Example #1
    def forward(self, data):
        return_dict = {}
        return_dict['visualization'] = {}
        return_dict['losses'] = {}

        if cfg.DANET.INPUT_MODE == 'rgb':
            para, _ = self.Conv_Body(data)
        elif cfg.DANET.INPUT_MODE in ['iuv', 'iuv_gt']:
            if data.size(1) == 3:
                Umap, Vmap, Imap, _ = iuv_img2map(data)
                iuv_map = torch.cat([Umap, Vmap, Imap], dim=1)
            else:
                iuv_map = data
            para, _ = self.Conv_Body(iuv_map)
        elif cfg.DANET.INPUT_MODE in ['iuv_feat', 'iuv_gt_feat']:
            para, _ = self.Conv_Body(
                torch.cat([data['iuv'], data['feat']], dim=1))
        elif cfg.DANET.INPUT_MODE == 'feat':
            para, _ = self.Conv_Body(data['feat'])
        elif cfg.DANET.INPUT_MODE == 'seg':
            para, _ = self.Conv_Body(data['index'])

        return_dict['para'] = para

        return return_dict
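All of these examples revolve around converting a 3-channel IUV image (a part-index channel plus per-part U/V coordinates) into stacked per-part maps before feeding a convolutional body. The actual `iuv_img2map` implementation is not reproduced on this page; the following is only a minimal sketch of the kind of conversion such a helper is expected to perform, assuming 24 body parts plus a background channel and part indices stored as `index / 24` in the I channel (both assumptions, not the library's actual layout).

import torch
import torch.nn.functional as F

def iuv_img_to_maps(iuv_image, n_parts=24):
    # Sketch only: split a (B, 3, H, W) IUV image into U, V and one-hot index maps.
    # Channel order and index encoding are assumptions, not the library's actual layout.
    u_img, v_img, i_img = iuv_image[:, 0], iuv_image[:, 1], iuv_image[:, 2]
    # Recover integer part labels, assuming the I channel stores index / n_parts in [0, 1].
    index = (i_img * n_parts).round().long().clamp(0, n_parts)
    index_map = F.one_hot(index, num_classes=n_parts + 1).permute(0, 3, 1, 2).float()
    # Keep U/V values only inside the channel of the part they belong to.
    u_map = index_map * u_img.unsqueeze(1)
    v_map = index_map * v_img.unsqueeze(1)
    return u_map, v_map, index_map

Under those assumptions, the concatenation in forward() above would correspond to torch.cat([u_map, v_map, index_map], dim=1).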
Example #2
    def infer_net(self, image):
        """For inference"""
        return_dict = {}
        return_dict['visualization'] = {}

        if cfg.DANET.INPUT_MODE in ['iuv_gt']:
            if cfg.DANET.DECOMPOSED:
                uv_return_dict = self.img2iuv(image[0], iuv_image_gt=image[1], smpl_kps_gt=image[2])
                u_pred, v_pred, index_pred, ann_pred = iuv_img2map(image[1])
            else:
                uv_return_dict = {}
                u_pred, v_pred, index_pred, ann_pred = iuv_img2map(image)
        elif cfg.DANET.INPUT_MODE in ['iuv_gt_feat']:
            uv_return_dict = self.img2iuv(image[0])
            u_pred, v_pred, index_pred, ann_pred = iuv_img2map(image[1])
        else:
            uv_return_dict = self.img2iuv(image)
            u_pred, v_pred, index_pred, ann_pred = iuvmap_clean(*uv_return_dict['uvia_pred'])

        return_dict['visualization']['iuv_pred'] = [u_pred, v_pred, index_pred, ann_pred]
        if 'part_iuv_pred' in uv_return_dict:
            return_dict['visualization']['part_iuv_pred'] = uv_return_dict['part_iuv_pred']

        iuv_map = torch.cat([u_pred, v_pred, index_pred], dim=1)

        if cfg.DANET.INPUT_MODE in ['iuv_gt', 'iuv_gt_feat'] and 'part_iuv_gt' in uv_return_dict:
            part_iuv_map = uv_return_dict['part_iuv_gt']
            part_index_map = part_iuv_map[:, :, 2]
        elif 'part_iuv_pred' in uv_return_dict:
            part_iuv_pred = uv_return_dict['part_iuv_pred']
            part_iuv_map = []
            for p_ind in range(part_iuv_pred.size(1)):
                p_u_pred, p_v_pred, p_index_pred = [part_iuv_pred[:, p_ind, iuv] for iuv in range(3)]
                p_u_map, p_v_map, p_i_map, _ = iuvmap_clean(p_u_pred, p_v_pred, p_index_pred)
                p_iuv_map = torch.stack([p_u_map, p_v_map, p_i_map], dim=1)
                part_iuv_map.append(p_iuv_map)
            part_iuv_map = torch.stack(part_iuv_map, dim=1)
            part_index_map = part_iuv_map[:, :, 2].detach()
        else:
            part_iuv_map = None
            part_index_map = None

        if 'part_featmaps' in uv_return_dict:
            part_feat_map = uv_return_dict['part_featmaps']
        else:
            part_feat_map = None

        if cfg.DANET.INPUT_MODE == 'feat':
            smpl_return_dict = self.iuv2smpl.smpl_infer_net(
                {'iuv_map': {'feat': uv_return_dict['global_featmaps']},
                 'part_iuv_map': {'pfeat': part_feat_map}})
        elif cfg.DANET.INPUT_MODE in ['iuv_feat', 'iuv_gt_feat']:
            smpl_return_dict = self.iuv2smpl.smpl_infer_net(
                {'iuv_map': {'iuv': iuv_map, 'feat': uv_return_dict['global_featmaps']},
                 'part_iuv_map': {'piuv': part_iuv_map, 'pfeat': part_feat_map}})
        elif cfg.DANET.INPUT_MODE in ['iuv', 'iuv_gt']:
            smpl_return_dict = self.iuv2smpl.smpl_infer_net(
                {'iuv_map': iuv_map,
                 'part_iuv_map': part_iuv_map})
        elif cfg.DANET.INPUT_MODE == 'seg':
            smpl_return_dict = self.iuv2smpl.smpl_infer_net(
                {'iuv_map': {'index': index_pred.detach()},
                 'part_iuv_map': {'pindex': part_index_map}})

        return_dict['para'] = smpl_return_dict['para']

        for k, v in smpl_return_dict['visualization'].items():
            return_dict['visualization'][k] = v

        return return_dict
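A hedged usage sketch for the inference path above, assuming `model` is an already-constructed and loaded instance of this module, that `cfg.DANET.INPUT_MODE` is a mode taking a plain image tensor, and that the 224x224 resolution below is only a placeholder for whatever the configuration expects.

import torch

model.eval()
with torch.no_grad():
    image = torch.randn(1, 3, 224, 224)               # placeholder input; real size depends on cfg
    out = model.infer_net(image)

smpl_para = out['para']                                # predicted camera/shape/pose parameters
u, v, index, ann = out['visualization']['iuv_pred']    # cleaned IUV maps for visualization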
Example #3
    def _forward(self, in_dict):

        if type(in_dict) is not dict:
            in_dict = {'img': in_dict, 'pretrain_mode': False, 'vis_on': False, 'dataset_name': ''}

        image = in_dict['img']
        gt_pose = in_dict['opt_pose'] if 'opt_pose' in in_dict else None  # SMPL pose parameters
        gt_betas = in_dict['opt_betas'] if 'opt_betas' in in_dict else None  # SMPL beta parameters
        target_kps = in_dict['target_kps'] if 'target_kps' in in_dict else None
        target_kps3d = in_dict['target_kps3d'] if 'target_kps3d' in in_dict else None
        has_iuv = in_dict['has_iuv'].byte() if 'has_iuv' in in_dict else None
        has_dp = in_dict['has_dp'].byte() if 'has_dp' in in_dict else None
        has_kp3d = in_dict['has_pose_3d'].byte() if 'has_pose_3d' in in_dict else None  # flag that indicates whether 3D pose is valid
        target_smpl_kps = in_dict['target_smpl_kps'] if 'target_smpl_kps' in in_dict else None
        target_verts = in_dict['target_verts'] if 'target_verts' in in_dict else None
        valid_fit = in_dict['valid_fit'] if 'valid_fit' in in_dict else None

        batch_size = image.shape[0]

        if gt_pose is not None:
            gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(-1, 24 * 3 * 3)
            target_cam = in_dict['target_cam']
            target = torch.cat([target_cam, gt_betas, gt_rotmat], dim=1)
            uv_image_gt = torch.zeros((batch_size, 3, cfg.DANET.HEATMAP_SIZE, cfg.DANET.HEATMAP_SIZE)).to(image.device)
            if torch.sum(has_iuv) > 0:
                uv_image_gt[has_iuv] = self.iuv2smpl.verts2uvimg(target_verts[has_iuv], cam=target_cam[has_iuv])  # [B, 3, 56, 56]
        else:
            target = None

        # target_iuv_dp = in_dict['target_iuv_dp'] if 'target_iuv_dp' in in_dict else None
        target_iuv_dp = in_dict['dp_dict'] if 'dp_dict' in in_dict else None

        if 'target_kps_coco' in in_dict:
            target_kps = in_dict['target_kps_coco']

        return_dict = {}
        return_dict['losses'] = {}
        return_dict['metrics'] = {}
        return_dict['visualization'] = {}
        return_dict['prediction'] = {}

        if cfg.DANET.INPUT_MODE in ['iuv_gt']:
            if cfg.DANET.DECOMPOSED:
                uv_return_dict = self.img2iuv(image, uv_image_gt, target_smpl_kps, pretrained=in_dict['pretrain_mode'], uvia_dp_gt=target_iuv_dp)
                uv_return_dict['uvia_pred'] = iuv_img2map(uv_image_gt)
            else:
                uv_return_dict = {}
                uv_return_dict['uvia_pred'] = iuv_img2map(uv_image_gt)
        elif cfg.DANET.INPUT_MODE in ['iuv_gt_feat']:
            uv_return_dict = self.img2iuv(image, uv_image_gt, target_smpl_kps, pretrained=in_dict['pretrain_mode'], uvia_dp_gt=target_iuv_dp)
            uv_return_dict['uvia_pred'] = iuv_img2map(uv_image_gt)
        elif cfg.DANET.INPUT_MODE in ['feat']:
            uv_return_dict = self.img2iuv(image, None, target_smpl_kps, pretrained=in_dict['pretrain_mode'], uvia_dp_gt=target_iuv_dp)
        else:
            uv_return_dict = self.img2iuv(image, uv_image_gt, target_smpl_kps, pretrained=in_dict['pretrain_mode'], uvia_dp_gt=target_iuv_dp, has_iuv=has_iuv, has_dp=has_dp)

        u_pred, v_pred, index_pred, ann_pred = uv_return_dict['uvia_pred']
        if self.training and cfg.DANET.PART_IUV_ZERO > 0:
            # Randomly zero out the IUV channels of some body parts during training (a part-level dropout).
            zero_idxs = []
            for bs in range(u_pred.shape[0]):
                zero_idxs.append([int(i) + 1 for i in torch.nonzero(torch.rand(24) < cfg.DANET.PART_IUV_ZERO)])
            for bs in range(len(zero_idxs)):
                u_pred[bs, zero_idxs[bs]] *= 0
                v_pred[bs, zero_idxs[bs]] *= 0
                index_pred[bs, zero_idxs[bs]] *= 0

        u_pred_cl, v_pred_cl, index_pred_cl, ann_pred_cl = iuvmap_clean(u_pred, v_pred, index_pred, ann_pred)

        iuv_pred_clean = [u_pred_cl.detach(), v_pred_cl.detach(), index_pred_cl.detach(), ann_pred_cl.detach()]
        return_dict['visualization']['iuv_pred'] = iuv_pred_clean

        if in_dict['vis_on']:
            uvi_pred_clean = [u_pred_cl.detach(), v_pred_cl.detach(), index_pred_cl.detach(), ann_pred_cl.detach()]
            return_dict['visualization']['pred_uv'] = iuv_map2img(*uvi_pred_clean)
            return_dict['visualization']['gt_uv'] = uv_image_gt
            if 'stn_kps_pred' in uv_return_dict:
                return_dict['visualization']['stn_kps_pred'] = uv_return_dict['stn_kps_pred']

            # index_pred_cl shape example: (2, 25, 56, 56), i.e. (B, 25, H, W)
            return_dict['visualization']['index_sum'] = [torch.sum(index_pred_cl[:, 1:].detach()).unsqueeze(0), np.prod(index_pred_cl[:, 0].shape)]

            for key in ['skps_hm_pred', 'skps_hm_gt']:
                if key in uv_return_dict:
                    return_dict['visualization'][key] = torch.max(uv_return_dict[key], dim=1)[0].unsqueeze(1)
                    return_dict['visualization'][key][return_dict['visualization'][key] > 1] = 1.
                    skps_hm_vis = uv_return_dict[key]
                    skps_hm_vis = skps_hm_vis.reshape((skps_hm_vis.shape[0], skps_hm_vis.shape[1], -1))
                    skps_hm_vis = F.softmax(skps_hm_vis, 2)
                    skps_hm_vis = skps_hm_vis.reshape(skps_hm_vis.shape[0], skps_hm_vis.shape[1],
                                                      cfg.DANET.HEATMAP_SIZE, cfg.DANET.HEATMAP_SIZE)
                    return_dict['visualization'][key + '_soft'] = torch.sum(skps_hm_vis, dim=1).unsqueeze(1)
            # for key in ['part_uvi_pred', 'part_uvi_gt']:
            for key in ['part_uvi_gt']:
                if key in uv_return_dict:
                    part_uvi_pred_vis = uv_return_dict[key][0]
                    p_uvi_vis = []
                    for i in range(part_uvi_pred_vis.size(0)):
                        p_u_vis, p_v_vis, p_i_vis = [part_uvi_pred_vis[i, uvi].unsqueeze(0) for uvi in range(3)]
                        if p_u_vis.size(1) == 25:
                            p_uvi_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach())
                        else:
                            p_uvi_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach(),
                                                         ind_mapping=[0] + self.img2iuv.dp2smpl_mapping[i])
                        # p_uvi_vis_i = uvmap_vis(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach(), self.img2iuv.dp2smpl_mapping[i])
                        p_uvi_vis.append(p_uvi_vis_i)
                    return_dict['visualization'][key] = torch.cat(p_uvi_vis, dim=0)

        if not in_dict['pretrain_mode']:

            iuv_map = torch.cat([u_pred_cl, v_pred_cl, index_pred_cl], dim=1)

            if cfg.DANET.INPUT_MODE in ['iuv_gt', 'iuv_gt_feat'] and 'part_iuv_gt' in uv_return_dict:
                part_iuv_map = uv_return_dict['part_iuv_gt']
                if self.training and cfg.DANET.PART_IUV_ZERO > 0:
                    for bs in range(len(zero_idxs)):
                        zero_channel = []
                        for zero_i in zero_idxs[bs]:
                            zero_channel.extend(
                                [(i, m_i + 1) for i, mapping in enumerate(self.img2iuv.dp2smpl_mapping) for m_i, map_idx in
                                 enumerate(mapping) if map_idx == zero_i])
                        zero_dp_i = [iterm[0] for iterm in zero_channel]
                        zero_p_i = [iterm[1] for iterm in zero_channel]
                        part_iuv_map[bs, zero_dp_i, :, zero_p_i] *= 0

                part_index_map = part_iuv_map[:, :, 2]
            elif 'part_iuv_pred' in uv_return_dict:
                part_iuv_pred = uv_return_dict['part_iuv_pred']
                if self.training and cfg.DANET.PART_IUV_ZERO > 0:
                    for bs in range(len(zero_idxs)):
                        zero_channel = []
                        for zero_i in zero_idxs[bs]:
                            zero_channel.extend(
                                [(i, m_i + 1) for i, mapping in enumerate(self.img2iuv.dp2smpl_mapping) for m_i, map_idx in
                                 enumerate(mapping) if map_idx == zero_i])
                        zero_dp_i = [iterm[0] for iterm in zero_channel]
                        zero_p_i = [iterm[1] for iterm in zero_channel]
                        part_iuv_pred[bs, zero_dp_i, :, zero_p_i] *= 0

                part_iuv_map = []
                for p_ind in range(part_iuv_pred.size(1)):
                    p_u_pred, p_v_pred, p_index_pred = [part_iuv_pred[:, p_ind, iuv] for iuv in range(3)]
                    p_u_map, p_v_map, p_i_map, _ = iuvmap_clean(p_u_pred, p_v_pred, p_index_pred)
                    p_iuv_map = torch.stack([p_u_map, p_v_map, p_i_map], dim=1)
                    part_iuv_map.append(p_iuv_map)
                part_iuv_map = torch.stack(part_iuv_map, dim=1)
                part_index_map = part_iuv_map[:, :, 2]

            else:
                part_iuv_map = None
                part_index_map = None

            return_dict['visualization']['part_iuv_pred'] = part_iuv_map

            if 'part_featmaps' in uv_return_dict:
                part_feat_map = uv_return_dict['part_featmaps']
            else:
                part_feat_map = None

            if cfg.DANET.INPUT_MODE == 'feat':
                smpl_return_dict = self.iuv2smpl(
                    {'iuv_map': {'feat': uv_return_dict['global_featmaps']},
                     'part_iuv_map': {'pfeat': part_feat_map},
                     'target': target,
                     'target_kps': target_kps,
                     'target_verts': target_verts,
                     'target_kps3d': target_kps3d,
                     'has_kp3d': has_kp3d})
            elif cfg.DANET.INPUT_MODE in ['iuv_feat', 'iuv_gt_feat']:
                smpl_return_dict = self.iuv2smpl(
                    {'iuv_map': {'iuv': iuv_map, 'feat': uv_return_dict['global_featmaps']},
                     'part_iuv_map': {'piuv': part_iuv_map, 'pfeat': part_feat_map},
                     'target': target,
                     'target_kps': target_kps,
                     'target_verts': target_verts,
                     'target_kps3d': target_kps3d,
                     'has_kp3d': has_kp3d})
            elif cfg.DANET.INPUT_MODE in ['iuv', 'iuv_gt']:
                smpl_return_dict = self.iuv2smpl(
                    {'iuv_map': iuv_map,
                     'part_iuv_map': part_iuv_map,
                     'target': target,
                     'target_kps': target_kps,
                     'target_verts': target_verts,
                     'target_kps3d': target_kps3d,
                     'has_kp3d': has_kp3d,
                     'has_smpl': valid_fit})
            elif cfg.DANET.INPUT_MODE == 'seg':
                # note: index_pred_cl is not detached here, unlike in the inference path
                smpl_return_dict = self.iuv2smpl(
                    {'iuv_map': {'index': index_pred_cl},
                     'part_iuv_map': {'pindex': part_index_map},
                     'target': target,
                     'target_kps': target_kps,
                     'target_verts': target_verts,
                     'target_kps3d': target_kps3d,
                     'has_kp3d': has_kp3d})

            if in_dict['vis_on'] and part_index_map is not None:
                # part_index_map shape example: (2, 24, 7, 56, 56), i.e. (B, 24, channels, H, W)
                return_dict['visualization']['p_index_sum'] = [torch.sum(part_index_map[:, :, 1:].detach()).unsqueeze(0),
                                                               np.prod(part_index_map[:, :, 0].shape)]

            if in_dict['vis_on'] and part_iuv_map is not None:
                part_uvi_pred_vis = part_iuv_map[0]
                p_uvi_vis = []
                for i in range(part_uvi_pred_vis.size(0)):
                    p_u_vis, p_v_vis, p_i_vis = [part_uvi_pred_vis[i, uvi].unsqueeze(0) for uvi in range(3)]
                    if p_u_vis.size(1) == 25:
                        p_uvi_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach())
                    else:
                        p_uvi_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach(),
                                                     ind_mapping=[0] + self.img2iuv.dp2smpl_mapping[i])
                    # p_uvi_vis_i = uvmap_vis(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach(), self.img2iuv.dp2smpl_mapping[i])
                    p_uvi_vis.append(p_uvi_vis_i)
                return_dict['visualization']['part_uvi_pred'] = torch.cat(p_uvi_vis, dim=0)

        for key_name in ['losses', 'metrics', 'visualization', 'prediction']:
            if key_name in uv_return_dict:
                return_dict[key_name].update(uv_return_dict[key_name])
            if not in_dict['pretrain_mode']:
                return_dict[key_name].update(smpl_return_dict[key_name])

        # pytorch0.4 bug on gathering scalar(0-dim) tensors
        for k, v in return_dict['losses'].items():
            if len(v.shape) == 0:
                return_dict['losses'][k] = v.unsqueeze(0)
        for k, v in return_dict['metrics'].items():
            if len(v.shape) == 0:
                return_dict['metrics'][k] = v.unsqueeze(0)

        return return_dict
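`_forward` converts the ground-truth axis-angle pose into flattened rotation matrices with `batch_rodrigues` before assembling the regression target. That helper is not reproduced in these examples; below is a minimal self-contained sketch of the underlying Rodrigues formula (the function name and the (N, 3) input layout are assumptions).

import torch

def axis_angle_to_rotmat(aa):
    # aa: (N, 3) axis-angle vectors; returns (N, 3, 3) rotation matrices.
    angle = aa.norm(dim=1, keepdim=True).clamp(min=1e-8)          # (N, 1)
    axis = aa / angle                                             # unit rotation axes
    x, y, z = axis[:, 0], axis[:, 1], axis[:, 2]
    zeros = torch.zeros_like(x)
    # Skew-symmetric cross-product matrix K for each axis.
    K = torch.stack([zeros, -z, y,
                     z, zeros, -x,
                     -y, x, zeros], dim=1).view(-1, 3, 3)
    eye = torch.eye(3, device=aa.device, dtype=aa.dtype).unsqueeze(0)
    sin = torch.sin(angle).unsqueeze(-1)
    cos = torch.cos(angle).unsqueeze(-1)
    # Rodrigues formula: R = I + sin(t) K + (1 - cos(t)) K^2
    return eye + sin * K + (1.0 - cos) * (K @ K)

A target like `gt_rotmat` above is then just this per-joint result flattened to 24 * 3 * 3 values per sample.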
Example #4
    def train_step(self, input_batch):
        self.model.train()

        # Get data from the batch
        images = input_batch['img']  # input image
        gt_keypoints_2d = input_batch['keypoints']  # 2D keypoints
        gt_pose = input_batch['pose']  # SMPL pose parameters
        gt_betas = input_batch['betas']  # SMPL beta parameters
        gt_joints = input_batch['pose_3d']  # 3D pose
        has_smpl = input_batch['has_smpl'].to(
            torch.bool
        )  # flag that indicates whether SMPL parameters are valid
        has_pose_3d = input_batch['has_pose_3d'].to(
            torch.bool)  # flag that indicates whether 3D pose is valid
        is_flipped = input_batch[
            'is_flipped']  # flag that indicates whether image was flipped during data augmentation
        rot_angle = input_batch[
            'rot_angle']  # rotation angle used for data augmentation
        dataset_name = input_batch[
            'dataset_name']  # name of the dataset the image comes from
        indices = input_batch[
            'sample_index']  # index of example inside its dataset
        batch_size = images.shape[0]

        # Get GT vertices and model joints
        # Note that gt_model_joints is different from gt_joints as it comes from SMPL
        gt_out = self.smpl(betas=gt_betas,
                           body_pose=gt_pose[:, 3:],
                           global_orient=gt_pose[:, :3])
        gt_model_joints = gt_out.joints
        gt_vertices = gt_out.vertices

        # Get current best fits from the dictionary
        opt_pose, opt_betas = self.fits_dict[(dataset_name, indices.cpu(),
                                              rot_angle.cpu(),
                                              is_flipped.cpu())]
        opt_pose = opt_pose.to(self.device)
        opt_betas = opt_betas.to(self.device)

        # Replace extreme betas with zero betas
        opt_betas[(opt_betas.abs() > 3).any(dim=-1)] = 0.
        # Replace the optimized parameters with the ground truth parameters, if available
        opt_pose[has_smpl, :] = gt_pose[has_smpl, :]
        opt_betas[has_smpl, :] = gt_betas[has_smpl, :]

        opt_output = self.smpl(betas=opt_betas,
                               body_pose=opt_pose[:, 3:],
                               global_orient=opt_pose[:, :3])
        opt_vertices = opt_output.vertices
        opt_joints = opt_output.joints

        input_batch['verts'] = opt_vertices

        # De-normalize 2D keypoints from [-1,1] to pixel space
        gt_keypoints_2d_orig = gt_keypoints_2d.clone()
        gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
            gt_keypoints_2d_orig[:, :, :-1] + 1)

        # Estimate camera translation given the model joints and 2D keypoints
        # by minimizing a weighted least squares loss
        gt_cam_t = estimate_translation(gt_model_joints,
                                        gt_keypoints_2d_orig,
                                        focal_length=self.focal_length,
                                        img_size=self.options.img_res)

        opt_cam_t = estimate_translation(opt_joints,
                                         gt_keypoints_2d_orig,
                                         focal_length=self.focal_length,
                                         img_size=self.options.img_res)

        # get fitted smpl parameters as pseudo ground truth
        valid_fit = self.fits_dict.get_vaild_state(
            dataset_name, indices.cpu()).to(torch.bool).to(self.device)

        try:
            valid_fit = valid_fit | has_smpl
        except RuntimeError:
            valid_fit = (valid_fit.byte() | has_smpl.byte()).to(torch.bool)

        # Render Dense Correspondences
        if self.options.regressor == 'pymaf_net' and cfg.MODEL.PyMAF.AUX_SUPV_ON:
            gt_cam_t_nr = opt_cam_t.detach().clone()
            gt_camera = torch.zeros(gt_cam_t_nr.shape).to(gt_cam_t_nr.device)
            gt_camera[:, 1:] = gt_cam_t_nr[:, :2]
            gt_camera[:, 0] = (2. * self.focal_length /
                               self.options.img_res) / gt_cam_t_nr[:, 2]
            iuv_image_gt = torch.zeros(
                (batch_size, 3, cfg.MODEL.PyMAF.DP_HEATMAP_SIZE,
                 cfg.MODEL.PyMAF.DP_HEATMAP_SIZE)).to(self.device)
            if torch.sum(valid_fit.float()) > 0:
                iuv_image_gt[valid_fit] = self.iuv_maker.verts2iuvimg(
                    opt_vertices[valid_fit],
                    cam=gt_camera[valid_fit])  # [B, 3, 56, 56]
            input_batch['iuv_image_gt'] = iuv_image_gt

            uvia_list = iuv_img2map(iuv_image_gt)

        # Feed images in the network to predict camera and SMPL parameters
        if self.options.regressor == 'hmr':
            pred_rotmat, pred_betas, pred_camera = self.model(images)
            # shapes: pred_rotmat (B, 24, 3, 3), pred_betas (B, 10), pred_camera (B, 3)
        elif self.options.regressor == 'pymaf_net':
            preds_dict, _ = self.model(images)

        output = preds_dict
        loss_dict = {}

        if self.options.regressor == 'pymaf_net' and cfg.MODEL.PyMAF.AUX_SUPV_ON:
            dp_out = preds_dict['dp_out']
            for i in range(len(dp_out)):
                # negative index so that loss names count back from the final stage
                r_i = i - len(dp_out)

                u_pred = dp_out[r_i]['predict_u']
                v_pred = dp_out[r_i]['predict_v']
                index_pred = dp_out[r_i]['predict_uv_index']
                ann_pred = dp_out[r_i]['predict_ann_index']
                if index_pred.shape[-1] == iuv_image_gt.shape[-1]:
                    uvia_list_i = uvia_list
                else:
                    iuv_image_gt_i = F.interpolate(iuv_image_gt,
                                                   u_pred.shape[-1],
                                                   mode='nearest')
                    uvia_list_i = iuv_img2map(iuv_image_gt_i)

                loss_U, loss_V, loss_IndexUV, loss_segAnn = self.body_uv_losses(
                    u_pred, v_pred, index_pred, ann_pred, uvia_list_i,
                    valid_fit)
                loss_dict[f'loss_U{r_i}'] = loss_U
                loss_dict[f'loss_V{r_i}'] = loss_V
                loss_dict[f'loss_IndexUV{r_i}'] = loss_IndexUV
                loss_dict[f'loss_segAnn{r_i}'] = loss_segAnn

        len_loop = len(preds_dict['smpl_out']
                       ) if self.options.regressor == 'pymaf_net' else 1

        for l_i in range(len_loop):

            if self.options.regressor == 'pymaf_net':
                if l_i == 0:
                    # initial parameters (mean poses)
                    continue
                pred_rotmat = preds_dict['smpl_out'][l_i]['rotmat']
                pred_betas = preds_dict['smpl_out'][l_i]['theta'][:, 3:13]
                pred_camera = preds_dict['smpl_out'][l_i]['theta'][:, :3]

            pred_output = self.smpl(betas=pred_betas,
                                    body_pose=pred_rotmat[:, 1:],
                                    global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                    pose2rot=False)
            pred_vertices = pred_output.vertices
            pred_joints = pred_output.joints

            # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
            # This camera translation can be used in a full perspective projection
            pred_cam_t = torch.stack([
                pred_camera[:, 1],
                pred_camera[:, 2],
                2 * self.focal_length / (self.options.img_res * pred_camera[:, 0] + 1e-9)
            ], dim=-1)

            camera_center = torch.zeros(batch_size, 2, device=self.device)
            pred_keypoints_2d = perspective_projection(
                pred_joints,
                rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                    batch_size, -1, -1),
                translation=pred_cam_t,
                focal_length=self.focal_length,
                camera_center=camera_center)
            # Normalize keypoints to [-1,1]
            pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res / 2.)

            # Compute loss on SMPL parameters
            loss_regr_pose, loss_regr_betas = self.smpl_losses(
                pred_rotmat, pred_betas, opt_pose, opt_betas, valid_fit)
            loss_regr_pose *= cfg.LOSS.POSE_W
            loss_regr_betas *= cfg.LOSS.SHAPE_W
            loss_dict['loss_regr_pose_{}'.format(l_i)] = loss_regr_pose
            loss_dict['loss_regr_betas_{}'.format(l_i)] = loss_regr_betas

            # Compute 2D reprojection loss for the keypoints
            if cfg.LOSS.KP_2D_W > 0:
                loss_keypoints = self.keypoint_loss(
                    pred_keypoints_2d, gt_keypoints_2d,
                    self.options.openpose_train_weight,
                    self.options.gt_train_weight) * cfg.LOSS.KP_2D_W
                loss_dict['loss_keypoints_{}'.format(l_i)] = loss_keypoints

            # Compute 3D keypoint loss
            loss_keypoints_3d = self.keypoint_3d_loss(
                pred_joints, gt_joints, has_pose_3d) * cfg.LOSS.KP_3D_W
            loss_dict['loss_keypoints_3d_{}'.format(l_i)] = loss_keypoints_3d

            # Per-vertex loss for the shape
            if cfg.LOSS.VERT_W > 0:
                loss_shape = self.shape_loss(pred_vertices, opt_vertices,
                                             valid_fit) * cfg.LOSS.VERT_W
                loss_dict['loss_shape_{}'.format(l_i)] = loss_shape

            # Camera
            # force the network to predict positive depth values
            loss_cam = ((torch.exp(-pred_camera[:, 0] * 10))**2).mean()
            loss_dict['loss_cam_{}'.format(l_i)] = loss_cam

        for key in loss_dict:
            if len(loss_dict[key].shape) > 0:
                loss_dict[key] = loss_dict[key][0]

        # Compute total loss
        loss = torch.stack(list(loss_dict.values())).sum()

        # Do backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Pack output arguments for tensorboard logging
        output.update({
            'pred_vertices': pred_vertices.detach(),
            'opt_vertices': opt_vertices,
            'pred_cam_t': pred_cam_t.detach(),
            'opt_cam_t': opt_cam_t
        })
        loss_dict['loss'] = loss.detach().item()

        if self.step_count % 100 == 0:
            if self.options.multiprocessing_distributed:
                for loss_name, val in loss_dict.items():
                    val = val / self.options.world_size
                    if not torch.is_tensor(val):
                        val = torch.Tensor([val]).to(self.device)
                    dist.all_reduce(val)
                    loss_dict[loss_name] = val
            if self.options.rank == 0:
                for loss_name, val in loss_dict.items():
                    self.summary_writer.add_scalar(
                        'losses/{}'.format(loss_name), val, self.step_count)

        return {'preds': output, 'losses': loss_dict}
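The comment in the loop above describes how the weak-perspective camera [s, tx, ty] is turned into a full-perspective translation. Isolated as a small helper (the function name is hypothetical; the formula and epsilon are the ones used in `train_step`):

import torch

def weak_perspective_to_translation(pred_camera, focal_length, img_res):
    # pred_camera: (B, 3) weak-perspective parameters [s, tx, ty].
    # tz follows from s = 2 * focal_length / (img_res * tz): the scale a perspective
    # camera at distance tz would produce for a crop of size img_res.
    s, tx, ty = pred_camera[:, 0], pred_camera[:, 1], pred_camera[:, 2]
    tz = 2 * focal_length / (img_res * s + 1e-9)   # epsilon matches the training code
    return torch.stack([tx, ty, tz], dim=-1)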
Example #5
    def forward(self, data, iuv_image_gt=None, smpl_kps_gt=None, kps3d_gt=None, uvia_dp_gt=None, has_iuv=None, has_dp=None):
        return_dict = {}
        return_dict['losses'] = {}
        return_dict['metrics'] = {}
        return_dict['visualization'] = {}

        if cfg.DANET.INPUT_MODE in ['iuv_gt']:
            uvia_list = iuv_img2map(iuv_image_gt)
            stn_centers_target = smpl_kps_gt[:, :, :2].contiguous()
            if self.training and cfg.DANET.STN_CENTER_JITTER > 0:
                stn_centers_target = stn_centers_target + cfg.DANET.STN_CENTER_JITTER * (
                            torch.rand(stn_centers_target.size()).cuda(stn_centers_target.device) - 0.5)

            thetas, scales = self.affine_para(stn_centers_target)

            part_map_size = iuv_image_gt.size(-1)
            pred_gt_ratio = float(part_map_size) / uvia_list[0].size(-1)
            iuv_resized = [F.interpolate(uvia_list[i], scale_factor=pred_gt_ratio, mode='nearest') for i in
                           range(3)]
            iuv_simplified = self.part_iuv_simp(iuv_resized)
            part_iuv_gt = []
            for i in range(len(iuv_simplified)):
                part_iuv_i = iuv_simplified[i]
                grid = F.affine_grid(thetas[i], part_iuv_i.size())
                part_iuv_i = F.grid_sample(part_iuv_i, grid)
                part_iuv_i = part_iuv_i.view(-1, 3, len(self.dp2smpl_mapping[i]) + 1, part_map_size, part_map_size)
                part_iuv_gt.append(part_iuv_i)

            # (bs, 24, 3, 7, 56, 56)
            return_dict['part_iuv_gt'] = torch.stack(part_iuv_gt, dim=1)

            return return_dict

        uv_est_dic = self.iuv_est(data)
        u_pred, v_pred, index_pred, ann_pred = uv_est_dic['predict_u'], uv_est_dic['predict_v'], uv_est_dic['predict_uv_index'], uv_est_dic['predict_ann_index']

        if cfg.DANET.INPUT_MODE in ['iuv_feat', 'feat', 'iuv_gt_feat']:
            return_dict['global_featmaps'] = uv_est_dic['xd']

        if self.training and iuv_image_gt is not None:
            uvia_list = iuv_img2map(iuv_image_gt)
            loss_U, loss_V, loss_IndexUV, loss_segAnn = self.body_uv_losses(u_pred, v_pred, index_pred, ann_pred,
                                                                            uvia_list, has_iuv)
            return_dict['losses']['loss_U'] = loss_U
            return_dict['losses']['loss_V'] = loss_V
            return_dict['losses']['loss_IndexUV'] = loss_IndexUV
            return_dict['losses']['loss_segAnn'] = loss_segAnn

        if self.training and uvia_dp_gt is not None:
            if torch.sum(has_dp) > 0:
                dp_on = (has_dp == 1)
                uvia_dp_gt_ = {k: v[dp_on] if isinstance(v, torch.Tensor) else v for k, v in uvia_dp_gt.items()}
                loss_Udp, loss_Vdp, loss_IndexUVdp, loss_segAnndp = self.dp_uvia_losses(u_pred[dp_on], v_pred[dp_on],
                                                                                        index_pred[dp_on],
                                                                                        ann_pred[dp_on], **uvia_dp_gt_)
                return_dict['losses']['loss_Udp'] = loss_Udp
                return_dict['losses']['loss_Vdp'] = loss_Vdp
                return_dict['losses']['loss_IndexUVdp'] = loss_IndexUVdp
                return_dict['losses']['loss_segAnndp'] = loss_segAnndp
            else:
                return_dict['losses']['loss_Udp'] = torch.zeros(1).to(data.device)
                return_dict['losses']['loss_Vdp'] = torch.zeros(1).to(data.device)
                return_dict['losses']['loss_IndexUVdp'] = torch.zeros(1).to(data.device)
                return_dict['losses']['loss_segAnndp'] = torch.zeros(1).to(data.device)

        return_dict['uvia_pred'] = [u_pred, v_pred, index_pred, ann_pred]

        if cfg.DANET.DECOMPOSED:

            u_pred_cl, v_pred_cl, index_pred_cl, ann_pred_cl = iuvmap_clean(u_pred, v_pred, index_pred, ann_pred)

            partial_decon_feat = uv_est_dic['xd']

            skps_hm_pred = uv_est_dic['predict_hm']

            smpl_kps_hm_size = skps_hm_pred.size(-1)

            return_dict['skps_hm_pred'] = skps_hm_pred.detach()

            stn_centers = softmax_integral_tensor(10 * skps_hm_pred, skps_hm_pred.size(1), skps_hm_pred.size(-2),
                                                  skps_hm_pred.size(-1))
            # map the integral (soft-argmax) coordinates from pixel space to [-1, 1]
            stn_centers /= 0.5 * smpl_kps_hm_size
            stn_centers -= 1

            if self.training and smpl_kps_gt is not None:
                if cfg.DANET.STN_HM_WEIGHTS > 0:
                    smpl_kps_norm = smpl_kps_gt.detach().clone()
                    # [-1, 1]  ->  [0, 1]
                    smpl_kps_norm[:, :, :2] *= 0.5
                    smpl_kps_norm[:, :, :2] += 0.5
                    smpl_kps_norm = smpl_kps_norm.view(smpl_kps_norm.size(0) * smpl_kps_norm.size(1), -1)[:, :2]
                    skps_hm_gt, _ = generate_heatmap(smpl_kps_norm, heatmap_size=cfg.DANET.HEATMAP_SIZE)
                    skps_hm_gt = skps_hm_gt.view(smpl_kps_gt.size(0), smpl_kps_gt.size(1),
                                                 cfg.DANET.HEATMAP_SIZE, cfg.DANET.HEATMAP_SIZE)
                    skps_hm_gt = skps_hm_gt.detach()
                    return_dict['skps_hm_gt'] = skps_hm_gt.detach()

                    loss_stnhm = F.smooth_l1_loss(skps_hm_pred, skps_hm_gt, reduction='mean')  # / smpl_kps_gt.size(0)
                    loss_stnhm *= cfg.DANET.STN_HM_WEIGHTS
                    return_dict['losses']['loss_stnhm'] = loss_stnhm

                if cfg.DANET.STN_KPS_WEIGHTS > 0:
                    if smpl_kps_gt.shape[-1] == 3:
                        loss_roi = 0
                        for w in torch.unique(smpl_kps_gt[:, :, 2]):
                            if w == 0:
                                continue
                            kps_w_idx = smpl_kps_gt[:, :, 2] == w
                            # stn_centers_target = smpl_kps_gt[:, :, :2][kps_w1_idx]
                            loss_roi += F.smooth_l1_loss(stn_centers[kps_w_idx], smpl_kps_gt[:, :, :2][kps_w_idx], reduction='sum') * w
                        loss_roi /= smpl_kps_gt.size(0)

                        loss_roi *= cfg.DANET.STN_KPS_WEIGHTS
                        return_dict['losses']['loss_roi'] = loss_roi

                if cfg.DANET.STN_CENTER_JITTER > 0:
                    stn_centers = stn_centers + cfg.DANET.STN_CENTER_JITTER * (torch.rand(stn_centers.size()).cuda(stn_centers.device) - 0.5)

            if cfg.DANET.STN_PART_VIS_SCORE > 0:
                part_hidden_score = []
                for i in range(24):
                    score_map = torch.max(index_pred_cl[:, self.smpl2dp_part[i]], dim=1)[0].detach()
                    score_i = F.grid_sample(score_map.unsqueeze(1), stn_centers[:, i].unsqueeze(1).unsqueeze(1)).detach()
                    part_hidden_score.append(score_i.squeeze(-1).squeeze(-1).squeeze(-1))

                part_hidden_score = torch.stack(part_hidden_score)
                part_hidden_score = part_hidden_score < cfg.DANET.STN_PART_VIS_SCORE

            else:
                part_hidden_score = None

            maps_transformed = []

            thetas, scales = self.affine_para(stn_centers, part_hidden_score)

            for i in range(24):
                theta_i = thetas[i]
                scale_i = scales[i]

                grid = F.affine_grid(theta_i.detach(), partial_decon_feat.size())
                maps_transformed_i = F.grid_sample(partial_decon_feat, grid)

                maps_transformed.append(maps_transformed_i)

            return_dict['stn_kps_pred'] = stn_centers.detach()

            part_maps = torch.cat(maps_transformed, dim=1)

            part_iuv_pred = self.iuv_est.final_pred.predict_partial_iuv(part_maps)
            part_map_size = part_iuv_pred.size(-1)
            # (bs, 24, 3, 7, 56, 56)
            part_iuv_pred = part_iuv_pred.view(part_iuv_pred.size(0), len(self.dp2smpl_mapping), 3, -1,
                                               part_map_size,
                                               part_map_size)

            if cfg.DANET.INPUT_MODE in ['iuv_feat', 'feat', 'iuv_gt_feat']:
                return_dict['part_featmaps'] = part_maps.view(part_maps.size(0), 24, -1, part_maps.size(-2), part_maps.size(-1))

            ## partial uv losses
            if self.training and iuv_image_gt is not None:
                pred_gt_ratio = float(part_map_size) / uvia_list[0].size(-1)
                iuv_resized = [F.interpolate(uvia_list[i], scale_factor=pred_gt_ratio, mode='nearest') for i in
                               range(3)]
                iuv_simplified = self.part_iuv_simp(iuv_resized)
                part_iuv_gt = []
                for i in range(len(iuv_simplified)):
                    part_iuv_i = iuv_simplified[i]
                    grid = F.affine_grid(thetas[i].detach(), part_iuv_i.size())
                    part_iuv_i = F.grid_sample(part_iuv_i, grid)
                    part_iuv_i = part_iuv_i.view(-1, 3, len(self.dp2smpl_mapping[i]) + 1, part_map_size, part_map_size)
                    part_iuv_gt.append(part_iuv_i)

                return_dict['part_iuv_gt'] = torch.stack(part_iuv_gt, dim=1)

                loss_p_U, loss_p_V, loss_p_IndexUV = None, None, None
                for i in range(len(part_iuv_gt)):
                    part_uvia_list = [part_iuv_gt[i][:, iuv] for iuv in range(3)]
                    part_uvia_list.append(None)

                    p_iuv_pred_i = [part_iuv_pred[:, i, iuv] for iuv in range(3)]

                    loss_p_U_i, loss_p_V_i, loss_p_IndexUV_i, _ = self.body_uv_losses(p_iuv_pred_i[0], p_iuv_pred_i[1],
                                                                                      p_iuv_pred_i[2], None,
                                                                                      part_uvia_list, has_iuv)

                    if i == 0:
                        loss_p_U, loss_p_V, loss_p_IndexUV = loss_p_U_i, loss_p_V_i, loss_p_IndexUV_i
                    else:
                        loss_p_U += loss_p_U_i
                        loss_p_V += loss_p_V_i
                        loss_p_IndexUV += loss_p_IndexUV_i

                loss_p_U /= 24.
                loss_p_V /= 24.
                loss_p_IndexUV /= 24.

                return_dict['losses']['loss_pU'] = loss_p_U
                return_dict['losses']['loss_pV'] = loss_p_V
                return_dict['losses']['loss_pIndexUV'] = loss_p_IndexUV

            return_dict['part_iuv_pred'] = part_iuv_pred

        return return_dict
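Example #5 crops per-part windows out of the shared feature map with `F.affine_grid` / `F.grid_sample`, using thetas produced by `self.affine_para` from the estimated SMPL keypoint centers. Below is a minimal sketch of that cropping pattern, assuming a normalized window center in [-1, 1] and a hypothetical fixed zoom factor (the real `affine_para` also predicts per-part scales and handles hidden parts).

import torch
import torch.nn.functional as F

def crop_part_window(feat_map, center_xy, scale=0.5):
    # feat_map: (B, C, H, W); center_xy: (B, 2) window centers in [-1, 1].
    # theta = [[scale, 0, cx], [0, scale, cy]] makes the sampling grid cover a
    # window of relative size `scale` around each center.
    B = feat_map.size(0)
    theta = torch.zeros(B, 2, 3, device=feat_map.device, dtype=feat_map.dtype)
    theta[:, 0, 0] = scale
    theta[:, 1, 1] = scale
    theta[:, :, 2] = center_xy
    grid = F.affine_grid(theta, feat_map.size(), align_corners=False)
    return F.grid_sample(feat_map, grid, align_corners=False)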