def forward(self, data):
    return_dict = {}
    return_dict['visualization'] = {}
    return_dict['losses'] = {}

    if cfg.DANET.INPUT_MODE == 'rgb':
        para, _ = self.Conv_Body(data)
    elif cfg.DANET.INPUT_MODE in ['iuv', 'iuv_gt']:
        if data.size(1) == 3:
            Umap, Vmap, Imap, _ = iuv_img2map(data)
            iuv_map = torch.cat([Umap, Vmap, Imap], dim=1)
        else:
            iuv_map = data
        para, _ = self.Conv_Body(iuv_map)
    elif cfg.DANET.INPUT_MODE in ['iuv_feat', 'iuv_gt_feat']:
        para, _ = self.Conv_Body(torch.cat([data['iuv'], data['feat']], dim=1))
    elif cfg.DANET.INPUT_MODE == 'feat':
        para, _ = self.Conv_Body(data['feat'])
    elif cfg.DANET.INPUT_MODE == 'seg':
        para, _ = self.Conv_Body(data['index'])

    return_dict['para'] = para

    return return_dict
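# --- Illustrative sketch (added for clarity, not in the original file) -------
# forward() above relies on iuv_img2map() to expand a 3-channel IUV image into
# per-part U/V/Index maps before feeding Conv_Body. The real helper is defined
# elsewhere in the repo; this minimal version only shows the expected layout,
# assuming 24 DensePose parts plus a background channel:
import torch

def iuv_img2map_sketch(iuv_img, n_parts=24):
    """iuv_img: (B, 3, H, W) with channels (U, V, I), I holding part ids scaled to [0, 1]."""
    U, V, I = iuv_img[:, 0], iuv_img[:, 1], iuv_img[:, 2]
    index = (I * n_parts).round().long().unsqueeze(1)                    # (B, 1, H, W)
    Imap = torch.zeros(iuv_img.size(0), n_parts + 1, *I.shape[1:],
                       device=iuv_img.device).scatter_(1, index, 1.0)    # one-hot part map
    Umap = Imap * U.unsqueeze(1)  # route each pixel's U value to its part channel
    Vmap = Imap * V.unsqueeze(1)
    return Umap, Vmap, Imap, None  # last slot mirrors the 4-tuple the code unpacks
# ------------------------------------------------------------------------------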
def infer_net(self, image):
    """For inference."""
    return_dict = {}
    return_dict['visualization'] = {}

    if cfg.DANET.INPUT_MODE in ['iuv_gt']:
        if cfg.DANET.DECOMPOSED:
            uv_return_dict = self.img2iuv(image[0], iuv_image_gt=image[1], smpl_kps_gt=image[2])
            u_pred, v_pred, index_pred, ann_pred = iuv_img2map(image[1])
        else:
            uv_return_dict = {}
            u_pred, v_pred, index_pred, ann_pred = iuv_img2map(image)
    elif cfg.DANET.INPUT_MODE in ['iuv_gt_feat']:
        uv_return_dict = self.img2iuv(image[0])
        u_pred, v_pred, index_pred, ann_pred = iuv_img2map(image[1])
    else:
        uv_return_dict = self.img2iuv(image)
        u_pred, v_pred, index_pred, ann_pred = iuvmap_clean(*uv_return_dict['uvia_pred'])

    return_dict['visualization']['iuv_pred'] = [u_pred, v_pred, index_pred, ann_pred]
    if 'part_iuv_pred' in uv_return_dict:
        return_dict['visualization']['part_iuv_pred'] = uv_return_dict['part_iuv_pred']

    iuv_map = torch.cat([u_pred, v_pred, index_pred], dim=1)

    if cfg.DANET.INPUT_MODE in ['iuv_gt', 'iuv_gt_feat'] and 'part_iuv_gt' in uv_return_dict:
        part_iuv_map = uv_return_dict['part_iuv_gt']
        part_index_map = part_iuv_map[:, :, 2]
    elif 'part_iuv_pred' in uv_return_dict:
        part_iuv_pred = uv_return_dict['part_iuv_pred']
        part_iuv_map = []
        for p_ind in range(part_iuv_pred.size(1)):
            p_u_pred, p_v_pred, p_index_pred = [part_iuv_pred[:, p_ind, iuv] for iuv in range(3)]
            p_u_map, p_v_map, p_i_map, _ = iuvmap_clean(p_u_pred, p_v_pred, p_index_pred)
            p_iuv_map = torch.stack([p_u_map, p_v_map, p_i_map], dim=1)
            part_iuv_map.append(p_iuv_map)
        part_iuv_map = torch.stack(part_iuv_map, dim=1)
        part_index_map = part_iuv_map[:, :, 2].detach()
    else:
        part_iuv_map = None
        part_index_map = None

    part_feat_map = uv_return_dict['part_featmaps'] if 'part_featmaps' in uv_return_dict else None

    if cfg.DANET.INPUT_MODE == 'feat':
        smpl_return_dict = self.iuv2smpl.smpl_infer_net(
            {'iuv_map': {'feat': uv_return_dict['global_featmaps']},
             'part_iuv_map': {'pfeat': part_feat_map}})
    elif cfg.DANET.INPUT_MODE in ['iuv_feat', 'iuv_gt_feat']:
        smpl_return_dict = self.iuv2smpl.smpl_infer_net(
            {'iuv_map': {'iuv': iuv_map, 'feat': uv_return_dict['global_featmaps']},
             'part_iuv_map': {'piuv': part_iuv_map, 'pfeat': part_feat_map}})
    elif cfg.DANET.INPUT_MODE in ['iuv', 'iuv_gt']:
        smpl_return_dict = self.iuv2smpl.smpl_infer_net(
            {'iuv_map': iuv_map, 'part_iuv_map': part_iuv_map})
    elif cfg.DANET.INPUT_MODE == 'seg':
        smpl_return_dict = self.iuv2smpl.smpl_infer_net(
            {'iuv_map': {'index': index_pred.detach()},
             'part_iuv_map': {'pindex': part_index_map}})

    return_dict['para'] = smpl_return_dict['para']
    for k, v in smpl_return_dict['visualization'].items():
        return_dict['visualization'][k] = v

    return return_dict
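# --- Illustrative sketch (added for clarity, not in the original file) -------
# infer_net() cleans raw network outputs with iuvmap_clean() before stacking
# them into iuv_map. The real helper is defined elsewhere in the repo; the
# sketch below only captures the masking idea (keep each pixel's U/V in its
# most likely part channel) and is an assumption about its behavior:
import torch

def iuvmap_clean_sketch(u_pred, v_pred, index_pred, ann_pred=None):
    """u_pred/v_pred/index_pred: (B, C, H, W) part-wise predictions."""
    part_id = index_pred.argmax(dim=1, keepdim=True)                 # (B, 1, H, W)
    onehot = torch.zeros_like(index_pred).scatter_(1, part_id, 1.0)  # hard part mask
    return u_pred * onehot, v_pred * onehot, onehot, ann_pred
# ------------------------------------------------------------------------------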
def _forward(self, in_dict):
    if not isinstance(in_dict, dict):
        in_dict = {'img': in_dict, 'pretrain_mode': False, 'vis_on': False, 'dataset_name': ''}

    image = in_dict['img']
    gt_pose = in_dict['opt_pose'] if 'opt_pose' in in_dict else None      # SMPL pose parameters
    gt_betas = in_dict['opt_betas'] if 'opt_betas' in in_dict else None   # SMPL beta parameters
    target_kps = in_dict['target_kps'] if 'target_kps' in in_dict else None
    target_kps3d = in_dict['target_kps3d'] if 'target_kps3d' in in_dict else None
    has_iuv = in_dict['has_iuv'].byte() if 'has_iuv' in in_dict else None
    has_dp = in_dict['has_dp'].byte() if 'has_dp' in in_dict else None
    has_kp3d = in_dict['has_pose_3d'].byte() if 'has_pose_3d' in in_dict else None  # flag that indicates whether 3D pose is valid
    target_smpl_kps = in_dict['target_smpl_kps'] if 'target_smpl_kps' in in_dict else None
    target_verts = in_dict['target_verts'] if 'target_verts' in in_dict else None
    valid_fit = in_dict['valid_fit'] if 'valid_fit' in in_dict else None

    batch_size = image.shape[0]

    if gt_pose is not None:
        gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(-1, 24 * 3 * 3)
        target_cam = in_dict['target_cam']
        target = torch.cat([target_cam, gt_betas, gt_rotmat], dim=1)
        uv_image_gt = torch.zeros((batch_size, 3, cfg.DANET.HEATMAP_SIZE,
                                   cfg.DANET.HEATMAP_SIZE)).to(image.device)
        if torch.sum(has_iuv) > 0:
            uv_image_gt[has_iuv] = self.iuv2smpl.verts2uvimg(
                target_verts[has_iuv], cam=target_cam[has_iuv])  # [B, 3, 56, 56]
    else:
        target = None

    target_iuv_dp = in_dict['dp_dict'] if 'dp_dict' in in_dict else None

    if 'target_kps_coco' in in_dict:
        target_kps = in_dict['target_kps_coco']

    return_dict = {}
    return_dict['losses'] = {}
    return_dict['metrics'] = {}
    return_dict['visualization'] = {}
    return_dict['prediction'] = {}

    if cfg.DANET.INPUT_MODE in ['iuv_gt']:
        if cfg.DANET.DECOMPOSED:
            uv_return_dict = self.img2iuv(image, uv_image_gt, target_smpl_kps,
                                          pretrained=in_dict['pretrain_mode'],
                                          uvia_dp_gt=target_iuv_dp)
            uv_return_dict['uvia_pred'] = iuv_img2map(uv_image_gt)
        else:
            uv_return_dict = {}
            uv_return_dict['uvia_pred'] = iuv_img2map(uv_image_gt)
    elif cfg.DANET.INPUT_MODE in ['iuv_gt_feat']:
        uv_return_dict = self.img2iuv(image, uv_image_gt, target_smpl_kps,
                                      pretrained=in_dict['pretrain_mode'],
                                      uvia_dp_gt=target_iuv_dp)
        uv_return_dict['uvia_pred'] = iuv_img2map(uv_image_gt)
    elif cfg.DANET.INPUT_MODE in ['feat']:
        uv_return_dict = self.img2iuv(image, None, target_smpl_kps,
                                      pretrained=in_dict['pretrain_mode'],
                                      uvia_dp_gt=target_iuv_dp)
    else:
        uv_return_dict = self.img2iuv(image, uv_image_gt, target_smpl_kps,
                                      pretrained=in_dict['pretrain_mode'],
                                      uvia_dp_gt=target_iuv_dp,
                                      has_iuv=has_iuv, has_dp=has_dp)

    u_pred, v_pred, index_pred, ann_pred = uv_return_dict['uvia_pred']

    # Randomly zero out the IUV predictions of some parts during training
    if self.training and cfg.DANET.PART_IUV_ZERO > 0:
        zero_idxs = []
        for bs in range(u_pred.shape[0]):
            zero_idxs.append([int(i) + 1 for i in torch.nonzero(torch.rand(24) < cfg.DANET.PART_IUV_ZERO)])
        for bs in range(len(zero_idxs)):
            u_pred[bs, zero_idxs[bs]] *= 0
            v_pred[bs, zero_idxs[bs]] *= 0
            index_pred[bs, zero_idxs[bs]] *= 0

    u_pred_cl, v_pred_cl, index_pred_cl, ann_pred_cl = iuvmap_clean(u_pred, v_pred, index_pred, ann_pred)
    iuv_pred_clean = [u_pred_cl.detach(), v_pred_cl.detach(), index_pred_cl.detach(), ann_pred_cl.detach()]
    return_dict['visualization']['iuv_pred'] = iuv_pred_clean

    if in_dict['vis_on']:
        return_dict['visualization']['pred_uv'] = iuv_map2img(*iuv_pred_clean)
        return_dict['visualization']['gt_uv'] = uv_image_gt
        if 'stn_kps_pred' in uv_return_dict:
            return_dict['visualization']['stn_kps_pred'] = uv_return_dict['stn_kps_pred']
        # index_pred_cl shape: (B, 25, 56, 56)
        return_dict['visualization']['index_sum'] = [torch.sum(index_pred_cl[:, 1:].detach()).unsqueeze(0),
                                                     np.prod(index_pred_cl[:, 0].shape)]
        for key in ['skps_hm_pred', 'skps_hm_gt']:
            if key in uv_return_dict:
                return_dict['visualization'][key] = torch.max(uv_return_dict[key], dim=1)[0].unsqueeze(1)
                return_dict['visualization'][key][return_dict['visualization'][key] > 1] = 1.
                skps_hm_vis = uv_return_dict[key]
                skps_hm_vis = skps_hm_vis.reshape((skps_hm_vis.shape[0], skps_hm_vis.shape[1], -1))
                skps_hm_vis = F.softmax(skps_hm_vis, 2)
                skps_hm_vis = skps_hm_vis.reshape(skps_hm_vis.shape[0], skps_hm_vis.shape[1],
                                                  cfg.DANET.HEATMAP_SIZE, cfg.DANET.HEATMAP_SIZE)
                return_dict['visualization'][key + '_soft'] = torch.sum(skps_hm_vis, dim=1).unsqueeze(1)

        for key in ['part_uvi_gt']:
            if key in uv_return_dict:
                part_uvi_pred_vis = uv_return_dict[key][0]
                p_uvi_vis = []
                for i in range(part_uvi_pred_vis.size(0)):
                    p_u_vis, p_v_vis, p_i_vis = [part_uvi_pred_vis[i, uvi].unsqueeze(0) for uvi in range(3)]
                    if p_u_vis.size(1) == 25:
                        p_uvi_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach())
                    else:
                        p_uvi_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach(),
                                                  ind_mapping=[0] + self.img2iuv.dp2smpl_mapping[i])
                    p_uvi_vis.append(p_uvi_vis_i)
                return_dict['visualization'][key] = torch.cat(p_uvi_vis, dim=0)

    if not in_dict['pretrain_mode']:
        iuv_map = torch.cat([u_pred_cl, v_pred_cl, index_pred_cl], dim=1)

        if cfg.DANET.INPUT_MODE in ['iuv_gt', 'iuv_gt_feat'] and 'part_iuv_gt' in uv_return_dict:
            part_iuv_map = uv_return_dict['part_iuv_gt']
            if self.training and cfg.DANET.PART_IUV_ZERO > 0:
                for bs in range(len(zero_idxs)):
                    zero_channel = []
                    for zero_i in zero_idxs[bs]:
                        zero_channel.extend(
                            [(i, m_i + 1) for i, mapping in enumerate(self.img2iuv.dp2smpl_mapping)
                             for m_i, map_idx in enumerate(mapping) if map_idx == zero_i])
                    zero_dp_i = [iterm[0] for iterm in zero_channel]
                    zero_p_i = [iterm[1] for iterm in zero_channel]
                    part_iuv_map[bs, zero_dp_i, :, zero_p_i] *= 0
            part_index_map = part_iuv_map[:, :, 2]
        elif 'part_iuv_pred' in uv_return_dict:
            part_iuv_pred = uv_return_dict['part_iuv_pred']
            if self.training and cfg.DANET.PART_IUV_ZERO > 0:
                for bs in range(len(zero_idxs)):
                    zero_channel = []
                    for zero_i in zero_idxs[bs]:
                        zero_channel.extend(
                            [(i, m_i + 1) for i, mapping in enumerate(self.img2iuv.dp2smpl_mapping)
                             for m_i, map_idx in enumerate(mapping) if map_idx == zero_i])
                    zero_dp_i = [iterm[0] for iterm in zero_channel]
                    zero_p_i = [iterm[1] for iterm in zero_channel]
                    part_iuv_pred[bs, zero_dp_i, :, zero_p_i] *= 0
            part_iuv_map = []
            for p_ind in range(part_iuv_pred.size(1)):
                p_u_pred, p_v_pred, p_index_pred = [part_iuv_pred[:, p_ind, iuv] for iuv in range(3)]
                p_u_map, p_v_map, p_i_map, _ = iuvmap_clean(p_u_pred, p_v_pred, p_index_pred)
                p_iuv_map = torch.stack([p_u_map, p_v_map, p_i_map], dim=1)
                part_iuv_map.append(p_iuv_map)
            part_iuv_map = torch.stack(part_iuv_map, dim=1)
            part_index_map = part_iuv_map[:, :, 2]
        else:
            part_iuv_map = None
            part_index_map = None

        return_dict['visualization']['part_iuv_pred'] = part_iuv_map

        part_feat_map = uv_return_dict['part_featmaps'] if 'part_featmaps' in uv_return_dict else None

        if cfg.DANET.INPUT_MODE == 'feat':
            smpl_return_dict = self.iuv2smpl({'iuv_map': {'feat': uv_return_dict['global_featmaps']},
                                              'part_iuv_map': {'pfeat': part_feat_map},
                                              'target': target,
                                              'target_kps': target_kps,
                                              'target_verts': target_verts,
                                              'target_kps3d': target_kps3d,
                                              'has_kp3d': has_kp3d})
        elif cfg.DANET.INPUT_MODE in ['iuv_feat', 'iuv_gt_feat']:
            smpl_return_dict = self.iuv2smpl({'iuv_map': {'iuv': iuv_map, 'feat': uv_return_dict['global_featmaps']},
                                              'part_iuv_map': {'piuv': part_iuv_map, 'pfeat': part_feat_map},
                                              'target': target,
                                              'target_kps': target_kps,
                                              'target_verts': target_verts,
                                              'target_kps3d': target_kps3d,
                                              'has_kp3d': has_kp3d})
        elif cfg.DANET.INPUT_MODE in ['iuv', 'iuv_gt']:
            smpl_return_dict = self.iuv2smpl({'iuv_map': iuv_map,
                                              'part_iuv_map': part_iuv_map,
                                              'target': target,
                                              'target_kps': target_kps,
                                              'target_verts': target_verts,
                                              'target_kps3d': target_kps3d,
                                              'has_kp3d': has_kp3d,
                                              'has_smpl': valid_fit})
        elif cfg.DANET.INPUT_MODE == 'seg':
            smpl_return_dict = self.iuv2smpl({'iuv_map': {'index': index_pred_cl},
                                              'part_iuv_map': {'pindex': part_index_map},
                                              'target': target,
                                              'target_kps': target_kps,
                                              'target_verts': target_verts,
                                              'target_kps3d': target_kps3d,
                                              'has_kp3d': has_kp3d})

        if in_dict['vis_on'] and part_index_map is not None:
            # part_index_map shape: (B, 24, 7, 56, 56)
            return_dict['visualization']['p_index_sum'] = [torch.sum(part_index_map[:, :, 1:].detach()).unsqueeze(0),
                                                           np.prod(part_index_map[:, :, 0].shape)]
        if in_dict['vis_on'] and part_iuv_map is not None:
            part_uvi_pred_vis = part_iuv_map[0]
            p_uvi_vis = []
            for i in range(part_uvi_pred_vis.size(0)):
                p_u_vis, p_v_vis, p_i_vis = [part_uvi_pred_vis[i, uvi].unsqueeze(0) for uvi in range(3)]
                if p_u_vis.size(1) == 25:
                    p_uvi_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach())
                else:
                    p_uvi_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach(),
                                              ind_mapping=[0] + self.img2iuv.dp2smpl_mapping[i])
                p_uvi_vis.append(p_uvi_vis_i)
            return_dict['visualization']['part_uvi_pred'] = torch.cat(p_uvi_vis, dim=0)

    for key_name in ['losses', 'metrics', 'visualization', 'prediction']:
        if key_name in uv_return_dict:
            return_dict[key_name].update(uv_return_dict[key_name])
        if not in_dict['pretrain_mode']:
            return_dict[key_name].update(smpl_return_dict[key_name])

    # PyTorch 0.4 bug on gathering scalar (0-dim) tensors across GPUs
    for k, v in return_dict['losses'].items():
        if len(v.shape) == 0:
            return_dict['losses'][k] = v.unsqueeze(0)
    for k, v in return_dict['metrics'].items():
        if len(v.shape) == 0:
            return_dict['metrics'][k] = v.unsqueeze(0)

    return return_dict
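# --- Illustrative sketch (added for clarity, not in the original file) -------
# The loop at the end of _forward() works around a PyTorch 0.4 issue where
# nn.DataParallel cannot gather 0-dim (scalar) tensors across GPUs. A minimal
# self-contained demonstration of the same fix:
import torch

losses = {'loss_U': torch.tensor(0.3), 'loss_V': torch.rand(1)}
for k, v in losses.items():
    if v.dim() == 0:
        losses[k] = v.unsqueeze(0)  # give scalar losses a leading dim for gather
assert all(v.dim() >= 1 for v in losses.values())
# ------------------------------------------------------------------------------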
def train_step(self, input_batch):
    self.model.train()

    # Get data from the batch
    images = input_batch['img']                    # input image
    gt_keypoints_2d = input_batch['keypoints']     # 2D keypoints
    gt_pose = input_batch['pose']                  # SMPL pose parameters
    gt_betas = input_batch['betas']                # SMPL beta parameters
    gt_joints = input_batch['pose_3d']             # 3D pose
    has_smpl = input_batch['has_smpl'].to(torch.bool)        # flag that indicates whether SMPL parameters are valid
    has_pose_3d = input_batch['has_pose_3d'].to(torch.bool)  # flag that indicates whether 3D pose is valid
    is_flipped = input_batch['is_flipped']         # flag that indicates whether image was flipped during data augmentation
    rot_angle = input_batch['rot_angle']           # rotation angle used for data augmentation
    dataset_name = input_batch['dataset_name']     # name of the dataset the image comes from
    indices = input_batch['sample_index']          # index of example inside its dataset
    batch_size = images.shape[0]

    # Get GT vertices and model joints
    # Note that gt_model_joints is different from gt_joints as it comes from SMPL
    gt_out = self.smpl(betas=gt_betas, body_pose=gt_pose[:, 3:], global_orient=gt_pose[:, :3])
    gt_model_joints = gt_out.joints
    gt_vertices = gt_out.vertices

    # Get current best fits from the dictionary
    opt_pose, opt_betas = self.fits_dict[(dataset_name, indices.cpu(), rot_angle.cpu(), is_flipped.cpu())]
    opt_pose = opt_pose.to(self.device)
    opt_betas = opt_betas.to(self.device)

    # Replace extreme betas with zero betas
    opt_betas[(opt_betas.abs() > 3).any(dim=-1)] = 0.
    # Replace the optimized parameters with the ground-truth parameters, if available
    opt_pose[has_smpl, :] = gt_pose[has_smpl, :]
    opt_betas[has_smpl, :] = gt_betas[has_smpl, :]

    opt_output = self.smpl(betas=opt_betas, body_pose=opt_pose[:, 3:], global_orient=opt_pose[:, :3])
    opt_vertices = opt_output.vertices
    opt_joints = opt_output.joints
    input_batch['verts'] = opt_vertices

    # De-normalize 2D keypoints from [-1, 1] to pixel space
    gt_keypoints_2d_orig = gt_keypoints_2d.clone()
    gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (gt_keypoints_2d_orig[:, :, :-1] + 1)

    # Estimate camera translation given the model joints and 2D keypoints
    # by minimizing a weighted least-squares loss
    gt_cam_t = estimate_translation(gt_model_joints, gt_keypoints_2d_orig,
                                    focal_length=self.focal_length, img_size=self.options.img_res)
    opt_cam_t = estimate_translation(opt_joints, gt_keypoints_2d_orig,
                                     focal_length=self.focal_length, img_size=self.options.img_res)

    # Get fitted SMPL parameters as pseudo ground truth
    valid_fit = self.fits_dict.get_vaild_state(dataset_name, indices.cpu()).to(torch.bool).to(self.device)
    try:
        valid_fit = valid_fit | has_smpl
    except RuntimeError:
        valid_fit = (valid_fit.byte() | has_smpl.byte()).to(torch.bool)

    # Render dense correspondences
    if self.options.regressor == 'pymaf_net' and cfg.MODEL.PyMAF.AUX_SUPV_ON:
        gt_cam_t_nr = opt_cam_t.detach().clone()
        gt_camera = torch.zeros(gt_cam_t_nr.shape).to(gt_cam_t_nr.device)
        gt_camera[:, 1:] = gt_cam_t_nr[:, :2]
        gt_camera[:, 0] = (2. * self.focal_length / self.options.img_res) / gt_cam_t_nr[:, 2]
        iuv_image_gt = torch.zeros((batch_size, 3, cfg.MODEL.PyMAF.DP_HEATMAP_SIZE,
                                    cfg.MODEL.PyMAF.DP_HEATMAP_SIZE)).to(self.device)
        if torch.sum(valid_fit.float()) > 0:
            iuv_image_gt[valid_fit] = self.iuv_maker.verts2iuvimg(
                opt_vertices[valid_fit], cam=gt_camera[valid_fit])  # [B, 3, 56, 56]
        input_batch['iuv_image_gt'] = iuv_image_gt
        uvia_list = iuv_img2map(iuv_image_gt)

    # Feed images to the network to predict camera and SMPL parameters
    if self.options.regressor == 'hmr':
        # pred_rotmat: (B, 24, 3, 3), pred_betas: (B, 10), pred_camera: (B, 3)
        pred_rotmat, pred_betas, pred_camera = self.model(images)
    elif self.options.regressor == 'pymaf_net':
        preds_dict, _ = self.model(images)

    output = preds_dict
    loss_dict = {}

    if self.options.regressor == 'pymaf_net' and cfg.MODEL.PyMAF.AUX_SUPV_ON:
        dp_out = preds_dict['dp_out']
        for i in range(len(dp_out)):
            r_i = i - len(dp_out)
            u_pred, v_pred, index_pred, ann_pred = (dp_out[r_i]['predict_u'],
                                                    dp_out[r_i]['predict_v'],
                                                    dp_out[r_i]['predict_uv_index'],
                                                    dp_out[r_i]['predict_ann_index'])
            if index_pred.shape[-1] == iuv_image_gt.shape[-1]:
                uvia_list_i = uvia_list
            else:
                iuv_image_gt_i = F.interpolate(iuv_image_gt, u_pred.shape[-1], mode='nearest')
                uvia_list_i = iuv_img2map(iuv_image_gt_i)

            loss_U, loss_V, loss_IndexUV, loss_segAnn = self.body_uv_losses(
                u_pred, v_pred, index_pred, ann_pred, uvia_list_i, valid_fit)
            loss_dict[f'loss_U{r_i}'] = loss_U
            loss_dict[f'loss_V{r_i}'] = loss_V
            loss_dict[f'loss_IndexUV{r_i}'] = loss_IndexUV
            loss_dict[f'loss_segAnn{r_i}'] = loss_segAnn

    len_loop = len(preds_dict['smpl_out']) if self.options.regressor == 'pymaf_net' else 1

    for l_i in range(len_loop):
        if self.options.regressor == 'pymaf_net':
            if l_i == 0:
                # initial parameters (mean poses) are not supervised
                continue
            pred_rotmat = preds_dict['smpl_out'][l_i]['rotmat']
            pred_betas = preds_dict['smpl_out'][l_i]['theta'][:, 3:13]
            pred_camera = preds_dict['smpl_out'][l_i]['theta'][:, :3]

        pred_output = self.smpl(betas=pred_betas,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)
        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints

        # Convert the weak-perspective camera [s, tx, ty] to a camera translation
        # [tx, ty, tz] in 3D, given the bounding-box size. This translation can be
        # used in a full perspective projection.
        pred_cam_t = torch.stack([
            pred_camera[:, 1], pred_camera[:, 2],
            2 * self.focal_length / (self.options.img_res * pred_camera[:, 0] + 1e-9)
        ], dim=-1)

        camera_center = torch.zeros(batch_size, 2, device=self.device)
        pred_keypoints_2d = perspective_projection(
            pred_joints,
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(batch_size, -1, -1),
            translation=pred_cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)
        # Normalize keypoints to [-1, 1]
        pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res / 2.)

        # Compute loss on SMPL parameters
        loss_regr_pose, loss_regr_betas = self.smpl_losses(pred_rotmat, pred_betas,
                                                           opt_pose, opt_betas, valid_fit)
        loss_regr_pose *= cfg.LOSS.POSE_W
        loss_regr_betas *= cfg.LOSS.SHAPE_W
        loss_dict['loss_regr_pose_{}'.format(l_i)] = loss_regr_pose
        loss_dict['loss_regr_betas_{}'.format(l_i)] = loss_regr_betas

        # Compute 2D reprojection loss for the keypoints
        if cfg.LOSS.KP_2D_W > 0:
            loss_keypoints = self.keypoint_loss(pred_keypoints_2d, gt_keypoints_2d,
                                                self.options.openpose_train_weight,
                                                self.options.gt_train_weight) * cfg.LOSS.KP_2D_W
            loss_dict['loss_keypoints_{}'.format(l_i)] = loss_keypoints

        # Compute 3D keypoint loss
        loss_keypoints_3d = self.keypoint_3d_loss(pred_joints, gt_joints, has_pose_3d) * cfg.LOSS.KP_3D_W
        loss_dict['loss_keypoints_3d_{}'.format(l_i)] = loss_keypoints_3d

        # Per-vertex loss for the shape
        if cfg.LOSS.VERT_W > 0:
            loss_shape = self.shape_loss(pred_vertices, opt_vertices, valid_fit) * cfg.LOSS.VERT_W
            loss_dict['loss_shape_{}'.format(l_i)] = loss_shape

        # Camera loss: force the network to predict positive depth values
        loss_cam = ((torch.exp(-pred_camera[:, 0] * 10)) ** 2).mean()
        loss_dict['loss_cam_{}'.format(l_i)] = loss_cam

    for key in loss_dict:
        if len(loss_dict[key].shape) > 0:
            loss_dict[key] = loss_dict[key][0]

    # Compute total loss
    loss = torch.stack(list(loss_dict.values())).sum()

    # Do backprop
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    # Pack output arguments for tensorboard logging
    output.update({
        'pred_vertices': pred_vertices.detach(),
        'opt_vertices': opt_vertices,
        'pred_cam_t': pred_cam_t.detach(),
        'opt_cam_t': opt_cam_t
    })
    loss_dict['loss'] = loss.detach().item()

    if self.step_count % 100 == 0:
        if self.options.multiprocessing_distributed:
            for loss_name, val in loss_dict.items():
                val = val / self.options.world_size
                if not torch.is_tensor(val):
                    val = torch.Tensor([val]).to(self.device)
                dist.all_reduce(val)
                loss_dict[loss_name] = val
        if self.options.rank == 0:
            for loss_name, val in loss_dict.items():
                self.summary_writer.add_scalar('losses/{}'.format(loss_name), val, self.step_count)

    return {'preds': output, 'losses': loss_dict}
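# --- Illustrative sketch (added for clarity, not in the original file) -------
# Worked form of the weak-perspective -> translation conversion used above for
# pred_cam_t: a scale s and in-plane offsets (tx, ty) map to a 3D translation
# with depth tz = 2 * f / (img_res * s). The default values below follow the
# usual SPIN/PyMAF settings and are assumptions, not read from this file.
import torch

def weak_perspective_to_translation(cam, focal_length=5000., img_res=224):
    """cam: (B, 3) weak-perspective parameters [s, tx, ty] -> (B, 3) [tx, ty, tz]."""
    s, tx, ty = cam[:, 0], cam[:, 1], cam[:, 2]
    tz = 2 * focal_length / (img_res * s + 1e-9)  # same epsilon as in train_step
    return torch.stack([tx, ty, tz], dim=-1)

# e.g. a scale of 1.0 at f=5000, img_res=224 places the body ~44.6 units deep
print(weak_perspective_to_translation(torch.tensor([[1.0, 0.0, 0.0]])))
# ------------------------------------------------------------------------------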
def forward(self, data, iuv_image_gt=None, smpl_kps_gt=None, kps3d_gt=None,
            uvia_dp_gt=None, has_iuv=None, has_dp=None):
    return_dict = {}
    return_dict['losses'] = {}
    return_dict['metrics'] = {}
    return_dict['visualization'] = {}

    if cfg.DANET.INPUT_MODE in ['iuv_gt']:
        uvia_list = iuv_img2map(iuv_image_gt)

        stn_centers_target = smpl_kps_gt[:, :, :2].contiguous()
        if self.training and cfg.DANET.STN_CENTER_JITTER > 0:
            stn_centers_target = stn_centers_target + cfg.DANET.STN_CENTER_JITTER * (
                torch.rand(stn_centers_target.size()).cuda(stn_centers_target.device) - 0.5)
        thetas, scales = self.affine_para(stn_centers_target)

        part_map_size = iuv_image_gt.size(-1)
        pred_gt_ratio = float(part_map_size) / uvia_list[0].size(-1)
        iuv_resized = [F.interpolate(uvia_list[i], scale_factor=pred_gt_ratio, mode='nearest')
                       for i in range(3)]
        iuv_simplified = self.part_iuv_simp(iuv_resized)

        part_iuv_gt = []
        for i in range(len(iuv_simplified)):
            part_iuv_i = iuv_simplified[i]
            grid = F.affine_grid(thetas[i], part_iuv_i.size())
            part_iuv_i = F.grid_sample(part_iuv_i, grid)
            part_iuv_i = part_iuv_i.view(-1, 3, len(self.dp2smpl_mapping[i]) + 1,
                                         part_map_size, part_map_size)
            part_iuv_gt.append(part_iuv_i)

        # (bs, 24, 3, 7, 56, 56)
        return_dict['part_iuv_gt'] = torch.stack(part_iuv_gt, dim=1)
        return return_dict

    uv_est_dic = self.iuv_est(data)
    u_pred, v_pred, index_pred, ann_pred = (uv_est_dic['predict_u'], uv_est_dic['predict_v'],
                                            uv_est_dic['predict_uv_index'], uv_est_dic['predict_ann_index'])

    if cfg.DANET.INPUT_MODE in ['iuv_feat', 'feat', 'iuv_gt_feat']:
        return_dict['global_featmaps'] = uv_est_dic['xd']

    if self.training and iuv_image_gt is not None:
        uvia_list = iuv_img2map(iuv_image_gt)
        loss_U, loss_V, loss_IndexUV, loss_segAnn = self.body_uv_losses(
            u_pred, v_pred, index_pred, ann_pred, uvia_list, has_iuv)
        return_dict['losses']['loss_U'] = loss_U
        return_dict['losses']['loss_V'] = loss_V
        return_dict['losses']['loss_IndexUV'] = loss_IndexUV
        return_dict['losses']['loss_segAnn'] = loss_segAnn

    if self.training and uvia_dp_gt is not None:
        if torch.sum(has_dp) > 0:
            dp_on = (has_dp == 1)
            uvia_dp_gt_ = {k: v[dp_on] if isinstance(v, torch.Tensor) else v
                           for k, v in uvia_dp_gt.items()}
            loss_Udp, loss_Vdp, loss_IndexUVdp, loss_segAnndp = self.dp_uvia_losses(
                u_pred[dp_on], v_pred[dp_on], index_pred[dp_on], ann_pred[dp_on], **uvia_dp_gt_)
            return_dict['losses']['loss_Udp'] = loss_Udp
            return_dict['losses']['loss_Vdp'] = loss_Vdp
            return_dict['losses']['loss_IndexUVdp'] = loss_IndexUVdp
            return_dict['losses']['loss_segAnndp'] = loss_segAnndp
        else:
            return_dict['losses']['loss_Udp'] = torch.zeros(1).to(data.device)
            return_dict['losses']['loss_Vdp'] = torch.zeros(1).to(data.device)
            return_dict['losses']['loss_IndexUVdp'] = torch.zeros(1).to(data.device)
            return_dict['losses']['loss_segAnndp'] = torch.zeros(1).to(data.device)

    return_dict['uvia_pred'] = [u_pred, v_pred, index_pred, ann_pred]

    if cfg.DANET.DECOMPOSED:
        u_pred_cl, v_pred_cl, index_pred_cl, ann_pred_cl = iuvmap_clean(u_pred, v_pred, index_pred, ann_pred)
        partial_decon_feat = uv_est_dic['xd']

        skps_hm_pred = uv_est_dic['predict_hm']
        smpl_kps_hm_size = skps_hm_pred.size(-1)
        return_dict['skps_hm_pred'] = skps_hm_pred.detach()

        # Soft-argmax over the SMPL keypoint heatmaps, normalized to [-1, 1]
        stn_centers = softmax_integral_tensor(10 * skps_hm_pred, skps_hm_pred.size(1),
                                              skps_hm_pred.size(-2), skps_hm_pred.size(-1))
        stn_centers /= 0.5 * smpl_kps_hm_size
        stn_centers -= 1

        if self.training and smpl_kps_gt is not None:
            if cfg.DANET.STN_HM_WEIGHTS > 0:
                smpl_kps_norm = smpl_kps_gt.detach().clone()
                # [-1, 1] -> [0, 1]
                smpl_kps_norm[:, :, :2] *= 0.5
                smpl_kps_norm[:, :, :2] += 0.5
                smpl_kps_norm = smpl_kps_norm.view(smpl_kps_norm.size(0) * smpl_kps_norm.size(1), -1)[:, :2]
                skps_hm_gt, _ = generate_heatmap(smpl_kps_norm, heatmap_size=cfg.DANET.HEATMAP_SIZE)
                skps_hm_gt = skps_hm_gt.view(smpl_kps_gt.size(0), smpl_kps_gt.size(1),
                                             cfg.DANET.HEATMAP_SIZE, cfg.DANET.HEATMAP_SIZE)
                skps_hm_gt = skps_hm_gt.detach()
                return_dict['skps_hm_gt'] = skps_hm_gt.detach()

                loss_stnhm = F.smooth_l1_loss(skps_hm_pred, skps_hm_gt, reduction='mean')
                loss_stnhm *= cfg.DANET.STN_HM_WEIGHTS
                return_dict['losses']['loss_stnhm'] = loss_stnhm

            if cfg.DANET.STN_KPS_WEIGHTS > 0:
                if smpl_kps_gt.shape[-1] == 3:
                    loss_roi = 0
                    for w in torch.unique(smpl_kps_gt[:, :, 2]):
                        if w == 0:
                            continue
                        kps_w_idx = smpl_kps_gt[:, :, 2] == w
                        loss_roi += F.smooth_l1_loss(stn_centers[kps_w_idx],
                                                     smpl_kps_gt[:, :, :2][kps_w_idx],
                                                     reduction='sum') * w
                    loss_roi /= smpl_kps_gt.size(0)
                    loss_roi *= cfg.DANET.STN_KPS_WEIGHTS
                    return_dict['losses']['loss_roi'] = loss_roi

        if cfg.DANET.STN_CENTER_JITTER > 0:
            stn_centers = stn_centers + cfg.DANET.STN_CENTER_JITTER * (
                torch.rand(stn_centers.size()).cuda(stn_centers.device) - 0.5)

        if cfg.DANET.STN_PART_VIS_SCORE > 0:
            part_hidden_score = []
            for i in range(24):
                score_map = torch.max(index_pred_cl[:, self.smpl2dp_part[i]], dim=1)[0].detach()
                score_i = F.grid_sample(score_map.unsqueeze(1),
                                        stn_centers[:, i].unsqueeze(1).unsqueeze(1)).detach()
                part_hidden_score.append(score_i.squeeze(-1).squeeze(-1).squeeze(-1))
            part_hidden_score = torch.stack(part_hidden_score)
            part_hidden_score = part_hidden_score < cfg.DANET.STN_PART_VIS_SCORE
        else:
            part_hidden_score = None

        maps_transformed = []
        thetas, scales = self.affine_para(stn_centers, part_hidden_score)
        for i in range(24):
            theta_i = thetas[i]
            scale_i = scales[i]
            grid = F.affine_grid(theta_i.detach(), partial_decon_feat.size())
            maps_transformed_i = F.grid_sample(partial_decon_feat, grid)
            maps_transformed.append(maps_transformed_i)

        return_dict['stn_kps_pred'] = stn_centers.detach()

        part_maps = torch.cat(maps_transformed, dim=1)
        part_iuv_pred = self.iuv_est.final_pred.predict_partial_iuv(part_maps)
        part_map_size = part_iuv_pred.size(-1)

        # (bs, 24, 3, 7, 56, 56)
        part_iuv_pred = part_iuv_pred.view(part_iuv_pred.size(0), len(self.dp2smpl_mapping),
                                           3, -1, part_map_size, part_map_size)

        if cfg.DANET.INPUT_MODE in ['iuv_feat', 'feat', 'iuv_gt_feat']:
            return_dict['part_featmaps'] = part_maps.view(part_maps.size(0), 24, -1,
                                                          part_maps.size(-2), part_maps.size(-1))

        # Partial UV losses
        if self.training and iuv_image_gt is not None:
            pred_gt_ratio = float(part_map_size) / uvia_list[0].size(-1)
            iuv_resized = [F.interpolate(uvia_list[i], scale_factor=pred_gt_ratio, mode='nearest')
                           for i in range(3)]
            iuv_simplified = self.part_iuv_simp(iuv_resized)

            part_iuv_gt = []
            for i in range(len(iuv_simplified)):
                part_iuv_i = iuv_simplified[i]
                grid = F.affine_grid(thetas[i].detach(), part_iuv_i.size())
                part_iuv_i = F.grid_sample(part_iuv_i, grid)
                part_iuv_i = part_iuv_i.view(-1, 3, len(self.dp2smpl_mapping[i]) + 1,
                                             part_map_size, part_map_size)
                part_iuv_gt.append(part_iuv_i)
            return_dict['part_iuv_gt'] = torch.stack(part_iuv_gt, dim=1)

            loss_p_U, loss_p_V, loss_p_IndexUV = None, None, None
            for i in range(len(part_iuv_gt)):
                part_uvia_list = [part_iuv_gt[i][:, iuv] for iuv in range(3)]
                part_uvia_list.append(None)
                p_iuv_pred_i = [part_iuv_pred[:, i, iuv] for iuv in range(3)]
                loss_p_U_i, loss_p_V_i, loss_p_IndexUV_i, _ = self.body_uv_losses(
                    p_iuv_pred_i[0], p_iuv_pred_i[1], p_iuv_pred_i[2], None,
                    part_uvia_list, has_iuv)
                if i == 0:
                    loss_p_U, loss_p_V, loss_p_IndexUV = loss_p_U_i, loss_p_V_i, loss_p_IndexUV_i
                else:
                    loss_p_U += loss_p_U_i
                    loss_p_V += loss_p_V_i
                    loss_p_IndexUV += loss_p_IndexUV_i

            loss_p_U /= 24.
            loss_p_V /= 24.
            loss_p_IndexUV /= 24.
            return_dict['losses']['loss_pU'] = loss_p_U
            return_dict['losses']['loss_pV'] = loss_p_V
            return_dict['losses']['loss_pIndexUV'] = loss_p_IndexUV

        return_dict['part_iuv_pred'] = part_iuv_pred

    return return_dict
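# --- Illustrative sketch (added for clarity, not in the original file) -------
# softmax_integral_tensor() above performs a soft-argmax: the keypoint heatmaps
# are softmax-normalized and (x, y) is recovered as the expectation over the
# pixel grid (the call site then rescales the result to [-1, 1]). The helper's
# real definition lives elsewhere in the repo; this only shows the underlying idea.
import torch
import torch.nn.functional as F

def soft_argmax_2d(heatmaps):
    """heatmaps: (B, K, H, W) -> (B, K, 2) expected (x, y) in pixel units."""
    B, K, H, W = heatmaps.shape
    probs = F.softmax(heatmaps.view(B, K, -1), dim=2).view(B, K, H, W)
    xs = torch.arange(W, dtype=probs.dtype)
    ys = torch.arange(H, dtype=probs.dtype)
    x = (probs.sum(dim=2) * xs).sum(dim=2)  # marginalize over rows, expectation in x
    y = (probs.sum(dim=3) * ys).sum(dim=2)  # marginalize over cols, expectation in y
    return torch.stack([x, y], dim=-1)
# ------------------------------------------------------------------------------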