Example #1
    def forward(self,
                features: List[torch.Tensor],
                iuv_feats: torch.Tensor,
                rel_coord: Any,
                abs_coord: Any,
                fg_mask: Any,
                ins_mask_list=None):
        for i, _ in enumerate(self.in_features):
            if i == 0:
                x = self.scale_heads[i](features[i])
            else:
                x = x + self.scale_heads[i](features[i])
        if rel_coord is not None:
            x = torch.cat([x, rel_coord], dim=1)
        if abs_coord is not None:
            x = torch.cat([x, abs_coord], dim=1)
        # if skeleton_feats is not None:
        #     x = torch.cat([x,skeleton_feats], dim=1)

        if rel_coord is not None or abs_coord is not None:
            x = self.comb_pe_conv(x)
        x = x * fg_mask

        if self.use_ins_gn:
            ## dense to sparse: build per-pixel instance indices (background = -1)
            N, C, H, W = x.shape
            ins_indices_batch = []
            for n in range(N):
                # stack a large background logit on top of the instance masks and take
                # the argmax over that dimension; subtracting 1 maps background pixels to -1
                logit_bg_fg = torch.cat(
                    [(1 - fg_mask[n]) * 99999., ins_mask_list[n].float()],
                    dim=0)
                ins_indices = torch.argmax(logit_bg_fg, dim=0) - 1
                ins_indices_batch.append(ins_indices)

            ins_indices_batch = torch.stack(ins_indices_batch, dim=0)

            x = self.densepose_head(x, ins_indices_batch)
        else:
            x = self.densepose_head(x)
        x = self.predictor(x)

        return x
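The compute_grid helper used throughout these examples is not shown. Below is a minimal sketch of what it is assumed to return, based on how the examples consume it: a [2, H, W] map whose first channel holds x coordinates and second channel y coordinates, normalized to [-1, 1] unless norm=False. The implementation details are assumptions, not the original helper.

import torch

def compute_grid(h, w, device, norm=True):
    # Assumed behaviour: channel 0 holds x (column) coordinates, channel 1 holds
    # y (row) coordinates; norm=True rescales both to [-1, 1].
    ys, xs = torch.meshgrid(torch.arange(h, device=device, dtype=torch.float32),
                            torch.arange(w, device=device, dtype=torch.float32),
                            indexing="ij")
    if norm:
        xs = xs / max(w - 1, 1) * 2 - 1
        ys = ys / max(h - 1, 1) * 2 - 1
    return torch.stack([xs, ys], dim=0)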
Example #2
    def forward(self,
                s_logits,
                iuv_feats,
                iuv_feat_stride,
                rel_coord,
                instances,
                mask_out_bg=False):
        N, _, H, W = iuv_feats.size()

        if self.use_rel_coords:
            if self.use_pos_emb:
                rel_coord = self.position_embedder(rel_coord)
        else:
            rel_coord = torch.zeros(
                [N, 2, H, W],
                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
        coord = rel_coord

        if self.use_abs_coords:
            abs_coord = compute_grid(
                H, W, device=iuv_feats.device)[None, ...].repeat(N, 1, 1, 1)
            if self.use_pos_emb:
                abs_coord = self.position_embedder(abs_coord)
        else:
            abs_coord = torch.zeros(
                [N, 2, H, W],
                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
        coord = torch.cat([abs_coord, coord], dim=1)

        if mask_out_bg:
            fg_mask = s_logits.detach()
            fg_mask_list = []
            for i in range(N):
                fg_mask_list.append(
                    torch.max(fg_mask[instances.im_inds == i],
                              dim=0,
                              keepdim=True)[0])
            fg_mask = torch.cat(fg_mask_list, dim=0).detach()
            # if mask_out_bg_feats=="hard":
            fg_mask = (fg_mask > 0.05).float()
            fg_mask = self._torch_dilate(fg_mask, kernel_size=3)
        else:
            fg_mask = torch.ones([N, 1, H, W], device=s_logits.device)

        x = iuv_feats
        for layer in self.layers:
            x = layer(torch.cat([coord, x], dim=1) * fg_mask)
        iuv_logit = x

        assert iuv_feat_stride >= self.iuv_out_stride
        assert iuv_feat_stride % self.iuv_out_stride == 0
        iuv_logit = aligned_bilinear(
            iuv_logit, int(iuv_feat_stride / self.iuv_out_stride))

        return iuv_logit
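_torch_dilate, applied to fg_mask here and in Example #6, is also not shown. A minimal sketch, assuming it is a binary dilation implemented with max pooling over a square kernel:

import torch.nn.functional as F

def _torch_dilate(mask, kernel_size=3):
    # Assumed implementation: dilate a binary [N, 1, H, W] mask with a
    # kernel_size x kernel_size structuring element, keeping the spatial size.
    return F.max_pool2d(mask, kernel_size=kernel_size, stride=1,
                        padding=kernel_size // 2)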
Example #3
    def forward(self,
                fpn_features,
                s_logits,
                iuv_feats,
                iuv_feat_stride,
                rel_coord,
                instances,
                fg_mask=None,
                gt_instances=None):
        N, _, H, W = iuv_feats.size()

        if self.use_rel_coords:
            if self.use_pos_emb:
                rel_coord = self.position_embedder(rel_coord)
        else:
            rel_coord = torch.zeros(
                [N, 2, H, W],
                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
        iuv_head_inputs = torch.cat([rel_coord, iuv_feats], dim=1)

        if self.use_abs_coords:
            abs_coord = compute_grid(
                H, W, device=iuv_feats.device)[None, ...].repeat(N, 1, 1, 1)
            if self.use_pos_emb:
                abs_coord = self.position_embedder(abs_coord)
        else:
            abs_coord = torch.zeros(
                [N, 2, H, W],
                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
        iuv_head_inputs = torch.cat([abs_coord, iuv_head_inputs], dim=1)

        iuv_head_inputs0 = iuv_head_inputs
        iuv_logit0 = self.tower0(iuv_head_inputs0)
        iuv_head_inputs1 = F.avg_pool2d(iuv_head_inputs0,
                                        kernel_size=3,
                                        stride=2)
        iuv_logit1 = self.tower1(iuv_head_inputs1)
        iuv_logit1 = F.interpolate(iuv_logit1, size=iuv_logit0.shape[-2:])
        iuv_head_inputs2 = F.avg_pool2d(iuv_head_inputs1,
                                        kernel_size=3,
                                        stride=2)
        iuv_logit2 = self.tower2(iuv_head_inputs2)
        iuv_logit2 = F.interpolate(iuv_logit2, size=iuv_logit0.shape[-2:])

        # attn = F.softmax(self.tower_attn(rel_coord), dim=1)
        # iuv_logit = iuv_logit0*attn[:,0:1] + iuv_logit1*attn[:,1:2] + iuv_logit2*attn[:,2:3]
        iuv_logit = torch.cat([iuv_logit0, iuv_logit1, iuv_logit2], dim=1)

        iuv_logit = self.tower_out(iuv_logit)

        assert iuv_feat_stride >= self.iuv_out_stride
        assert iuv_feat_stride % self.iuv_out_stride == 0
        iuv_logit = aligned_bilinear(
            iuv_logit, int(iuv_feat_stride / self.iuv_out_stride))

        return iuv_logit
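aligned_bilinear, used at the end of Examples #2, #3, #4 and #6 to upsample the logits by iuv_feat_stride / iuv_out_stride, is an AdelaiDet-style pixel-aligned upsampling. The sketch below reflects the commonly used implementation and should be treated as an assumption rather than the exact function imported here:

import torch.nn.functional as F

def aligned_bilinear(tensor, factor):
    # Sketch of pixel-aligned bilinear upsampling of a 4D tensor by an integer factor.
    assert tensor.dim() == 4 and factor >= 1
    if factor == 1:
        return tensor
    h, w = tensor.shape[2:]
    tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode="replicate")
    tensor = F.interpolate(tensor, size=(factor * h + 1, factor * w + 1),
                           mode="bilinear", align_corners=True)
    tensor = F.pad(tensor, pad=(factor // 2, 0, factor // 2, 0), mode="replicate")
    return tensor[:, :, :factor * h, :factor * w]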
Example #4
    def forward(self,
                fpn_features,
                s_logits,
                iuv_feats,
                iuv_feat_stride,
                rel_coord,
                instances,
                fg_mask=None,
                gt_instances=None):
        N, _, H, W = iuv_feats.size()

        if self.use_rel_coords:
            if self.use_pos_emb:
                rel_coord = self.position_embedder(rel_coord)
        else:
            rel_coord = torch.zeros(
                [N, 2, H, W],
                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
        iuv_head_inputs = torch.cat([rel_coord, iuv_feats], dim=1)

        if self.use_abs_coords:
            abs_coord = compute_grid(
                H, W, device=iuv_feats.device)[None, ...].repeat(N, 1, 1, 1)
            if self.use_pos_emb:
                abs_coord = self.position_embedder(abs_coord)
        else:
            abs_coord = torch.zeros(
                [N, 2, H, W],
                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
        iuv_head_inputs = torch.cat([abs_coord, iuv_head_inputs], dim=1)
        iuv_logit = self.tower(iuv_head_inputs)

        assert iuv_feat_stride >= self.iuv_out_stride
        assert iuv_feat_stride % self.iuv_out_stride == 0
        iuv_logit = aligned_bilinear(
            iuv_logit, int(iuv_feat_stride / self.iuv_out_stride))

        return iuv_logit
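The position_embedder applied to rel_coord and abs_coord when use_pos_emb is set is not defined in any of these examples. The following is a hypothetical sketch of a sinusoidal coordinate embedder; the class name, the num_freqs parameter and the channel layout are assumptions for illustration only:

import math
import torch
import torch.nn as nn

class SineCoordEmbedder(nn.Module):
    # Hypothetical: lifts a [N, 2, H, W] coordinate map to
    # [N, 2 * 2 * num_freqs, H, W] sin/cos features at geometric frequencies.
    def __init__(self, num_freqs=4):
        super().__init__()
        freqs = (2.0 ** torch.arange(num_freqs, dtype=torch.float32)) * math.pi
        self.register_buffer("freqs", freqs)

    def forward(self, coord):
        x = coord.unsqueeze(2) * self.freqs.view(1, 1, -1, 1, 1)  # [N, 2, F, H, W]
        emb = torch.cat([torch.sin(x), torch.cos(x)], dim=2)      # [N, 2, 2F, H, W]
        n, c, f, h, w = emb.shape
        return emb.reshape(n, c * f, h, w)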
Example #5
    def forward(self, features: List[torch.Tensor], iuv_feats: torch.Tensor,
                rel_coord: torch.Tensor, abs_coord: torch.Tensor,
                fg_mask: torch.Tensor, ins_mask_list: List[torch.Tensor]):
        # assert fg_mask.min()==0, "the fg_mask is all 1"
        if not self.use_agg_feat:
            for i, _ in enumerate(self.in_features):
                if i == 0:
                    x = self.scale_heads[i](features[i])
                else:
                    x = x + self.scale_heads[i](features[i])
        else:
            x = iuv_feats
        if rel_coord is not None:
            x = torch.cat([x,rel_coord], dim=1)
        if abs_coord is not None:
            x = torch.cat([x,abs_coord], dim=1)
        # if skeleton_feats is not None:
        #     x = torch.cat([x,skeleton_feats], dim=1)

        if rel_coord is not None or abs_coord is not None:
            x = self.comb_pe_conv(x)

        if self.use_dropout:
            x = self.dropout_layer(x)

        if self.use_san:
            x = self.san_blk_1(x)

        ## dense to sparse: keep only foreground pixels
        N, C, H, W = x.shape
        coord = compute_grid(H, W, device=x.device, norm=False)
        sparse_coord_batch = []
        sparse_feat_batch = []
        ins_indices_batch = []
        ins_indices_len = []
        ins_cnt = 0
        for n in range(N):
            m = fg_mask[n:n + 1]
            x_indices = coord[0][m[0, 0] > 0]
            y_indices = coord[1][m[0, 0] > 0]
            # per-pixel instance index, offset so indices stay unique across the batch
            ins_indices = torch.argmax(ins_mask_list[n].float(), dim=0)[m[0, 0] > 0] + ins_cnt
            ins_indices_batch.append(ins_indices)
            ins_cnt += ins_mask_list[n].shape[0]
            ins_indices_len.append(torch.sum(ins_mask_list[n], dim=[1, 2]))

            # (batch, y, x) indices and the corresponding features of foreground pixels
            b_indices = torch.ones_like(x_indices) * n
            sparse_coord = torch.stack([b_indices, y_indices, x_indices], dim=-1).int()
            sparse_coord_batch.append(sparse_coord)
            sparse_feat = x[n:n + 1]
            sparse_feat = sparse_feat[m.expand_as(sparse_feat) > 0].reshape([C, -1]).permute([1, 0])
            sparse_feat_batch.append(sparse_feat)
        sparse_coord_batch = torch.cat(sparse_coord_batch, dim=0)
        sparse_feat_batch = torch.cat(sparse_feat_batch, dim=0)
        # pack the foreground pixels into a sparse tensor over the (H, W) grid
        x = spconv.SparseConvTensor(sparse_feat_batch, sparse_coord_batch, (H, W), N)
        ins_indices_batch = torch.cat(ins_indices_batch, dim=0)
        ins_indices_len = torch.cat(ins_indices_len, dim=0)
        x = self.densepose_head(x, ins_indices_batch, ins_indices_len)

        if self.predictor_conv_type == "sparse":
            x = self.predictor(x).dense()
        else:
            x = x.dense()
            if self.checkpoint_grad_num > 0:
                # trade compute for memory by checkpointing the predictor
                x = checkpoint.checkpoint(self.custom(self.predictor), x)
            else:
                x = self.predictor(x)

        return x
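The dense-to-sparse loop in Example #5 can be read as a standalone gather of foreground pixels. Below is a compact sketch of the same idea (an illustration, assuming x is [N, C, H, W] and fg_mask a binary [N, 1, H, W] mask); its outputs would feed spconv.SparseConvTensor(feats, coords, (H, W), N) as in the code above:

import torch

def dense_to_sparse(x, fg_mask):
    # Gather per-pixel features of foreground pixels together with their
    # (batch, y, x) integer indices.
    N, C, H, W = x.shape
    ys, xs = torch.meshgrid(torch.arange(H, device=x.device),
                            torch.arange(W, device=x.device), indexing="ij")
    feats, coords = [], []
    for n in range(N):
        m = fg_mask[n, 0] > 0
        feats.append(x[n][:, m].permute(1, 0))                       # [P, C]
        b = torch.full_like(ys[m], n)
        coords.append(torch.stack([b, ys[m], xs[m]], dim=-1).int())  # [P, 3]
    return torch.cat(feats, dim=0), torch.cat(coords, dim=0)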
Example #6
    def forward(self, fpn_features, s_logits, iuv_feats, iuv_feat_stride, rel_coord, instances, fg_mask, gt_instances=None):
        N, _, H, W = iuv_feats.size()

        if self.use_rel_coords: 
            if self.use_pos_emb:
                rel_coord = self.position_embedder(rel_coord)
        else:
            rel_coord = torch.zeros([N,2,H,W], device=iuv_feats.device).to(dtype=iuv_feats.dtype)
        iuv_head_inputs = torch.cat([rel_coord, iuv_feats], dim=1) 

        if self.use_abs_coords: 
            abs_coord = compute_grid(H, W, device=iuv_feats.device)[None,...].repeat(N,1,1,1)
            if self.use_pos_emb:
                abs_coord = self.position_embedder(abs_coord)
        else:
            abs_coord = torch.zeros([N,2,H,W], device=iuv_feats.device).to(dtype=iuv_feats.dtype)
        iuv_head_inputs = torch.cat([abs_coord, iuv_head_inputs], dim=1)

        fg_mask = self._torch_dilate(fg_mask, kernel_size=3)

        x = iuv_head_inputs
        
        if self.use_partial_norm:
            for layer in self.layers:
                if isinstance(layer, (Conv2d, PartialConv2d)):
                    x = layer(x)
                elif isinstance(layer, nn.GroupNorm):
                    # partial GN: undo the layer's own normalization, then
                    # re-normalize with statistics computed from the masked features
                    fg_mask_sum = fg_mask.sum(dim=[0, -2, -1], keepdim=True)[:, None, ...]
                    x = x * fg_mask
                    n, c, h, w = x.shape
                    num_groups = layer.num_groups
                    x_group = torch.stack(torch.chunk(x, num_groups, dim=1), dim=2)

                    x_group_mean = torch.mean(x_group, dim=[-3, -2, -1], keepdim=True)
                    x_group_std = torch.std(x_group, dim=[-3, -2, -1], keepdim=True)
                    x_group_mean = x_group_mean.repeat(1, 1, num_groups, 1, 1).reshape([n, c, 1, 1])
                    x_group_std = x_group_std.repeat(1, 1, num_groups, 1, 1).reshape([n, c, 1, 1])

                    x_group_mean_p = torch.sum(x_group, dim=[-3, -2, -1], keepdim=True) / fg_mask_sum
                    x_group_std_p = torch.sqrt(
                        torch.sum((x_group - x_group_mean_p) ** 2 + 1e-5,
                                  dim=[-3, -2, -1], keepdim=True) / fg_mask_sum)
                    x_group_mean_p = x_group_mean_p.repeat(1, 1, num_groups, 1, 1).reshape([n, c, 1, 1])
                    x_group_std_p = x_group_std_p.repeat(1, 1, num_groups, 1, 1).reshape([n, c, 1, 1])

                    gamma, beta = layer.parameters()
                    gamma, beta = gamma[None, ..., None, None], beta[None, ..., None, None]

                    x = layer(x)
                    x = (x - beta) / gamma * x_group_std + x_group_mean
                    x = (x - x_group_mean_p) / x_group_std_p * gamma + beta
                elif isinstance(layer, nn.BatchNorm2d):
                    # bbox BN: normalize with per-channel statistics computed over
                    # foreground (box-mask) pixels only, then apply the layer's affine params
                    fg_mask_sum = fg_mask.sum(dim=[0, -2, -1], keepdim=True)
                    x_mean_p = torch.sum(x * fg_mask, dim=[0, -2, -1], keepdim=True) / fg_mask_sum
                    x_std_p = torch.sqrt(
                        torch.sum((x * fg_mask - x_mean_p) ** 2 + 1e-5,
                                  dim=[0, -2, -1], keepdim=True) / fg_mask_sum)

                    gamma, beta = layer.parameters()
                    gamma, beta = gamma[None, ..., None, None], beta[None, ..., None, None]

                    x = (x - x_mean_p) / x_std_p * gamma + beta
                else:
                    x = layer(x)
        else:
            for layer in self.layers:
                if isinstance(layer, LambdaLayer):
                    x = layer(x)
                else:
                    x = layer(x * fg_mask)

        iuv_logit = x

        assert iuv_feat_stride >= self.iuv_out_stride
        assert iuv_feat_stride % self.iuv_out_stride == 0
        iuv_logit = aligned_bilinear(iuv_logit, int(iuv_feat_stride / self.iuv_out_stride))

        return iuv_logit
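The partial-normalization branches in Example #6 replace batch statistics with statistics computed over foreground pixels only. A self-contained sketch of that idea for the BatchNorm case, written as a simplified standalone function (an assumption-level rewrite, not the author's exact code):

import torch

def masked_batchnorm2d(x, fg_mask, gamma, beta, eps=1e-5):
    # Per-channel mean/variance over the foreground pixels of a binary
    # [N, 1, H, W] mask, followed by the usual affine transform with [C] gamma/beta.
    denom = fg_mask.sum(dim=[0, 2, 3], keepdim=True).clamp(min=1.0)
    mean = (x * fg_mask).sum(dim=[0, 2, 3], keepdim=True) / denom
    var = (((x - mean) * fg_mask) ** 2).sum(dim=[0, 2, 3], keepdim=True) / denom
    x_hat = (x - mean) / torch.sqrt(var + eps)
    return x_hat * gamma.view(1, -1, 1, 1) + beta.view(1, -1, 1, 1)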
Example #7
    def forward(self,
                fpn_features,
                s_logits,
                iuv_feats,
                iuv_feat_stride,
                rel_coord,
                instances,
                fg_mask,
                gt_instances=None,
                ins_mask_list=None):
        # assert not self.use_abs_coords

        fea0 = fpn_features[self.in_features[0]]
        N, _, H, W = fea0.shape

        if self.use_rel_coords:
            if self.use_pos_emb:
                rel_coord = self.position_embedder(rel_coord)
        else:
            rel_coord = None

        if self.use_abs_coords:
            abs_coord = compute_grid(H, W, device=fea0.device)[None,
                                                               ...].repeat(
                                                                   N, 1, 1, 1)
            if self.use_pos_emb:
                abs_coord = self.position_embedder(abs_coord)
        else:
            abs_coord = None

        features = [fpn_features[f] for f in self.in_features]

        if self.inference_global_siuv:
            assert not self.training

        if self.training:
            features = [
                self.decoder(features, iuv_feats, rel_coord, abs_coord,
                             fg_mask, ins_mask_list)
            ]
            features_dp_ori = features[0]
            proposal_boxes = [x.gt_boxes for x in gt_instances]
            features_dp = self.densepose_pooler(features, proposal_boxes)
            iuv_logits = features_dp
            # iuv_logit_global = features[0]
            return None, iuv_logits, features_dp_ori
        else:
            features = [
                self.decoder(features, iuv_feats, rel_coord, abs_coord,
                             fg_mask, ins_mask_list)
            ]
            features_dp_ori = features[0]

            if self.inference_global_siuv:
                iuv_logits = features[0]
                coarse_segm = s_logits
            else:
                proposal_boxes = [instances.pred_boxes]
                features_dp = self.densepose_pooler(features, proposal_boxes)
                s_logit_list = []
                for idx in range(s_logits.shape[0]):
                    s_logit = self.densepose_pooler(
                        [s_logits[idx:idx + 1]],
                        [proposal_boxes[0][idx:idx + 1]])
                    s_logit_list.append(s_logit)
                coarse_segm = torch.cat(s_logit_list, dim=0)
                iuv_logits = features_dp

            return coarse_segm, iuv_logits, features_dp_ori