def mask_heads_forward_with_coords(self, mask_feats, mask_feat_stride, instances):
    """Apply the instance-specific dynamic mask head to the shared mask features.

    Args:
        mask_feats: (N, in_channels, H, W) mask-branch features.
        mask_feat_stride: stride of ``mask_feats`` w.r.t. the input image.
        instances: proposals carrying ``im_inds``, ``locations``,
            ``fpn_levels`` and the predicted ``mask_head_params``.

    Returns:
        (n_inst, 1, H*s, W*s) mask probabilities in [0, 1], where
        ``s = mask_feat_stride / self.mask_out_stride``.
    """
    # Real input-image coordinates of every mask-feature cell
    # (e.g. the 1/8-stride P3 map).
    locations = compute_locations(
        mask_feats.size(2), mask_feats.size(3),
        stride=mask_feat_stride, device=mask_feats.device
    )
    n_inst = len(instances)
    # During FCOS point sampling each positive location recorded which image
    # it came from; below, the matching image's mask features are replicated
    # once per instance, i.e. (N, C, H, W) -> (n_inst, C, H*W).
    im_inds = instances.im_inds
    mask_head_params = instances.mask_head_params
    N, _, H, W = mask_feats.size()
    if not self.disable_rel_coords:
        # FCOS recorded the center of every positive sample; build relative
        # coordinates so the target's own cell is 0 and every other cell
        # holds its offset from that center.
        instance_locations = instances.locations
        # (n_inst, 1, 2) - (1, H*W, 2) -> (n_inst, H*W, 2)
        relative_coords = instance_locations.reshape(-1, 1, 2) - locations.reshape(1, -1, 2)
        # (n_inst, H*W, 2) -> (n_inst, 2, H*W)
        relative_coords = relative_coords.permute(0, 2, 1).float()
        # Normalize each offset by the size-of-interest of the instance's FPN
        # level: larger instances (higher levels) get a stronger decay factor.
        soi = self.sizes_of_interest.float()[instances.fpn_levels]
        relative_coords = relative_coords / soi.reshape(-1, 1, 1)
        relative_coords = relative_coords.to(dtype=mask_feats.dtype)
        mask_head_inputs = torch.cat([
            relative_coords,
            mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W)
        ], dim=1)
    else:
        mask_head_inputs = mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W)
    # Fold all instances into the channel dim, giving one (1, n_inst*C, H, W)
    # tensor — presumably so the dynamic convs run as grouped convolutions;
    # confirm in mask_heads_forward.
    mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W)
    weights, biases = parse_dynamic_params(
        mask_head_params, self.channels, self.weight_nums, self.bias_nums
    )
    mask_logits = self.mask_heads_forward(mask_head_inputs, weights, biases, n_inst)
    mask_logits = mask_logits.reshape(-1, 1, H, W)
    assert mask_feat_stride >= self.mask_out_stride
    assert mask_feat_stride % self.mask_out_stride == 0
    # Upsample from the mask-feature stride to the mask output stride.
    mask_logits = aligned_bilinear(mask_logits, int(mask_feat_stride / self.mask_out_stride))
    return mask_logits.sigmoid()
def mask_heads_forward_with_coords(self, mask_feats, mask_feat_stride, instances, out_size):
    """Run the grid variant of the dynamic mask head and resize to ``out_size``.

    Besides the usual dynamic filters, the head receives one flat grid-cell
    index per instance (``new_locations_ind``) locating the instance center
    in a ``grid_num x grid_num`` partition of the feature map; its output is
    then bilinearly resized to ``out_size`` and squashed with a sigmoid.

    Args:
        mask_feats: (N, in_channels, H, W) mask-branch features.
        mask_feat_stride: stride of ``mask_feats`` w.r.t. the input image.
        instances: proposals carrying ``im_inds``, ``locations``,
            ``fpn_levels`` and ``mask_head_params``.
        out_size: (h, w) target size of the returned probability maps.

    Returns:
        (n_inst, 1, out_size[0], out_size[1]) mask probabilities in [0, 1].
    """
    locations = compute_locations(mask_feats.size(2), mask_feats.size(3),
                                  stride=mask_feat_stride, device=mask_feats.device)
    n_inst = len(instances)
    im_inds = instances.im_inds
    mask_head_params = instances.mask_head_params
    N, _, H, W = mask_feats.size()

    instance_locations = instances.locations
    # Map instance centers (image coordinates) onto the feature grid.
    # NOTE(review): the fixed /8. presumably assumes mask_feat_stride == 8 —
    # confirm with the caller.
    new_locations = instance_locations / 8.  # (n, 2) as (x, y)
    new_locations_x = new_locations[:, 0].clamp(min=0, max=W - 1) // int((W / self.grid_num))
    new_locations_y = new_locations[:, 1].clamp(min=0, max=H - 1) // int((H / self.grid_num))
    # Flat cell index inside the grid_num x grid_num grid.
    new_locations_ind = (new_locations_x + self.grid_num * new_locations_y).to(torch.int64)

    if not self.disable_rel_coords:
        # Relative coordinates of every cell w.r.t. the instance center,
        # normalized by the FPN level's size of interest.
        relative_coords = instance_locations.reshape(-1, 1, 2) - locations.reshape(1, -1, 2)
        relative_coords = relative_coords.permute(0, 2, 1).float()
        soi = self.sizes_of_interest.float()[instances.fpn_levels]
        relative_coords = relative_coords / soi.reshape(-1, 1, 1)
        relative_coords = relative_coords.to(dtype=mask_feats.dtype)
        mask_head_inputs = torch.cat([
            relative_coords,
            mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W),
        ], dim=1)
        # 2 coordinate channels + in_channels feature channels.  (This was a
        # hard-coded 10, which silently assumed in_channels == 8.)
        mask_head_inputs = mask_head_inputs.reshape(n_inst, self.in_channels + 2, H, W)
    else:
        mask_head_inputs = mask_feats[im_inds].reshape(n_inst, self.in_channels, H, W)

    weights, biases = parse_dynamic_params(mask_head_params, self.channels,
                                           self.weight_nums, self.bias_nums)
    mask_logits = self.mask_heads_forward(mask_head_inputs, weights, biases,
                                          n_inst, new_locations_ind)
    # The head emits one grid cell per instance, hence the reduced resolution.
    mask_logits = mask_logits.reshape(-1, 1, int(H / self.grid_num), int(W / self.grid_num))

    assert mask_feat_stride >= self.mask_out_stride
    assert mask_feat_stride % self.mask_out_stride == 0
    mask_logits = F.interpolate(mask_logits, size=out_size,
                                mode='bilinear', align_corners=True)
    return mask_logits.sigmoid()
def compute_locations(self, features):
    """Return absolute input-image coordinates for every cell of each FPN map.

    One (H_l*W_l, 2) location tensor is produced per feature level, using
    the stride configured for that level in ``self.fpn_strides``.
    """
    return [
        compute_locations(feat.size(-2), feat.size(-1),
                          self.fpn_strides[lvl], feat.device)
        for lvl, feat in enumerate(features)
    ]
def mask_heads_forward_with_coords(self, mask_feats, mask_feat_stride, instances):
    """Run the dynamic mask head for each instance and return raw logits.

    Same pipeline as the sigmoid-producing variant, but the upsampled
    logits are returned without a final activation.
    """
    grid = compute_locations(mask_feats.size(2), mask_feats.size(3),
                             stride=mask_feat_stride, device=mask_feats.device)
    num_insts = len(instances)
    image_ids = instances.im_inds
    dynamic_params = instances.mask_head_params
    _, _, feat_h, feat_w = mask_feats.size()

    # Replicate each instance's image features: (n_inst, C, H*W).
    per_inst_feats = mask_feats[image_ids].reshape(num_insts, self.in_channels,
                                                   feat_h * feat_w)
    if self.disable_rel_coords:
        head_inputs = per_inst_feats
    else:
        # Offset of every grid cell from the instance center, scaled by the
        # size of interest of the instance's FPN level.
        centers = instances.locations
        rel = centers.reshape(-1, 1, 2) - grid.reshape(1, -1, 2)
        rel = rel.permute(0, 2, 1).float()
        scale = self.sizes_of_interest.float()[instances.fpn_levels]
        rel = (rel / scale.reshape(-1, 1, 1)).to(dtype=mask_feats.dtype)
        head_inputs = torch.cat([rel, per_inst_feats], dim=1)

    # Single batched tensor so the dynamic convs process all instances at once.
    head_inputs = head_inputs.reshape(1, -1, feat_h, feat_w)
    weights, biases = parse_dynamic_params(dynamic_params, self.channels,
                                           self.weight_nums, self.bias_nums)
    logits = self.mask_heads_forward(head_inputs, weights, biases, num_insts)
    logits = logits.reshape(-1, 1, feat_h, feat_w)

    assert mask_feat_stride >= self.mask_out_stride
    assert mask_feat_stride % self.mask_out_stride == 0
    return aligned_bilinear(logits, int(mask_feat_stride / self.mask_out_stride))
def mask_heads_forward_with_coords(self, mask_feats, mask_feat_stride,
                                   instances, gt_instances=None):
    """Run the multi-resolution dynamic mask head, split by FPN level.

    Instances are partitioned into three groups by FPN level using
    ``self.split``; each group uses its own output scale factor
    (``self.grid_num``) and its own slice of dynamic weights/biases.  In
    ``concat`` mode, instances whose center falls near a grid-cell border
    are duplicated so each copy predicts one neighboring cell.

    Args:
        mask_feats: (N, in_channels, H, W) mask-branch features.
        mask_feat_stride: stride of ``mask_feats`` w.r.t. the input image.
        instances: proposals with ``im_inds``, ``locations``,
            ``fpn_levels``, ``mask_head_params`` (and ``gt_inds`` in training).
        gt_instances: per-image ground truth; when given, matching cropped
            GT bitmasks are returned alongside the logits.

    Returns:
        Tuple (mask_logits_list, gt_bitmasks_list, inds_list,
        [None, gt_ind_2, gt_ind_4]) with one entry per level group.
    """
    n_inst = len(instances)
    im_inds = instances.im_inds
    mask_head_params = instances.mask_head_params
    instance_locations = instances.locations
    levels = instances.fpn_levels
    # FPN levels 0..4 => P3/P4/P5/P6/P7.  self.split partitions them into
    # three groups with output scale factors 1, 2 and 4:
    #   levels in (split[2], split[3]] -> factor 1
    #   levels in (split[1], split[2]] -> factor 2
    #   levels in (split[0], split[1]] -> factor 4
    ind1 = (levels > self.split[2]) & (levels <= self.split[3])
    ind2 = (levels > self.split[1]) & (levels <= self.split[2])
    ind4 = (levels > self.split[0]) & (levels <= self.split[1])
    weights, biases = parse_dynamic_params(
        mask_head_params, self.channels, self.weight_nums, self.bias_nums,
        [ind1, ind2, ind4], self.concat
    )
    N, _, H, W = mask_feats.size()
    num_layers = self.num_layers
    mask_head_inputs = mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W)
    if not self.disable_rel_coords:
        base_locations = compute_locations(
            mask_feats.size(2), mask_feats.size(3),
            stride=mask_feat_stride, device=mask_feats.device
        )
        # Appends 2 relative-coordinate channels per instance (hence the
        # in_channels + 2 below).
        mask_head_inputs = self.add_locaions_info(base_locations, instances, mask_head_inputs)
        mask_head_inputs = mask_head_inputs.reshape(n_inst, self.in_channels + 2, H, W)
    else:
        mask_head_inputs = mask_head_inputs.reshape(n_inst, self.in_channels, H, W)

    def get_loaction_weights_bias(level_ind, locations, i_weight, i_bias,
                                  grid_num=4, grid_inside_num=3):
        # Map instance centers (image coords) onto the 2H x 2W output grid,
        # then to a grid_num x grid_num cell index.
        # NOTE(review): the fixed /4. presumably matches an output stride of
        # 4 — confirm against mask_feat_stride.
        new_locations = locations / 4.
        # (n, 2) as (x, y)
        new_locations[:, 0] = new_locations[:, 0].clamp(min=0, max=2*W - 1)
        new_locations[:, 1] = new_locations[:, 1].clamp(min=0, max=2*H - 1)
        new_locations_x = new_locations[:, 0] // int((2*W / grid_num))
        new_locations_y = new_locations[:, 1] // int((2*H / grid_num))
        new_locations_ind = (new_locations_x + grid_num * new_locations_y).to(torch.int64)
        if not self.concat or not len(locations):
            # Non-concat mode (or empty group): parameters stay untouched.
            return level_ind, new_locations_ind, i_weight, i_bias, None
        # Position of the center inside its cell, quantized into a
        # grid_inside_num x grid_inside_num sub-grid.
        new_inside_locations = torch.zeros_like(new_locations)
        new_inside_locations[:, 0] = new_locations[:, 0] % int((2*W / grid_num))
        new_inside_locations[:, 1] = new_locations[:, 1] % int((2*H / grid_num))
        new_locations_inside_x = new_inside_locations[:, 0] // round((2*W / (grid_num * grid_inside_num)) + 0.5)
        new_locations_inside_y = new_inside_locations[:, 1] // round((2*H / (grid_num * grid_inside_num)) + 0.5)
        relative_inside_location = (new_locations_inside_x + grid_inside_num * new_locations_inside_y).to(
            torch.int64)
        assert (relative_inside_location < grid_inside_num ** 2).all(), \
            (W, H, new_inside_locations, new_locations_inside_x, new_locations_inside_y, relative_inside_location)
        # get_grid_map yields, per sub-cell, the (up to 9) neighboring-cell
        # offsets this instance should also predict.
        grid_map = torch.from_numpy(get_grid_map(grid_inside_num)).to(device=mask_feats.device)
        # NOTE(review): new_locations_ind becomes a numpy array here, while
        # the torch tensor is returned on the early-exit paths above/below —
        # callers may see mixed types; verify mask_heads_forward handles both.
        new_locations_ind = new_locations_ind.cpu().numpy()
        maps = grid_map[relative_inside_location].clone()
        # maps: (n, 9, 2)
        n, _, _ = maps.shape
        maps[:, :, 0] = maps[:, :, 0] + new_locations_x.repeat(9).reshape(9, -1).T
        maps[:, :, 1] = maps[:, :, 1] + new_locations_y.repeat(9).reshape(9, -1).T
        maps = maps.reshape(-1, 2)
        # Drop duplicated predictions that fall outside the grid.
        ind_valid = (maps[:, 0] >= 0) & (maps[:, 0] < grid_num) & (maps[:, 1] >= 0) & (maps[:, 1] < grid_num)
        maps = maps[ind_valid]
        final_locations_ind = (maps[:, 0] + grid_num * maps[:, 1]).to(dtype=torch.int64, device=mask_feats.device)
        # param_ind maps each surviving duplicate back to its source instance
        # (1-based multiply trick: masked-out entries collapse to 0, the rest
        # are shifted back to 0-based indices).
        param_ind = ind_valid.reshape(n, 9)
        param_ind = param_ind.T*torch.arange(1, n+1).to(device=mask_feats.device)
        param_ind = param_ind.T.flatten()
        param_ind = param_ind[param_ind > 0] - 1
        param_ind = param_ind.to(dtype=torch.int64, device=mask_feats.device)
        # Indices (into the n*9 duplicate slots) of the valid predictions,
        # used later to select the matching GT bitmask copies.
        gt_ind = torch.arange(0, n*9)[ind_valid].to(dtype=torch.int64, device=mask_feats.device)
        if not len(param_ind):
            return level_ind, new_locations_ind, i_weight, i_bias, None
        # Replicate the dynamic filters for every duplicated instance and
        # flatten per-layer weights back to the layout mask_heads_forward
        # expects (presumably grouped-conv layout — confirm there).
        for l in range(num_layers):
            assert len(param_ind), param_ind
            i_weight[l] = i_weight[l][param_ind]
            i_bias[l] = i_bias[l][param_ind]
            if l < num_layers - 1:
                n, c, _, _, _ = i_weight[l].shape
                i_weight[l] = i_weight[l].reshape(n * c, -1, 1, 1)
                i_bias[l] = i_bias[l].reshape(n * c)
        return param_ind, final_locations_ind, i_weight, i_bias, gt_ind

    param_ind_2, new_ind2, new_weights_2, new_biases_2, gt_ind_2 = get_loaction_weights_bias(
        ind2, instance_locations[ind2], weights[1], biases[1], grid_num=self.grid_num[1]
    )
    param_ind_4, new_ind4, new_weights_4, new_biases_4, gt_ind_4 = get_loaction_weights_bias(
        ind4, instance_locations[ind4], weights[2], biases[2], grid_num=self.grid_num[2]
    )
    mask_head_inputs1 = mask_head_inputs[ind1]
    if not self.concat:
        mask_head_inputs2 = mask_head_inputs[ind2]
        mask_head_inputs4 = mask_head_inputs[ind4]
        inds_list = [ind1, ind2, ind4]
    else:
        # Concat mode: gather inputs once per duplicated prediction.
        mask_head_inputs2 = mask_head_inputs[param_ind_2]
        mask_head_inputs4 = mask_head_inputs[param_ind_4]
        inds_list = [(ind1, ind1), (ind2, param_ind_2), (ind4, param_ind_4)]
    n_inst1 = mask_head_inputs1.shape[0]
    n_inst2 = mask_head_inputs2.shape[0]
    n_inst4 = mask_head_inputs4.shape[0]
    if gt_instances is not None:
        # Training path: crop/expand GT bitmasks to match each group's
        # output resolution.
        gt_inds = instances.gt_inds
        gt_bitmasks = torch.cat([per_im.gt_bitmasks for per_im in gt_instances])  # 1/4
        gt_boxes = torch.cat([per_im.gt_boxes.tensor for per_im in gt_instances])  # 1
        if not self.concat:
            gt_inds1 = gt_inds[ind1]
            crop_gt_bitmasks1, out_size1 = self.crop_and_expand(gt_bitmasks, gt_boxes, grid_num=self.grid_num[0])
            gt_bitmasks1 = crop_gt_bitmasks1[gt_inds1].unsqueeze(dim=1).to(dtype=mask_feats.dtype)
            gt_inds2 = gt_inds[ind2]
            crop_gt_bitmasks2, out_size2 = self.crop_and_expand(gt_bitmasks, gt_boxes, grid_num=self.grid_num[1])
            gt_inds4 = gt_inds[ind4]
            crop_gt_bitmasks4, out_size4 = self.crop_and_expand(gt_bitmasks, gt_boxes, grid_num=self.grid_num[2])
            gt_bitmasks2 = crop_gt_bitmasks2[gt_inds2]
            gt_bitmasks4 = crop_gt_bitmasks4[gt_inds4]
        else:
            gt_inds1 = gt_inds[ind1]
            crop_gt_bitmasks1, out_size1 = self.crop_and_expand_concate(gt_bitmasks, gt_boxes, grid_num=self.grid_num[0])
            gt_bitmasks1 = crop_gt_bitmasks1[gt_inds1].unsqueeze(dim=1).to(dtype=mask_feats.dtype)
            # NOTE(review): unlike the grid_num[0] call above, these two are
            # not tuple-unpacked — crop_and_expand_concate apparently returns
            # a different type depending on grid_num; confirm its contract.
            crop_gt_bitmasks2 = self.crop_and_expand_concate(gt_bitmasks, gt_boxes, grid_num=self.grid_num[1])
            crop_gt_bitmasks4 = self.crop_and_expand_concate(gt_bitmasks, gt_boxes, grid_num=self.grid_num[2])
            gt_inds2 = gt_inds[ind2]
            gt_inds4 = gt_inds[ind4]
            gt_bitmasks2 = [crop_gt_bitmasks2[ind] for ind in gt_inds2.tolist()]
            gt_bitmasks4 = [crop_gt_bitmasks4[ind] for ind in gt_inds4.tolist()]
            # Keep only the bitmask copies matching the valid duplicated
            # predictions (gt_ind_2 / gt_ind_4 from the closure above).
            gt_bitmasks2 = torch.cat(gt_bitmasks2)[gt_ind_2] if len(gt_inds2) else None
            gt_bitmasks4 = torch.cat(gt_bitmasks4)[gt_ind_4] if len(gt_inds4) else None
            out_size2 = gt_bitmasks2.shape[1:] if len(gt_inds2) else None
            out_size4 = gt_bitmasks4.shape[1:] if len(gt_inds4) else None
            gt_bitmasks2 = gt_bitmasks2.unsqueeze(dim=1).to(dtype=mask_feats.dtype) if len(gt_inds2) else None
            gt_bitmasks4 = gt_bitmasks4.unsqueeze(dim=1).to(dtype=mask_feats.dtype) if len(gt_inds4) else None
        gt_bitmasks_list = [gt_bitmasks1, gt_bitmasks2, gt_bitmasks4]
    else:
        # Inference path: no GT; fixed output sizes per scale factor.
        gt_bitmasks_list = []
        out_size1 = (int(H * 2), int(W * 2))
        out_size2 = (H, W)
        out_size4 = (int(H / 2), int(W / 2))
    # Output size looked up by each group's scale factor (1, 2 or 4).
    out_size = {1: out_size1, 2: out_size2, 4: out_size4}
    mask_logits1 = self.mask_heads_forward(
        mask_head_inputs1, out_size[self.grid_num[0]], weights[0], biases[0],
        n_inst1, [], self.grid_num[0])
    mask_logits2 = self.mask_heads_forward(
        mask_head_inputs2, out_size[self.grid_num[1]], weights[1], biases[1],
        n_inst2, new_ind2, self.grid_num[1])
    mask_logits4 = self.mask_heads_forward(
        mask_head_inputs4, out_size[self.grid_num[2]], weights[2], biases[2],
        n_inst4, new_ind4, self.grid_num[2])
    mask_logits_list = [mask_logits1, mask_logits2, mask_logits4]
    return mask_logits_list, gt_bitmasks_list, inds_list, [None, gt_ind_2, gt_ind_4]