def init_rtvec_train(BATCH_SIZE, device):
    """Sample a random ground-truth pose vector and a random initial pose vector.

    Each 6-vector per batch element is drawn from N(0, 0.15); its first three
    components (fed to ``euler2mat``) are then scaled by ``0.35 * PI``. The
    ground-truth sample is drawn first, then the initial sample. Both are
    converted from Euler angles to axis-angle, keeping the remaining three
    components unchanged.

    Returns:
        (transform_mat3x4_gt, rtvec, rtvec_gt): the top three rows of the
        ground-truth pose from ``tgm.rtvec_to_pose``, the sampled
        [axis-angle | rest] vector and the ground-truth one.
    """
    def _draw():
        # One (BATCH_SIZE, 6) sample with the angular part rescaled.
        sample = np.random.normal(0, 0.15, (BATCH_SIZE, 6))
        sample[:, :3] = sample[:, :3] * 0.35 * PI
        return sample

    def _to_axis_angle(sample):
        # Leaf tensor that tracks gradients, exactly as created before.
        leaf = torch.tensor(sample,
                            dtype=torch.float,
                            requires_grad=True,
                            device=device)
        # rotation_matrix_to_angle_axis takes 3x4 matrices: append a zero column.
        zero_col = torch.zeros(BATCH_SIZE, 3, 1).to(device)
        mats_3x4 = torch.cat([euler2mat(leaf[:, :3]), zero_col], dim=-1)
        axis_angle = tgm.rotation_matrix_to_angle_axis(mats_3x4)
        return torch.cat([axis_angle, leaf[:, 3:]], dim=-1)

    # Draw order matters for reproducibility: ground truth first.
    gt_sample = _draw()
    init_sample = _draw()

    rtvec = _to_axis_angle(init_sample)
    rtvec_gt = _to_axis_angle(gt_sample)
    transform_mat3x4_gt = tgm.rtvec_to_pose(rtvec_gt)[:, :3, :]

    return transform_mat3x4_gt, rtvec, rtvec_gt
    def forward(self, x):
        """Decode ``x`` into axis-angle rotations.

        The last layer's tanh output is reshaped into 3x3 matrices. Each
        matrix is projected onto the nearest proper rotation (det = +1) with
        an SVD. The result is converted to axis-angle with
        ``tgm.rotation_matrix_to_angle_axis``.

        Returns:
            tensor of shape (batch_size, 3 * matrices_per_sample); with a
            9-unit ``fc3`` this is (batch_size, 3) -- TODO confirm fc3 size.
        """
        fc1_out = F.leaky_relu(self.fc1(x), 0.2)
        fc1_out = F.dropout(fc1_out, p=0.25, training=self.is_training)
        fc2_out = F.leaky_relu(self.fc2(fc1_out), 0.2)
        fc3_out = torch.tanh(self.fc3(fc2_out))
        output = torch.reshape(fc3_out, shape=(-1, 3, 3))
        batch_size = x.shape[0]

        # Before converting the decoder's matrices to axis-angle, project them
        # onto valid rotation matrices.
        with torch.no_grad():
            norm_rotation = torch.zeros_like(output)
            for bidx in range(output.shape[0]):
                U, _, V = torch.svd(output[bidx])
                # U V^T is the closest orthogonal matrix, but it may be a
                # reflection (det = -1), which is not a rotation. Flipping the
                # singular vector of the smallest singular value (torch.svd
                # returns them in descending order) gives the closest proper
                # rotation (special orthogonal Procrustes / Kabsch).
                rot = torch.matmul(U, V.t())
                if torch.det(rot) < 0:
                    U = U.clone()
                    U[:, -1] = -U[:, -1]
                    rot = torch.matmul(U, V.t())
                norm_rotation[bidx] = rot

        # Straight-through estimator: the forward value equals norm_rotation,
        # while gradients flow unchanged into `output`. torch.svd only
        # supports backprop for full-rank matrices, so it is kept off the
        # graph. (Same trick as PyTorch's hard Gumbel-Softmax sampling.)
        correct_rot = norm_rotation - output.detach() + output

        # rotation_matrix_to_angle_axis takes 3x4 matrices: pad a zero column.
        return tgm.rotation_matrix_to_angle_axis(
            F.pad(correct_rot.view(-1, 3, 3), [0, 1, 0, 0])).view(batch_size, -1)
Exemplo n.º 3
0
 def matrot2aa(pose_matrot):
     '''
     Convert rotation matrices to axis-angle vectors.

     :param pose_matrot: tensor whose elements form consecutive 3x3 matrices,
         e.g. Nx1xnum_jointsx9
     :return: (number of matrices)x3 -- one axis-angle row per matrix, with
         the leading batch/joint dimensions flattened away
     '''
     # tgm expects 3x4 matrices, so append a zero fourth column.
     matrices_3x4 = F.pad(pose_matrot.view(-1, 3, 3), [0, 1])
     axis_angle = tgm.rotation_matrix_to_angle_axis(matrices_3x4)
     return axis_angle.view(-1, 3).contiguous()
Exemplo n.º 4
0
def rotmat2aa(rotmat):
    '''
    Convert per-joint rotation matrices to axis-angle vectors.

    :param rotmat: Nx1xnum_jointsx9
    :return: Nx1xnum_jointsx3
    '''
    n = rotmat.shape[0]
    # Append a zero fourth column: tgm works on 3x4 matrices.
    mats_3x4 = F.pad(rotmat.view(-1, 3, 3), [0, 1])
    axis_angle = tgm.rotation_matrix_to_angle_axis(mats_3x4)
    return axis_angle.view(n, 1, -1, 3).contiguous()
Exemplo n.º 5
0
def create_pinhole(intrinsic, extrinsic, height, width):
    """Pack camera parameters into a (1, 12) pinhole tensor.

    Layout: [fx, fy, cx, cy, height, width, rx, ry, rz, tx, ty, tz], where
    (rx, ry, rz) is the axis-angle form of the extrinsic rotation and
    (tx, ty, tz) its translation column.

    Args:
        intrinsic: camera matrix indexable as ``intrinsic[i, j]``.
        extrinsic: [R | t] extrinsics of shape (3, 4) (numpy array or tensor).
            A (4, 4) homogeneous matrix is also accepted; its last row is
            ignored.
        height, width: image size.

    Returns:
        float32 tensor of shape (1, 12).
    """
    pinhole = torch.zeros(12)
    pinhole[0] = intrinsic[0, 0]  # fx
    pinhole[1] = intrinsic[1, 1]  # fy
    pinhole[2] = intrinsic[0, 2]  # cx
    pinhole[3] = intrinsic[1, 2]  # cy
    pinhole[4] = height
    pinhole[5] = width

    extrinsic_t = torch.as_tensor(extrinsic)
    if not extrinsic_t.is_floating_point():
        extrinsic_t = extrinsic_t.float()
    extrinsic_t = extrinsic_t[:3, :4]
    # rotation_matrix_to_angle_axis operates on a batch of (3, 4) matrices
    # and fails on a bare 2-D matrix, so add a batch dimension and take the
    # single result back out.
    pinhole[6:9] = tgm.rotation_matrix_to_angle_axis(extrinsic_t.unsqueeze(0))[0]
    pinhole[9:12] = extrinsic_t[:, 3]
    return pinhole.view(1, -1)
Exemplo n.º 6
0
    def forward(self, pose_3d):
        """Regress SMPL-style pose (axis-angle) and shape from 3D joints.

        Args:
            pose_3d: joint positions; flattened here to
                (batch, self.joint_num * 3).

        Returns:
            (pose, shape): ``pose`` is the axis-angle output reshaped to
            (-1, 72) (presumably 24 joints x 3 per sample -- TODO confirm);
            ``shape`` is the raw ``fc_shape`` output.
        """
        pose_3d = pose_3d.view(-1, self.joint_num * 3)
        feat = self.fc(pose_3d)

        pose = self.fc_pose(feat)
        pose = self.rot6d_to_rotmat(pose)  # presumably (M, 3, 3) matrices
        # rotation_matrix_to_angle_axis takes 3x4 matrices. Build the zero
        # column on the same device and dtype as `pose` so this also runs on
        # CPU, on non-default GPUs and in other precisions.
        pose = torch.cat([pose, pose.new_zeros((pose.shape[0], 3, 1))], 2)
        pose = tgm.rotation_matrix_to_angle_axis(pose).reshape(-1, 72)

        shape = self.fc_shape(feat)

        return pose, shape
Exemplo n.º 7
0
    def test_rotation_matrix_to_angle_axis(self):
        # Reference (3x4 matrix with zero last column, expected axis-angle)
        # pairs, checked together as one batch in this order.
        cases = [
            ([[0.6027768, -0.79275544, -0.09054801, 0.],
              [-0.67915707, -0.56931658, 0.46327563, 0.],
              [-0.41881476, -0.21775548, -0.88157628, 0.]],
             [-2.44916812, 1.18053411, 0.4085298]),
            ([[-0.30382753, -0.95095137, -0.05814062, 0.],
              [-0.71581715, 0.26812278, -0.64476041, 0.],
              [0.62872461, -0.15427791, -0.76217038, 0.]],
             [1.50485376, -2.10737739, 0.7214174]),
        ]
        rmat = torch.tensor([matrix for matrix, _ in cases])
        rvec = torch.tensor([expected for _, expected in cases])
        self.assertTrue(
            check_equal_torch(tgm.rotation_matrix_to_angle_axis(rmat), rvec))
Exemplo n.º 8
0
def rotmat_to_angleaxis(init_pred_rotmat):
    """Convert 24 SMPL joint rotation matrices to axis-angle vectors.

    Args:
        init_pred_rotmat: tensor holding 3x3 rotation matrices whose count is
            a multiple of 24, e.g. shape (1, 24, 3, 3), (24, 3, 3) or
            (B, 24, 3, 3).

    Returns:
        tensor of shape (B, 24, 3), where B = number of matrices / 24
        (B == 1 for (24, 3, 3) or (1, 24, 3, 3) input). NaN entries are
        replaced with 0.
    """
    rotmats = init_pred_rotmat.reshape(-1, 3, 3)
    num_mats = rotmats.shape[0]
    # rotation_matrix_to_angle_axis takes 3x4 matrices, so append a dummy
    # fourth column. Size it from the flattened matrix count, not shape[1],
    # so every supported input layout works. Match the input dtype/device so
    # torch.cat does not fail.
    pad = rotmats.new_tensor([0, 0, 1]).view(1, 3, 1).expand(num_mats, -1, -1)
    pred_rotmat_hom = torch.cat([rotmats, pad], dim=-1)  # (num_mats, 3, 4)
    pred_aa = torchgeometry.rotation_matrix_to_angle_axis(
        pred_rotmat_hom).contiguous()  # (num_mats, 3)
    # tgm.rotation_matrix_to_angle_axis returns NaN for 0 rotation, so manually hack it
    pred_aa[torch.isnan(pred_aa)] = 0.0

    return pred_aa.view(-1, 24, 3)
Exemplo n.º 9
0
def rotmat3x3_to_angleaxis(init_pred_rotmat):
    """Convert a (B, N, 3, 3) tensor of rotation matrices to axis-angle.

    Args:
        init_pred_rotmat: tensor of shape (B, N, 3, 3), any batch size B.

    Returns:
        tensor of shape (B, N, 3); NaN entries are replaced with 0.

    Raises:
        ValueError: if the input is not 4-D with trailing (3, 3) dimensions.
    """
    if init_pred_rotmat.dim() != 4 or tuple(init_pred_rotmat.shape[-2:]) != (3, 3):
        raise ValueError(
            "expected a (B, N, 3, 3) tensor, got shape {}".format(tuple(init_pred_rotmat.shape)))

    batch_size, jointNum = init_pred_rotmat.shape[:2]
    rotmats = init_pred_rotmat.reshape(-1, 3, 3)  # (B*N, 3, 3)
    # rotation_matrix_to_angle_axis takes 3x4 matrices, so append a dummy
    # fourth column with the input's dtype/device (torch.cat needs them to match).
    pad = rotmats.new_tensor([0, 0, 1]).view(1, 3, 1).expand(rotmats.shape[0], -1, -1)
    pred_rotmat_hom = torch.cat([rotmats, pad], dim=-1)  # (B*N, 3, 4)
    pred_aa = tgm.rotation_matrix_to_angle_axis(pred_rotmat_hom).contiguous()
    # tgm.rotation_matrix_to_angle_axis returns NaN for 0 rotation, so manually hack it
    pred_aa[torch.isnan(pred_aa)] = 0.0

    return pred_aa.view(batch_size, jointNum, 3)
Exemplo n.º 10
0
def run_evaluation(model, dataset_name, dataset, result_file,
                   batch_size=32, img_res=224,
                   num_workers=32, shuffle=False, log_freq=50, bVerbose=True):
    """Run evaluation on the datasets and metrics we report in the paper.

    The SMPL body models are created on first use and cached in the
    module-level globals ``g_smpl_neutral``/``g_smpl_male``/``g_smpl_female``.
    The model is not moved to the device here.

    Args:
        model: called as ``model(images)``; must return
            ``(pred_rotmat, pred_betas, pred_camera)``.
        dataset_name: ``'h36m-p1'``, ``'h36m-p2'``, ``'3dpw'`` or
            ``'mpi-inf-3dhp'`` enable 3D pose metrics (MPJPE and
            reconstruction error, reported per sequence).
        dataset: dataset yielding dict batches; keys read below include
            ``'imgname'``, ``'pose'``, ``'betas'``, ``'img'``, ``'gender'``
            (plus ``'pose_3d'`` for H36M/MPI-INF and ``'scale'``/``'center'``
            when saving).
        result_file: path of a pickle file to write predictions and errors
            to, or ``None``. Saving forces ``shuffle=False``.
        batch_size, num_workers, shuffle: forwarded to ``DataLoader``.
        img_res: unused; kept for interface compatibility.
        log_freq: print running pose metrics every ``log_freq`` batches.
        bVerbose: print per-batch and per-sequence results.

    Returns:
        ``(mpjpe_mm, reconstruction_error_mm)`` averaged over all samples
        when pose metrics are computed, otherwise ``-1``.

    Raises:
        NotImplementedError: for ``'lsp'``. Its mask/part metrics need the
            PartRenderer, which is disabled in this version.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Choose appropriate evaluation for each dataset
    eval_pose = dataset_name in ('h36m-p1', 'h36m-p2', '3dpw', 'mpi-inf-3dhp')
    if dataset_name == 'lsp':
        # Mask/part metrics are computed from rendered silhouettes, and the
        # renderer is disabled here: fail fast with a clear message.
        raise NotImplementedError(
            "Mask/part evaluation ('lsp') needs PartRenderer, which is disabled in this version")

    # Load SMPL models once and reuse them across calls.
    global g_smpl_neutral, g_smpl_male, g_smpl_female
    if g_smpl_neutral is None:
        g_smpl_neutral = SMPL(config.SMPL_MODEL_DIR,
                              create_transl=False).to(device)
        g_smpl_male = SMPL(config.SMPL_MODEL_DIR,
                           gender='male',
                           create_transl=False).to(device)
        g_smpl_female = SMPL(config.SMPL_MODEL_DIR,
                             gender='female',
                             create_transl=False).to(device)
    smpl_neutral, smpl_male, smpl_female = g_smpl_neutral, g_smpl_male, g_smpl_female

    # Regressor for H36m joints, moved to the device once instead of per batch.
    J_regressor = torch.from_numpy(np.load(config.JOINT_REGRESSOR_H36M)).float().to(device)

    save_results = result_file is not None
    # Results are stored in loader order, so disable shuffling when saving.
    if save_results:
        shuffle = False
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

    joint_mapper_h36m = constants.H36M_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.H36M_TO_J14
    joint_mapper_gt = constants.J24_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.J24_TO_J14
    # 14 joints normally, 17 for MPI-INF-3DHP; size the joint buffers to match.
    num_eval_joints = len(joint_mapper_h36m)

    # Per-sequence errors (metres), keyed by the image's parent directory name.
    quant_mpjpe = {}
    quant_recon_err = {}
    # Running sums over all evaluated samples, used by the progress logs.
    total_mpjpe = 0.0
    total_recon_err = 0.0
    num_evaluated = 0

    # Buffers for the optional result file, filled in loader order.
    num_samples = len(dataset)
    output_pred_pose = np.zeros((num_samples, 72))
    output_pred_betas = np.zeros((num_samples, 10))
    output_pred_camera = np.zeros((num_samples, 3))
    output_pred_joints = np.zeros((num_samples, num_eval_joints, 3))
    output_gt_pose = np.zeros((num_samples, 72))
    output_gt_betas = np.zeros((num_samples, 10))
    output_gt_joints = np.zeros((num_samples, num_eval_joints, 3))
    output_error_MPJPE = np.zeros(num_samples)
    output_error_recon = np.zeros(num_samples)
    output_imgNames = []
    output_cropScale = np.zeros(num_samples)
    output_cropCenter = np.zeros((num_samples, 2))
    outputStartPointer = 0

    # Iterate over the entire dataset
    for step, batch in enumerate(tqdm(data_loader, desc='Eval', total=len(data_loader))):
        gt_pose = batch['pose'].to(device)
        gt_betas = batch['betas'].to(device)
        images = batch['img'].to(device)
        gender = batch['gender'].to(device)
        curr_batch_size = images.shape[0]

        with torch.no_grad():
            pred_rotmat, pred_betas, pred_camera = model(images)
            pred_output = smpl_neutral(betas=pred_betas, body_pose=pred_rotmat[:, 1:],
                                       global_orient=pred_rotmat[:, 0].unsqueeze(1), pose2rot=False)
            pred_vertices = pred_output.vertices

        # 3D pose evaluation
        if eval_pose:
            # Regressor broadcasting
            J_regressor_batch = J_regressor[None, :].expand(pred_vertices.shape[0], -1, -1)
            if 'h36m' in dataset_name or 'mpi-inf' in dataset_name:
                # Ground-truth joints ship with the dataset; the last channel
                # is dropped (presumably a confidence value).
                gt_keypoints_3d = batch['pose_3d'].to(device)
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_gt, :-1]
            else:
                # For 3DPW get the common joints from the gender-specific GT mesh
                gt_vertices = smpl_male(global_orient=gt_pose[:, :3], body_pose=gt_pose[:, 3:], betas=gt_betas).vertices
                gt_vertices_female = smpl_female(global_orient=gt_pose[:, :3], body_pose=gt_pose[:, 3:], betas=gt_betas).vertices
                gt_vertices[gender == 1, :, :] = gt_vertices_female[gender == 1, :, :]
                gt_keypoints_3d = torch.matmul(J_regressor_batch, gt_vertices)
                gt_pelvis = gt_keypoints_3d[:, [0], :].clone()
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_h36m, :]
                gt_keypoints_3d = gt_keypoints_3d - gt_pelvis

            # Predicted joints from the mesh, pelvis-centred the same way
            pred_keypoints_3d = torch.matmul(J_regressor_batch, pred_vertices)
            pred_pelvis = pred_keypoints_3d[:, [0], :].clone()
            pred_keypoints_3d = pred_keypoints_3d[:, joint_mapper_h36m, :]
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis

            # Absolute error (MPJPE) and reconstruction error, per sample
            error = torch.sqrt(((pred_keypoints_3d - gt_keypoints_3d) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            r_error = reconstruction_error(pred_keypoints_3d.cpu().numpy(), gt_keypoints_3d.cpu().numpy(), reduction=None)

            for ii, p in enumerate(batch['imgname'][:len(r_error)]):
                seqName = os.path.basename(os.path.dirname(p))
                quant_mpjpe.setdefault(seqName, []).append(error[ii])
                quant_recon_err.setdefault(seqName, []).append(r_error[ii])
                total_mpjpe += error[ii]
                total_recon_err += r_error[ii]
                num_evaluated += 1

            if bVerbose and num_evaluated:
                print(">>> {} : MPJPE {:.02f} mm, error: {:.02f} mm | Total MPJPE {:.02f} mm, error {:.02f} mm".format(
                    seqName, np.mean(error) * 1000, np.mean(r_error) * 1000,
                    total_mpjpe / num_evaluated * 1000, total_recon_err / num_evaluated * 1000))

        # Print intermediate results during evaluation (running means over
        # every sample evaluated so far).
        if bVerbose and eval_pose and num_evaluated and step % log_freq == log_freq - 1:
            print('MPJPE: ' + str(1000 * total_mpjpe / num_evaluated))
            print('Reconstruction Error: ' + str(1000 * total_recon_err / num_evaluated))
            print()

        if save_results:
            start = outputStartPointer
            end = outputStartPointer + curr_batch_size
            # Convert predicted rotation matrices (24 per sample) to axis-angle
            rot_pad = torch.tensor([0, 0, 1], dtype=torch.float32, device=device).view(1, 3, 1)
            rotmat = torch.cat((pred_rotmat.view(-1, 3, 3), rot_pad.expand(curr_batch_size * 24, -1, -1)), dim=-1)
            pred_pose = tgm.rotation_matrix_to_angle_axis(rotmat).contiguous().view(-1, 72)

            output_pred_pose[start:end, :] = pred_pose.cpu().numpy()
            output_pred_betas[start:end, :] = pred_betas.cpu().numpy()
            output_pred_camera[start:end, :] = pred_camera.cpu().numpy()
            output_gt_pose[start:end, :] = gt_pose.cpu().numpy()
            output_gt_betas[start:end, :] = gt_betas.cpu().numpy()
            # Joint/error fields exist only when pose metrics were computed.
            if eval_pose:
                output_pred_joints[start:end, :] = pred_keypoints_3d.cpu().numpy()
                output_gt_joints[start:end, :] = gt_keypoints_3d.cpu().numpy()
                output_error_MPJPE[start:end] = error * 1000
                output_error_recon[start:end] = r_error * 1000
            output_cropScale[start:end] = batch['scale'].cpu().numpy()
            output_cropCenter[start:end, :] = batch['center'].cpu().numpy()
            output_imgNames += batch['imgname']
            outputStartPointer = end

    # Save reconstructions to a file for further processing
    if save_results:
        import pickle
        finalLen = len(output_imgNames)
        output = {
            'imageNames': output_imgNames,
            'pred_pose': output_pred_pose[:finalLen],
            'pred_betas': output_pred_betas[:finalLen],
            'pred_camera': output_pred_camera[:finalLen],
            'pred_joints': output_pred_joints[:finalLen],
            'gt_pose': output_gt_pose[:finalLen],
            'gt_betas': output_gt_betas[:finalLen],
            'gt_joints': output_gt_joints[:finalLen],
            'error_MPJPE': output_error_MPJPE[:finalLen],
            'error_recon': output_error_recon[:finalLen],
            'cropScale': output_cropScale[:finalLen],
            'cropCenter': output_cropCenter[:finalLen],
        }
        with open(result_file, 'wb') as f:
            pickle.dump(output, f)
        print("Saved to:{}".format(result_file))

    # Print final results during evaluation
    if bVerbose:
        print('*** Final Results ***')
        print()
    if not eval_pose:
        return -1  # no pose metrics for this dataset

    quant_mpjpe_avg_mm = np.hstack([quant_mpjpe[k] for k in quant_mpjpe]).mean() * 1000
    quant_recon_error_avg_mm = np.hstack([quant_recon_err[k] for k in quant_recon_err]).mean() * 1000

    output_str = 'SeqNames; ' + ''.join(seq + ';' for seq in quant_mpjpe)
    output_str += '\n MPJPE; '
    output_str += "Avg {:.02f} mm; ".format(quant_mpjpe_avg_mm)
    output_str += ''.join('{:.02f}; '.format(1000 * np.hstack(errs).mean()) for errs in quant_mpjpe.values())
    output_str += '\n Recon Error; '
    output_str += "Avg {:.02f}mm; ".format(quant_recon_error_avg_mm)
    output_str += ''.join('{:.02f}; '.format(1000 * np.hstack(errs).mean()) for errs in quant_recon_err.values())
    if bVerbose:
        print(output_str)
    else:
        print(">>>  Test on 3DPW: MPJPE: {} | quant_recon_error_avg_mm: {}".format(quant_mpjpe_avg_mm, quant_recon_error_avg_mm))

    return quant_mpjpe_avg_mm, quant_recon_error_avg_mm
Exemplo n.º 11
0
def rot_to_axisang(rot):
    """Convert a batch of 3x3 rotation matrices to axis-angle vectors.

    A zero fourth column (default dtype, on ``rot``'s device) is appended
    because ``rotation_matrix_to_angle_axis`` works on 3x4 matrices.
    """
    padded = torch.cat(
        [rot, torch.zeros(rot.shape[0], 3, 1, device=rot.device)], dim=-1)
    return rotation_matrix_to_angle_axis(padded)
Exemplo n.º 12
0
def run_evaluation(model, dataset_name, dataset, result_file,
                   batch_size=32, img_res=224,
                   num_workers=32, shuffle=False, log_freq=50):
    """Run evaluation on the datasets and metrics we report in the paper.

    Args:
        model: called as ``model(images)``; must return
            ``(pred_rotmat, pred_betas, pred_camera)``. Moved to the
            evaluation device.
        dataset_name: ``'h36m-p1'``, ``'h36m-p2'``, ``'3dpw'`` or
            ``'mpi-inf-3dhp'`` enable 3D pose metrics (MPJPE and
            reconstruction error); ``'lsp'`` enables mask and part metrics.
        dataset: dataset yielding dict batches; keys read below include
            ``'pose'``, ``'betas'``, ``'img'``, ``'gender'`` plus
            dataset-specific ones (``'pose_3d'``, ``'center'``, ``'scale'``,
            ``'orig_shape'``, ``'maskname'``, ``'partname'``).
        result_file: path passed to ``np.savez`` for predicted joints, pose,
            betas and camera, or ``None``. Saving disables shuffling so rows
            follow dataset order.
        batch_size, num_workers, shuffle: forwarded to ``DataLoader``.
        img_res: unused; kept for interface compatibility.
        log_freq: print running metrics every ``log_freq`` batches.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Transfer model to the GPU
    model.to(device)

    # Choose appropriate evaluation for each dataset
    eval_pose = dataset_name in ('h36m-p1', 'h36m-p2', '3dpw', 'mpi-inf-3dhp')
    eval_masks = eval_parts = dataset_name == 'lsp'
    if eval_masks:
        annot_path = config.DATASET_FOLDERS['upi-s1h']

    # Load SMPL model
    smpl_neutral = SMPL(config.SMPL_MODEL_DIR,
                        create_transl=False).to(device)
    smpl_male = SMPL(config.SMPL_MODEL_DIR,
                     gender='male',
                     create_transl=False).to(device)
    smpl_female = SMPL(config.SMPL_MODEL_DIR,
                       gender='female',
                       create_transl=False).to(device)

    # The renderer is only needed for the silhouette-based mask/part metrics.
    renderer = PartRenderer() if (eval_masks or eval_parts) else None

    # Regressor for H36m joints, moved to the device once instead of per batch.
    J_regressor = torch.from_numpy(np.load(config.JOINT_REGRESSOR_H36M)).float().to(device)

    save_results = result_file is not None
    # Disable shuffling if you want to save the results
    if save_results:
        shuffle = False
    # Create dataloader for the dataset
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

    # Pose metrics: per-sample MPJPE and reconstruction error
    mpjpe = np.zeros(len(dataset))
    recon_err = np.zeros(len(dataset))

    # Mask and part metrics: accuracy, true/false positives, false negatives
    accuracy = 0.
    parts_accuracy = 0.
    tp = np.zeros((2, 1))
    fp = np.zeros((2, 1))
    fn = np.zeros((2, 1))
    parts_tp = np.zeros((7, 1))
    parts_fp = np.zeros((7, 1))
    parts_fn = np.zeros((7, 1))
    # Pixel count accumulators
    pixel_count = 0
    parts_pixel_count = 0

    # Store SMPL parameters
    smpl_pose = np.zeros((len(dataset), 72))
    smpl_betas = np.zeros((len(dataset), 10))
    smpl_camera = np.zeros((len(dataset), 3))
    pred_joints = np.zeros((len(dataset), 17, 3))

    joint_mapper_h36m = constants.H36M_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.H36M_TO_J14
    joint_mapper_gt = constants.J24_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.J24_TO_J14
    # Iterate over the entire dataset
    for step, batch in enumerate(tqdm(data_loader, desc='Eval', total=len(data_loader))):
        # Get ground truth annotations from the batch
        gt_pose = batch['pose'].to(device)
        gt_betas = batch['betas'].to(device)
        images = batch['img'].to(device)
        gender = batch['gender'].to(device)
        curr_batch_size = images.shape[0]
        # Rows of the per-sample arrays filled by this batch
        rows = slice(step * batch_size, step * batch_size + curr_batch_size)

        with torch.no_grad():
            pred_rotmat, pred_betas, pred_camera = model(images)
            pred_output = smpl_neutral(betas=pred_betas, body_pose=pred_rotmat[:, 1:],
                                       global_orient=pred_rotmat[:, 0].unsqueeze(1), pose2rot=False)
            pred_vertices = pred_output.vertices

        if save_results:
            # Convert predicted rotation matrices (24 per sample) to axis-angle
            rot_pad = torch.tensor([0, 0, 1], dtype=torch.float32, device=device).view(1, 3, 1)
            rotmat = torch.cat((pred_rotmat.view(-1, 3, 3), rot_pad.expand(curr_batch_size * 24, -1, -1)), dim=-1)
            pred_pose = tgm.rotation_matrix_to_angle_axis(rotmat).contiguous().view(-1, 72)
            smpl_pose[rows, :] = pred_pose.cpu().numpy()
            smpl_betas[rows, :] = pred_betas.cpu().numpy()
            smpl_camera[rows, :] = pred_camera.cpu().numpy()

        # 3D pose evaluation
        if eval_pose:
            # Regressor broadcasting
            J_regressor_batch = J_regressor[None, :].expand(pred_vertices.shape[0], -1, -1)
            if 'h36m' in dataset_name or 'mpi-inf' in dataset_name:
                # Ground-truth joints ship with the dataset; the last channel
                # is dropped (presumably a confidence value).
                gt_keypoints_3d = batch['pose_3d'].to(device)
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_gt, :-1]
            else:
                # For 3DPW get the common joints from the gender-specific GT mesh
                gt_vertices = smpl_male(global_orient=gt_pose[:, :3], body_pose=gt_pose[:, 3:], betas=gt_betas).vertices
                gt_vertices_female = smpl_female(global_orient=gt_pose[:, :3], body_pose=gt_pose[:, 3:], betas=gt_betas).vertices
                gt_vertices[gender == 1, :, :] = gt_vertices_female[gender == 1, :, :]
                gt_keypoints_3d = torch.matmul(J_regressor_batch, gt_vertices)
                gt_pelvis = gt_keypoints_3d[:, [0], :].clone()
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_h36m, :]
                gt_keypoints_3d = gt_keypoints_3d - gt_pelvis

            # Get predicted joints from the mesh, pelvis-centred the same way
            pred_keypoints_3d = torch.matmul(J_regressor_batch, pred_vertices)
            if save_results:
                pred_joints[rows, :, :] = pred_keypoints_3d.cpu().numpy()
            pred_pelvis = pred_keypoints_3d[:, [0], :].clone()
            pred_keypoints_3d = pred_keypoints_3d[:, joint_mapper_h36m, :]
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis

            # Absolute error (MPJPE)
            error = torch.sqrt(((pred_keypoints_3d - gt_keypoints_3d) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            mpjpe[rows] = error

            # Reconstruction error
            r_error = reconstruction_error(pred_keypoints_3d.cpu().numpy(), gt_keypoints_3d.cpu().numpy(), reduction=None)
            recon_err[rows] = r_error

        # If mask or part evaluation, render the mask and part images
        if eval_masks or eval_parts:
            mask, parts = renderer(pred_vertices, pred_camera)

        # Mask evaluation (for LSP)
        if eval_masks:
            center = batch['center'].cpu().numpy()
            scale = batch['scale'].cpu().numpy()
            # Dimensions of original image
            orig_shape = batch['orig_shape'].cpu().numpy()
            for i in range(curr_batch_size):
                # After rendering, convert image back to original resolution
                pred_mask = uncrop(mask[i].cpu().numpy(), center[i], scale[i], orig_shape[i]) > 0
                # Load gt mask
                gt_mask = cv2.imread(os.path.join(annot_path, batch['maskname'][i]), 0) > 0
                # Evaluation consistent with the original UP-3D code
                accuracy += (gt_mask == pred_mask).sum()
                pixel_count += np.prod(np.array(gt_mask.shape))
                for c in range(2):
                    cgt = gt_mask == c
                    cpred = pred_mask == c
                    tp[c] += (cgt & cpred).sum()
                    fp[c] += (~cgt & cpred).sum()
                    fn[c] += (cgt & ~cpred).sum()
            f1 = 2 * tp / (2 * tp + fp + fn)

        # Part evaluation (for LSP)
        if eval_parts:
            center = batch['center'].cpu().numpy()
            scale = batch['scale'].cpu().numpy()
            orig_shape = batch['orig_shape'].cpu().numpy()
            for i in range(curr_batch_size):
                pred_parts = uncrop(parts[i].cpu().numpy().astype(np.uint8), center[i], scale[i], orig_shape[i])
                # Load gt part segmentation
                gt_parts = cv2.imread(os.path.join(annot_path, batch['partname'][i]), 0)
                # Evaluation consistent with the original UP-3D code
                # 6 parts + background; label 255 is excluded from scoring
                for c in range(7):
                    cgt = gt_parts == c
                    cpred = pred_parts == c
                    cpred[gt_parts == 255] = 0
                    parts_tp[c] += (cgt & cpred).sum()
                    parts_fp[c] += (~cgt & cpred).sum()
                    parts_fn[c] += (cgt & ~cpred).sum()
                gt_parts[gt_parts == 255] = 0
                pred_parts[pred_parts == 255] = 0
                parts_accuracy += (gt_parts == pred_parts).sum()
                parts_pixel_count += np.prod(np.array(gt_parts.shape))
            parts_f1 = 2 * parts_tp / (2 * parts_tp + parts_fp + parts_fn)

        # Print intermediate results during evaluation
        if step % log_freq == log_freq - 1:
            # Rows [0, (step + 1) * batch_size) are filled, including this batch.
            num_seen = (step + 1) * batch_size
            if eval_pose:
                print('MPJPE: ' + str(1000 * mpjpe[:num_seen].mean()))
                print('Reconstruction Error: ' + str(1000 * recon_err[:num_seen].mean()))
                print()
            if eval_masks:
                print('Accuracy: ', accuracy / pixel_count)
                print('F1: ', f1.mean())
                print()
            if eval_parts:
                print('Parts Accuracy: ', parts_accuracy / parts_pixel_count)
                print('Parts F1 (BG): ', parts_f1[[0, 1, 2, 3, 4, 5, 6]].mean())
                print()

    # Save reconstructions to a file for further processing
    if save_results:
        np.savez(result_file, pred_joints=pred_joints, pose=smpl_pose, betas=smpl_betas, camera=smpl_camera)
    # Print final results during evaluation
    print('*** Final Results ***')
    print()
    if eval_pose:
        print('MPJPE: ' + str(1000 * mpjpe.mean()))
        print('Reconstruction Error: ' + str(1000 * recon_err.mean()))
        print()
    if eval_masks:
        print('Accuracy: ', accuracy / pixel_count)
        print('F1: ', f1.mean())
        print()
    if eval_parts:
        print('Parts Accuracy: ', parts_accuracy / parts_pixel_count)
        print('Parts F1 (BG): ', parts_f1[[0, 1, 2, 3, 4, 5, 6]].mean())
        print()
Exemplo n.º 13
0
def _load_pkl_predictions(batch, pkl_dir, device):
    """Load the precomputed SMPL predictions for every sample of ``batch``.

    For each sample the file ``{pkl_dir}/{seqName}_{imgName}_pid{subjectId}.pkl``
    is read, where ``seqName`` is the name of the directory containing the
    sample's image, ``imgName`` is the image file name with its last four
    characters removed (presumably the file extension) and ``subjectId``
    comes from ``batch['subjectId']``.

    Returns:
        ``(pred_rotmat, pred_betas, pred_camera, missing)``. The first three
        stack the predictions of all *found* files along a new leading axis
        and are moved to ``device``: ``pred_rotmat`` is ``(B, 24, 3, 3)`` and
        ``pred_betas`` is ``(B, 10)``. ``pred_camera`` presumably is ``(B, 3)``
        (TODO confirm against the pkl writer). All three are ``None`` when no
        file was found. ``missing`` lists the pkl paths that do not exist.
    """
    pred_rotmat_list, pred_betas_list, pred_camera_list = [], [], []
    missing = []
    for subjectid, name in zip(batch['subjectId'], batch['imgname']):
        seqName = os.path.basename(os.path.dirname(name))
        imgNameOnly = os.path.basename(name)[:-4]
        pklfilename = f'{seqName}_{imgNameOnly}_pid{subjectid.item()}.pkl'
        pklfilepath = os.path.join(pkl_dir, pklfilename)
        if not os.path.exists(pklfilepath):
            missing.append(pklfilepath)
            continue

        # NOTE(review): pickle can execute arbitrary code on load; only point
        # pkl_dir at trusted, locally produced outputs.
        with open(pklfilepath, 'rb') as f:
            data = pkl.load(f, encoding='latin1')

        cam = data['cam'][0]
        # theta[0, 3:-10] must hold 24 axis-angle rotations (it is viewed as
        # (24, 3)); the last 10 entries are the SMPL shape coefficients.
        shape = data['theta'][0, -10:]
        pose_aa = data['theta'][0, 3:-10]
        pose_rotmat = angle_axis_to_rotation_matrix(
            torch.from_numpy(pose_aa).view(24, 3))  # 24,4,4
        pose_rotmat = pose_rotmat[:, :3, :3].numpy()  # 24,3,3

        pred_rotmat_list.append(pose_rotmat[None, :])
        pred_betas_list.append(shape[None, :])
        pred_camera_list.append(cam[None, :])

    if not pred_rotmat_list:
        return None, None, None, missing

    # Concatenate once per batch (not once per sample) and honour `device`
    # instead of hard-coding CUDA.
    pred_rotmat = torch.from_numpy(
        np.concatenate(pred_rotmat_list, axis=0)).to(device)
    pred_betas = torch.from_numpy(
        np.concatenate(pred_betas_list, axis=0)).to(device)
    pred_camera = torch.from_numpy(
        np.concatenate(pred_camera_list, axis=0)).to(device)
    return pred_rotmat, pred_betas, pred_camera, missing


def _summarize_per_sequence(quant_mpjpe, quant_recon_err):
    """Format per-sequence and overall pose errors as a ';'-separated report.

    Args:
        quant_mpjpe: dict mapping sequence name -> list of per-sample MPJPE.
        quant_recon_err: dict mapping sequence name -> list of per-sample
            reconstruction errors. Values in both dicts are multiplied by
            1000 and labelled "mm", i.e. they are presumably in metres.

    Returns:
        ``(output_str, mpjpe_avg_mm, recon_err_avg_mm)``. The averages are
        taken over all samples pooled together, not over sequences.
    """
    all_mpjpe = np.hstack([quant_mpjpe[k] for k in quant_mpjpe])
    all_recon = np.hstack([quant_recon_err[k] for k in quant_recon_err])
    mpjpe_avg_mm = all_mpjpe.mean() * 1000
    recon_err_avg_mm = all_recon.mean() * 1000

    parts = ['SeqNames; ']
    parts.extend(seq + ';' for seq in quant_mpjpe)
    parts.append('\n MPJPE; ')
    parts.append("Avg {:.02f} mm; ".format(mpjpe_avg_mm))
    parts.extend('{:.02f}; '.format(1000 * np.hstack(quant_mpjpe[seq]).mean())
                 for seq in quant_mpjpe)
    parts.append('\n Recon Error; ')
    parts.append("Avg {:.02f}mm; ".format(recon_err_avg_mm))
    parts.extend(
        '{:.02f}; '.format(1000 * np.hstack(quant_recon_err[seq]).mean())
        for seq in quant_recon_err)
    return ''.join(parts), mpjpe_avg_mm, recon_err_avg_mm


# Directory the original experiments ended up reading predictions from; many
# earlier experiment folders were assigned before it but always overwritten.
_DEFAULT_PKL_DIR = '/checkpoint/hjoo/data/3dpw_crop/output_hmr/croplev_7'


def run_evaluation(model,
                   dataset_name,
                   dataset,
                   result_file,
                   batch_size=1,
                   img_res=224,
                   num_workers=32,
                   shuffle=False,
                   log_freq=50,
                   bVerbose=True,
                   pkl_dir=None):
    """Evaluate precomputed (pkl) SMPL predictions on a 3D pose benchmark.

    Predictions are NOT produced by ``model``. They are read per sample from
    ``pkl_dir`` (see ``_load_pkl_predictions``), turned into meshes with the
    neutral SMPL model, and compared against the dataset's ground truth.
    The comparison uses MPJPE and the error returned by
    ``reconstruction_error``, reported per sequence and overall.

    Args:
        model: unused. It is kept for interface compatibility with the other
            evaluation entry points.
        dataset_name: one of 'h36m-p1', 'h36m-p2', '3dpw', '3dpw-vibe',
            'mpi-inf-3dhp' for pose evaluation. 'lsp' is rejected; any other
            name runs the loop without computing metrics.
        dataset: dataset whose batches provide 'pose', 'betas', 'img',
            'gender', 'subjectId', 'imgname' and, for h36m/mpi-inf, 'pose_3d'.
        result_file: path for ``np.savez`` of the predictions, or None.
        batch_size, num_workers, shuffle: forwarded to the DataLoader.
            Shuffling is disabled when ``result_file`` is given.
        img_res: unused.
        log_freq: print running MPJPE every ``log_freq`` batches when verbose.
        bVerbose: print per-batch and final per-sequence reports.
        pkl_dir: directory of prediction pkls. Defaults to the last
            experiment directory used by the original script.

    Returns:
        ``(mpjpe_mm, recon_err_mm)`` averaged over all samples for pose
        datasets, otherwise ``-1``.

    Raises:
        NotImplementedError: for 'lsp'. Mask/part metrics need the renderer,
            which is disabled in this pkl-based evaluation.
        FileNotFoundError: if any sample's prediction pkl is missing.
    """
    if dataset_name == 'lsp':
        # The mask/part metrics need rendered silhouettes, but the renderer
        # is disabled here, so that path could only fail with a NameError.
        raise NotImplementedError(
            'Mask/part evaluation (lsp) is not supported when evaluating '
            'predictions loaded from pkl files.')

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    pklDir = _DEFAULT_PKL_DIR if pkl_dir is None else pkl_dir
    bDebugVis = False  # set True to inspect meshes/skeletons interactively

    # Load the SMPL models once per process and cache them in module globals.
    global g_smpl_neutral, g_smpl_male, g_smpl_female
    if g_smpl_neutral is None:
        g_smpl_neutral = SMPL(config.SMPL_MODEL_DIR,
                              create_transl=False).to(device)
        g_smpl_male = SMPL(config.SMPL_MODEL_DIR,
                           gender='male',
                           create_transl=False).to(device)
        g_smpl_female = SMPL(config.SMPL_MODEL_DIR,
                             gender='female',
                             create_transl=False).to(device)
    smpl_neutral = g_smpl_neutral
    smpl_male = g_smpl_male
    smpl_female = g_smpl_female

    # Regressor for H36m joints (moved to the device once, not per batch)
    J_regressor = torch.from_numpy(np.load(
        config.JOINT_REGRESSOR_H36M)).float().to(device)

    save_results = result_file is not None
    # Disable shuffling if you want to save the results
    if save_results:
        shuffle = False
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=num_workers)

    # Per-sample errors in dataset order, plus the same errors grouped by
    # sequence name (used for the final report).
    mpjpe = np.zeros(len(dataset))
    recon_err = np.zeros(len(dataset))
    quant_mpjpe = {}
    quant_recon_err = {}

    # Store SMPL parameters
    smpl_pose = np.zeros((len(dataset), 72))
    smpl_betas = np.zeros((len(dataset), 10))
    smpl_camera = np.zeros((len(dataset), 3))
    pred_joints = np.zeros((len(dataset), 17, 3))

    eval_pose = dataset_name in ('h36m-p1', 'h36m-p2', '3dpw', '3dpw-vibe',
                                 'mpi-inf-3dhp')

    joint_mapper_h36m = constants.H36M_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.H36M_TO_J14
    joint_mapper_gt = constants.J24_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.J24_TO_J14

    for step, batch in enumerate(
            tqdm(data_loader, desc='Eval', total=len(data_loader))):
        gt_pose = batch['pose'].to(device)
        gt_betas = batch['betas'].to(device)
        gender = batch['gender'].to(device)
        curr_batch_size = batch['img'].shape[0]
        start = step * batch_size

        pred_rotmat, pred_betas, pred_camera, missing = _load_pkl_predictions(
            batch, pklDir, device)
        if missing:
            raise FileNotFoundError(
                '{} prediction pkl(s) missing in this batch, e.g. {}'.format(
                    len(missing), missing[0]))

        pred_output = smpl_neutral(betas=pred_betas,
                                   body_pose=pred_rotmat[:, 1:],
                                   global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                   pose2rot=False)
        pred_vertices = pred_output.vertices

        if save_results:
            # Pad each 3x3 rotation to 3x4 so tgm can convert it to axis-angle.
            rot_pad = torch.tensor([0, 0, 1],
                                   dtype=torch.float32,
                                   device=device).view(1, 3, 1)
            rotmat = torch.cat((pred_rotmat.view(
                -1, 3, 3), rot_pad.expand(curr_batch_size * 24, -1, -1)),
                               dim=-1)
            pred_pose = tgm.rotation_matrix_to_angle_axis(
                rotmat).contiguous().view(-1, 72)
            smpl_pose[start:start + curr_batch_size, :] = pred_pose.cpu().numpy()
            smpl_betas[start:start + curr_batch_size, :] = pred_betas.cpu().numpy()
            smpl_camera[start:start + curr_batch_size, :] = pred_camera.cpu().numpy()

        # 3D pose evaluation
        if eval_pose:
            # Regressor broadcasting
            J_regressor_batch = J_regressor[None, :].expand(
                pred_vertices.shape[0], -1, -1)
            if 'h36m' in dataset_name or 'mpi-inf' in dataset_name:
                gt_keypoints_3d = batch['pose_3d'].to(device)
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_gt, :-1]
            # For 3DPW get the common joints from the gendered GT mesh
            else:
                gt_vertices = smpl_male(global_orient=gt_pose[:, :3],
                                        body_pose=gt_pose[:, 3:],
                                        betas=gt_betas).vertices
                gt_vertices_female = smpl_female(global_orient=gt_pose[:, :3],
                                                 body_pose=gt_pose[:, 3:],
                                                 betas=gt_betas).vertices
                gt_vertices[gender == 1, :, :] = gt_vertices_female[gender == 1, :, :]
                gt_keypoints_3d = torch.matmul(J_regressor_batch, gt_vertices)
                gt_pelvis = gt_keypoints_3d[:, [0], :].clone()
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_h36m, :]
                gt_keypoints_3d = gt_keypoints_3d - gt_pelvis

                if bDebugVis:
                    from renderer import glViewer
                    import humanModelViewer
                    smpl_face = humanModelViewer.GetSMPLFace()
                    for i in range(gt_pose.shape[0]):
                        meshes_gt = {
                            'ver': gt_vertices[i].cpu().numpy() * 100,
                            'f': smpl_face
                        }
                        meshes_pred = {
                            'ver': pred_vertices[i].cpu().numpy() * 100,
                            'f': smpl_face
                        }
                        glViewer.setMeshData([meshes_gt, meshes_pred],
                                             bComputeNormal=True)
                        glViewer.show(5)

            # Predicted joints from the mesh, pelvis-centred like the GT
            pred_keypoints_3d = torch.matmul(J_regressor_batch, pred_vertices)
            if save_results:
                pred_joints[start:start + curr_batch_size, :, :] = \
                    pred_keypoints_3d.cpu().numpy()
            pred_pelvis = pred_keypoints_3d[:, [0], :].clone()
            pred_keypoints_3d = pred_keypoints_3d[:, joint_mapper_h36m, :]
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis

            if bDebugVis:
                from renderer import glViewer
                # (skelNum, dim, frames) layout expected by setSkeleton
                gt_vis = gt_keypoints_3d.cpu().numpy()
                gt_vis = np.swapaxes(
                    np.reshape(gt_vis, (gt_vis.shape[0], -1)), 0, 1) * 100
                pred_vis = pred_keypoints_3d.cpu().numpy()
                pred_vis = np.swapaxes(
                    np.reshape(pred_vis, (pred_vis.shape[0], -1)), 0, 1) * 100
                glViewer.setSkeleton([gt_vis, pred_vis], jointType='smplcoco')
                glViewer.show()

            # Absolute error (MPJPE)
            error = torch.sqrt(
                ((pred_keypoints_3d -
                  gt_keypoints_3d)**2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            # Reconstruction error
            r_error = reconstruction_error(pred_keypoints_3d.cpu().numpy(),
                                           gt_keypoints_3d.cpu().numpy(),
                                           reduction=None)
            mpjpe[start:start + curr_batch_size] = error
            recon_err[start:start + curr_batch_size] = r_error

            for ii, p in enumerate(batch['imgname'][:len(r_error)]):
                seqName = os.path.basename(os.path.dirname(p))
                quant_mpjpe.setdefault(seqName, []).append(error[ii])
                quant_recon_err.setdefault(seqName, []).append(r_error[ii])

            if bVerbose:
                list_mpjpe = np.hstack([quant_mpjpe[k] for k in quant_mpjpe])
                list_reconError = np.hstack(
                    [quant_recon_err[k] for k in quant_recon_err])
                # seqName is the sequence of the last sample in this batch.
                print(
                    ">>> {} : MPJPE {:.02f} mm, error: {:.02f} mm | Total MPJPE {:.02f} mm, error {:.02f} mm"
                    .format(seqName,
                            np.mean(error) * 1000,
                            np.mean(r_error) * 1000,
                            list_mpjpe.mean() * 1000,
                            list_reconError.mean() * 1000))

        # Print intermediate results (including the current batch)
        if bVerbose and eval_pose and step % log_freq == log_freq - 1:
            seen = start + curr_batch_size
            print('MPJPE: ' + str(1000 * mpjpe[:seen].mean()))
            print('Reconstruction Error: ' +
                  str(1000 * recon_err[:seen].mean()))
            print()

    # Save reconstructions to a file for further processing
    if save_results:
        np.savez(result_file,
                 pred_joints=pred_joints,
                 pose=smpl_pose,
                 betas=smpl_betas,
                 camera=smpl_camera)

    if bVerbose:
        print('*** Final Results ***')
        print()

    if not eval_pose:
        return -1  # no metric was computed for this dataset

    output_str, quant_mpjpe_avg_mm, quant_recon_error_avg_mm = \
        _summarize_per_sequence(quant_mpjpe, quant_recon_err)
    if bVerbose:
        print(output_str)
    else:
        print(">>>  Test on 3DPW: MPJPE: {} | quant_recon_error_avg_mm: {}".
              format(quant_mpjpe_avg_mm, quant_recon_error_avg_mm))

    # Export the text report (only where the original author's folder exists)
    log_dir = "/private/home/hjoo/codes/fairmocap/benchmarks_eval"
    if os.path.exists(log_dir):
        targetFile = os.path.join(
            log_dir, "evalFromPkl_" +
            os.path.basename(os.path.normpath(pklDir)) + ".txt")
        print(f"\n Writing output to:{targetFile}")
        with open(targetFile, "w") as text_file:
            text_file.write(output_str)
            text_file.write("\n Input Pkl Dir:{}".format(pklDir))

    return quant_mpjpe_avg_mm, quant_recon_error_avg_mm
Exemplo n.º 14
0
def run_evaluation(model,
                   dataset_name,
                   dataset,
                   result_file,
                   batch_size=32,
                   img_res=224,
                   num_workers=32,
                   shuffle=False,
                   log_freq=50):
    """Run evaluation on the datasets and metrics we report in the paper.

    Args:
        model: Regressor called as ``model(images)`` and returning
            ``(pred_rotmat, pred_betas, pred_camera)``.
        dataset_name: Dataset key. ``'h36m-p1'``, ``'h36m-p2'``, ``'3dpw'`` and
            ``'mpi-inf-3dhp'`` enable 3D pose metrics (MPJPE and
            reconstruction error); ``'lsp'`` only triggers mask/part
            rendering (no mask/part metric is computed here).
        dataset: Dataset whose batches provide ``'pose'``, ``'betas'``,
            ``'img'``, ``'gender'`` and, for H36M / MPI-INF-3DHP, ``'pose_3d'``.
        result_file: Path handed to ``np.savez`` for predicted joints and SMPL
            parameters, or ``None`` to skip saving.
        batch_size: Batch size of the evaluation ``DataLoader``.
        img_res: Not used by this function; kept for interface compatibility.
        num_workers: Number of ``DataLoader`` worker processes.
        shuffle: Whether to shuffle the dataset; forced to ``False`` when
            results are saved so rows line up with dataset indices.
        log_freq: Print running pose metrics every ``log_freq`` batches.
    """

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Transfer model to the GPU
    model.to(device)

    # Load SMPL model
    smpl_neutral = SMPL(config.SMPL_MODEL_DIR, create_transl=False).to(device)
    smpl_male = SMPL(config.SMPL_MODEL_DIR, gender='male',
                     create_transl=False).to(device)
    smpl_female = SMPL(config.SMPL_MODEL_DIR,
                       gender='female',
                       create_transl=False).to(device)

    renderer = PartRenderer()

    # Regressor for H36m joints; moved to the device once instead of per batch
    J_regressor = torch.from_numpy(np.load(
        config.JOINT_REGRESSOR_H36M)).float().to(device)

    save_results = result_file is not None
    # Disable shuffling if you want to save the results
    if save_results:
        shuffle = False
    # Create dataloader for the dataset
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=num_workers)

    # Pose metrics: per-sample MPJPE and reconstruction error
    mpjpe = np.zeros(len(dataset))
    recon_err = np.zeros(len(dataset))

    # Store SMPL parameters
    smpl_pose = np.zeros((len(dataset), 72))
    smpl_betas = np.zeros((len(dataset), 10))
    smpl_camera = np.zeros((len(dataset), 3))
    pred_joints = np.zeros((len(dataset), 17, 3))

    eval_pose = False
    eval_masks = False
    eval_parts = False
    # Choose appropriate evaluation for each dataset
    if dataset_name in ('h36m-p1', 'h36m-p2', '3dpw', 'mpi-inf-3dhp'):
        eval_pose = True
    elif dataset_name == 'lsp':
        eval_masks = True
        eval_parts = True

    joint_mapper_h36m = constants.H36M_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.H36M_TO_J14
    joint_mapper_gt = constants.J24_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.J24_TO_J14
    # Iterate over the entire dataset
    for step, batch in enumerate(
            tqdm(data_loader, desc='Eval', total=len(data_loader))):
        # Get ground truth annotations from the batch
        gt_pose = batch['pose'].to(device)
        gt_betas = batch['betas'].to(device)
        images = batch['img'].to(device)
        gender = batch['gender'].to(device)
        curr_batch_size = images.shape[0]
        # Rows of the per-sample arrays filled by this batch
        start = step * batch_size
        end = start + curr_batch_size

        with torch.no_grad():
            pred_rotmat, pred_betas, pred_camera = model(images)
            pred_output = smpl_neutral(
                betas=pred_betas,
                body_pose=pred_rotmat[:, 1:],
                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                pose2rot=False)
            pred_vertices = pred_output.vertices

        if save_results:
            # Pad every 3x3 joint rotation to 3x4 with a [0, 0, 1] column
            # before converting it to axis-angle (24 joints * 3 = 72 values).
            rot_pad = torch.tensor([0, 0, 1],
                                   dtype=torch.float32,
                                   device=device).view(1, 3, 1)
            rotmat = torch.cat((pred_rotmat.view(
                -1, 3, 3), rot_pad.expand(curr_batch_size * 24, -1, -1)),
                               dim=-1)
            pred_pose = tgm.rotation_matrix_to_angle_axis(
                rotmat).contiguous().view(-1, 72)
            smpl_pose[start:end, :] = pred_pose.cpu().numpy()
            smpl_betas[start:end, :] = pred_betas.cpu().numpy()
            smpl_camera[start:end, :] = pred_camera.cpu().numpy()

        # 3D pose evaluation
        if eval_pose:
            # Regressor broadcasting
            J_regressor_batch = J_regressor[None, :].expand(
                pred_vertices.shape[0], -1, -1)
            # Get 14 ground truth joints
            if 'h36m' in dataset_name or 'mpi-inf' in dataset_name:
                # Use .to(device) rather than .cuda() so CPU-only runs work
                gt_keypoints_3d = batch['pose_3d'].to(device)
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_gt, :-1]
            # For 3DPW get the 14 common joints from the rendered shape
            else:
                gt_vertices = smpl_male(global_orient=gt_pose[:, :3],
                                        body_pose=gt_pose[:, 3:],
                                        betas=gt_betas).vertices
                gt_vertices_female = smpl_female(global_orient=gt_pose[:, :3],
                                                 body_pose=gt_pose[:, 3:],
                                                 betas=gt_betas).vertices
                # Samples with gender == 1 use the female model's vertices
                gt_vertices[gender == 1, :, :] = gt_vertices_female[gender ==
                                                                    1, :, :]
                gt_keypoints_3d = torch.matmul(J_regressor_batch, gt_vertices)
                gt_pelvis = gt_keypoints_3d[:, [0], :].clone()
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_h36m, :]
                gt_keypoints_3d = gt_keypoints_3d - gt_pelvis

            # Get 14 predicted joints from the mesh
            pred_keypoints_3d = torch.matmul(J_regressor_batch, pred_vertices)
            if save_results:
                pred_joints[start:end, :, :] = pred_keypoints_3d.cpu().numpy()
            pred_pelvis = pred_keypoints_3d[:, [0], :].clone()
            pred_keypoints_3d = pred_keypoints_3d[:, joint_mapper_h36m, :]
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis
            # Absolute error (MPJPE)
            error = torch.sqrt(
                ((pred_keypoints_3d -
                  gt_keypoints_3d)**2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            mpjpe[start:end] = error

            # Reconstruction error
            r_error = reconstruction_error(pred_keypoints_3d.cpu().numpy(),
                                           gt_keypoints_3d.cpu().numpy(),
                                           reduction=None)
            recon_err[start:end] = r_error

        # If mask or part evaluation, render the mask and part images.
        # NOTE(review): mask/parts are not consumed by any metric in this
        # function; only the rendering itself is performed.
        if eval_masks or eval_parts:
            mask, parts = renderer(pred_vertices, pred_camera)

        # Print intermediate results during evaluation. Average over every
        # sample processed so far, including the current batch.
        if step % log_freq == log_freq - 1:
            if eval_pose:
                print('MPJPE: ' + str(1000 * mpjpe[:end].mean()))
                print('Reconstruction Error: ' +
                      str(1000 * recon_err[:end].mean()))
                print()

    # Save reconstructions to a file for further processing
    if save_results:
        np.savez(result_file,
                 pred_joints=pred_joints,
                 pose=smpl_pose,
                 betas=smpl_betas,
                 camera=smpl_camera)
        np.savez('error.npz', mpjpe=mpjpe, recon_err=recon_err)
    # Print final results during evaluation
    print('*** Final Results ***')
    print()
    if eval_pose:
        print('MPJPE: ' + str(1000 * mpjpe.mean()))
        print('Reconstruction Error: ' + str(1000 * recon_err.mean()))
        print()
Exemplo n.º 15
0
                        global_orient=pred_rotmat[:, 0].unsqueeze(1),
                        pose2rot=False)
     pred_vertices = pred_output.vertices
 camera_translation = torch.stack([
     pred_camera[:, 1], pred_camera[:, 2], 2 * constants.FOCAL_LENGTH /
     (constants.IMG_RES * pred_camera[:, 0] + 1e-9)
 ],
                                  dim=-1)
 camera_translation = camera_translation[0].cpu().numpy()
 pred_vertices = pred_vertices[0].cpu().numpy()
 #Convert rotation matrix of joint points to rotate vector
 rot_pad = torch.tensor([0, 0, 1], dtype=torch.float32,
                        device=device).view(1, 3, 1)
 rotmat = torch.cat(
     (pred_rotmat.view(-1, 3, 3), rot_pad.expand(1 * 24, -1, -1)), dim=-1)
 pred_pose = tgm.rotation_matrix_to_angle_axis(rotmat).contiguous().view(
     -1, 72)
 pred_theta = pred_pose.cpu().numpy()
 pred_beta = pred_betas.cpu().numpy()
 pred_param = {'pose': pred_theta, 'shape': pred_beta}
 print(pred_theta.shape, pred_beta.shape)
 img = img.permute(1, 2, 0).cpu().numpy()
 # Rendered result
 img_shape = renderer(pred_vertices, camera_translation, img)
 aroundy = cv2.Rodrigues(np.array([0, np.radians(90.), 0]))[0]
 center = pred_vertices.mean(axis=0)
 rot_vertices = np.dot((pred_vertices - center), aroundy) + center
 # The other side result
 img_shape_side = renderer(rot_vertices, camera_translation,
                           np.ones_like(img))
 # Output filename
 outfile = args.test_image.split(
Exemplo n.º 16
0
    def train_step(self, input_batch):
        """Run one optimization step on a batch containing four image views.

        The batch provides four images per sample (``img_0`` .. ``img_3``) and
        matching 2D keypoints (``keypoints_0`` .. ``keypoints_3``). These are
        concatenated along dim 0, and every per-sample quantity (GT SMPL
        parameters, current fits, ``has_smpl``) is tiled four times in the
        same order so that rows line up.

        Args:
            input_batch: Dict of batched tensors; the keys read are listed in
                the body below.

        Returns:
            Tuple ``(output, losses)``. ``output`` holds predicted/optimized
            vertices and camera translations. ``losses`` maps loss names to
            Python floats.
        """
        self.model.train()

        def tile4(t):
            # Repeat a per-sample tensor once per view along dim 0, in the
            # same order as the concatenated img_0..img_3 / keypoints_0..3.
            return torch.cat((t, t, t, t), 0)

        # get data from batch
        has_smpl = input_batch['has_smpl'].bool()
        # Only needed by the disabled 3D keypoint loss below.
        has_pose_3d = input_batch['has_pose_3d'].bool()
        gt_pose1 = input_batch['pose']  # SMPL pose parameters
        gt_betas1 = input_batch['betas']  # SMPL beta parameters
        dataset_name = input_batch['dataset_name']
        # index of example inside its dataset
        indices = input_batch['sample_index']
        # flag that indicates whether image was flipped during data augmentation
        is_flipped = input_batch['is_flipped']
        # rotation angle used for data augmentation
        rot_angle = input_batch['rot_angle']

        # Get GT vertices and model joints
        # Note that gt_model_joints is different from gt_joints as it comes from SMPL
        gt_betas = tile4(gt_betas1)
        gt_pose = tile4(gt_pose1)
        gt_out = self.smpl(betas=gt_betas,
                           body_pose=gt_pose[:, 3:],
                           global_orient=gt_pose[:, :3])
        gt_model_joints = gt_out.joints
        gt_vertices = gt_out.vertices

        # Get current best fits from the dictionary
        opt_pose1, opt_betas1 = self.fits_dict[(dataset_name, indices.cpu(),
                                                rot_angle.cpu(),
                                                is_flipped.cpu())]
        opt_pose = tile4(opt_pose1.to(self.device))
        opt_betas = tile4(opt_betas1.to(self.device))
        opt_output = self.smpl(betas=opt_betas,
                               body_pose=opt_pose[:, 3:],
                               global_orient=opt_pose[:, :3])
        opt_vertices = opt_output.vertices
        opt_joints = opt_output.joints

        # images
        images = torch.cat((input_batch['img_0'], input_batch['img_1'],
                            input_batch['img_2'], input_batch['img_3']), 0)
        batch_size = input_batch['img_0'].shape[0]

        # Output of CNN
        pred_rotmat, pred_betas, pred_camera = self.model(images)
        pred_output = self.smpl(betas=pred_betas,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)
        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints
        # Convert the predicted weak-perspective camera [s, tx, ty] into a
        # translation; the depth is 2 * f / (img_res * s), with 1e-9 to avoid
        # division by zero.
        pred_cam_t = torch.stack([
            pred_camera[:, 1], pred_camera[:, 2], 2 * self.focal_length /
            (self.options.img_res * pred_camera[:, 0] + 1e-9)
        ],
                                 dim=-1)
        camera_center = torch.zeros(batch_size * 4, 2, device=self.device)
        pred_keypoints_2d = perspective_projection(
            pred_joints,
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size * 4, -1, -1),
            translation=pred_cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)
        pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res / 2.)

        # 2d joint points
        gt_keypoints_2d = torch.cat(
            (input_batch['keypoints_0'], input_batch['keypoints_1'],
             input_batch['keypoints_2'], input_batch['keypoints_3']), 0)
        # Map keypoint coordinates from [-1, 1] back to pixel units; the last
        # channel is left unchanged.
        gt_keypoints_2d_orig = gt_keypoints_2d.clone()
        gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
            gt_keypoints_2d_orig[:, :, :-1] + 1)
        gt_cam_t = estimate_translation(gt_model_joints,
                                        gt_keypoints_2d_orig,
                                        focal_length=self.focal_length,
                                        img_size=self.options.img_res)
        opt_cam_t = estimate_translation(opt_joints,
                                         gt_keypoints_2d_orig,
                                         focal_length=self.focal_length,
                                         img_size=self.options.img_res)

        opt_joint_loss = self.smplify.get_fitting_loss(
            opt_pose, opt_betas, opt_cam_t, 0.5 * self.options.img_res *
            torch.ones(batch_size * 4, 2, device=self.device),
            gt_keypoints_2d_orig).mean(dim=-1)
        if self.options.run_smplify:
            # Pad every 3x3 rotation to 3x4 with a [0, 0, 1] column before the
            # axis-angle conversion.
            pred_rotmat_hom = torch.cat([
                pred_rotmat.detach().view(-1, 3, 3),
                torch.tensor(
                    [0, 0, 1], dtype=torch.float32, device=self.device).view(
                        1, 3, 1).expand(batch_size * 4 * 24, -1, -1)
            ],
                                        dim=-1)
            pred_pose = rotation_matrix_to_angle_axis(
                pred_rotmat_hom).contiguous().view(batch_size * 4, -1)
            # Replace any NaN produced by the conversion with 0
            pred_pose[torch.isnan(pred_pose)] = 0.0
            (new_opt_vertices, new_opt_joints, new_opt_pose, new_opt_betas,
             new_opt_cam_t, new_opt_joint_loss) = self.smplify(
                 pred_pose.detach(), pred_betas.detach(), pred_cam_t.detach(),
                 0.5 * self.options.img_res *
                 torch.ones(batch_size * 4, 2, device=self.device),
                 gt_keypoints_2d_orig)
            new_opt_joint_loss = new_opt_joint_loss.mean(dim=-1)
            # Keep the new fit only where its loss is lower than the current one.
            # NOTE(review): tiling `update` 4x assumes the joint losses hold one
            # entry per sample (batch_size). opt_pose and gt_keypoints_2d_orig
            # are already batch_size * 4 long, so confirm the loss shape
            # returned by self.smplify. Also, despite the name, self.fits_dict
            # is not written back here.
            update = (new_opt_joint_loss < opt_joint_loss)
            update1 = tile4(update)
            opt_joint_loss[update] = new_opt_joint_loss[update]
            opt_joints[update1, :] = new_opt_joints[update1, :]
            opt_betas[update1, :] = new_opt_betas[update1, :]
            opt_pose[update1, :] = new_opt_pose[update1, :]
            opt_vertices[update1, :] = new_opt_vertices[update1, :]
            opt_cam_t[update1, :] = new_opt_cam_t[update1, :]

        # now we compute the loss on the four images
        # Replace the optimized parameters with the ground truth parameters, if available
        has_smpl1 = tile4(has_smpl)
        opt_vertices[has_smpl1, :, :] = gt_vertices[has_smpl1, :, :]
        opt_pose[has_smpl1, :] = gt_pose[has_smpl1, :]
        opt_cam_t[has_smpl1, :] = gt_cam_t[has_smpl1, :]
        opt_joints[has_smpl1, :, :] = gt_model_joints[has_smpl1, :, :]
        opt_betas[has_smpl1, :] = gt_betas[has_smpl1, :]
        # Assert whether a fit is valid by comparing the joint loss with the threshold
        valid_fit1 = (opt_joint_loss < self.options.smplify_threshold).to(
            self.device)
        # Add the examples with GT parameters to the list of valid fits
        valid_fit = tile4(valid_fit1) | has_smpl1

        loss_keypoints = self.keypoint_loss(pred_keypoints_2d, gt_keypoints_2d,
                                            0, 1)
        # 3D keypoint supervision is disabled (its weight below is 0). Its
        # computation was commented out, which left the name undefined and made
        # every call raise NameError. Use a zero placeholder so the weighted
        # sum and the logged dict work.
        loss_keypoints_3d = torch.zeros((), device=self.device)
        loss_regr_pose, loss_regr_betas = self.smpl_losses(
            pred_rotmat, pred_betas, opt_pose, opt_betas, valid_fit)
        loss_shape = self.shape_loss(pred_vertices, opt_vertices, valid_fit)
        # The last term penalizes small predicted camera scales pred_camera[:, 0]
        loss_all = 0 * loss_shape +\
                   5. * loss_keypoints +\
                   0. * loss_keypoints_3d +\
                   loss_regr_pose + 0.001* loss_regr_betas +\
                   ((torch.exp(-pred_camera[:,0]*10)) ** 2 ).mean()

        loss_all *= 60

        # Do backprop
        self.optimizer.zero_grad()
        loss_all.backward()
        self.optimizer.step()
        output = {
            'pred_vertices': pred_vertices,
            'opt_vertices': opt_vertices,
            'pred_cam_t': pred_cam_t,
            'opt_cam_t': opt_cam_t
        }
        losses = {
            'loss': loss_all.detach().item(),
            'loss_keypoints': loss_keypoints.detach().item(),
            'loss_keypoints_3d': loss_keypoints_3d.detach().item(),
            'loss_regr_pose': loss_regr_pose.detach().item(),
            'loss_regr_betas': loss_regr_betas.detach().item(),
            'loss_shape': loss_shape.detach().item()
        }

        return output, losses
Exemplo n.º 17
0
def run_evaluation(model, dataset):
    """Run evaluation on the datasets and metrics we report in the paper. """

    shuffle = args.shuffle
    log_freq = args.log_freq
    batch_size = args.batch_size
    dataset_name = args.dataset
    result_file = args.result_file
    num_workers = args.num_workers
    device = torch.device('cuda') if torch.cuda.is_available() \
                                else torch.device('cpu')

    # Transfer model to the GPU
    model.to(device)

    # Load SMPL model
    smpl_neutral = SMPL(path_config.SMPL_MODEL_DIR,
                        create_transl=False).to(device)
    smpl_male = SMPL(path_config.SMPL_MODEL_DIR,
                     gender='male',
                     create_transl=False).to(device)
    smpl_female = SMPL(path_config.SMPL_MODEL_DIR,
                       gender='female',
                       create_transl=False).to(device)

    renderer = PartRenderer()

    # Regressor for H36m joints
    J_regressor = torch.from_numpy(np.load(
        path_config.JOINT_REGRESSOR_H36M)).float()

    save_results = result_file is not None
    # Disable shuffling if you want to save the results
    if save_results:
        shuffle = False
    # Create dataloader for the dataset
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=num_workers)

    # Pose metrics
    # MPJPE and Reconstruction error for the non-parametric and parametric shapes
    mpjpe = np.zeros(len(dataset))
    recon_err = np.zeros(len(dataset))
    mpjpe_smpl = np.zeros(len(dataset))
    recon_err_smpl = np.zeros(len(dataset))
    pve = np.zeros(len(dataset))

    # Shape metrics
    # Mean per-vertex error
    shape_err = np.zeros(len(dataset))
    shape_err_smpl = np.zeros(len(dataset))

    # Mask and part metrics
    # Accuracy
    accuracy = 0.
    parts_accuracy = 0.
    # True positive, false positive and false negative
    tp = np.zeros((2, 1))
    fp = np.zeros((2, 1))
    fn = np.zeros((2, 1))
    parts_tp = np.zeros((7, 1))
    parts_fp = np.zeros((7, 1))
    parts_fn = np.zeros((7, 1))
    # Pixel count accumulators
    pixel_count = 0
    parts_pixel_count = 0

    # Store SMPL parameters
    smpl_pose = np.zeros((len(dataset), 72))
    smpl_betas = np.zeros((len(dataset), 10))
    smpl_camera = np.zeros((len(dataset), 3))
    pred_joints = np.zeros((len(dataset), 17, 3))
    action_idxes = {}
    idx_counter = 0
    # for each action
    act_PVE = {}
    act_MPJPE = {}
    act_paMPJPE = {}

    eval_pose = False
    eval_masks = False
    eval_parts = False
    # Choose appropriate evaluation for each dataset
    if dataset_name == 'h36m-p1' or dataset_name == 'h36m-p2' or dataset_name == 'h36m-p2-mosh' \
       or dataset_name == '3dpw' or dataset_name == 'mpi-inf-3dhp' or dataset_name == '3doh50k':
        eval_pose = True
    elif dataset_name == 'lsp':
        eval_masks = True
        eval_parts = True
        annot_path = path_config.DATASET_FOLDERS['upi-s1h']

    joint_mapper_h36m = constants.H36M_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.H36M_TO_J14
    joint_mapper_gt = constants.J24_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.J24_TO_J14
    # Iterate over the entire dataset
    cnt = 0
    results_dict = {'id': [], 'pred': [], 'pred_pa': [], 'gt': []}
    for step, batch in enumerate(
            tqdm(data_loader, desc='Eval', total=len(data_loader))):
        # Get ground truth annotations from the batch
        gt_pose = batch['pose'].to(device)
        gt_betas = batch['betas'].to(device)
        gt_smpl_out = smpl_neutral(betas=gt_betas,
                                   body_pose=gt_pose[:, 3:],
                                   global_orient=gt_pose[:, :3])
        gt_vertices_nt = gt_smpl_out.vertices
        images = batch['img'].to(device)
        gender = batch['gender'].to(device)
        curr_batch_size = images.shape[0]

        if save_results:
            s_id = np.array(
                [int(item.split('/')[-3][-1])
                 for item in batch['imgname']]) * 10000
            s_id += np.array(
                [int(item.split('/')[-1][4:-4]) for item in batch['imgname']])
            results_dict['id'].append(s_id)

        if dataset_name == 'h36m-p2':
            action = [
                im_path.split('/')[-1].split('.')[0].split('_')[1]
                for im_path in batch['imgname']
            ]
            for act_i in range(len(action)):

                if action[act_i] in action_idxes:
                    action_idxes[action[act_i]].append(idx_counter + act_i)
                else:
                    action_idxes[action[act_i]] = [idx_counter + act_i]
            idx_counter += len(action)

        with torch.no_grad():
            if args.regressor == 'hmr':
                pred_rotmat, pred_betas, pred_camera = model(images)
                # torch.Size([32, 24, 3, 3]) torch.Size([32, 10]) torch.Size([32, 3])
            elif args.regressor == 'pymaf_net':
                preds_dict, _ = model(images)
                pred_rotmat = preds_dict['smpl_out'][-1]['rotmat'].contiguous(
                ).view(-1, 24, 3, 3)
                pred_betas = preds_dict['smpl_out'][-1][
                    'theta'][:, 3:13].contiguous()
                pred_camera = preds_dict['smpl_out'][-1][
                    'theta'][:, :3].contiguous()

            pred_output = smpl_neutral(
                betas=pred_betas,
                body_pose=pred_rotmat[:, 1:],
                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                pose2rot=False)
            pred_vertices = pred_output.vertices

        if save_results:
            rot_pad = torch.tensor([0, 0, 1],
                                   dtype=torch.float32,
                                   device=device).view(1, 3, 1)
            rotmat = torch.cat((pred_rotmat.view(
                -1, 3, 3), rot_pad.expand(curr_batch_size * 24, -1, -1)),
                               dim=-1)
            pred_pose = tgm.rotation_matrix_to_angle_axis(
                rotmat).contiguous().view(-1, 72)
            smpl_pose[step * batch_size:step * batch_size +
                      curr_batch_size, :] = pred_pose.cpu().numpy()
            smpl_betas[step * batch_size:step * batch_size +
                       curr_batch_size, :] = pred_betas.cpu().numpy()
            smpl_camera[step * batch_size:step * batch_size +
                        curr_batch_size, :] = pred_camera.cpu().numpy()

        # 3D pose evaluation
        if eval_pose:
            # Regressor broadcasting
            J_regressor_batch = J_regressor[None, :].expand(
                pred_vertices.shape[0], -1, -1).to(device)
            # Get 14 ground truth joints
            if 'h36m' in dataset_name or 'mpi-inf' in dataset_name or '3doh50k' in dataset_name:
                gt_keypoints_3d = batch['pose_3d'].cuda()
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_gt, :-1]
            # For 3DPW get the 14 common joints from the rendered shape
            else:
                gt_vertices = smpl_male(global_orient=gt_pose[:, :3],
                                        body_pose=gt_pose[:, 3:],
                                        betas=gt_betas).vertices
                gt_vertices_female = smpl_female(global_orient=gt_pose[:, :3],
                                                 body_pose=gt_pose[:, 3:],
                                                 betas=gt_betas).vertices
                gt_vertices[gender == 1, :, :] = gt_vertices_female[gender ==
                                                                    1, :, :]
                gt_keypoints_3d = torch.matmul(J_regressor_batch, gt_vertices)
                gt_pelvis = gt_keypoints_3d[:, [0], :].clone()
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_h36m, :]
                gt_keypoints_3d = gt_keypoints_3d - gt_pelvis

            if '3dpw' in dataset_name:
                per_vertex_error = torch.sqrt(
                    ((pred_vertices -
                      gt_vertices)**2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            else:
                per_vertex_error = torch.sqrt(
                    ((pred_vertices - gt_vertices_nt)**2).sum(dim=-1)).mean(
                        dim=-1).cpu().numpy()
            pve[step * batch_size:step * batch_size +
                curr_batch_size] = per_vertex_error

            # Get 14 predicted joints from the mesh
            pred_keypoints_3d = torch.matmul(J_regressor_batch, pred_vertices)
            if save_results:
                pred_joints[
                    step * batch_size:step * batch_size +
                    curr_batch_size, :, :] = pred_keypoints_3d.cpu().numpy()
            pred_pelvis = pred_keypoints_3d[:, [0], :].clone()
            pred_keypoints_3d = pred_keypoints_3d[:, joint_mapper_h36m, :]
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis

            # Absolute error (MPJPE)
            error = torch.sqrt(
                ((pred_keypoints_3d -
                  gt_keypoints_3d)**2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            mpjpe[step * batch_size:step * batch_size +
                  curr_batch_size] = error

            # Reconstuction_error
            r_error, pred_keypoints_3d_pa = reconstruction_error(
                pred_keypoints_3d.cpu().numpy(),
                gt_keypoints_3d.cpu().numpy(),
                reduction=None)
            recon_err[step * batch_size:step * batch_size +
                      curr_batch_size] = r_error

            if save_results:
                results_dict['gt'].append(gt_keypoints_3d.cpu().numpy())
                results_dict['pred'].append(pred_keypoints_3d.cpu().numpy())
                results_dict['pred_pa'].append(pred_keypoints_3d_pa)

        if args.vis_demo:
            imgnames = [i_n.split('/')[-1] for i_n in batch['imgname']]

            if args.regressor == 'hmr':
                iuv_pred = None

            images_vis = images * torch.tensor([0.229, 0.224, 0.225],
                                               device=images.device).reshape(
                                                   1, 3, 1, 1)
            images_vis = images_vis + torch.tensor(
                [0.485, 0.456, 0.406], device=images.device).reshape(
                    1, 3, 1, 1)
            vis_smpl_iuv(
                images_vis.cpu().numpy(),
                pred_camera.cpu().numpy(),
                pred_output.vertices.cpu().numpy(), smpl_neutral.faces,
                iuv_pred, 100 * per_vertex_error, imgnames,
                os.path.join('./notebooks/output/demo_results', dataset_name,
                             args.checkpoint.split('/')[-3]), args)

        # If mask or part evaluation, render the mask and part images
        if eval_masks or eval_parts:
            mask, parts = renderer(pred_vertices, pred_camera)
        # Mask evaluation (for LSP)
        if eval_masks:
            center = batch['center'].cpu().numpy()
            scale = batch['scale'].cpu().numpy()
            # Dimensions of original image
            orig_shape = batch['orig_shape'].cpu().numpy()
            for i in range(curr_batch_size):
                # After rendering, convert imate back to original resolution
                pred_mask = uncrop(mask[i].cpu().numpy(), center[i], scale[i],
                                   orig_shape[i]) > 0
                # Load gt mask
                gt_mask = cv2.imread(
                    os.path.join(annot_path, batch['maskname'][i]), 0) > 0
                # Evaluation consistent with the original UP-3D code
                accuracy += (gt_mask == pred_mask).sum()
                pixel_count += np.prod(np.array(gt_mask.shape))
                for c in range(2):
                    cgt = gt_mask == c
                    cpred = pred_mask == c
                    tp[c] += (cgt & cpred).sum()
                    fp[c] += (~cgt & cpred).sum()
                    fn[c] += (cgt & ~cpred).sum()
                f1 = 2 * tp / (2 * tp + fp + fn)

        # Part evaluation (for LSP)
        if eval_parts:
            center = batch['center'].cpu().numpy()
            scale = batch['scale'].cpu().numpy()
            orig_shape = batch['orig_shape'].cpu().numpy()
            for i in range(curr_batch_size):
                pred_parts = uncrop(parts[i].cpu().numpy().astype(np.uint8),
                                    center[i], scale[i], orig_shape[i])
                # Load gt part segmentation
                gt_parts = cv2.imread(
                    os.path.join(annot_path, batch['partname'][i]), 0)
                # Evaluation consistent with the original UP-3D code
                # 6 parts + background
                for c in range(7):
                    cgt = gt_parts == c
                    cpred = pred_parts == c
                    cpred[gt_parts == 255] = 0
                    parts_tp[c] += (cgt & cpred).sum()
                    parts_fp[c] += (~cgt & cpred).sum()
                    parts_fn[c] += (cgt & ~cpred).sum()
                gt_parts[gt_parts == 255] = 0
                pred_parts[pred_parts == 255] = 0
                parts_f1 = 2 * parts_tp / (2 * parts_tp + parts_fp + parts_fn)
                parts_accuracy += (gt_parts == pred_parts).sum()
                parts_pixel_count += np.prod(np.array(gt_parts.shape))

        # Print intermediate results during evaluation
        if step % log_freq == log_freq - 1:
            if eval_pose:
                print('MPJPE: ' + str(1000 * mpjpe[:step * batch_size].mean()))
                print('Reconstruction Error: ' +
                      str(1000 * recon_err[:step * batch_size].mean()))
                print()
            if eval_masks:
                print('Accuracy: ', accuracy / pixel_count)
                print('F1: ', f1.mean())
                print()
            if eval_parts:
                print('Parts Accuracy: ', parts_accuracy / parts_pixel_count)
                print('Parts F1 (BG): ', parts_f1[[0, 1, 2, 3, 4, 5,
                                                   6]].mean())
                print()

    # Save reconstructions to a file for further processing
    if save_results:
        np.savez(result_file,
                 pred_joints=pred_joints,
                 pose=smpl_pose,
                 betas=smpl_betas,
                 camera=smpl_camera)
        for k in results_dict.keys():
            results_dict[k] = np.concatenate(results_dict[k])
            print(k, results_dict[k].shape)

        scipy.io.savemat(result_file + '.mat', results_dict)

    # Print final results during evaluation
    print('*** Final Results ***')
    try:
        print(os.path.split(args.checkpoint)[-3:], args.dataset)
    except:
        pass
    if eval_pose:
        print('PVE: ' + str(1000 * pve.mean()))
        print('MPJPE: ' + str(1000 * mpjpe.mean()))
        print('Reconstruction Error: ' + str(1000 * recon_err.mean()))
        print()
    if eval_masks:
        print('Accuracy: ', accuracy / pixel_count)
        print('F1: ', f1.mean())
        print()
    if eval_parts:
        print('Parts Accuracy: ', parts_accuracy / parts_pixel_count)
        print('Parts F1 (BG): ', parts_f1[[0, 1, 2, 3, 4, 5, 6]].mean())
        print()

    if dataset_name == 'h36m-p2':
        print(
            'Note: PVE is not available for h36m-p2. To evaluate PVE, use h36m-p2-mosh instead.'
        )
        for act in action_idxes:
            act_idx = action_idxes[act]
            act_pve = [pve[i] for i in act_idx]
            act_errors = [mpjpe[i] for i in act_idx]
            act_errors_pa = [recon_err[i] for i in act_idx]

            act_errors_mean = np.mean(np.array(act_errors)) * 1000.
            act_errors_pa_mean = np.mean(np.array(act_errors_pa)) * 1000.
            act_pve_mean = np.mean(np.array(act_pve)) * 1000.
            act_MPJPE[act] = act_errors_mean
            act_paMPJPE[act] = act_errors_pa_mean
            act_PVE[act] = act_pve_mean

        act_err_info = ['action err']
        act_row = [str(act_paMPJPE[act])
                   for act in action_idxes] + [act for act in action_idxes]
        act_err_info.extend(act_row)
        print(act_err_info)
    else:
        act_row = None
Exemplo n.º 18
0
    def train_step(self, input_batch):
        """Run one training iteration (prediction, optional SMPLify, backprop).

        Predicts SMPL rotations, betas and a weak-perspective camera from the
        batch images. When ``self.options.run_smplify`` is set, it refines the
        prediction with SMPLify and writes improved fits back into
        ``self.fits_dict``. It then builds the combined loss and takes one
        optimizer step.

        Args:
            input_batch: dict of batched tensors. Keys read here: 'img',
                'keypoints', 'pose', 'betas', 'pose_3d', 'has_smpl',
                'has_pose_3d', 'is_flipped', 'rot_angle', 'dataset_name',
                'sample_index'.

        Returns:
            A tuple ``(output, losses)``:

            - ``output`` holds the predicted and fitted vertices and camera
              translations for logging.
            - ``losses`` maps each loss-term name to a Python float.
        """
        self.model.train()

        # Get data from the batch
        images = input_batch['img']  # input image
        gt_keypoints_2d = input_batch['keypoints']  # 2D keypoints
        gt_pose = input_batch['pose']  # SMPL pose parameters
        gt_betas = input_batch['betas']  # SMPL beta parameters
        gt_joints = input_batch['pose_3d']  # 3D pose
        # The validity flags are kept as bool masks: they are used to index
        # tensors below, and indexing with uint8 tensors is deprecated.
        # flag that indicates whether SMPL parameters are valid
        has_smpl = input_batch['has_smpl'].byte() == 1
        # flag that indicates whether 3D pose is valid
        has_pose_3d = input_batch['has_pose_3d'].byte() == 1
        # flag that indicates whether image was flipped during data augmentation
        is_flipped = input_batch['is_flipped']
        # rotation angle used for data augmentation
        rot_angle = input_batch['rot_angle']
        # name of the dataset the image comes from
        dataset_name = input_batch['dataset_name']
        # index of example inside its dataset
        indices = input_batch['sample_index']
        batch_size = images.shape[0]

        # Get GT vertices and model joints
        # Note that gt_model_joints is different from gt_joints as it comes from SMPL
        gt_out = self.smpl(betas=gt_betas,
                           body_pose=gt_pose[:, 3:],
                           global_orient=gt_pose[:, :3])
        gt_model_joints = gt_out.joints
        gt_vertices = gt_out.vertices

        # Get current best fits from the dictionary
        opt_pose, opt_betas = self.fits_dict[(dataset_name, indices.cpu(),
                                              rot_angle.cpu(),
                                              is_flipped.cpu())]
        opt_pose = opt_pose.to(self.device)
        opt_betas = opt_betas.to(self.device)
        opt_output = self.smpl(betas=opt_betas,
                               body_pose=opt_pose[:, 3:],
                               global_orient=opt_pose[:, :3])
        opt_vertices = opt_output.vertices
        # Guard against an unexpected SMPL output layout. Compare with the
        # batch actually being processed: comparing with the configured
        # ``self.options.batch_size`` zeroed the vertices of any smaller
        # (e.g. final, partial) batch.
        if opt_vertices.shape != (batch_size, 6890, 3):
            opt_vertices = torch.zeros_like(opt_vertices, device=self.device)
        opt_joints = opt_output.joints

        # De-normalize 2D keypoints from [-1,1] to pixel space
        gt_keypoints_2d_orig = gt_keypoints_2d.clone()
        gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
            gt_keypoints_2d_orig[:, :, :-1] + 1)

        # Estimate camera translation given the model joints and 2D keypoints
        # by minimizing a weighted least squares loss
        gt_cam_t = estimate_translation(gt_model_joints,
                                        gt_keypoints_2d_orig,
                                        focal_length=self.focal_length,
                                        img_size=self.options.img_res)

        opt_cam_t = estimate_translation(opt_joints,
                                         gt_keypoints_2d_orig,
                                         focal_length=self.focal_length,
                                         img_size=self.options.img_res)

        opt_joint_loss = self.smplify.get_fitting_loss(
            opt_pose, opt_betas, opt_cam_t, 0.5 * self.options.img_res *
            torch.ones(batch_size, 2, device=self.device),
            gt_keypoints_2d_orig).mean(dim=-1)

        # Feed images in the network to predict camera and SMPL parameters
        pred_rotmat, pred_betas, pred_camera = self.model(images)

        pred_output = self.smpl(betas=pred_betas,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)
        pred_vertices = pred_output.vertices
        # Same guard as for opt_vertices, using the actual batch size.
        if pred_vertices.shape != (batch_size, 6890, 3):
            pred_vertices = torch.zeros_like(pred_vertices, device=self.device)

        pred_joints = pred_output.joints

        # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
        # This camera translation can be used in a full perspective projection
        pred_cam_t = torch.stack([
            pred_camera[:, 1], pred_camera[:, 2], 2 * self.focal_length /
            (self.options.img_res * pred_camera[:, 0] + 1e-9)
        ],
                                 dim=-1)

        camera_center = torch.zeros(batch_size, 2, device=self.device)
        pred_keypoints_2d = perspective_projection(
            pred_joints,
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size, -1, -1),
            translation=pred_cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)
        # Normalize keypoints to [-1,1]
        pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res / 2.)

        if self.options.run_smplify:

            # Convert predicted rotation matrices to axis-angle.
            # The converter takes 3x4 matrices, so a column is appended to
            # every 3x3 rotation (one per joint of every sample).
            pred_rotmat_flat = pred_rotmat.detach().view(-1, 3, 3)
            hom_col = torch.tensor(
                [0, 0, 1], dtype=torch.float32, device=self.device).view(
                    1, 3, 1).expand(pred_rotmat_flat.shape[0], -1, -1)
            pred_rotmat_hom = torch.cat([pred_rotmat_flat, hom_col], dim=-1)
            pred_pose = rotation_matrix_to_angle_axis(
                pred_rotmat_hom).contiguous().view(batch_size, -1)
            # tgm.rotation_matrix_to_angle_axis returns NaN for 0 rotation, so manually hack it
            pred_pose[torch.isnan(pred_pose)] = 0.0

            # Run SMPLify optimization starting from the network prediction
            new_opt_vertices, new_opt_joints,\
            new_opt_pose, new_opt_betas,\
            new_opt_cam_t, new_opt_joint_loss = self.smplify(
                                        pred_pose.detach(), pred_betas.detach(),
                                        pred_cam_t.detach(),
                                        0.5 * self.options.img_res * torch.ones(batch_size, 2, device=self.device),
                                        gt_keypoints_2d_orig)
            new_opt_joint_loss = new_opt_joint_loss.mean(dim=-1)

            # Will update the dictionary for the examples where the new loss is less than the current one
            update = (new_opt_joint_loss < opt_joint_loss)

            opt_joint_loss[update] = new_opt_joint_loss[update]
            opt_vertices[update, :] = new_opt_vertices[update, :]
            opt_joints[update, :] = new_opt_joints[update, :]
            opt_pose[update, :] = new_opt_pose[update, :]
            opt_betas[update, :] = new_opt_betas[update, :]
            opt_cam_t[update, :] = new_opt_cam_t[update, :]

            self.fits_dict[(dataset_name, indices.cpu(), rot_angle.cpu(),
                            is_flipped.cpu(),
                            update.cpu())] = (opt_pose.cpu(), opt_betas.cpu())

        else:
            update = torch.zeros(batch_size,
                                 dtype=torch.bool,
                                 device=self.device)

        # Replace extreme betas with zero betas
        opt_betas[(opt_betas.abs() > 3).any(dim=-1)] = 0.

        # Replace the optimized parameters with the ground truth parameters, if available
        opt_vertices[has_smpl, :, :] = gt_vertices[has_smpl, :, :]
        opt_cam_t[has_smpl, :] = gt_cam_t[has_smpl, :]
        opt_joints[has_smpl, :, :] = gt_model_joints[has_smpl, :, :]
        opt_pose[has_smpl, :] = gt_pose[has_smpl, :]
        opt_betas[has_smpl, :] = gt_betas[has_smpl, :]

        # Assert whether a fit is valid by comparing the joint loss with the threshold
        valid_fit = (opt_joint_loss < self.options.smplify_threshold).to(
            self.device)
        # Add the examples with GT parameters to the list of valid fits
        # (both operands are bool masks, so no dtype conversion is needed).
        valid_fit = valid_fit | has_smpl

        opt_keypoints_2d = perspective_projection(
            opt_joints,
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size, -1, -1),
            translation=opt_cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)

        opt_keypoints_2d = opt_keypoints_2d / (self.options.img_res / 2.)

        # Compute loss on SMPL parameters
        loss_regr_pose, loss_regr_betas = self.smpl_losses(
            pred_rotmat, pred_betas, opt_pose, opt_betas, valid_fit)

        # Compute 2D reprojection loss for the keypoints
        loss_keypoints = self.keypoint_loss(pred_keypoints_2d, gt_keypoints_2d,
                                            self.options.openpose_train_weight,
                                            self.options.gt_train_weight)

        # Compute 3D keypoint loss
        loss_keypoints_3d = self.keypoint_3d_loss(pred_joints, gt_joints,
                                                  has_pose_3d)

        # Per-vertex loss for the shape
        loss_shape = self.shape_loss(pred_vertices, opt_vertices, valid_fit)

        # Compute total loss
        # The last component is a loss that forces the network to predict positive depth values
        loss = self.options.shape_loss_weight * loss_shape +\
               self.options.keypoint_loss_weight * loss_keypoints +\
               self.options.keypoint_loss_weight * loss_keypoints_3d +\
               loss_regr_pose + self.options.beta_loss_weight * loss_regr_betas +\
               ((torch.exp(-pred_camera[:,0]*10)) ** 2 ).mean()
        loss *= 60

        # Do backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Pack output arguments for tensorboard logging
        output = {
            'pred_vertices': pred_vertices.detach(),
            'opt_vertices': opt_vertices,
            'pred_cam_t': pred_cam_t.detach(),
            'opt_cam_t': opt_cam_t
        }
        losses = {
            'loss': loss.detach().item(),
            'loss_keypoints': loss_keypoints.detach().item(),
            'loss_keypoints_3d': loss_keypoints_3d.detach().item(),
            'loss_regr_pose': loss_regr_pose.detach().item(),
            'loss_regr_betas': loss_regr_betas.detach().item(),
            'loss_shape': loss_shape.detach().item()
        }

        return output, losses
Exemplo n.º 19
0
def exportOursToSpin(eftDir, out_path):
    """Convert a directory of per-sample EFT fitting pickles into one SPIN-style ``.npz``.

    Every entry of ``eftDir`` is unpickled, in sorted filename order. Each
    entry must provide the keys 'imageName', 'center', 'scale',
    'pred_shape', 'pred_pose_rotmat' and 'keypoint2d'.

    The per-joint rotation matrices are converted to a flat 72-value
    axis-angle pose. The pose is then converted back with
    ``batch_rodrigues``, and the maximum absolute round-trip difference is
    printed as a sanity check.

    Args:
        eftDir: directory containing the EFT pickle files. Its base name,
            ignoring any trailing path separator, names the output file.
        out_path: output directory; it is created if it does not exist.

    The result is written to ``<out_path>/<basename(eftDir)>.npz`` with the
    arrays imgname, center, scale, part, openpose, pose, shape, has_smpl
    and S. S is all zeros, because no 3D joints are available here.
    """
    # structs we need
    imgnames_, scales_, centers_, parts_, openposes_ = [], [], [], [], []

    # additional 3D
    poses_, shapes_, skel3D_, has_smpl_ = [], [], [], []

    pose3DList = os.listdir(eftDir)
    sampleNum = len(pose3DList)
    print("\n\n### SampleNum: {} ###".format(sampleNum))

    maxDiff = 0
    for fname in tqdm(sorted(pose3DList)):
        fname_path = os.path.join(eftDir, fname)

        # NOTE(review): pickle.load can execute arbitrary code; only run this
        # on trusted EFT output. The context manager closes each file handle
        # (the original leaked one per sample).
        with open(fname_path, 'rb') as f:
            pose3d = pickle.load(f)

        # Keep the image's parent directory in the saved name (e.g. "train2014/<file>")
        imgPathFull = pose3d['imageName'][0]
        fileName_saved = os.path.join(
            os.path.basename(os.path.dirname(imgPathFull)),
            os.path.basename(imgPathFull))
        center = pose3d['center'][0]
        scale = pose3d['scale'][0]

        smpl_shape = pose3d['pred_shape'].ravel()
        smpl_pose_mat = torch.from_numpy(
            pose3d['pred_pose_rotmat'][0])  # 24,3,3
        rotmats = smpl_pose_mat.view(-1, 3, 3)
        # The converter takes 3x4 matrices; only the 3x3 rotation part is
        # meaningful here, so a zero column is appended.
        pred_rotmat_hom = torch.cat([
            rotmats,
            torch.tensor(
                [0, 0, 0],
                dtype=torch.float32,
            ).view(1, 3, 1).expand(rotmats.shape[0], -1, -1)
        ],
                                    dim=-1)
        smpl_pose = tgm.rotation_matrix_to_angle_axis(
            pred_rotmat_hom).contiguous().view(-1, 72)

        # Verification: convert axis-angle back to rotation matrices and
        # track the worst absolute deviation from the stored matrices.
        recon_mat = batch_rodrigues(smpl_pose.view(-1, 3))
        diff = abs(recon_mat.numpy() - pose3d['pred_pose_rotmat'][0])
        maxDiff = max(maxDiff, np.max(diff))

        smpl_pose = smpl_pose.numpy().ravel()

        # keypoint2d rows: first 25 are OpenPose joints, the rest the 24 SPIN joints
        openpose2d = pose3d['keypoint2d'][0][:25]  # 25,3
        spin2d_skel24 = pose3d['keypoint2d'][0][25:]  # 24,3

        # Save data
        imgnames_.append(fileName_saved)
        centers_.append(center)
        scales_.append(scale)
        has_smpl_.append(1)
        poses_.append(smpl_pose)  # (72,)
        shapes_.append(smpl_shape)  # (10,)

        openposes_.append(openpose2d)
        parts_.append(spin2d_skel24)

        # 3D joints: placeholder zeros. TODO: may need to add valid data for this
        skel3D_.append(np.zeros([24, 4]))

    print("Final Sample Num: {}".format(len(imgnames_)))
    print("maxDiff in rot conv.: {}".format(maxDiff))

    # store the data struct
    os.makedirs(out_path, exist_ok=True)
    # normpath drops a trailing separator, which would otherwise make the
    # basename empty and the output file just ".npz".
    out_file = os.path.join(out_path,
                            os.path.basename(os.path.normpath(eftDir)) + '.npz')

    print(f"Save to {out_file}")
    np.savez(out_file,
             imgname=imgnames_,
             center=centers_,
             scale=scales_,
             part=parts_,
             openpose=openposes_,
             pose=poses_,
             shape=shapes_,
             has_smpl=has_smpl_,
             S=skel3D_)
Exemplo n.º 20
0
    def train_step(self, input_batch):
        self.model.train()

        # Get data from the batch
        images = input_batch['img']  # input image
        gt_keypoints_2d = input_batch[
            'keypoints']  # 2D keypoints           #[N,49,3]
        gt_pose = input_batch[
            'pose']  # SMPL pose parameters                #[N,72]
        gt_betas = input_batch[
            'betas']  # SMPL beta parameters              #[N,10]
        gt_joints = input_batch[
            'pose_3d']  # 3D pose                        #[N,24,4]
        has_smpl = input_batch['has_smpl'].byte(
        ) == 1  # flag that indicates whether SMPL parameters are valid
        has_pose_3d = input_batch['has_pose_3d'].byte(
        ) == 1  # flag that indicates whether 3D pose is valid
        is_flipped = input_batch[
            'is_flipped']  # flag that indicates whether image was flipped during data augmentation
        rot_angle = input_batch[
            'rot_angle']  # rotation angle used for data augmentation
        dataset_name = input_batch[
            'dataset_name']  # name of the dataset the image comes from
        indices = input_batch[
            'sample_index']  # index of example inside its dataset
        batch_size = images.shape[0]

        #Debug temporary scaling for h36m
        # Get GT vertices and model joints
        # Note that gt_model_joints is different from gt_joints as it comes from SMPL
        gt_out = self.smpl(betas=gt_betas,
                           body_pose=gt_pose[:, 3:],
                           global_orient=gt_pose[:, :3])

        gt_model_joints = gt_out.joints.detach()  #[N, 49, 3]
        gt_vertices = gt_out.vertices

        # else:
        #     gt_out = self.smpl(betas=gt_betas, body_pose=gt_pose[:,3:-6], global_orient=gt_pose[:,:3])

        #     gt_model_joints = gt_out.joints.detach()             #[N, 49, 3]
        #     gt_vertices = gt_out.vertices

        # Get current best fits from the dictionary

        opt_pose, opt_betas, opt_validity = self.fits_dict[(dataset_name,
                                                            indices.cpu(),
                                                            rot_angle.cpu(),
                                                            is_flipped.cpu())]
        opt_pose = opt_pose.to(self.device)
        opt_betas = opt_betas.to(self.device)
        # if g_smplx == False:
        opt_output = self.smpl(betas=opt_betas,
                               body_pose=opt_pose[:, 3:],
                               global_orient=opt_pose[:, :3])

        opt_vertices = opt_output.vertices
        opt_joints = opt_output.joints.detach()

        # else:
        #     opt_output = self.smpl(betas=opt_betas, body_pose=opt_pose[:,3:-6], global_orient=opt_pose[:,:3])

        #     opt_vertices = opt_output.vertices
        #     opt_joints = opt_output.joints.detach()

        #assuer that non valid opt has GT values
        if len(has_smpl[opt_validity == 0]) > 0:
            assert min(has_smpl[opt_validity == 0])  #All should be True

        #assuer that non valid opt has GT values
        if len(has_smpl[opt_validity == 0]) > 0:
            assert min(has_smpl[opt_validity == 0])  #All should be True

        # De-normalize 2D keypoints from [-1,1] to pixel space
        gt_keypoints_2d_orig = gt_keypoints_2d.clone()
        gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
            gt_keypoints_2d_orig[:, :, :-1] + 1)

        # Estimate camera translation given the model joints and 2D keypoints
        # by minimizing a weighted least squares loss
        gt_cam_t = estimate_translation(gt_model_joints,
                                        gt_keypoints_2d_orig,
                                        focal_length=self.focal_length,
                                        img_size=self.options.img_res)

        opt_cam_t = estimate_translation(opt_joints,
                                         gt_keypoints_2d_orig,
                                         focal_length=self.focal_length,
                                         img_size=self.options.img_res)

        opt_joint_loss = self.smplify.get_fitting_loss(
            opt_pose,
            opt_betas,
            opt_cam_t,  #opt_pose (N,72)  (N,10)  opt_cam_t: (N,3)
            0.5 * self.options.img_res *
            torch.ones(batch_size, 2, device=self.device),  #(N,2)   (112, 112)
            gt_keypoints_2d_orig).mean(dim=-1)

        # Feed images in the network to predict camera and SMPL parameters
        pred_rotmat, pred_betas, pred_camera = self.model(images)

        # if g_smplx == False: #Original
        pred_output = self.smpl(betas=pred_betas,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)
        # else:
        #     pred_output = self.smpl(betas=pred_betas, body_pose=pred_rotmat[:,1:-2], global_orient=pred_rotmat[:,0].unsqueeze(1), pose2rot=False)

        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints

        # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
        # This camera translation can be used in a full perspective projection
        pred_cam_t = torch.stack([
            pred_camera[:, 1], pred_camera[:, 2], 2 * self.focal_length /
            (self.options.img_res * pred_camera[:, 0] + 1e-9)
        ],
                                 dim=-1)

        camera_center = torch.zeros(batch_size, 2, device=self.device)
        pred_keypoints_2d = perspective_projection(
            pred_joints,
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size, -1, -1),
            translation=pred_cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)
        # Normalize keypoints to [-1,1]
        pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res / 2.)

        #Weak Projection
        if self.options.bUseWeakProj:
            pred_keypoints_2d = weakProjection_gpu(pred_joints, pred_camera[:,
                                                                            0],
                                                   pred_camera[:,
                                                               1:])  #N, 49, 2

        bFootOriLoss = False
        if bFootOriLoss:  #Ignore hips and hip centers, foot
            # LENGTH_THRESHOLD = 0.0089 #1/112.0     #at least it should be 5 pixel
            #Disable parts
            gt_keypoints_2d[:, 2 + 25, 2] = 0
            gt_keypoints_2d[:, 3 + 25, 2] = 0
            gt_keypoints_2d[:, 14 + 25, 2] = 0

            #Disable Foots
            gt_keypoints_2d[:, 5 + 25, 2] = 0  #Left foot
            gt_keypoints_2d[:, 0 + 25, 2] = 0  #Right foot

        if self.options.run_smplify:

            # Convert predicted rotation matrices to axis-angle
            pred_rotmat_hom = torch.cat([
                pred_rotmat.detach().view(-1, 3, 3).detach(),
                torch.tensor(
                    [0, 0, 1], dtype=torch.float32, device=self.device).view(
                        1, 3, 1).expand(batch_size * 24, -1, -1)
            ],
                                        dim=-1)
            pred_pose = rotation_matrix_to_angle_axis(
                pred_rotmat_hom).contiguous().view(batch_size, -1)
            # tgm.rotation_matrix_to_angle_axis returns NaN for 0 rotation, so manually hack it
            pred_pose[torch.isnan(pred_pose)] = 0.0

            # Run SMPLify optimization starting from the network prediction
            new_opt_vertices, new_opt_joints,\
            new_opt_pose, new_opt_betas,\
            new_opt_cam_t, new_opt_joint_loss = self.smplify(
                                        pred_pose.detach(), pred_betas.detach(),
                                        pred_cam_t.detach(),
                                        0.5 * self.options.img_res * torch.ones(batch_size, 2, device=self.device),
                                        gt_keypoints_2d_orig)
            new_opt_joint_loss = new_opt_joint_loss.mean(dim=-1)

            # Will update the dictionary for the examples where the new loss is less than the current one
            update = (new_opt_joint_loss < opt_joint_loss)
            # print("new_opt_joint_loss{} vs opt_joint_loss{}".format(new_opt_joint_loss))

            if True:  #Visualize opt
                for b in range(batch_size):

                    curImgVis = images[b]  #3,224,224
                    curImgVis = self.de_normalize_img(curImgVis).cpu().numpy()
                    curImgVis = np.transpose(curImgVis, (1, 2, 0)) * 255.0
                    curImgVis = curImgVis[:, :, [2, 1, 0]]

                    #Denormalize image
                    curImgVis = np.ascontiguousarray(curImgVis, dtype=np.uint8)
                    viewer2D.ImShow(curImgVis, name='rawIm')
                    originalImg = curImgVis.copy()

                    pred_camera_vis = pred_camera.detach().cpu().numpy()

                    opt_vert_vis = opt_vertices[b].detach().cpu().numpy()
                    opt_vert_vis *= pred_camera_vis[b, 0]
                    opt_vert_vis[:, 0] += pred_camera_vis[
                        b,
                        1]  #no need +1 (or  112). Rendernig has this offset already
                    opt_vert_vis[:, 1] += pred_camera_vis[
                        b,
                        2]  #no need +1 (or  112). Rendernig has this offset already
                    opt_vert_vis *= 112
                    opt_meshes = {'ver': opt_vert_vis, 'f': self.smpl.faces}

                    gt_vert_vis = gt_vertices[b].detach().cpu().numpy()
                    gt_vert_vis *= pred_camera_vis[b, 0]
                    gt_vert_vis[:, 0] += pred_camera_vis[
                        b,
                        1]  #no need +1 (or  112). Rendernig has this offset already
                    gt_vert_vis[:, 1] += pred_camera_vis[
                        b,
                        2]  #no need +1 (or  112). Rendernig has this offset already
                    gt_vert_vis *= 112
                    gt_meshes = {'ver': gt_vert_vis, 'f': self.smpl.faces}

                    new_opt_output = self.smpl(
                        betas=new_opt_betas,
                        body_pose=new_opt_pose[:, 3:],
                        global_orient=new_opt_pose[:, :3])
                    new_opt_vertices = new_opt_output.vertices
                    new_opt_joints = new_opt_output.joints
                    new_opt_vert_vis = new_opt_vertices[b].detach().cpu(
                    ).numpy()
                    new_opt_vert_vis *= pred_camera_vis[b, 0]
                    new_opt_vert_vis[:, 0] += pred_camera_vis[
                        b,
                        1]  #no need +1 (or  112). Rendernig has this offset already
                    new_opt_vert_vis[:, 1] += pred_camera_vis[
                        b,
                        2]  #no need +1 (or  112). Rendernig has this offset already
                    new_opt_vert_vis *= 112
                    new_opt_meshes = {
                        'ver': new_opt_vert_vis,
                        'f': self.smpl.faces
                    }

                    glViewer.setMeshData(
                        [new_opt_meshes, gt_meshes, new_opt_meshes],
                        bComputeNormal=True)

                    glViewer.setBackgroundTexture(originalImg)
                    glViewer.setWindowSize(curImgVis.shape[1],
                                           curImgVis.shape[0])
                    glViewer.SetOrthoCamera(True)

                    print(has_smpl[b])
                    glViewer.show()

            opt_joint_loss[update] = new_opt_joint_loss[update]
            opt_vertices[update, :] = new_opt_vertices[update, :]
            opt_joints[update, :] = new_opt_joints[update, :]
            opt_pose[update, :] = new_opt_pose[update, :]
            opt_betas[update, :] = new_opt_betas[update, :]
            opt_cam_t[update, :] = new_opt_cam_t[update, :]

            self.fits_dict[(dataset_name, indices.cpu(), rot_angle.cpu(),
                            is_flipped.cpu(),
                            update.cpu())] = (opt_pose.cpu(), opt_betas.cpu())

        else:
            update = torch.zeros(batch_size, device=self.device).byte()

        # Replace the optimized parameters with the ground truth parameters, if available
        opt_vertices[has_smpl, :, :] = gt_vertices[has_smpl, :, :]
        opt_cam_t[has_smpl, :] = gt_cam_t[has_smpl, :]
        opt_joints[has_smpl, :, :] = gt_model_joints[has_smpl, :, :]
        opt_pose[has_smpl, :] = gt_pose[has_smpl, :]
        opt_betas[has_smpl, :] = gt_betas[has_smpl, :]

        # Assert whether a fit is valid by comparing the joint loss with the threshold
        valid_fit = (opt_joint_loss < self.options.smplify_threshold).to(
            self.device)

        if self.options.ablation_no_pseudoGT:
            valid_fit[:] = False  #Disable all pseudoGT

        # Add the examples with GT parameters to the list of valid fits
        valid_fit = valid_fit | has_smpl

        # if len(valid_fit) > sum(valid_fit):
        #     print(">> Rejected fit: {}/{}".format(len(valid_fit) - sum(valid_fit), len(valid_fit) ))

        opt_keypoints_2d = perspective_projection(
            opt_joints,
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size, -1, -1),
            translation=opt_cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)

        opt_keypoints_2d = opt_keypoints_2d / (self.options.img_res / 2.)

        # Compute loss on SMPL parameters
        loss_regr_pose, loss_regr_betas = self.smpl_losses(
            pred_rotmat, pred_betas, opt_pose, opt_betas, valid_fit)

        # Compute 2D reprojection loss for the keypoints
        loss_keypoints = self.keypoint_loss(pred_keypoints_2d, gt_keypoints_2d,
                                            self.options.openpose_train_weight,
                                            self.options.gt_train_weight)

        # Compute 3D keypoint loss
        loss_keypoints_3d = self.keypoint_3d_loss(pred_joints, gt_joints,
                                                  has_pose_3d)

        # Per-vertex loss for the shape
        loss_shape = self.shape_loss(pred_vertices, opt_vertices, valid_fit)

        #Regularization term for shape
        loss_regr_betas_noReject = torch.mean(pred_betas**2)

        # Compute total loss
        # The last component is a loss that forces the network to predict positive depth values
        if self.options.ablation_loss_2dkeyonly:  #2D keypoint only
            loss = self.options.keypoint_loss_weight * loss_keypoints +\
                ((torch.exp(-pred_camera[:,0]*10)) ** 2 ).mean() +\
                    self.options.beta_loss_weight * loss_regr_betas_noReject        #Beta regularization

        elif self.options.ablation_loss_noSMPLloss:  #2D no Pose parameter
            loss = self.options.keypoint_loss_weight * loss_keypoints +\
                self.options.keypoint_loss_weight * loss_keypoints_3d +\
                ((torch.exp(-pred_camera[:,0]*10)) ** 2 ).mean() +\
                self.options.beta_loss_weight * loss_regr_betas_noReject        #Beta regularization

        else:
            loss = self.options.shape_loss_weight * loss_shape +\
                self.options.keypoint_loss_weight * loss_keypoints +\
                self.options.keypoint_loss_weight * loss_keypoints_3d +\
                loss_regr_pose + self.options.beta_loss_weight * loss_regr_betas +\
                ((torch.exp(-pred_camera[:,0]*10)) ** 2 ).mean()

        # loss = self.options.keypoint_loss_weight * loss_keypoints #Debug: 2d error only
        # print("DEBUG: 2donly loss")
        loss *= 60

        # Do backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Pack output arguments for tensorboard logging
        output = {
            'pred_vertices': pred_vertices.detach(),
            'opt_vertices': opt_vertices,
            'pred_cam_t': pred_cam_t.detach(),
            'opt_cam_t': opt_cam_t
        }
        losses = {
            'loss': loss.detach().item(),
            'loss_keypoints': loss_keypoints.detach().item(),
            'loss_keypoints_3d': loss_keypoints_3d.detach().item(),
            'loss_regr_pose': loss_regr_pose.detach().item(),
            'loss_regr_betas': loss_regr_betas.detach().item(),
            'loss_shape': loss_shape.detach().item()
        }

        if self.options.bDebug_visEFT:  #g_debugVisualize:    #Debug Visualize input
            for b in range(batch_size):
                #denormalizeImg
                curImgVis = images[b]  #3,224,224
                curImgVis = self.de_normalize_img(curImgVis).cpu().numpy()
                curImgVis = np.transpose(curImgVis, (1, 2, 0)) * 255.0
                curImgVis = curImgVis[:, :, [2, 1, 0]]

                #Denormalize image
                curImgVis = np.ascontiguousarray(curImgVis, dtype=np.uint8)
                viewer2D.ImShow(curImgVis, name='rawIm')
                originalImg = curImgVis.copy()

                # curImgVis = viewer2D.Vis_Skeleton_2D_general(gt_keypoints_2d_orig[b,:,:2].cpu().numpy(), gt_keypoints_2d_orig[b,:,2], bVis= False, image=curImgVis)

                pred_keypoints_2d_vis = pred_keypoints_2d[
                    b, :, :2].detach().cpu().numpy()
                pred_keypoints_2d_vis = 0.5 * self.options.img_res * (
                    pred_keypoints_2d_vis + 1)  #49: (25+24) x 3

                curImgVis = viewer2D.Vis_Skeleton_2D_general(
                    pred_keypoints_2d_vis, bVis=False, image=curImgVis)
                viewer2D.ImShow(curImgVis, scale=2.0, waitTime=1)

                #Get camera pred_params
                pred_camera_vis = pred_camera.detach().cpu().numpy()

                ############### Visualize Mesh ###############
                pred_vert_vis = pred_vertices[b].detach().cpu().numpy()
                # meshVertVis = gt_vertices[b].detach().cpu().numpy()
                # meshVertVis = meshVertVis-pelvis        #centering
                pred_vert_vis *= pred_camera_vis[b, 0]
                pred_vert_vis[:, 0] += pred_camera_vis[
                    b,
                    1]  #no need +1 (or  112). Rendernig has this offset already
                pred_vert_vis[:, 1] += pred_camera_vis[
                    b,
                    2]  #no need +1 (or  112). Rendernig has this offset already
                pred_vert_vis *= 112
                pred_meshes = {'ver': pred_vert_vis, 'f': self.smpl.faces}

                opt_vert_vis = opt_vertices[b].detach().cpu().numpy()
                opt_vert_vis *= pred_camera_vis[b, 0]
                opt_vert_vis[:, 0] += pred_camera_vis[
                    b,
                    1]  #no need +1 (or  112). Rendernig has this offset already
                opt_vert_vis[:, 1] += pred_camera_vis[
                    b,
                    2]  #no need +1 (or  112). Rendernig has this offset already
                opt_vert_vis *= 112
                opt_meshes = {'ver': opt_vert_vis, 'f': self.smpl.faces}

                # glViewer.setMeshData([pred_meshes, opt_meshes], bComputeNormal= True)
                glViewer.setMeshData([pred_meshes, opt_meshes],
                                     bComputeNormal=True)
                # glViewer.setMeshData([opt_meshes], bComputeNormal= True)

                ############### Visualize Skeletons ###############
                #Vis pred-SMPL joint
                pred_joints_vis = pred_joints[
                    b, :, :3].detach().cpu().numpy()  #[N,49,3]
                pred_joints_vis = pred_joints_vis.ravel()[:, np.newaxis]
                #Weak-perspective projection
                pred_joints_vis *= pred_camera_vis[b, 0]
                pred_joints_vis[::3] += pred_camera_vis[b, 1]
                pred_joints_vis[1::3] += pred_camera_vis[b, 2]
                pred_joints_vis *= 112  #112 == 0.5*224
                glViewer.setSkeleton([pred_joints_vis])

                # #GT joint
                gt_jointsVis = gt_joints[b, :, :3].cpu().numpy()  #[N,49,3]
                # gt_pelvis = (gt_smpljointsVis[ 25+2,:] + gt_smpljointsVis[ 25+3,:]) / 2
                # gt_smpljointsVis = gt_smpljointsVis- gt_pelvis
                gt_jointsVis = gt_jointsVis.ravel()[:, np.newaxis]
                gt_jointsVis *= pred_camera_vis[b, 0]
                gt_jointsVis[::3] += pred_camera_vis[b, 1]
                gt_jointsVis[1::3] += pred_camera_vis[b, 2]
                gt_jointsVis *= 112
                glViewer.addSkeleton([gt_jointsVis], jointType='spin')

                # #Vis SMPL's Skeleton
                # gt_smpljointsVis = gt_model_joints[b,:,:3].cpu().numpy()        #[N,49,3]
                # # gt_pelvis = (gt_smpljointsVis[ 25+2,:] + gt_smpljointsVis[ 25+3,:]) / 2
                # # gt_smpljointsVis = gt_smpljointsVis- gt_pelvis
                # gt_smpljointsVis = gt_smpljointsVis.ravel()[:,np.newaxis]
                # gt_smpljointsVis*=pred_camera_vis[b,0]
                # gt_smpljointsVis[::3] += pred_camera_vis[b,1]
                # gt_smpljointsVis[1::3] += pred_camera_vis[b,2]
                # gt_smpljointsVis*=112
                # glViewer.addSkeleton( [gt_smpljointsVis])

                # #Vis GT  joint  (not model (SMPL) joint!!)
                # if has_pose_3d[b]:
                #     gt_jointsVis = gt_model_joints[b,:,:3].cpu().numpy()        #[N,49,3]
                #     # gt_jointsVis = gt_joints[b,:,:3].cpu().numpy()        #[N,49,3]
                #     # gt_pelvis = (gt_jointsVis[ 25+2,:] + gt_jointsVis[ 25+3,:]) / 2
                #     # gt_jointsVis = gt_jointsVis- gt_pelvis

                #     gt_jointsVis = gt_jointsVis.ravel()[:,np.newaxis]
                #     gt_jointsVis*=pred_camera_vis[b,0]
                #     gt_jointsVis[::3] += pred_camera_vis[b,1]
                #     gt_jointsVis[1::3] += pred_camera_vis[b,2]
                #     gt_jointsVis*=112

                #     glViewer.addSkeleton( [gt_jointsVis])
                # # glViewer.show()

                glViewer.setBackgroundTexture(originalImg)
                glViewer.setWindowSize(curImgVis.shape[1], curImgVis.shape[0])
                glViewer.SetOrthoCamera(True)
                glViewer.show(0)

                # continue

        return output, losses
Exemplo n.º 21
0
                spin_betas = torch.from_numpy(data['spin_beta'])
                spin_pose = torch.from_numpy(data['spin_pose'])

            pred_camera_vis = data['pred_camera']
            keypoint2d_49 = data['keypoint2d']

            pred_rotmat_hom = torch.cat([
                ours_pose_rotmat.view(-1, 3, 3),
                torch.tensor(
                    [0, 0, 1],
                    dtype=torch.float32,
                ).view(1, 3, 1).expand(ours_pose_rotmat.shape[0] * 24, -1, -1)
            ],
                                        dim=-1)
            ours_pose_aa = rotation_matrix_to_angle_axis(
                pred_rotmat_hom).contiguous().view(ours_pose_rotmat.shape[0],
                                                   -1)
            # ours_pose_aa = pred_aa.cpu().numpy()

            if np.isnan(np.max(ours_pose_aa.numpy())):
                print("Warning: !!NAN detected!!!: {}".format(imgName))
                continue

            #Visualize SMPL output
            # Note that gt_model_joints is different from gt_joints as it comes from SMPL
            # ours_output = smpl(betas=ours_betas, body_pose=ours_pose_rotmat[:,1:], global_orient=ours_pose_rotmat[:,0].unsqueeze(1), pose2rot=False )
            ours_output = smpl(betas=ours_betas,
                               body_pose=ours_pose_aa[:, 3:],
                               global_orient=ours_pose_aa[:, :3])

            ours_joints_3d = ours_output.joints.detach().cpu().numpy()