def forward(self, features: List[torch.Tensor], iuv_feats: torch.Tensor,
            rel_coord: Any, abs_coord: Any, fg_mask: Any, ins_mask_list=None):
    for i, _ in enumerate(self.in_features):
        if i == 0:
            x = self.scale_heads[i](features[i])
        else:
            x = x + self.scale_heads[i](features[i])
    if rel_coord is not None:
        x = torch.cat([x, rel_coord], dim=1)
    if abs_coord is not None:
        x = torch.cat([x, abs_coord], dim=1)
    if rel_coord is not None or abs_coord is not None:
        x = self.comb_pe_conv(x)
    x = x * fg_mask
    if self.use_ins_gn:
        # Dense to sparse: derive a per-pixel instance index map so the head
        # can normalize each instance separately.
        N, C, H, W = x.shape
        ins_indices_batch = []
        for n in range(N):
            # Prepend a heavily weighted background channel to the instance
            # masks; argmax then yields 0 for background pixels and
            # (instance id + 1) for foreground pixels.
            logit_bg_fg = torch.cat(
                [(1 - fg_mask[n]) * 99999., ins_mask_list[n].float()], dim=0)
            ins_indices = torch.argmax(logit_bg_fg, dim=0) - 1  # background -> -1
            ins_indices_batch.append(ins_indices)
        ins_indices_batch = torch.stack(ins_indices_batch, dim=0)
        x = self.densepose_head(x, ins_indices_batch)
    else:
        x = self.densepose_head(x)
    x = self.predictor(x)
    return x
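# `compute_grid` is referenced throughout these forward passes but not defined
# in this section. Below is a minimal sketch consistent with its call sites
# (it must return a [2, H, W] tensor of (x, y) coordinates, normalized unless
# `norm=False`); this is an assumption, not the repository's implementation.
import torch

def compute_grid(h, w, device="cpu", norm=True):
    ys = torch.arange(h, device=device, dtype=torch.float32)
    xs = torch.arange(w, device=device, dtype=torch.float32)
    if norm:
        # Map pixel indices to [-1, 1], the usual coordinate-feature convention.
        ys = ys / max(h - 1, 1) * 2.0 - 1.0
        xs = xs / max(w - 1, 1) * 2.0 - 1.0
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    # coord[0] is x and coord[1] is y, matching the indexing used above.
    return torch.stack([grid_x, grid_y], dim=0)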
def forward(self, s_logits, iuv_feats, iuv_feat_stride, rel_coord, instances,
            mask_out_bg=False):
    N, _, H, W = iuv_feats.size()
    if self.use_rel_coords:
        if self.use_pos_emb:
            rel_coord = self.position_embedder(rel_coord)
    else:
        rel_coord = torch.zeros([N, 2, H, W],
                                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
    coord = rel_coord
    if self.use_abs_coords:
        abs_coord = compute_grid(H, W, device=iuv_feats.device)[None, ...].repeat(N, 1, 1, 1)
        if self.use_pos_emb:
            abs_coord = self.position_embedder(abs_coord)
    else:
        abs_coord = torch.zeros([N, 2, H, W],
                                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
    coord = torch.cat([abs_coord, coord], dim=1)
    if mask_out_bg:
        # Merge the per-instance segmentation logits into one foreground mask
        # per image, binarize, and dilate slightly.
        fg_mask = s_logits.detach()
        fg_mask_list = []
        for i in range(N):
            fg_mask_list.append(
                torch.max(fg_mask[instances.im_inds == i], dim=0, keepdim=True)[0])
        fg_mask = torch.cat(fg_mask_list, dim=0).detach()
        fg_mask = (fg_mask > 0.05).float()
        fg_mask = self._torch_dilate(fg_mask, kernel_size=3)
    else:
        fg_mask = torch.ones([N, 1, H, W], device=s_logits.device)
    x = iuv_feats
    for layer in self.layers:
        x = layer(torch.cat([coord, x], dim=1) * fg_mask)
    iuv_logit = x
    assert iuv_feat_stride >= self.iuv_out_stride
    assert iuv_feat_stride % self.iuv_out_stride == 0
    iuv_logit = aligned_bilinear(iuv_logit, int(iuv_feat_stride / self.iuv_out_stride))
    return iuv_logit
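# Neither `aligned_bilinear` nor `_torch_dilate` is defined in this section.
# Two hedged sketches: `aligned_bilinear` follows the upsampling used in
# AdelaiDet's CondInst (padding so output pixels stay aligned with input pixel
# centers), and `_torch_dilate` approximates binary dilation with max-pooling.
# Both are assumptions about helpers the code above expects to exist.
import torch.nn.functional as F

def aligned_bilinear(tensor, factor):
    assert tensor.dim() == 4 and factor >= 1 and int(factor) == factor
    if factor == 1:
        return tensor
    h, w = tensor.size()[2:]
    tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode="replicate")
    oh, ow = factor * h + 1, factor * w + 1
    tensor = F.interpolate(tensor, size=(oh, ow), mode="bilinear", align_corners=True)
    tensor = F.pad(tensor, pad=(factor // 2, 0, factor // 2, 0), mode="replicate")
    return tensor[:, :, :oh - 1, :ow - 1]

def _torch_dilate(mask, kernel_size=3):
    # Dilation of a {0, 1} mask: a pixel becomes 1 if any neighbor in the
    # kernel window is 1.
    return F.max_pool2d(mask, kernel_size=kernel_size, stride=1,
                        padding=kernel_size // 2)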
def forward(self, fpn_features, s_logits, iuv_feats, iuv_feat_stride, rel_coord,
            instances, fg_mask=None, gt_instances=None):
    N, _, H, W = iuv_feats.size()
    if self.use_rel_coords:
        if self.use_pos_emb:
            rel_coord = self.position_embedder(rel_coord)
    else:
        rel_coord = torch.zeros([N, 2, H, W],
                                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
    iuv_head_inputs = torch.cat([rel_coord, iuv_feats], dim=1)
    if self.use_abs_coords:
        abs_coord = compute_grid(H, W, device=iuv_feats.device)[None, ...].repeat(N, 1, 1, 1)
        if self.use_pos_emb:
            abs_coord = self.position_embedder(abs_coord)
    else:
        abs_coord = torch.zeros([N, 2, H, W],
                                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
    iuv_head_inputs = torch.cat([abs_coord, iuv_head_inputs], dim=1)
    # Three-scale pyramid: run separate towers at full, 1/2 and 1/4 resolution,
    # upsample the coarser logits back to full size, and fuse with tower_out.
    iuv_head_inputs0 = iuv_head_inputs
    iuv_logit0 = self.tower0(iuv_head_inputs0)
    iuv_head_inputs1 = F.avg_pool2d(iuv_head_inputs0, kernel_size=3, stride=2)
    iuv_logit1 = self.tower1(iuv_head_inputs1)
    iuv_logit1 = F.interpolate(iuv_logit1, size=iuv_logit0.shape[-2:])
    iuv_head_inputs2 = F.avg_pool2d(iuv_head_inputs1, kernel_size=3, stride=2)
    iuv_logit2 = self.tower2(iuv_head_inputs2)
    iuv_logit2 = F.interpolate(iuv_logit2, size=iuv_logit0.shape[-2:])
    iuv_logit = torch.cat([iuv_logit0, iuv_logit1, iuv_logit2], dim=1)
    iuv_logit = self.tower_out(iuv_logit)
    assert iuv_feat_stride >= self.iuv_out_stride
    assert iuv_feat_stride % self.iuv_out_stride == 0
    iuv_logit = aligned_bilinear(iuv_logit, int(iuv_feat_stride / self.iuv_out_stride))
    return iuv_logit
def forward(self, fpn_features, s_logits, iuv_feats, iuv_feat_stride, rel_coord,
            instances, fg_mask=None, gt_instances=None):
    N, _, H, W = iuv_feats.size()
    if self.use_rel_coords:
        if self.use_pos_emb:
            rel_coord = self.position_embedder(rel_coord)
    else:
        rel_coord = torch.zeros([N, 2, H, W],
                                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
    iuv_head_inputs = torch.cat([rel_coord, iuv_feats], dim=1)
    if self.use_abs_coords:
        abs_coord = compute_grid(H, W, device=iuv_feats.device)[None, ...].repeat(N, 1, 1, 1)
        if self.use_pos_emb:
            abs_coord = self.position_embedder(abs_coord)
    else:
        abs_coord = torch.zeros([N, 2, H, W],
                                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
    iuv_head_inputs = torch.cat([abs_coord, iuv_head_inputs], dim=1)
    iuv_logit = self.tower(iuv_head_inputs)
    assert iuv_feat_stride >= self.iuv_out_stride
    assert iuv_feat_stride % self.iuv_out_stride == 0
    iuv_logit = aligned_bilinear(iuv_logit, int(iuv_feat_stride / self.iuv_out_stride))
    return iuv_logit
def forward(self, features: List[torch.Tensor], iuv_feats: torch.Tensor,
            rel_coord: torch.Tensor, abs_coord: torch.Tensor,
            fg_mask: torch.Tensor, ins_mask_list: List[torch.Tensor]):
    if not self.use_agg_feat:
        for i, _ in enumerate(self.in_features):
            if i == 0:
                x = self.scale_heads[i](features[i])
            else:
                x = x + self.scale_heads[i](features[i])
    else:
        x = iuv_feats
    if rel_coord is not None:
        x = torch.cat([x, rel_coord], dim=1)
    if abs_coord is not None:
        x = torch.cat([x, abs_coord], dim=1)
    if rel_coord is not None or abs_coord is not None:
        x = self.comb_pe_conv(x)
    if self.use_dropout:
        x = self.dropout_layer(x)
    if self.use_san:
        x = self.san_blk_1(x)

    # Dense to sparse: gather foreground pixels into a point list so the head
    # can run sparse convolutions on foreground only.
    N, C, H, W = x.shape
    coord = compute_grid(H, W, device=x.device, norm=False)
    sparse_coord_batch = []
    sparse_feat_batch = []
    ins_indices_batch = []
    ins_indices_len = []
    ins_cnt = 0
    for n in range(N):
        m = fg_mask[n:n + 1]
        x_indices = coord[0][m[0, 0] > 0]
        y_indices = coord[1][m[0, 0] > 0]
        # Instance id of each foreground pixel, offset by ins_cnt so ids are
        # unique across the batch.
        ins_indices = torch.argmax(ins_mask_list[n].float(), dim=0)[m[0, 0] > 0] + ins_cnt
        ins_indices_batch.append(ins_indices)
        ins_cnt += ins_mask_list[n].shape[0]
        ins_indices_len.append(torch.sum(ins_mask_list[n], dim=[1, 2]))
        b_indices = torch.ones_like(x_indices) * n
        sparse_coord = torch.stack([b_indices, y_indices, x_indices], dim=-1).int()
        sparse_coord_batch.append(sparse_coord)
        sparse_feat = x[n:n + 1]
        sparse_feat = sparse_feat[m.expand_as(sparse_feat) > 0].reshape([C, -1]).permute([1, 0])
        sparse_feat_batch.append(sparse_feat)
    sparse_coord_batch = torch.cat(sparse_coord_batch, dim=0)
    sparse_feat_batch = torch.cat(sparse_feat_batch, dim=0)
    x = spconv.SparseConvTensor(sparse_feat_batch, sparse_coord_batch, (H, W), N)
    ins_indices_batch = torch.cat(ins_indices_batch, dim=0)
    ins_indices_len = torch.cat(ins_indices_len, dim=0)
    x = self.densepose_head(x, ins_indices_batch, ins_indices_len)

    if self.predictor_conv_type == "sparse":
        x = self.predictor(x).dense()
    else:
        x = x.dense()
        if self.checkpoint_grad_num > 0:
            # Trade compute for memory by checkpointing the predictor.
            x = checkpoint.checkpoint(self.custom(self.predictor), x)
        else:
            x = self.predictor(x)
    return x
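# The sparse head above receives `ins_indices_batch` (one instance id per
# foreground point) and `ins_indices_len` (points per instance), which suggests
# per-instance normalization of the sparse features. A hedged sketch of such a
# normalization over a [P, C] sparse feature matrix (an assumption; the actual
# `densepose_head` is not shown in this section):
import torch

def instance_norm_sparse(feats, ins_indices, num_instances, eps=1e-5):
    # feats: [P, C] foreground-point features; ins_indices: [P] instance ids.
    C = feats.shape[1]
    counts = torch.zeros(num_instances, 1, device=feats.device)
    counts.index_add_(0, ins_indices,
                      torch.ones(len(ins_indices), 1, device=feats.device))
    # Per-instance mean via scatter-add, then broadcast back to the points.
    mean = torch.zeros(num_instances, C, device=feats.device)
    mean.index_add_(0, ins_indices, feats)
    mean = mean / counts.clamp(min=1.0)
    centered = feats - mean[ins_indices]
    # Per-instance variance, same scatter pattern.
    var = torch.zeros(num_instances, C, device=feats.device)
    var.index_add_(0, ins_indices, centered ** 2)
    var = var / counts.clamp(min=1.0)
    return centered / (var[ins_indices] + eps).sqrt()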
def forward(self, fpn_features, s_logits, iuv_feats, iuv_feat_stride, rel_coord,
            instances, fg_mask, gt_instances=None):
    N, _, H, W = iuv_feats.size()
    if self.use_rel_coords:
        if self.use_pos_emb:
            rel_coord = self.position_embedder(rel_coord)
    else:
        rel_coord = torch.zeros([N, 2, H, W],
                                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
    iuv_head_inputs = torch.cat([rel_coord, iuv_feats], dim=1)
    if self.use_abs_coords:
        abs_coord = compute_grid(H, W, device=iuv_feats.device)[None, ...].repeat(N, 1, 1, 1)
        if self.use_pos_emb:
            abs_coord = self.position_embedder(abs_coord)
    else:
        abs_coord = torch.zeros([N, 2, H, W],
                                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
    iuv_head_inputs = torch.cat([abs_coord, iuv_head_inputs], dim=1)
    fg_mask = self._torch_dilate(fg_mask, kernel_size=3)
    x = iuv_head_inputs
    if self.use_partial_norm:
        for layer in self.layers:
            if isinstance(layer, (Conv2d, PartialConv2d)):
                x = layer(x)
            elif isinstance(layer, nn.GroupNorm):
                # Partial GroupNorm: apply the layer as usual, undo its
                # normalization with the dense statistics, then re-apply the
                # affine transform with foreground-only statistics.
                # Note: as written, fg_mask_sum is accumulated over the whole
                # batch, not per sample.
                fg_mask_sum = fg_mask.sum(dim=[0, -2, -1], keepdim=True)[:, None, ...]
                x = x * fg_mask
                n, c, h, w = x.shape
                num_groups = layer.num_groups
                x_group = torch.stack(torch.chunk(x, num_groups, dim=1), dim=2)
                x_group_mean = torch.mean(x_group, dim=[-3, -2, -1], keepdim=True)
                x_group_std = torch.std(x_group, dim=[-3, -2, -1], keepdim=True)
                x_group_mean = x_group_mean.repeat(1, 1, num_groups, 1, 1).reshape([n, c, 1, 1])
                x_group_std = x_group_std.repeat(1, 1, num_groups, 1, 1).reshape([n, c, 1, 1])
                x_group_mean_p = torch.sum(x_group, dim=[-3, -2, -1], keepdim=True) / fg_mask_sum
                x_group_std_p = torch.sqrt(
                    torch.sum((x_group - x_group_mean_p) ** 2 + 1e-5,
                              dim=[-3, -2, -1], keepdim=True) / fg_mask_sum)
                x_group_mean_p = x_group_mean_p.repeat(1, 1, num_groups, 1, 1).reshape([n, c, 1, 1])
                x_group_std_p = x_group_std_p.repeat(1, 1, num_groups, 1, 1).reshape([n, c, 1, 1])
                gamma, beta = layer.parameters()
                gamma, beta = gamma[None, ..., None, None], beta[None, ..., None, None]
                x = layer(x)
                x = (x - beta) / gamma * x_group_std + x_group_mean      # undo dense normalization
                x = (x - x_group_mean_p) / x_group_std_p * gamma + beta  # renormalize with masked stats
            elif isinstance(layer, nn.BatchNorm2d):
                # Partial BatchNorm: per-channel statistics computed over
                # foreground pixels only, with the layer's own affine params.
                fg_mask_sum = fg_mask.sum(dim=[0, -2, -1], keepdim=True)
                x_mean_p = torch.sum(x * fg_mask, dim=[0, -2, -1], keepdim=True) / fg_mask_sum
                x_std_p = torch.sqrt(
                    torch.sum((x * fg_mask - x_mean_p) ** 2 + 1e-5,
                              dim=[0, -2, -1], keepdim=True) / fg_mask_sum)
                gamma, beta = layer.parameters()
                gamma, beta = gamma[None, ..., None, None], beta[None, ..., None, None]
                x = (x - x_mean_p) / x_std_p * gamma + beta
            else:
                x = layer(x)
    else:
        for layer in self.layers:
            if isinstance(layer, LambdaLayer):
                x = layer(x)
            else:
                x = layer(x * fg_mask)
    iuv_logit = x
    assert iuv_feat_stride >= self.iuv_out_stride
    assert iuv_feat_stride % self.iuv_out_stride == 0
    iuv_logit = aligned_bilinear(iuv_logit, int(iuv_feat_stride / self.iuv_out_stride))
    return iuv_logit
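# The inline partial-GroupNorm above is hard to follow. A distilled standalone
# sketch of the same idea (an assumption, simplified to a single GroupNorm
# layer and per-sample foreground counts, unlike the batch-wide mask sum used
# inline): normalize each group with statistics taken over foreground pixels
# only, then apply the layer's affine parameters.
import torch
import torch.nn as nn

def partial_group_norm(x, fg_mask, gn: nn.GroupNorm):
    # x: [N, C, H, W]; fg_mask: [N, 1, H, W] in {0, 1}; gn must have affine params.
    n, c, h, w = x.shape
    g = gn.num_groups
    xg = (x * fg_mask).reshape(n, g, c // g, h, w)
    mg = fg_mask.reshape(n, 1, 1, h, w)
    # Foreground element count per group: fg pixels times channels per group.
    cnt = (mg.sum(dim=[-3, -2, -1], keepdim=True) * (c // g)).clamp(min=1.0)
    mean = xg.sum(dim=[-3, -2, -1], keepdim=True) / cnt
    var = (((xg - mean) * mg) ** 2).sum(dim=[-3, -2, -1], keepdim=True) / cnt
    x_hat = ((xg - mean) / (var + gn.eps).sqrt()).reshape(n, c, h, w)
    return x_hat * gn.weight[None, :, None, None] + gn.bias[None, :, None, None]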
def forward(self, fpn_features, s_logits, iuv_feats, iuv_feat_stride, rel_coord,
            instances, fg_mask, gt_instances=None, ins_mask_list=None):
    fea0 = fpn_features[self.in_features[0]]
    N, _, H, W = fea0.shape
    if self.use_rel_coords:
        if self.use_pos_emb:
            rel_coord = self.position_embedder(rel_coord)
    else:
        rel_coord = None
    if self.use_abs_coords:
        abs_coord = compute_grid(H, W, device=fea0.device)[None, ...].repeat(N, 1, 1, 1)
        if self.use_pos_emb:
            abs_coord = self.position_embedder(abs_coord)
    else:
        abs_coord = None
    features = [fpn_features[f] for f in self.in_features]
    if self.inference_global_siuv:
        assert not self.training
    if self.training:
        features = [
            self.decoder(features, iuv_feats, rel_coord, abs_coord, fg_mask,
                         ins_mask_list)
        ]
        features_dp_ori = features[0]
        # During training, pool per-instance features from ground-truth boxes.
        proposal_boxes = [x.gt_boxes for x in gt_instances]
        features_dp = self.densepose_pooler(features, proposal_boxes)
        iuv_logits = features_dp
        return None, iuv_logits, features_dp_ori
    else:
        features = [
            self.decoder(features, iuv_feats, rel_coord, abs_coord, fg_mask,
                         ins_mask_list)
        ]
        features_dp_ori = features[0]
        if self.inference_global_siuv:
            iuv_logits = features[0]
            coarse_segm = s_logits
        else:
            proposal_boxes = [instances.pred_boxes]
            features_dp = self.densepose_pooler(features, proposal_boxes)
            # Pool the coarse segmentation logits per instance with the same
            # boxes, one instance at a time.
            s_logit_list = []
            for idx in range(s_logits.shape[0]):
                s_logit = self.densepose_pooler([s_logits[idx:idx + 1]],
                                                [proposal_boxes[0][idx:idx + 1]])
                s_logit_list.append(s_logit)
            coarse_segm = torch.cat(s_logit_list, dim=0)
            iuv_logits = features_dp
        return coarse_segm, iuv_logits, features_dp_ori
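# `densepose_pooler` is used above like detectron2's ROIPooler: it takes a list
# of feature maps and a list of Boxes (one element per image) and returns
# per-box crops. A hedged construction sketch (an assumption: single-level
# pooling over the decoder output; `heatmap_stride` and `output_size=28` are
# hypothetical values, not taken from this repository):
from detectron2.modeling.poolers import ROIPooler

heatmap_stride = 8  # hypothetical stride of the decoder output feature map
densepose_pooler = ROIPooler(
    output_size=28,                  # hypothetical side length of pooled crops
    scales=(1.0 / heatmap_stride,),  # one scale per input feature map
    sampling_ratio=2,
    pooler_type="ROIAlignV2",
)
# Usage mirrors the calls above: densepose_pooler([feature_map], [boxes])
# returns an [R, C, 28, 28] tensor for R boxes.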