def handle(self, data_batch, pred_results, use_gt=False):
    """Evaluate the predicted meshes of one image against its 3D ground truth.

    Joints for each predicted mesh:
        - Regressed with ``self.J_regressor``.
        - Re-indexed with ``H36M_TO_J14``.
        - Root-aligned on regressed joint 0.

    The ground-truth skeleton:
        - Taken from ``gt_kpts3d``, with the last channel dropped.
        - Re-indexed with ``J24_TO_J14``.
        - Root-aligned on its joint 14.
        - Compared with every prediction.

    Only the smallest per-prediction error is kept.

    Args:
        data_batch: batch dict; ``gt_kpts3d`` and ``img_meta`` are read
            through ``.data[0][0]``.
        pred_results: dict providing ``pred_vertices``, ``pred_rotmat`` and
            ``pred_betas``.
        use_gt: unused; kept for interface compatibility.

    Returns:
        dict with these keys:
            - ``file_name``
            - ``MPJPE``: best mean joint distance x 1000.
            - ``r_error``: best ``reconstruction_error`` x 1000, or -1 when
              the file name does not contain ``self.pattern``.
            - ``pred_rotmat``
            - ``pred_betas``
    """
    verts = pred_results['pred_vertices'].cpu()
    n_pred = verts.shape[0]

    # One copy of the GT skeleton per prediction so the tensors line up.
    gt_full = data_batch['gt_kpts3d'].data[0][0].clone().repeat([n_pred, 1, 1])
    # The last channel is dropped (presumably a confidence flag -- verify).
    gt_root = gt_full[:, [14], :-1].clone()
    gt_joints = gt_full[:, J24_TO_J14, :-1] - gt_root

    regressor = self.J_regressor[None, :].expand(n_pred, -1, -1).to(verts.device)
    joints_all = torch.matmul(regressor, verts)
    pred_root = joints_all[:, [0], :].clone()
    pred_joints = joints_all[:, H36M_TO_J14, :] - pred_root

    file_name = data_batch['img_meta'].data[0][0]['file_name']

    # Mean per-joint Euclidean distance for every prediction, best one kept.
    # The x1000 factor is presumably metres -> millimetres; confirm.
    dist = (pred_joints - gt_joints).pow(2).sum(dim=-1).sqrt()
    mpjpe = float(dist.mean(dim=-1).min() * 1000)
    self.p1_meter.update(mpjpe)

    r_error = -1
    if self.pattern in file_name:
        # NOTE(review): the best r_error may come from a different prediction
        # than the best MPJPE; confirm this is intended.
        pa_errors = reconstruction_error(
            pred_joints.cpu().numpy(), gt_joints.cpu().numpy(), reduction=None)
        r_error = float(pa_errors.min() * 1000)
        self.p2_meter.update(r_error)

    return {
        'file_name': file_name,
        'MPJPE': mpjpe,
        'r_error': r_error,
        'pred_rotmat': pred_results['pred_rotmat'],
        'pred_betas': pred_results['pred_betas'],
    }
def handle(self, data_batch, pred_results, use_gt=False):
    """Evaluate one image's predicted meshes and optionally save a visualization.

    Metrics:
        - Each predicted mesh is turned into joints with ``self.J_regressor``
          (re-indexed by ``H36M_TO_J14``) and root-aligned on regressed
          joint 0.
        - The ground-truth skeleton (re-indexed by ``J24_TO_J14``, last
          channel dropped) is root-aligned on its joint 14.
        - The smallest per-prediction mean joint distance (x 1000) is fed to
          ``self.p1_meter``.
        - Images whose file name contains ``self.pattern`` also get the
          smallest ``reconstruction_error`` (x 1000) fed to ``self.p2_meter``.

    Visualization (when ``self.viz_dir`` is truthy):
        - The input image is de-normalized.
        - The box of the lowest-MPJPE prediction is drawn, together with that
          error.
        - A two-column grid is written to ``self.viz_dir`` under the image's
          base name. The left column is the renderer output over the image;
          the right column is a render of the ``get_bv_verts`` vertices.

    Args:
        data_batch: batch dict; reads ``gt_kpts3d``, ``img_meta`` and (for
            visualization) ``img``, each through ``.data[0][0]``.
        pred_results: dict with ``pred_vertices``, ``pred_rotmat``,
            ``pred_betas`` and, for visualization, ``bboxes`` and
            ``pred_translation``.
        use_gt: unused; kept for interface compatibility.

    Returns:
        dict with these keys:
            - ``file_name``
            - ``MPJPE``
            - ``r_error``: -1 when not computed.
            - ``pred_rotmat``
            - ``pred_betas``
    """
    pred_vertices = pred_results['pred_vertices'].cpu()
    num_preds = pred_vertices.shape[0]

    # Repeat the single GT skeleton once per prediction so shapes line up.
    gt_keypoints_3d = data_batch['gt_kpts3d'].data[0][0].clone().repeat(
        [num_preds, 1, 1])
    gt_pelvis_smpl = gt_keypoints_3d[:, [14], :-1].clone()
    gt_keypoints_3d = gt_keypoints_3d[:, J24_TO_J14, :-1] - gt_pelvis_smpl

    J_regressor_batch = self.J_regressor[None, :].expand(
        num_preds, -1, -1).to(pred_vertices.device)
    # Get 14 predicted joints from the SMPL mesh
    pred_keypoints_3d_smpl = torch.matmul(J_regressor_batch, pred_vertices)
    pred_pelvis_smpl = pred_keypoints_3d_smpl[:, [0], :].clone()
    pred_keypoints_3d_smpl = (
        pred_keypoints_3d_smpl[:, H36M_TO_J14, :] - pred_pelvis_smpl)

    file_name = data_batch['img_meta'].data[0][0]['file_name']

    # Absolute error (MPJPE) per prediction; the best one is reported.
    # The x1000 factor is presumably metres -> millimetres; confirm.
    # NOTE(review): with zero predictions ``error_smpl.min()`` raises; confirm
    # the caller guarantees at least one detection.
    error_smpl = torch.sqrt(
        ((pred_keypoints_3d_smpl - gt_keypoints_3d) ** 2).sum(dim=-1)
    ).mean(dim=-1)
    mpjpe = float(error_smpl.min() * 1000)
    self.p1_meter.update(mpjpe)

    if self.pattern in file_name:
        # Reconstruction error
        r_error_smpl = reconstruction_error(
            pred_keypoints_3d_smpl.cpu().numpy(),
            gt_keypoints_3d.cpu().numpy(),
            reduction=None)
        r_error = float(r_error_smpl.min() * 1000)
        self.p2_meter.update(r_error)
    else:
        r_error = -1

    save_pack = {
        'file_name': file_name,
        'MPJPE': mpjpe,
        'r_error': r_error,
        'pred_rotmat': pred_results['pred_rotmat'],
        'pred_betas': pred_results['pred_betas'],
    }

    if self.viz_dir:
        fname = osp.basename(file_name)
        bboxes = pred_results['bboxes'][0][:, :4]
        pred_translation = pred_results['pred_translation'].cpu()

        # Undo the per-channel normalization (standard ImageNet std/mean);
        # the arithmetic yields a new tensor, so the batch is not modified.
        img = data_batch['img'].data[0][0] * torch.tensor(
            [0.229, 0.224, 0.225]).view(3, 1, 1) + torch.tensor(
                [0.485, 0.456, 0.406]).view(3, 1, 1)
        # CHW float -> contiguous HWC uint8 for OpenCV drawing.
        img_cv = (img.numpy() * 255).astype(np.uint8).transpose(
            [1, 2, 0]).copy()

        # Draw the box of the prediction that achieved the reported MPJPE.
        # OpenCV requires integer pixel coordinates for its points.
        best_idx = int(error_smpl.argmin())
        x1, y1, x2, y2 = (int(v) for v in bboxes[best_idx])
        img_cv = cv2.rectangle(img_cv, (x1, y1), (x2, y2), (255, 0, 0), 2)
        img_cv = draw_text(img_cv, {'error': str(mpjpe)})
        img_cv = img_cv / 255.

        # NOTE(review): this changes process-wide print options and has no
        # visible effect here; kept to preserve behavior -- consider removing.
        torch.set_printoptions(precision=1)
        img_render = self.renderer(
            [torch.tensor(img_cv.transpose([2, 0, 1]))], [pred_vertices],
            translation=[pred_translation])
        bv_verts = get_bv_verts(bboxes, pred_vertices, pred_translation,
                                img.shape, self.FOCAL_LENGTH)
        img_bv = self.renderer(
            [torch.ones_like(img)], [bv_verts],
            translation=[torch.zeros(bv_verts.shape[0], 3)])
        img_grid = torchvision.utils.make_grid(
            torch.tensor(([img_render[0], img_bv[0]])),
            nrow=2).numpy().transpose([1, 2, 0])
        # Clamp to the [0, 1] range expected for float images.
        img_grid = np.clip(img_grid, 0, 1)

        os.makedirs(self.viz_dir, exist_ok=True)
        plt.imsave(osp.join(self.viz_dir, fname), img_grid)
    return save_pack