Example #1
    def handle(self, data_batch, pred_results, use_gt=False):
        # Evaluate collision metric
        pred_vertices = pred_results['pred_vertices']
        pred_translation = pred_results['pred_translation']
        cur_collision_volume = self.collision_volume(pred_vertices, pred_translation)
        if cur_collision_volume.item() > 0:
            # self.writer(f'Collision found with {cur_collision_volume.item() * 1000} L')
            self.coll_cnt += 1
        # collision_volume is in m^3; scale by 1000 to track liters
        self.collision_meter.update(cur_collision_volume.item() * 1000.)

        pred_vertices = pred_results['pred_vertices'].cpu()
        pred_camera = pred_results['pred_camera'].cpu()
        pred_translation = pred_results['pred_translation'].cpu()
        bboxes = pred_results['bboxes'][0][:, :4]
        img = data_batch['img'].data[0][0].clone()

        gt_keypoints_3d = data_batch['gt_kpts3d'].data[0][0].clone()
        gt_pelvis_smpl = gt_keypoints_3d[:, [14], :-1].clone()
        visible_kpts = gt_keypoints_3d[:, J24_TO_H36M, -1].clone()
        origin_gt_kpts3d = data_batch['gt_kpts3d'].data[0][0].clone().cpu()
        origin_gt_kpts3d = origin_gt_kpts3d[:, J24_TO_H36M]
        # origin_gt_kpts3d[:, :, :-1] -= gt_pelvis_smpl
        gt_keypoints_3d = gt_keypoints_3d[:, J24_TO_H36M, :-1].clone()
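        # Pelvis-center the GT keypoints so MPJPE ignores global translation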
        gt_keypoints_3d = gt_keypoints_3d - gt_pelvis_smpl

        J_regressor_batch = self.J_regressor[None, :].expand(pred_vertices.shape[0], -1, -1).to(
            pred_vertices.device)
        # Regress H36M joints from the predicted SMPL mesh
        pred_keypoints_3d_smpl = torch.matmul(J_regressor_batch, pred_vertices)
        pred_pelvis_smpl = pred_keypoints_3d_smpl[:, [0], :].clone()
        # pred_keypoints_3d_smpl = pred_keypoints_3d_smpl[:, H36M_TO_J14, :]
        pred_keypoints_3d_smpl = pred_keypoints_3d_smpl - pred_pelvis_smpl

        file_name = data_batch['img_meta'].data[0][0]['file_name']
        fname = osp.basename(file_name)

        # To select closest points
        # A joint counts as globally visible only when it is visible for every
        # person; the -0.1 margin guards against floating point inaccuracy.
        glb_vis = (visible_kpts.sum(0) >=
                   (visible_kpts.shape[0] - 0.1)).float()[None, :, None]
        if use_gt:
            paired_idxs = torch.arange(gt_keypoints_3d.shape[0])
        else:
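            # Greedy matching: each GT person takes its nearest prediction;
            # duplicate indices mean two GT people share one prediction.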
            dist = vectorize_distance((glb_vis * gt_keypoints_3d).numpy(),
                                      (glb_vis * pred_keypoints_3d_smpl).numpy())
            paired_idxs = torch.from_numpy(dist.argmin(1))
        is_mismatch = len(set(paired_idxs.tolist())) < len(paired_idxs)
        if is_mismatch:
            self.mismatch_cnt += 1

        selected_prediction = pred_keypoints_3d_smpl[paired_idxs]

        # Compute error metrics
        # Absolute error (MPJPE)
        error_smpl = (torch.sqrt(((selected_prediction - gt_keypoints_3d) ** 2).sum(dim=-1)) * visible_kpts)

        mpjpe = float(error_smpl.mean() * 1000)
        self.p1_meter.update(mpjpe, n=error_smpl.shape[0])

        save_pack = {'file_name': osp.basename(file_name),
                     'MPJPE': mpjpe,
                     'pred_rotmat': pred_results['pred_rotmat'].cpu(),
                     'pred_betas': pred_results['pred_betas'].cpu(),
                     'gt_kpts': origin_gt_kpts3d,
                     'kpts_paired': selected_prediction,
                     'pred_kpts': pred_keypoints_3d_smpl,
                     }

        # Visualize hard cases: pairing mismatch or best per-person MPJPE > 200 mm
        if self.viz_dir and (is_mismatch or error_smpl.mean(-1).min() * 1000 > 200):
            # Undo ImageNet normalization before drawing
            img = img.clone() * torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + torch.tensor(
                [0.485, 0.456, 0.406]).view(3, 1, 1)
            img_cv = img.clone().numpy()
            img_cv = (img_cv * 255).astype(np.uint8).transpose([1, 2, 0]).copy()
            for bbox in bboxes[paired_idxs]:
                # cv2.rectangle needs integer pixel coordinates
                img_cv = cv2.rectangle(img_cv, (int(bbox[0]), int(bbox[1])),
                                       (int(bbox[2]), int(bbox[3])), (255, 0, 0), 2)
            img_cv = draw_text(img_cv, {'mismatch': is_mismatch, 'error': str(error_smpl.mean(-1) * 1000)})
            img_cv = img_cv / 255.

            torch.set_printoptions(precision=1)
            img_render = self.renderer([torch.tensor(img_cv.transpose([2, 0, 1]))], [pred_vertices],
                                       translation=[pred_translation])

            bv_verts = get_bv_verts(bboxes, pred_vertices, pred_translation,
                                    img.shape, self.FOCAL_LENGTH)
            img_bv = self.renderer([torch.ones_like(img)], [bv_verts],
                                   translation=[torch.zeros(bv_verts.shape[0], 3)])
            # Stack the rendered view and the bird's-eye view into one grid
            img_grid = torchvision.utils.make_grid(torch.stack([img_render[0], img_bv[0]]),
                                                   nrow=2).numpy().transpose([1, 2, 0])
            img_grid = np.clip(img_grid, 0., 1.)
            plt.imsave(osp.join(self.viz_dir, fname), img_grid)
        return save_pack
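
Example #1 resolves the GT-to-prediction pairing through `vectorize_distance`, whose definition is not shown above. Below is a minimal sketch consistent with the call site (two (N, K, 3) arrays in, an (N_gt, N_pred) cost matrix out, consumed by `argmin(1)`); the exact per-joint reduction inside the real helper is an assumption.

import numpy as np

def vectorize_distance(gt_joints, pred_joints):
    """Cost matrix between ground-truth people and predictions (sketch).

    gt_joints: (N_gt, K, 3) ground-truth 3D keypoints.
    pred_joints: (N_pred, K, 3) predicted 3D keypoints.
    Returns an (N_gt, N_pred) matrix whose (i, j) entry is the mean
    per-joint L2 distance between GT person i and prediction j.
    """
    # Broadcast to (N_gt, N_pred, K, 3), then reduce over coordinates and joints.
    diff = gt_joints[:, None] - pred_joints[None, :]
    return np.linalg.norm(diff, axis=-1).mean(axis=-1)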
Example #2
    def handle(self, data_batch, pred_results, use_gt=False):
        pred_vertices = pred_results['pred_vertices'].cpu()

        gt_keypoints_3d = data_batch['gt_kpts3d'].data[0][0].clone().repeat(
            [pred_vertices.shape[0], 1, 1])
        gt_pelvis_smpl = gt_keypoints_3d[:, [14], :-1].clone()
        gt_keypoints_3d = gt_keypoints_3d[:, J24_TO_J14, :-1].clone()
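        # Pelvis-center the GT so only pose, not global translation, is scored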
        gt_keypoints_3d = gt_keypoints_3d - gt_pelvis_smpl

        J_regressor_batch = self.J_regressor[None, :].expand(
            pred_vertices.shape[0], -1, -1).to(pred_vertices.device)
        # Get 14 predicted joints from the SMPL mesh
        pred_keypoints_3d_smpl = torch.matmul(J_regressor_batch, pred_vertices)
        pred_pelvis_smpl = pred_keypoints_3d_smpl[:, [0], :].clone()
        pred_keypoints_3d_smpl = pred_keypoints_3d_smpl[:, H36M_TO_J14, :]
        pred_keypoints_3d_smpl = pred_keypoints_3d_smpl - pred_pelvis_smpl

        file_name = data_batch['img_meta'].data[0][0]['file_name']

        # Compute error metrics
        # Absolute error (MPJPE)
        error_smpl = torch.sqrt(
            ((pred_keypoints_3d_smpl - gt_keypoints_3d) ** 2).sum(dim=-1)).mean(dim=-1)

        # One GT person against every detection: keep the best match
        mpjpe = float(error_smpl.min() * 1000)
        self.p1_meter.update(mpjpe)

        if self.pattern in file_name:
            # Reconstruction error
            r_error_smpl = reconstruction_error(
                pred_keypoints_3d_smpl.cpu().numpy(),
                gt_keypoints_3d.cpu().numpy(),
                reduction=None)
            r_error = float(r_error_smpl.min() * 1000)
            self.p2_meter.update(r_error)
        else:
            r_error = -1

        save_pack = {
            'file_name': file_name,
            'MPJPE': mpjpe,
            'r_error': r_error,
            'pred_rotmat': pred_results['pred_rotmat'],
            'pred_betas': pred_results['pred_betas'],
        }

        if self.viz_dir:
            file_name = data_batch['img_meta'].data[0][0]['file_name']
            fname = osp.basename(file_name)
            bboxes = pred_results['bboxes'][0][:, :4]
            pred_translation = pred_results['pred_translation'].cpu()
            img = data_batch['img'].data[0][0].clone()
            # Undo ImageNet normalization before drawing
            img = img.clone() * torch.tensor([0.229, 0.224, 0.225]).view(
                3, 1, 1) + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
            img_cv = img.clone().numpy()
            img_cv = (img_cv * 255).astype(np.uint8).transpose([1, 2, 0]).copy()

            # Highlight the best-matching detection
            index = error_smpl.argmin()
            bbox = bboxes[index]
            img_cv = cv2.rectangle(img_cv, (int(bbox[0]), int(bbox[1])),
                                   (int(bbox[2]), int(bbox[3])), (255, 0, 0), 2)
            img_cv = draw_text(img_cv, {'error': str(mpjpe)})
            img_cv = img_cv / 255.

            torch.set_printoptions(precision=1)
            img_render = self.renderer(
                [torch.tensor(img_cv.transpose([2, 0, 1]))], [pred_vertices],
                translation=[pred_translation])

            bv_verts = get_bv_verts(bboxes, pred_vertices, pred_translation,
                                    img.shape, self.FOCAL_LENGTH)
            img_bv = self.renderer(
                [torch.ones_like(img)], [bv_verts],
                translation=[torch.zeros(bv_verts.shape[0], 3)])
            # Stack the rendered view and the bird's-eye view into one grid
            img_grid = torchvision.utils.make_grid(
                torch.stack([img_render[0], img_bv[0]]),
                nrow=2).numpy().transpose([1, 2, 0])
            img_grid = np.clip(img_grid, 0., 1.)
            os.makedirs(self.viz_dir, exist_ok=True)
            plt.imsave(osp.join(self.viz_dir, fname), img_grid)

        return save_pack
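
Example #2 additionally reports `reconstruction_error` (PA-MPJPE), which is likewise not defined above. A minimal sketch under the usual SPIN-style conventions, i.e. (N, K, 3) inputs and a per-sample error array when reduction=None; the helper name matches the call above, but `procrustes_align` and the exact signature are illustrative assumptions.

import numpy as np

def procrustes_align(S1, S2):
    """Align S1 to S2 with a similarity transform (scale, rotation,
    translation). S1, S2: (K, 3). Returns the aligned copy of S1."""
    # Work in (3, K) column-vector form.
    S1t, S2t = S1.T, S2.T
    mu1 = S1t.mean(axis=1, keepdims=True)
    mu2 = S2t.mean(axis=1, keepdims=True)
    X1, X2 = S1t - mu1, S2t - mu2
    var1 = (X1 ** 2).sum()
    # Optimal rotation from the SVD of the covariance; the Z term removes
    # a possible reflection so R stays a proper rotation.
    U, _, Vh = np.linalg.svd(X1 @ X2.T)
    V = Vh.T
    Z = np.eye(3)
    Z[-1, -1] = np.sign(np.linalg.det(V @ U.T))
    R = V @ Z @ U.T
    scale = np.trace(R @ (X1 @ X2.T)) / var1
    t = mu2 - scale * (R @ mu1)
    return (scale * (R @ S1t) + t).T

def reconstruction_error(pred, gt, reduction=None):
    """Mean per-joint error after Procrustes alignment (PA-MPJPE).
    pred, gt: (N, K, 3); reduction=None returns one error per sample."""
    errors = np.array([
        np.linalg.norm(procrustes_align(p, g) - g, axis=-1).mean()
        for p, g in zip(pred, gt)
    ])
    return errors.mean() if reduction == 'mean' else errors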