Example #1
    def forward(self, iuv_feats, iuv_feat_stride, instances=None):
        iuv_logit = self.tower(iuv_feats)

        assert iuv_feat_stride >= self.iuv_out_stride
        assert iuv_feat_stride % self.iuv_out_stride == 0
        iuv_logit = aligned_bilinear(iuv_logit, int(iuv_feat_stride / self.iuv_out_stride))

        return iuv_logit
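All of these examples end by upsampling with aligned_bilinear. Its definition is not included on this page; the sketch below follows the AdelaiDet (CondInst) helper of the same name, which this code appears to build on, and should be read as an assumption rather than the authoritative source.

import torch.nn.functional as F

def aligned_bilinear(tensor, factor):
    # Upsample a (N, C, H, W) tensor by an integer factor with pixel-aligned
    # bilinear interpolation: replicate-pad by one pixel, interpolate with
    # align_corners=True, shift by half a stride, then crop.
    assert tensor.dim() == 4
    assert factor >= 1 and int(factor) == factor
    if factor == 1:
        return tensor
    h, w = tensor.size()[2:]
    tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode="replicate")
    oh = factor * h + 1
    ow = factor * w + 1
    tensor = F.interpolate(tensor, size=(oh, ow), mode="bilinear", align_corners=True)
    tensor = F.pad(tensor, pad=(factor // 2, 0, factor // 2, 0), mode="replicate")
    return tensor[:, :, :oh - 1, :ow - 1]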
Example #2
    def forward(self,
                s_logits,
                iuv_feats,
                iuv_feat_stride,
                rel_coord,
                instances,
                mask_out_bg=False):
        N, _, H, W = iuv_feats.size()

        if self.use_rel_coords:
            if self.use_pos_emb:
                rel_coord = self.position_embedder(rel_coord)
        else:
            rel_coord = torch.zeros(
                [N, 2, H, W],
                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
        coord = rel_coord

        if self.use_abs_coords:
            abs_coord = compute_grid(
                H, W, device=iuv_feats.device)[None, ...].repeat(N, 1, 1, 1)
            if self.use_pos_emb:
                abs_coord = self.position_embedder(abs_coord)
        else:
            abs_coord = torch.zeros(
                [N, 2, H, W],
                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
        coord = torch.cat([abs_coord, coord], dim=1)

        if mask_out_bg:
            fg_mask = s_logits.detach()
            fg_mask_list = []
            for i in range(N):
                fg_mask_list.append(
                    torch.max(fg_mask[instances.im_inds == i],
                              dim=0,
                              keepdim=True)[0])
            fg_mask = torch.cat(fg_mask_list, dim=0).detach()
            # hard-threshold and slightly dilate the aggregated foreground mask
            fg_mask = (fg_mask > 0.05).float()
            fg_mask = self._torch_dilate(fg_mask, kernel_size=3)
        else:
            fg_mask = torch.ones([N, 1, H, W], device=s_logits.device)

        x = iuv_feats
        for layer in self.layers:
            x = layer(torch.cat([coord, x], dim=1) * fg_mask)
        iuv_logit = x

        assert iuv_feat_stride >= self.iuv_out_stride
        assert iuv_feat_stride % self.iuv_out_stride == 0
        iuv_logit = aligned_bilinear(
            iuv_logit, int(iuv_feat_stride / self.iuv_out_stride))

        return iuv_logit
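Example #2 also calls two helpers that are not shown here: compute_grid (an absolute coordinate grid) and self._torch_dilate (mask dilation). The sketches below are plausible implementations, hedged as assumptions; the dilation-as-max-pooling trick is standard.

import torch
import torch.nn.functional as F

def compute_grid(h, w, device):
    # Hypothetical: normalized (x, y) coordinate grid of shape (2, h, w)
    # with values in [-1, 1], in the style of F.grid_sample inputs.
    ys = torch.linspace(-1, 1, h, device=device)
    xs = torch.linspace(-1, 1, w, device=device)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    return torch.stack([grid_x, grid_y], dim=0)

def _torch_dilate(self, mask, kernel_size=3):
    # Morphological dilation of a soft/binary mask via max-pooling.
    pad = kernel_size // 2
    return F.max_pool2d(mask, kernel_size=kernel_size, stride=1, padding=pad)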
Example #3
    def postprocess(self,
                    results,
                    output_height,
                    output_width,
                    padded_im_h,
                    padded_im_w,
                    mask_threshold=0.5):
        """
        Resize the output instances.
        The input images are often resized when entering an object detector.
        As a result, we often need the outputs of the detector in a different
        resolution from its inputs.
        This function will resize the raw outputs of an R-CNN detector
        to produce outputs according to the desired output resolution.
        Args:
            results (Instances): the raw outputs from the detector.
                `results.image_size` contains the input image resolution the detector sees.
                This object might be modified in-place.
            output_height, output_width: the desired output resolution.
            padded_im_h, padded_im_w: the size of the padded image batch on
                which the global masks were predicted.
            mask_threshold: threshold used to binarize the predicted masks.
        Returns:
            Instances: the resized output from the model, based on the output resolution
        """
        scale_x, scale_y = (output_width / results.image_size[1],
                            output_height / results.image_size[0])
        resized_im_h, resized_im_w = results.image_size
        results = Instances((output_height, output_width),
                            **results.get_fields())

        if results.has("pred_boxes"):
            output_boxes = results.pred_boxes
        elif results.has("proposal_boxes"):
            output_boxes = results.proposal_boxes

        output_boxes.scale(scale_x, scale_y)
        output_boxes.clip(results.image_size)

        results = results[output_boxes.nonempty()]

        if results.has("pred_global_masks"):
            mask_h, mask_w = results.pred_global_masks.size()[-2:]
            factor_h = padded_im_h // mask_h
            factor_w = padded_im_w // mask_w
            assert factor_h == factor_w
            factor = factor_h
            pred_global_masks = aligned_bilinear(results.pred_global_masks,
                                                 factor)
            pred_global_masks = pred_global_masks[:, :, :resized_im_h, :resized_im_w]
            pred_global_masks = F.interpolate(pred_global_masks,
                                              size=(output_height,
                                                    output_width),
                                              mode="bilinear",
                                              align_corners=False)
            pred_global_masks = pred_global_masks[:, 0, :, :]
            results.pred_masks = (pred_global_masks > mask_threshold).float()

        return results
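For context, a typical call site in a detectron2-style model; the variable names (batched_inputs, images) are hypothetical and follow detectron2 conventions:

processed_results = []
for results_per_image, input_per_image, image_size in zip(
        results, batched_inputs, images.image_sizes):
    height = input_per_image.get("height", image_size[0])
    width = input_per_image.get("width", image_size[1])
    instances = self.postprocess(
        results_per_image, height, width,
        padded_im_h=images.tensor.size(-2),
        padded_im_w=images.tensor.size(-1))
    processed_results.append({"instances": instances})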
Example #4
    def forward(self,
                fpn_features,
                s_logits,
                iuv_feats,
                iuv_feat_stride,
                rel_coord,
                instances,
                fg_mask=None,
                gt_instances=None):
        N, _, H, W = iuv_feats.size()

        if self.use_rel_coords:
            if self.use_pos_emb:
                rel_coord = self.position_embedder(rel_coord)
        else:
            rel_coord = torch.zeros(
                [N, 2, H, W],
                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
        iuv_head_inputs = torch.cat([rel_coord, iuv_feats], dim=1)

        if self.use_abs_coords:
            abs_coord = compute_grid(
                H, W, device=iuv_feats.device)[None, ...].repeat(N, 1, 1, 1)
            if self.use_pos_emb:
                abs_coord = self.position_embedder(abs_coord)
        else:
            abs_coord = torch.zeros(
                [N, 2, H, W],
                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
        iuv_head_inputs = torch.cat([abs_coord, iuv_head_inputs], dim=1)

        iuv_head_inputs0 = iuv_head_inputs
        iuv_logit0 = self.tower0(iuv_head_inputs0)
        iuv_head_inputs1 = F.avg_pool2d(iuv_head_inputs0,
                                        kernel_size=3,
                                        stride=2)
        iuv_logit1 = self.tower1(iuv_head_inputs1)
        iuv_logit1 = F.interpolate(iuv_logit1, size=iuv_logit0.shape[-2:])
        iuv_head_inputs2 = F.avg_pool2d(iuv_head_inputs1,
                                        kernel_size=3,
                                        stride=2)
        iuv_logit2 = self.tower2(iuv_head_inputs2)
        iuv_logit2 = F.interpolate(iuv_logit2, size=iuv_logit0.shape[-2:])

        iuv_logit = torch.cat([iuv_logit0, iuv_logit1, iuv_logit2], dim=1)

        iuv_logit = self.tower_out(iuv_logit)

        assert iuv_feat_stride >= self.iuv_out_stride
        assert iuv_feat_stride % self.iuv_out_stride == 0
        iuv_logit = aligned_bilinear(
            iuv_logit, int(iuv_feat_stride / self.iuv_out_stride))

        return iuv_logit
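The three towers form a small scale pyramid: the same inputs are processed at full, half, and quarter resolution, upsampled back to full size, concatenated, and fused by tower_out. The module definitions are not included in these examples; a hypothetical builder consistent with how they are called:

import torch.nn as nn

def make_tower(in_channels, channels, num_convs, out_channels):
    # Hypothetical constructor for self.tower0/1/2 and self.tower_out:
    # a stack of 3x3 conv + ReLU blocks followed by a 1x1 projection.
    layers = []
    for i in range(num_convs):
        layers.append(nn.Conv2d(in_channels if i == 0 else channels,
                                channels, kernel_size=3, padding=1))
        layers.append(nn.ReLU(inplace=True))
    layers.append(nn.Conv2d(channels, out_channels, kernel_size=1))
    return nn.Sequential(*layers)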
Example #5
    def forward(self, s_logits, iuv_feats, iuv_feat_stride, instances):

        locations = compute_locations(
            iuv_feats.size(2), iuv_feats.size(3),
            stride=iuv_feat_stride, device=iuv_feats.device
        )

        im_inds = instances.im_inds

        N, _, H, W = iuv_feats.size()
        rel_coord = torch.zeros([N, 2, H, W], device=iuv_feats.device).to(dtype=iuv_feats.dtype)

        if not self.disable_rel_coords: 
            instance_locations = instances.locations
            relative_coords = instance_locations.reshape(-1, 1, 2) - locations.reshape(1, -1, 2)
            relative_coords = relative_coords.permute(0, 2, 1).float()
            soi = self.sizes_of_interest.float()[instances.fpn_levels]
            relative_coords = relative_coords / soi.reshape(-1, 1, 1)
            relative_coords = relative_coords.to(dtype=iuv_feats.dtype)
            for idx in range(N):
                if idx in im_inds:
                    cc = relative_coords[im_inds == idx].reshape(-1, 2, H, W)
                    ss = s_logits[im_inds == idx, -1:]
                    # average the score-modulated relative coordinates over
                    # all instances belonging to this image
                    coord = torch.mean(cc * ss, dim=0, keepdim=True)
                    rel_coord[idx:idx + 1] = coord
            iuv_head_inputs = torch.cat([rel_coord, iuv_feats], dim=1)
        else:
            iuv_head_inputs = iuv_feats

        iuv_logit = self.tower(iuv_head_inputs)

        assert iuv_feat_stride >= self.iuv_out_stride
        assert iuv_feat_stride % self.iuv_out_stride == 0
        iuv_logit = aligned_bilinear(iuv_logit, int(iuv_feat_stride / self.iuv_out_stride))

        return iuv_logit
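compute_locations is the standard FCOS/AdelaiDet helper that returns the input-image coordinates of every feature-map cell center; reproduced below for reference (an assumption, since its definition is not part of these examples):

import torch

def compute_locations(h, w, stride, device):
    # One (x, y) location per feature-map cell, at the cell center
    # in input-image coordinates.
    shifts_x = torch.arange(0, w * stride, step=stride, dtype=torch.float32, device=device)
    shifts_y = torch.arange(0, h * stride, step=stride, dtype=torch.float32, device=device)
    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
    return torch.stack((shift_x.reshape(-1), shift_y.reshape(-1)), dim=1) + stride // 2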
Example #6
    def mask_heads_forward_with_coords(self, mask_feats, mask_feat_stride,
                                       instances):
        locations = compute_locations(mask_feats.size(2),
                                      mask_feats.size(3),
                                      stride=mask_feat_stride,
                                      device=mask_feats.device)
        n_inst = len(instances)

        im_inds = instances.im_inds
        mask_head_params = instances.mask_head_params

        N, _, H, W = mask_feats.size()

        if not self.disable_rel_coords:
            instance_locations = instances.locations
            relative_coords = instance_locations.reshape(
                -1, 1, 2) - locations.reshape(1, -1, 2)
            relative_coords = relative_coords.permute(0, 2, 1).float()
            soi = self.sizes_of_interest.float()[instances.fpn_levels]
            relative_coords = relative_coords / soi.reshape(-1, 1, 1)
            relative_coords = relative_coords.to(dtype=mask_feats.dtype)

            mask_head_inputs = torch.cat([
                relative_coords,
                mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W)
            ], dim=1)
        else:
            mask_head_inputs = mask_feats[im_inds].reshape(
                n_inst, self.in_channels, H * W)

        mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W)

        weights, biases = parse_dynamic_params(mask_head_params, self.channels,
                                               self.weight_nums,
                                               self.bias_nums,
                                               self.out_logit_dim)

        mask_logits = self.mask_heads_forward(mask_head_inputs, weights,
                                              biases, n_inst)

        mask_logits = mask_logits.reshape(-1, self.out_logit_dim, H, W)

        assert mask_feat_stride >= self.mask_out_stride
        assert mask_feat_stride % self.mask_out_stride == 0
        mask_logits = aligned_bilinear(
            mask_logits, int(mask_feat_stride / self.mask_out_stride))

        # all channels but the last are raw IUV logits; the last channel is a
        # coarse segmentation map passed through a sigmoid
        return mask_logits[:, :-1, :, :], mask_logits[:, -1:, :, :].sigmoid()
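parse_dynamic_params and mask_heads_forward come from CondInst's dynamic mask head: a flat per-instance parameter vector is split into conv weights and biases, which are then applied as grouped 1x1 convolutions. The sketch below follows CondInst; the parsing variant used above additionally takes out_logit_dim, so the exact weight split there is an assumption.

import torch.nn.functional as F

def mask_heads_forward(self, features, weights, biases, num_insts):
    # Run the per-instance dynamic conv layers as grouped convolutions,
    # with ReLU between layers (CondInst-style).
    assert features.dim() == 4
    n_layers = len(weights)
    x = features
    for i, (w, b) in enumerate(zip(weights, biases)):
        x = F.conv2d(x, w, bias=b, stride=1, padding=0, groups=num_insts)
        if i < n_layers - 1:
            x = F.relu(x)
    return x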
Example #7
    def forward(self,
                fpn_features,
                s_logits,
                iuv_feats,
                iuv_feat_stride,
                rel_coord,
                instances,
                fg_mask=None,
                gt_instances=None):
        N, _, H, W = iuv_feats.size()

        if self.use_rel_coords:
            if self.use_pos_emb:
                rel_coord = self.position_embedder(rel_coord)
        else:
            rel_coord = torch.zeros(
                [N, 2, H, W],
                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
        iuv_head_inputs = torch.cat([rel_coord, iuv_feats], dim=1)

        if self.use_abs_coords:
            abs_coord = compute_grid(
                H, W, device=iuv_feats.device)[None, ...].repeat(N, 1, 1, 1)
            if self.use_pos_emb:
                abs_coord = self.position_embedder(abs_coord)
        else:
            abs_coord = torch.zeros(
                [N, 2, H, W],
                device=iuv_feats.device).to(dtype=iuv_feats.dtype)
        iuv_head_inputs = torch.cat([abs_coord, iuv_head_inputs], dim=1)
        iuv_logit = self.tower(iuv_head_inputs)

        assert iuv_feat_stride >= self.iuv_out_stride
        assert iuv_feat_stride % self.iuv_out_stride == 0
        iuv_logit = aligned_bilinear(
            iuv_logit, int(iuv_feat_stride / self.iuv_out_stride))

        return iuv_logit
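self.position_embedder is never defined in these examples; it maps a 2-channel coordinate map to a higher-dimensional embedding before concatenation. A NeRF-style Fourier-feature embedder is one plausible, purely hypothetical realization:

import math
import torch
import torch.nn as nn

class PositionEmbedder(nn.Module):
    # Hypothetical Fourier-feature embedding of a (N, 2, H, W) coordinate map.
    def __init__(self, num_freqs=8):
        super().__init__()
        self.register_buffer("freqs", (2.0 ** torch.arange(num_freqs)) * math.pi)

    def forward(self, coord):
        # (N, 2, H, W) -> (N, 2, F, H, W) -> (N, 4F, H, W)
        x = coord[:, :, None] * self.freqs[None, None, :, None, None]
        return torch.cat([torch.sin(x), torch.cos(x)], dim=2).flatten(1, 2)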
Example #8
    def forward(self, features, gt_instances=None):
        for i, f in enumerate(self.in_features):
            if i == 0:
                x = self.refine[i](features[f])
            else:
                x_p = self.refine[i](features[f])

                target_h, target_w = x.size()[2:]
                h, w = x_p.size()[2:]
                assert target_h % h == 0
                assert target_w % w == 0
                factor_h, factor_w = target_h // h, target_w // w
                assert factor_h == factor_w
                x_p = aligned_bilinear(x_p, factor_h)
                x = x + x_p

        mask_feats = self.tower(x)

        if self.num_outputs == 0:
            mask_feats = mask_feats[:, :self.num_outputs]

        losses = {}
        # auxiliary thing semantic loss
        if self.training and self.sem_loss_on:
            logits_pred = self.logits(
                self.seg_head(features[self.in_features[0]]))

            # compute semantic targets
            semantic_targets = []
            for per_im_gt in gt_instances:
                h, w = per_im_gt.gt_bitmasks_full.size()[-2:]
                areas = per_im_gt.gt_bitmasks_full.sum(dim=-1).sum(dim=-1)
                areas = areas[:, None, None].repeat(1, h, w)
                areas[per_im_gt.gt_bitmasks_full == 0] = INF
                areas = areas.permute(1, 2, 0).reshape(h * w, -1)
                min_areas, inds = areas.min(dim=1)
                per_im_semantic_targets = per_im_gt.gt_classes[inds] + 1
                per_im_semantic_targets[min_areas == INF] = 0
                per_im_semantic_targets = per_im_semantic_targets.reshape(h, w)
                semantic_targets.append(per_im_semantic_targets)

            semantic_targets = torch.stack(semantic_targets, dim=0)

            # resize target to reduce memory
            semantic_targets = semantic_targets[
                :, None,
                self.out_stride // 2::self.out_stride,
                self.out_stride // 2::self.out_stride]

            # prepare one-hot targets
            num_classes = logits_pred.size(1)
            class_range = torch.arange(
                num_classes,
                dtype=logits_pred.dtype,
                device=logits_pred.device)[:, None, None]
            class_range = class_range + 1
            one_hot = (semantic_targets == class_range).float()
            num_pos = (one_hot > 0).sum().float().clamp(min=1.0)

            loss_sem = sigmoid_focal_loss_jit(
                logits_pred,
                one_hot,
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
                reduction="sum",
            ) / num_pos
            losses['loss_sem'] = loss_sem

        return mask_feats, losses
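The semantic-target block above assigns every output pixel the class of the smallest ground-truth instance covering it; pixels covered by no instance keep the INF sentinel area and get label 0 (background). The loss and constant follow AdelaiDet conventions; the imports below are assumptions:

from fvcore.nn import sigmoid_focal_loss_jit  # JIT-scripted sigmoid focal loss

INF = 100000000  # sentinel area marking background pixels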
Example #9
    def forward(self, fpn_features, s_logits, iuv_feats, iuv_feat_stride, rel_coord, instances, fg_mask, gt_instances=None):
        N, _, H, W = iuv_feats.size()

        if self.use_rel_coords: 
            if self.use_pos_emb:
                rel_coord = self.position_embedder(rel_coord)
        else:
            rel_coord = torch.zeros([N,2,H,W], device=iuv_feats.device).to(dtype=iuv_feats.dtype)
        iuv_head_inputs = torch.cat([rel_coord, iuv_feats], dim=1) 

        if self.use_abs_coords: 
            abs_coord = compute_grid(H, W, device=iuv_feats.device)[None,...].repeat(N,1,1,1)
            if self.use_pos_emb:
                abs_coord = self.position_embedder(abs_coord)
        else:
            abs_coord = torch.zeros([N,2,H,W], device=iuv_feats.device).to(dtype=iuv_feats.dtype)
        iuv_head_inputs = torch.cat([abs_coord, iuv_head_inputs], dim=1)

        fg_mask = self._torch_dilate(fg_mask, kernel_size=3)

        x = iuv_head_inputs
        
        if self.use_partial_norm:
            for layer in self.layers:
                if isinstance(layer, (Conv2d, PartialConv2d)):
                    x = layer(x)
                elif isinstance(layer, nn.GroupNorm):
                    # partial GroupNorm: compute statistics over foreground pixels only
                    fg_mask_sum = fg_mask.sum(dim=[0, -2, -1], keepdim=True)[:, None, ...]
                    x = x * fg_mask
                    n, c, h, w = x.shape
                    num_groups = layer.num_groups
                    x_group = torch.stack(torch.chunk(x, num_groups, dim=1), dim=2)

                    x_group_mean = torch.mean(x_group, dim=[-3,-2,-1], keepdim=True)
                    x_group_std = torch.std(x_group, dim=[-3,-2,-1], keepdim=True)
                    x_group_mean = x_group_mean.repeat(1,1,num_groups,1,1).reshape([n,c,1,1])
                    x_group_std = x_group_std.repeat(1,1,num_groups,1,1).reshape([n,c,1,1])

                    x_group_mean_p = torch.sum(x_group, dim=[-3,-2,-1], keepdim=True)/fg_mask_sum
                    x_group_std_p = torch.sqrt(torch.sum((x_group-x_group_mean_p)**2+1e-5, dim=[-3,-2,-1], keepdim=True)/fg_mask_sum)
                    x_group_mean_p = x_group_mean_p.repeat(1,1,num_groups,1,1).reshape([n,c,1,1])
                    x_group_std_p = x_group_std_p.repeat(1,1,num_groups,1,1).reshape([n,c,1,1])

                    gamma, beta = layer.parameters()
                    gamma, beta = gamma[None,...,None,None], beta[None,...,None,None]

                    # undo the layer's own (full-image) normalization, then
                    # re-normalize with the foreground-only statistics
                    x = layer(x)
                    x = (x - beta) / gamma * x_group_std + x_group_mean
                    x = (x - x_group_mean_p) / x_group_std_p * gamma + beta
                elif isinstance(layer, nn.BatchNorm2d):
                    # partial BatchNorm: compute statistics over foreground pixels only
                    fg_mask_sum = fg_mask.sum(dim=[0, -2, -1], keepdim=True)

                    x_mean_p = torch.sum(x * fg_mask, dim=[0, -2, -1], keepdim=True) / fg_mask_sum
                    x_std_p = torch.sqrt(torch.sum((x * fg_mask - x_mean_p) ** 2 + 1e-5, dim=[0, -2, -1], keepdim=True) / fg_mask_sum)

                    gamma, beta = layer.parameters()
                    gamma, beta = gamma[None, ..., None, None], beta[None, ..., None, None]

                    x = (x - x_mean_p) / x_std_p * gamma + beta
                else:
                    x = layer(x)
        else:
            for layer in self.layers:
                if isinstance(layer, LambdaLayer):
                    x = layer(x)
                else:
                    x = layer(x * fg_mask)

        iuv_logit = x

        assert iuv_feat_stride >= self.iuv_out_stride
        assert iuv_feat_stride % self.iuv_out_stride == 0
        iuv_logit = aligned_bilinear(iuv_logit, int(iuv_feat_stride / self.iuv_out_stride))

        return iuv_logit
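Distilled from the BatchNorm branch above, the "partial" normalization idea is to compute mean and variance over foreground pixels only and then reapply the layer's learned affine transform. A standalone sketch, assuming a binary fg_mask of shape (N, 1, H, W) and gamma/beta of shape (1, C, 1, 1):

import torch

def partial_batch_norm(x, fg_mask, gamma, beta, eps=1e-5):
    # Masked BatchNorm: statistics over foreground pixels only,
    # followed by the usual learned affine transform.
    fg_sum = fg_mask.sum(dim=[0, 2, 3], keepdim=True).clamp(min=1.0)
    mean = (x * fg_mask).sum(dim=[0, 2, 3], keepdim=True) / fg_sum
    var = (((x - mean) * fg_mask) ** 2).sum(dim=[0, 2, 3], keepdim=True) / fg_sum
    x = (x - mean) / torch.sqrt(var + eps)
    return x * gamma + beta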
Example #10
    def __call__(self,
                 iuv_head_func,
                 iuv_feats,
                 mask_feats,
                 mask_feat_stride,
                 pred_instances,
                 gt_instances=None):
        if self.training:
            gt_inds = pred_instances.gt_inds
            gt_bitmasks = torch.cat(
                [per_im.gt_bitmasks for per_im in gt_instances])
            gt_bitmasks = gt_bitmasks[gt_inds].unsqueeze(dim=1).to(
                dtype=mask_feats.dtype)

            losses = {}
            if len(pred_instances) == 0:
                # no instances: emit zero-valued losses that keep the graph connected
                loss_mask = (mask_feats.sum() * 0
                             + pred_instances.mask_head_params.sum() * 0)
                losses["loss_mask"] = loss_mask.float()
                losses["loss_densepose_I"] = mask_feats.sum() * 0
                losses["loss_densepose_U"] = mask_feats.sum() * 0
                losses["loss_densepose_V"] = mask_feats.sum() * 0
                losses["loss_densepose_S"] = mask_feats.sum() * 0
            else:
                s_logits = self.mask_heads_forward_with_coords(
                    mask_feats, mask_feat_stride, pred_instances)
                if self.n_segm_chan == 1:
                    s_logits = s_logits.sigmoid()
                elif self.n_segm_chan == 3:
                    s_logits = s_logits[:, :1].sigmoid()
                else:
                    raise NotImplementedError

                iuv_logits = iuv_head_func(s_logits.detach(), iuv_feats,
                                           mask_feat_stride, pred_instances)

                assert mask_feat_stride >= self.mask_out_stride
                assert mask_feat_stride % self.mask_out_stride == 0
                s_logits = aligned_bilinear(
                    s_logits, int(mask_feat_stride / self.mask_out_stride))

                densepose_outputs = DensePoseChartPredictorOutput(
                    coarse_segm=s_logits,
                    fine_segm=iuv_logits[:, :25],
                    u=iuv_logits[:, 25:50],
                    v=iuv_logits[:, 50:75],
                )
                for i in range(len(gt_instances)):
                    gt_instances[i].set(
                        'proposal_boxes',
                        gt_instances[i].get('gt_boxes').clone())

                densepose_loss_dict = self.densepose_losses(
                    gt_instances, densepose_outputs, gt_bitmasks)
                losses.update(densepose_loss_dict)

            return losses

        else:
            if len(pred_instances) > 0:
                s_logits = self.mask_heads_forward_with_coords(
                    mask_feats, mask_feat_stride, pred_instances)

                if self.n_segm_chan == 1:
                    # mimic a 2-channel segmentation output during inference
                    s_logits = s_logits.sigmoid()
                    s_logits = torch.cat([1 - s_logits, s_logits], dim=1)
                elif self.n_segm_chan == 3:
                    s_logits = s_logits[:, :1].sigmoid()
                    s_logits = torch.cat([1 - s_logits, s_logits], dim=1)
                else:
                    raise NotImplementedError

                iuv_logits = iuv_head_func(s_logits, iuv_feats,
                                           mask_feat_stride, pred_instances)

                assert mask_feat_stride >= self.mask_out_stride
                assert mask_feat_stride % self.mask_out_stride == 0
                s_logits = aligned_bilinear(
                    s_logits, int(mask_feat_stride / self.mask_out_stride))

                densepose_outputs = DensePoseChartPredictorOutput(
                    coarse_segm=s_logits,
                    fine_segm=iuv_logits[:, :25],
                    u=iuv_logits[:, 25:50],
                    v=iuv_logits[:, 50:75],
                )
            else:
                densepose_outputs = None
            pred_instances = convert_condInst_to_densepose_inference(
                densepose_outputs, pred_instances, size=(256, 256))
            return pred_instances, densepose_outputs