Example 1
 def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas):
     # Despite its name, pred_rotmat holds axis-angle vectors here; both poses
     # are converted to (N, 24, 3, 3) rotation matrices before the loss
     pred_rotmat_valid = batch_rodrigues(pred_rotmat.reshape(
         -1, 3)).reshape(-1, 24, 3, 3)
     gt_rotmat_valid = batch_rodrigues(gt_pose.reshape(-1, 3)).reshape(
         -1, 24, 3, 3)
     pred_betas_valid = pred_betas
     gt_betas_valid = gt_betas
     if len(pred_rotmat_valid) > 0:
         loss_regr_pose = self.criterion_regr(pred_rotmat_valid,
                                              gt_rotmat_valid)
         loss_regr_betas = self.criterion_regr(pred_betas_valid,
                                               gt_betas_valid)
     else:
         loss_regr_pose = torch.zeros(1, device=self.device)
         loss_regr_betas = torch.zeros(1, device=self.device)
     return loss_regr_pose, loss_regr_betas
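Every example on this page presumes a batch_rodrigues helper that maps axis-angle vectors to rotation matrices. For reference, here is a minimal self-contained sketch of such a helper, built directly on Rodrigues' formula; the name batch_rodrigues_sketch and the 1e-8 zero-angle guard are our own choices, not the projects' exact implementation.

    import torch

    def batch_rodrigues_sketch(axisang):
        """Map axis-angle vectors (N, 3) to rotation matrices (N, 3, 3)."""
        angle = torch.norm(axisang + 1e-8, dim=1, keepdim=True)  # theta, (N, 1)
        axis = axisang / angle                                   # unit axis, (N, 3)
        cos = torch.cos(angle)[..., None]                        # (N, 1, 1)
        sin = torch.sin(angle)[..., None]
        zero = torch.zeros_like(axis[:, 0])
        # Skew-symmetric cross-product matrix K of each axis, (N, 3, 3)
        K = torch.stack([zero, -axis[:, 2], axis[:, 1],
                         axis[:, 2], zero, -axis[:, 0],
                         -axis[:, 1], axis[:, 0], zero], dim=1).view(-1, 3, 3)
        eye = torch.eye(3, device=axisang.device)[None]
        # Rodrigues' formula: R = I + sin(theta) K + (1 - cos(theta)) K^2
        return eye + sin * K + (1.0 - cos) * (K @ K)

    # A (B, 72) SMPL pose vector then becomes (B, 24, 3, 3) rotation matrices:
    # rotmats = batch_rodrigues_sketch(pose.view(-1, 3)).view(-1, 24, 3, 3)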
Example 2
 def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas, has_smpl):
     # Keep only the samples that carry valid SMPL annotations
     pred_rotmat_valid = pred_rotmat[has_smpl == 1]
     gt_rotmat_valid = batch_rodrigues(gt_pose.view(-1, 3)).view(-1, 24, 3, 3)[has_smpl == 1]
     pred_betas_valid = pred_betas[has_smpl == 1]
     gt_betas_valid = gt_betas[has_smpl == 1]
     if len(pred_rotmat_valid) > 0:
         loss_regr_pose = self.criterion_regr(pred_rotmat_valid, gt_rotmat_valid)
         loss_regr_betas = self.criterion_regr(pred_betas_valid, gt_betas_valid)
     else:
         loss_regr_pose = torch.zeros(1, device=self.device)
         loss_regr_betas = torch.zeros(1, device=self.device)
     return loss_regr_pose, loss_regr_betas
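Compared with Example 1, the only change is the has_smpl mask, which restricts both losses to the samples that actually carry SMPL annotations. A small, hypothetical illustration of the masking:

    import torch

    has_smpl = torch.tensor([1., 0., 1., 0.])    # 2 of 4 samples are labeled
    pred_rotmat = torch.randn(4, 24, 3, 3)
    pred_rotmat_valid = pred_rotmat[has_smpl == 1]
    print(pred_rotmat_valid.shape)               # torch.Size([2, 24, 3, 3])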
Example 3
    def __call__(self, V, pose, bone, scale, to_rotmats=True):
        batch_size = len(V)
        device = pose.device
        V = F.pad(V.unsqueeze(-1), [0, 0, 0, 1], value=1)  # homogeneous coordinates: (B, N, 4, 1)
        kin_tree = (scale * self.kin_tree) * bone[:, :, None, None]  # scaled bone offsets along the kinematic tree

        if to_rotmats:
            pose = batch_rodrigues(pose.view(-1, 3))
        pose = pose.view([batch_size, -1, 3, 3])
        # Per-joint 4x4 rigid transforms: rotation (3x3) beside translation (3x1)
        T = torch.zeros([batch_size, self.n_joints, 4, 4]).float().to(device)
        T[:, :, -1, -1] = 1
        T[:, :, :3, :] = torch.cat([pose, kin_tree], dim=-1)
        # Compose transforms down the kinematic chain (root joint first)
        T_rel = [T[:, 0]]
        for i in range(1, self.n_joints):
            T_rel.append(T_rel[self.parents[i]] @ T[:, i])
        T_rel = torch.stack(T_rel, dim=1)
        # Subtract transformed rest-pose joints to make the transforms relative to the rest pose
        T_rel[:, :, :, [-1]] -= T_rel.clone() @ (self.h_joints * scale)
        # Blend per-joint transforms with skinning weights, then apply to vertices
        T_ = self.weights @ T_rel.view(batch_size, self.n_joints, -1)
        T_ = T_.view(batch_size, -1, 4, 4)
        V = T_ @ V

        return V[:, :, :3, 0]
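Here batch_rodrigues sits inside a linear-blend-skinning forward pass: axis-angle poses become per-joint rotations, 4x4 transforms are composed along self.parents, and the skinning weights blend them onto homogeneous vertices. A hypothetical call, assuming SMPL-like dimensions (batch of 2, 6890 vertices, 24 joints) and an instance lbs of the class above:

    import torch

    V = torch.randn(2, 6890, 3)    # rest-pose vertices
    pose = torch.randn(2, 24 * 3)  # per-joint axis-angle rotations
    bone = torch.ones(2, 24)       # per-bone scale factors
    # posed = lbs(V, pose, bone, scale=1.0, to_rotmats=True)  # -> (2, 6890, 3)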
Example 4
File: boa.py Project: xjwxjw/BOA
    def adapt_on_labeled_data(self, learner, databatch):
        image = databatch['img'].squeeze(0)
        trg_s3d = databatch['pose_3d'].squeeze(0)
        trg_s2d = databatch['keypoints'].squeeze(0)
        trg_betas = databatch['betas'].squeeze(0)
        trg_pose = databatch['pose'].squeeze(0)
        losses_dict = {}
        pred_rotmat, pred_betas, pred_cam = learner(image)
        smpl_out = decode_smpl_params(pred_rotmat,
                                      pred_betas,
                                      pred_cam,
                                      neutral=True,
                                      pose2rot=False)
        pred_s3d = smpl_out['s3d']
        pred_vts = smpl_out['vts']

        s2d_loss = cal_s2d_loss(pred_s3d, trg_s2d, pred_cam)   # 2D reprojection loss
        s3d_loss = cal_s3d_loss(pred_s3d, trg_s3d)             # 3D joint loss
        # Compare poses in rotation-matrix form
        trg_rotmat = batch_rodrigues(trg_pose.view(-1, 3)).view(-1, 24, 3, 3)
        loss_pose = F.mse_loss(pred_rotmat, trg_rotmat)
        loss_beta = F.mse_loss(pred_betas, trg_betas)
        # Fixed weighting: joint losses dominate, the shape (beta) term is down-weighted
        loss = s3d_loss * 5 + s2d_loss * 5 + loss_pose * 1 + loss_beta * 0.001
        return loss, s3d_loss, s2d_loss, loss_pose, loss_beta
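Note that adapt_on_labeled_data compares poses in rotation-matrix form rather than as raw axis-angle vectors, so equivalent encodings of the same rotation are not penalized. The pattern in isolation, with hypothetical shapes, reusing the batch_rodrigues_sketch helper sketched after Example 1:

    import torch
    import torch.nn.functional as F

    pred_rotmat = torch.randn(4, 24, 3, 3)  # predicted per-joint rotations
    trg_pose = torch.randn(4, 72)           # axis-angle ground truth
    trg_rotmat = batch_rodrigues_sketch(trg_pose.view(-1, 3)).view(-1, 24, 3, 3)
    loss_pose = F.mse_loss(pred_rotmat, trg_rotmat)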
    def _forward(self, in_dict):

        if type(in_dict) is not dict:
            in_dict = {'img': in_dict, 'pretrain_mode': False, 'vis_on': False, 'dataset_name': ''}

        image = in_dict['img']
        gt_pose = in_dict['opt_pose'] if 'opt_pose' in in_dict else None  # SMPL pose parameters
        gt_betas = in_dict['opt_betas'] if 'opt_betas' in in_dict else None  # SMPL beta parameters
        target_kps = in_dict['target_kps'] if 'target_kps' in in_dict else None
        target_kps3d = in_dict['target_kps3d'] if 'target_kps3d' in in_dict else None
        has_iuv = in_dict['has_iuv'].byte() if 'has_iuv' in in_dict else None
        has_dp = in_dict['has_dp'].byte() if 'has_dp' in in_dict else None
        has_kp3d = in_dict['has_pose_3d'].byte() if 'has_pose_3d' in in_dict else None  # flag that indicates whether 3D pose is valid
        target_smpl_kps = in_dict['target_smpl_kps'] if 'target_smpl_kps' in in_dict else None
        target_verts = in_dict['target_verts'] if 'target_verts' in in_dict else None
        valid_fit = in_dict['valid_fit'] if 'valid_fit' in in_dict else None

        batch_size = image.shape[0]

        if gt_pose is not None:
            gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(-1, 24 * 3 * 3)
            target_cam = in_dict['target_cam']
            target = torch.cat([target_cam, gt_betas, gt_rotmat], dim=1)
            uv_image_gt = torch.zeros((batch_size, 3, cfg.DANET.HEATMAP_SIZE, cfg.DANET.HEATMAP_SIZE)).to(image.device)
            if torch.sum(has_iuv) > 0:
                uv_image_gt[has_iuv] = self.iuv2smpl.verts2uvimg(target_verts[has_iuv], cam=target_cam[has_iuv])  # [B, 3, 56, 56]
        else:
            target = None

        # target_iuv_dp = in_dict['target_iuv_dp'] if 'target_iuv_dp' in in_dict else None
        target_iuv_dp = in_dict['dp_dict'] if 'dp_dict' in in_dict else None

        if 'target_kps_coco' in in_dict:
            target_kps = in_dict['target_kps_coco']

        return_dict = {}
        return_dict['losses'] = {}
        return_dict['metrics'] = {}
        return_dict['visualization'] = {}
        return_dict['prediction'] = {}

        if cfg.DANET.INPUT_MODE in ['iuv_gt']:
            if cfg.DANET.DECOMPOSED:
                uv_return_dict = self.img2iuv(image, uv_image_gt, target_smpl_kps, pretrained=in_dict['pretrain_mode'], uvia_dp_gt=target_iuv_dp)
                uv_return_dict['uvia_pred'] = iuv_img2map(uv_image_gt)
            else:
                uv_return_dict = {}
                uv_return_dict['uvia_pred'] = iuv_img2map(uv_image_gt)
        elif cfg.DANET.INPUT_MODE in ['iuv_gt_feat']:
            uv_return_dict = self.img2iuv(image, uv_image_gt, target_smpl_kps, pretrained=in_dict['pretrain_mode'], uvia_dp_gt=target_iuv_dp)
            uv_return_dict['uvia_pred'] = iuv_img2map(uv_image_gt)
        elif cfg.DANET.INPUT_MODE in ['feat']:
            uv_return_dict = self.img2iuv(image, None, target_smpl_kps, pretrained=in_dict['pretrain_mode'], uvia_dp_gt=target_iuv_dp)
        else:
            uv_return_dict = self.img2iuv(image, uv_image_gt, target_smpl_kps, pretrained=in_dict['pretrain_mode'], uvia_dp_gt=target_iuv_dp, has_iuv=has_iuv, has_dp=has_dp)

        u_pred, v_pred, index_pred, ann_pred = uv_return_dict['uvia_pred']
        if self.training and cfg.DANET.PART_IUV_ZERO > 0:
            # Randomly zero out part channels with probability PART_IUV_ZERO (augmentation)
            zero_idxs = []
            for bs in range(u_pred.shape[0]):
                zero_idxs.append([int(i) + 1 for i in torch.nonzero(torch.rand(24) < cfg.DANET.PART_IUV_ZERO)])
                u_pred[bs, zero_idxs[bs]] *= 0
                v_pred[bs, zero_idxs[bs]] *= 0
                index_pred[bs, zero_idxs[bs]] *= 0

        u_pred_cl, v_pred_cl, index_pred_cl, ann_pred_cl = iuvmap_clean(u_pred, v_pred, index_pred, ann_pred)

        iuv_pred_clean = [u_pred_cl.detach(), v_pred_cl.detach(), index_pred_cl.detach(), ann_pred_cl.detach()]
        return_dict['visualization']['iuv_pred'] = iuv_pred_clean

        if in_dict['vis_on']:
            return_dict['visualization']['pred_uv'] = iuv_map2img(*iuv_pred_clean)
            return_dict['visualization']['gt_uv'] = uv_image_gt
            if 'stn_kps_pred' in uv_return_dict:
                return_dict['visualization']['stn_kps_pred'] = uv_return_dict['stn_kps_pred']

            # index_pred_cl shape: (B, 25, 56, 56)
            return_dict['visualization']['index_sum'] = [torch.sum(index_pred_cl[:, 1:].detach()).unsqueeze(0), np.prod(index_pred_cl[:, 0].shape)]

            for key in ['skps_hm_pred', 'skps_hm_gt']:
                if key in uv_return_dict:
                    return_dict['visualization'][key] = torch.max(uv_return_dict[key], dim=1)[0].unsqueeze(1)
                    return_dict['visualization'][key][return_dict['visualization'][key] > 1] = 1.
                    skps_hm_vis = uv_return_dict[key]
                    skps_hm_vis = skps_hm_vis.reshape((skps_hm_vis.shape[0], skps_hm_vis.shape[1], -1))
                    skps_hm_vis = F.softmax(skps_hm_vis, 2)
                    skps_hm_vis = skps_hm_vis.reshape(skps_hm_vis.shape[0], skps_hm_vis.shape[1],
                                                      cfg.DANET.HEATMAP_SIZE, cfg.DANET.HEATMAP_SIZE)
                    return_dict['visualization'][key + '_soft'] = torch.sum(skps_hm_vis, dim=1).unsqueeze(1)
            # for key in ['part_uvi_pred', 'part_uvi_gt']:
            for key in ['part_uvi_gt']:
                if key in uv_return_dict:
                    part_uvi_pred_vis = uv_return_dict[key][0]
                    p_uvi_vis = []
                    for i in range(part_uvi_pred_vis.size(0)):
                        p_u_vis, p_v_vis, p_i_vis = [part_uvi_pred_vis[i, uvi].unsqueeze(0) for uvi in range(3)]
                        if p_u_vis.size(1) == 25:
                            p_uvi_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach())
                        else:
                            p_uvi_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach(),
                                                         ind_mapping=[0] + self.img2iuv.dp2smpl_mapping[i])
                        # p_uvi_vis_i = uvmap_vis(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach(), self.img2iuv.dp2smpl_mapping[i])
                        p_uvi_vis.append(p_uvi_vis_i)
                    return_dict['visualization'][key] = torch.cat(p_uvi_vis, dim=0)

        if not in_dict['pretrain_mode']:

            iuv_map = torch.cat([u_pred_cl, v_pred_cl, index_pred_cl], dim=1)

            if cfg.DANET.INPUT_MODE in ['iuv_gt', 'iuv_gt_feat'] and 'part_iuv_gt' in uv_return_dict:
                part_iuv_map = uv_return_dict['part_iuv_gt']
                if self.training and cfg.DANET.PART_IUV_ZERO > 0:
                    for bs in range(len(zero_idxs)):
                        zero_channel = []
                        for zero_i in zero_idxs[bs]:
                            zero_channel.extend(
                                [(i, m_i + 1) for i, mapping in enumerate(self.img2iuv.dp2smpl_mapping) for m_i, map_idx in
                                 enumerate(mapping) if map_idx == zero_i])
                        zero_dp_i = [item[0] for item in zero_channel]
                        zero_p_i = [item[1] for item in zero_channel]
                        part_iuv_map[bs, zero_dp_i, :, zero_p_i] *= 0

                part_index_map = part_iuv_map[:, :, 2]
            elif 'part_iuv_pred' in uv_return_dict:
                part_iuv_pred = uv_return_dict['part_iuv_pred']
                if self.training and cfg.DANET.PART_IUV_ZERO > 0:
                    for bs in range(len(zero_idxs)):
                        zero_channel = []
                        for zero_i in zero_idxs[bs]:
                            zero_channel.extend(
                                [(i, m_i + 1) for i, mapping in enumerate(self.img2iuv.dp2smpl_mapping) for m_i, map_idx in
                                 enumerate(mapping) if map_idx == zero_i])
                        zero_dp_i = [item[0] for item in zero_channel]
                        zero_p_i = [item[1] for item in zero_channel]
                        part_iuv_pred[bs, zero_dp_i, :, zero_p_i] *= 0

                part_iuv_map = []
                for p_ind in range(part_iuv_pred.size(1)):
                    p_u_pred, p_v_pred, p_index_pred = [part_iuv_pred[:, p_ind, iuv] for iuv in range(3)]
                    p_u_map, p_v_map, p_i_map, _ = iuvmap_clean(p_u_pred, p_v_pred, p_index_pred)
                    p_iuv_map = torch.stack([p_u_map, p_v_map, p_i_map], dim=1)
                    part_iuv_map.append(p_iuv_map)
                part_iuv_map = torch.stack(part_iuv_map, dim=1)
                part_index_map = part_iuv_map[:, :, 2]

            else:
                part_iuv_map = None
                part_index_map = None

            return_dict['visualization']['part_iuv_pred'] = part_iuv_map

            if 'part_featmaps' in uv_return_dict:
                part_feat_map = uv_return_dict['part_featmaps']
            else:
                part_feat_map = None

            if cfg.DANET.INPUT_MODE == 'feat':
                smpl_return_dict = self.iuv2smpl({'iuv_map': {'feat': uv_return_dict['global_featmaps']},
                                                 'part_iuv_map': {'pfeat': part_feat_map},
                                                 'target': target,
                                                 'target_kps': target_kps,
                                                 'target_verts': target_verts,
                                                 'target_kps3d': target_kps3d,
                                                 'has_kp3d': has_kp3d
                                                  })
            elif cfg.DANET.INPUT_MODE in ['iuv_feat', 'iuv_gt_feat']:
                smpl_return_dict = self.iuv2smpl({'iuv_map': {'iuv': iuv_map, 'feat': uv_return_dict['global_featmaps']},
                                                 'part_iuv_map': {'piuv': part_iuv_map, 'pfeat': part_feat_map},
                                                 'target': target,
                                                 'target_kps': target_kps,
                                                 'target_verts': target_verts,
                                                 'target_kps3d': target_kps3d,
                                                 'has_kp3d': has_kp3d
                                                  })
            elif cfg.DANET.INPUT_MODE in ['iuv', 'iuv_gt']:
                smpl_return_dict = self.iuv2smpl({'iuv_map': iuv_map,
                                                 'part_iuv_map': part_iuv_map,
                                                 'target': target,
                                                 'target_kps': target_kps,
                                                 'target_verts': target_verts,
                                                 'target_kps3d': target_kps3d,
                                                 'has_kp3d': has_kp3d,
                                                 'has_smpl': valid_fit
                                                  })
            elif cfg.DANET.INPUT_MODE == 'seg':
                # REMOVE _.detach
                smpl_return_dict = self.iuv2smpl({'iuv_map': {'index': index_pred_cl},
                                                 'part_iuv_map': {'pindex': part_index_map},
                                                 'target': target,
                                                 'target_kps': target_kps,
                                                 'target_verts': target_verts,
                                                 'target_kps3d': target_kps3d,
                                                 'has_kp3d': has_kp3d
                                                  })

            if in_dict['vis_on'] and part_index_map is not None:
                # part_index_map shape: (B, 24, 7, 56, 56)
                return_dict['visualization']['p_index_sum'] = [torch.sum(part_index_map[:, :, 1:].detach()).unsqueeze(0),
                                                               np.prod(part_index_map[:, :, 0].shape)]

            if in_dict['vis_on'] and part_iuv_map is not None:
                part_uvi_pred_vis = part_iuv_map[0]
                p_uvi_vis = []
                for i in range(part_uvi_pred_vis.size(0)):
                    p_u_vis, p_v_vis, p_i_vis = [part_uvi_pred_vis[i, uvi].unsqueeze(0) for uvi in range(3)]
                    if p_u_vis.size(1) == 25:
                        p_uvi_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach())
                    else:
                        p_uvi_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach(),
                                                     ind_mapping=[0] + self.img2iuv.dp2smpl_mapping[i])
                    # p_uvi_vis_i = uvmap_vis(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach(), self.img2iuv.dp2smpl_mapping[i])
                    p_uvi_vis.append(p_uvi_vis_i)
                return_dict['visualization']['part_uvi_pred'] = torch.cat(p_uvi_vis, dim=0)

        for key_name in ['losses', 'metrics', 'visualization', 'prediction']:
            if key_name in uv_return_dict:
                return_dict[key_name].update(uv_return_dict[key_name])
            if not in_dict['pretrain_mode']:
                return_dict[key_name].update(smpl_return_dict[key_name])

        # Workaround: PyTorch 0.4 cannot gather scalar (0-dim) tensors across GPUs
        for k, v in return_dict['losses'].items():
            if len(v.shape) == 0:
                return_dict['losses'][k] = v.unsqueeze(0)
        for k, v in return_dict['metrics'].items():
            if len(v.shape) == 0:
                return_dict['metrics'][k] = v.unsqueeze(0)

        return return_dict
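The final loop works around a DataParallel limitation in PyTorch 0.4: scalar (0-dim) tensors could not be gathered across GPUs, so every scalar loss and metric is lifted to shape (1,). A minimal illustration:

    import torch

    loss = torch.tensor(0.5)          # 0-dim scalar tensor
    print(loss.shape)                 # torch.Size([])
    print(loss.unsqueeze(0).shape)    # torch.Size([1]) -- safe to gather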