Example #1
0
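# Imports assumed by this snippet; the project-local helpers used below
# (read_calibration, read_openpose, glViewer) come from the surrounding
# preprocessing code of the repository.
import os
import glob

import cv2
import numpy as np
import scipy.io as sio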
def train_data(dataset_path, openpose_path, out_path, joints_idx, scaleFactor, extract_img=False, fits_3d=None):

    joints17_idx = [4, 18, 19, 20, 23, 24, 25, 3, 5, 6, 7, 9, 10, 11, 14, 15, 16]

    h, w = 2048, 2048
    imgnames_, scales_, centers_ = [], [], []
    parts_, Ss_, openposes_ = [], [], []
    S17s_ = []

    # training data
    user_list = range(1, 9)                          # subjects S1-S8
    seq_list = range(1, 3)                           # sequences Seq1-Seq2
    vid_list = list(range(3)) + list(range(4, 9))    # cameras 0-8, skipping camera 3

    counter = 0

    for user_i in user_list:
        for seq_i in seq_list:
            seq_path = os.path.join(dataset_path,
                                    'S' + str(user_i),
                                    'Seq' + str(seq_i))
            # mat file with annotations
            annot_file = os.path.join(seq_path, 'annot.mat')
            annot2 = sio.loadmat(annot_file)['annot2']
            annot3 = sio.loadmat(annot_file)['annot3']
            # calibration file and camera parameters
            calib_file = os.path.join(seq_path, 'camera.calibration')
            Ks, Rs, Ts = read_calibration(calib_file, vid_list)

            for j, vid_i in enumerate(vid_list):

                # image folder
                imgs_path = os.path.join(seq_path,
                                         'imageFrames',
                                         'video_' + str(vid_i))

                # extract frames from video file
                if extract_img:

                    # create the image folder if it doesn't exist
                    if not os.path.isdir(imgs_path):
                        os.makedirs(imgs_path)

                    # video file
                    vid_file = os.path.join(seq_path,
                                            'imageSequence',
                                            'video_' + str(vid_i) + '.avi')
                    vidcap = cv2.VideoCapture(vid_file)

                    # process video
                    frame = 0
                    while True:
                        # extract all frames
                        success, image = vidcap.read()
                        if not success:
                            break
                        frame += 1
                        # image name
                        imgname = os.path.join(imgs_path,
                            'frame_%06d.jpg' % frame)
                        # save image
                        cv2.imwrite(imgname, image)

                # camera rotation as axis-angle (one per view)
                cam_aa = cv2.Rodrigues(Rs[j])[0].T[0]
                pattern = os.path.join(imgs_path, '*.jpg')
                img_list = glob.glob(pattern)
                for i, img_i in enumerate(img_list):

                    # for each image we store the relevant annotations
                    img_name = img_i.split('/')[-1]
                    img_view = os.path.join('S' + str(user_i),
                                            'Seq' + str(seq_i),
                                            'imageFrames',
                                            'video_' + str(vid_i),
                                            img_name)
                    joints = np.reshape(annot2[vid_i][0][i], (28, 2))[joints17_idx]    # 2D joints: (17, 2)
                    S17 = np.reshape(annot3[vid_i][0][i], (28, 3)) / 1000              # 3D joints in meters: (28, 3)
                    S17 = S17[joints17_idx] - S17[4]        # root-centered (index 4 is the root): (17, 3)

                    # optional debug visualization of the 3D skeleton (disabled)
                    if False:
                        skelVis = np.reshape(S17.ravel(), (-1, 1)) * 100
                        glViewer.setSkeleton(skelVis)
                        glViewer.show(5)
                        continue


                    bbox = [min(joints[:,0]), min(joints[:,1]),
                            max(joints[:,0]), max(joints[:,1])]
                    center = [(bbox[2]+bbox[0])/2, (bbox[3]+bbox[1])/2]
                    scale = scaleFactor*max(bbox[2]-bbox[0], bbox[3]-bbox[1])/200
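                    # Convention (SPIN-style): scale s means the person's tight bounding box
                    # spans roughly s * 200 px, so a crop of s * 200 px around `center` frames
                    # the subject; e.g. a 300 x 400 px box with scaleFactor = 1.2 gives
                    # scale = 1.2 * 400 / 200 = 2.4.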

                    # check that all joints are visible
                    x_in = np.logical_and(joints[:, 0] < w, joints[:, 0] >= 0)
                    y_in = np.logical_and(joints[:, 1] < h, joints[:, 1] >= 0)
                    ok_pts = np.logical_and(x_in, y_in)
                    if np.sum(ok_pts) < len(joints_idx):
                        continue
                        
                    part = np.zeros([24,3])
                    part[joints_idx] = np.hstack([joints, np.ones([17,1])])
                    json_file = os.path.join(openpose_path,
                        img_view.replace('.jpg', '_keypoints.json'))
                    openpose = read_openpose(json_file, part, 'mpi_inf_3dhp')

                    # optional debug visualization of the OpenPose 2D keypoints (disabled)
                    if False:
                        from renderer import viewer2D
                        viewer2D.Vis_Skeleton_2D_general(openpose[:, :2])

                    S = np.zeros([24,4])
                    S[joints_idx] = np.hstack([S17, np.ones([17,1])])

                    # because of the dataset size, we only keep every 10th frame
                    counter += 1
                    if counter % 10 != 1:
                        continue

                    # store the data
                    imgnames_.append(img_view)
                    centers_.append(center)
                    scales_.append(scale)
                    parts_.append(part)             # 2D joints (24, 3)
                    Ss_.append(S)                   # 3D joints (24, 4)
                    S17s_.append(S17)               # 3D joints (17, 3), kept for debugging
                    openposes_.append(openpose)     # openpose (25, 3)

                    if len(S17s_) == 105:           # Debug: stop after 105 samples
                        break

                break       # Debug: process only the first video

            break           # Debug: process only the first sequence

        break               # Debug: process only the first subject
    # store the data struct
    if not os.path.isdir(out_path):
        os.makedirs(out_path)
    out_file = os.path.join(out_path, 'mpi_inf_3dhp_train.npz')
    if fits_3d is not None:
        fits_3d = np.load(fits_3d)


        if True:    # Debug: visualize the SMPL fits
            from humanModelViewer import smplModelViewer
            frameLen=100
            poses = fits_3d['pose'][:frameLen]
            shape = fits_3d['shape'][:frameLen]
            meshes, j_coco19, _ = smplModelViewer(poses, shape, returnUnit='cm', bCentering=False,
                                                  jointType='smplcoco_total26', modelType='f')
            glViewer.setMeshData([meshes],bComputeNormal=True)

            #Skeleton
            S17s_ = np.array(S17s_[:frameLen]) *100
            skelVis = np.reshape(S17s_, (S17s_.shape[0],-1) )            #N,dim
            skelVis = np.swapaxes(skelVis, 0,1)
            glViewer.setSkeleton([skelVis])

            
            glViewer.show()




        np.savez(out_file, imgname=imgnames_,
                           center=centers_,
                           scale=scales_,
                           part=parts_,
                           pose=fits_3d['pose'],
                           shape=fits_3d['shape'],
                           has_smpl=fits_3d['has_smpl'],
                           S=Ss_,
                           openpose=openposes_)
    else:
        np.savez(out_file, imgname=imgnames_,
                           center=centers_,
                           scale=scales_,
                           part=parts_,
                           S=Ss_,
                           openpose=openposes_)        
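
# A minimal usage sketch (not part of the original script): load the npz written by
# train_data() above and check the stored fields; the path here is illustrative.
if __name__ == '__main__':
    data = np.load('./mpi_inf_3dhp_train.npz')
    print(list(data.keys()))        # imgname, center, scale, part, S, openpose, ...
    print(data['part'].shape)       # (num_samples, 24, 3)  2D joints + visibility flag
    print(data['S'].shape)          # (num_samples, 24, 4)  3D joints + validity flag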
Example #2
0
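    # This method is assumed to sit inside a SPIN-style Trainer class: self.model, self.smpl,
    # self.smplify, self.options, self.optimizer, self.fits_dict, self.focal_length and the
    # loss helpers (smpl_losses, keypoint_loss, keypoint_3d_loss, shape_loss) are created
    # elsewhere in that class; viewer2D / glViewer are the repository's visualization utilities.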
    def train_step(self, input_batch):
        self.model.train()

        # Get data from the batch
        images = input_batch['img']                     # input image
        gt_keypoints_2d = input_batch['keypoints']      # 2D keypoints          [N, 49, 3]
        gt_pose = input_batch['pose']                   # SMPL pose parameters  [N, 72]
        gt_betas = input_batch['betas']                 # SMPL beta parameters  [N, 10]
        gt_joints = input_batch['pose_3d']              # 3D pose               [N, 24, 4]
        has_smpl = input_batch['has_smpl'].byte() == 1  # flag: SMPL parameters are valid
        has_pose_3d = input_batch['has_pose_3d'].byte() == 1  # flag: 3D pose is valid
        is_flipped = input_batch['is_flipped']          # flag: image was flipped during data augmentation
        rot_angle = input_batch['rot_angle']            # rotation angle used for data augmentation
        dataset_name = input_batch['dataset_name']      # name of the dataset the image comes from
        indices = input_batch['sample_index']           # index of example inside its dataset
        batch_size = images.shape[0]

        #Debug temporary scaling for h36m
        # Get GT vertices and model joints
        # Note that gt_model_joints is different from gt_joints as it comes from SMPL
        gt_out = self.smpl(betas=gt_betas,
                           body_pose=gt_pose[:, 3:],
                           global_orient=gt_pose[:, :3])

        gt_model_joints = gt_out.joints.detach()  #[N, 49, 3]
        gt_vertices = gt_out.vertices

        # else:
        #     gt_out = self.smpl(betas=gt_betas, body_pose=gt_pose[:,3:-6], global_orient=gt_pose[:,:3])

        #     gt_model_joints = gt_out.joints.detach()             #[N, 49, 3]
        #     gt_vertices = gt_out.vertices

        # Get current best fits from the dictionary

        opt_pose, opt_betas, opt_validity = self.fits_dict[(dataset_name,
                                                            indices.cpu(),
                                                            rot_angle.cpu(),
                                                            is_flipped.cpu())]
        opt_pose = opt_pose.to(self.device)
        opt_betas = opt_betas.to(self.device)
        # if g_smplx == False:
        opt_output = self.smpl(betas=opt_betas,
                               body_pose=opt_pose[:, 3:],
                               global_orient=opt_pose[:, :3])

        opt_vertices = opt_output.vertices
        opt_joints = opt_output.joints.detach()

        # else:
        #     opt_output = self.smpl(betas=opt_betas, body_pose=opt_pose[:,3:-6], global_orient=opt_pose[:,:3])

        #     opt_vertices = opt_output.vertices
        #     opt_joints = opt_output.joints.detach()

        # assert that samples without a valid fit have GT SMPL parameters
        if len(has_smpl[opt_validity == 0]) > 0:
            assert min(has_smpl[opt_validity == 0])  # all should be True

        # De-normalize 2D keypoints from [-1,1] to pixel space
        gt_keypoints_2d_orig = gt_keypoints_2d.clone()
        gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
            gt_keypoints_2d_orig[:, :, :-1] + 1)
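        # Mapping: normalized x in [-1, 1]  ->  pixel x = 0.5 * img_res * (x + 1);
        # e.g. with img_res = 224: -1 -> 0, 0 -> 112, +1 -> 224. The confidence column
        # (last channel) is left untouched.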

        # Estimate camera translation given the model joints and 2D keypoints
        # by minimizing a weighted least squares loss
        gt_cam_t = estimate_translation(gt_model_joints,
                                        gt_keypoints_2d_orig,
                                        focal_length=self.focal_length,
                                        img_size=self.options.img_res)

        opt_cam_t = estimate_translation(opt_joints,
                                         gt_keypoints_2d_orig,
                                         focal_length=self.focal_length,
                                         img_size=self.options.img_res)

        opt_joint_loss = self.smplify.get_fitting_loss(
            opt_pose,
            opt_betas,
            opt_cam_t,  #opt_pose (N,72)  (N,10)  opt_cam_t: (N,3)
            0.5 * self.options.img_res *
            torch.ones(batch_size, 2, device=self.device),  #(N,2)   (112, 112)
            gt_keypoints_2d_orig).mean(dim=-1)

        # Feed images in the network to predict camera and SMPL parameters
        pred_rotmat, pred_betas, pred_camera = self.model(images)

        # if g_smplx == False: #Original
        pred_output = self.smpl(betas=pred_betas,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)
        # else:
        #     pred_output = self.smpl(betas=pred_betas, body_pose=pred_rotmat[:,1:-2], global_orient=pred_rotmat[:,0].unsqueeze(1), pose2rot=False)

        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints

        # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
        # This camera translation can be used in a full perspective projection
        pred_cam_t = torch.stack([
            pred_camera[:, 1],
            pred_camera[:, 2],
            2 * self.focal_length / (self.options.img_res * pred_camera[:, 0] + 1e-9)
        ], dim=-1)
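        # With a weak-perspective camera [s, tx, ty] and a crop of img_res px, the equivalent
        # full-perspective depth satisfies s = 2 * f / (img_res * tz), i.e.
        # tz = 2 * f / (img_res * s); the 1e-9 above only guards against division by zero.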

        camera_center = torch.zeros(batch_size, 2, device=self.device)
        pred_keypoints_2d = perspective_projection(
            pred_joints,
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size, -1, -1),
            translation=pred_cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)
        # Normalize keypoints to [-1,1]
        pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res / 2.)

        #Weak Projection
        if self.options.bUseWeakProj:
            pred_keypoints_2d = weakProjection_gpu(
                pred_joints, pred_camera[:, 0], pred_camera[:, 1:])  # [N, 49, 2]

        bFootOriLoss = False
        if bFootOriLoss:  # ignore hips, hip center, and feet
            # LENGTH_THRESHOLD = 0.0089  # 1/112.0; at least it should be 5 pixels
            # disable hips and hip center
            gt_keypoints_2d[:, 2 + 25, 2] = 0
            gt_keypoints_2d[:, 3 + 25, 2] = 0
            gt_keypoints_2d[:, 14 + 25, 2] = 0

            # disable feet
            gt_keypoints_2d[:, 5 + 25, 2] = 0  # left foot
            gt_keypoints_2d[:, 0 + 25, 2] = 0  # right foot

        if self.options.run_smplify:

            # Convert predicted rotation matrices to axis-angle
            pred_rotmat_hom = torch.cat([
                pred_rotmat.detach().view(-1, 3, 3),
                torch.tensor([0, 0, 1], dtype=torch.float32,
                             device=self.device).view(1, 3, 1).expand(batch_size * 24, -1, -1)
            ], dim=-1)
            pred_pose = rotation_matrix_to_angle_axis(
                pred_rotmat_hom).contiguous().view(batch_size, -1)
            # tgm.rotation_matrix_to_angle_axis returns NaN for 0 rotation, so manually hack it
            pred_pose[torch.isnan(pred_pose)] = 0.0

            # Run SMPLify optimization starting from the network prediction
            new_opt_vertices, new_opt_joints,\
            new_opt_pose, new_opt_betas,\
            new_opt_cam_t, new_opt_joint_loss = self.smplify(
                                        pred_pose.detach(), pred_betas.detach(),
                                        pred_cam_t.detach(),
                                        0.5 * self.options.img_res * torch.ones(batch_size, 2, device=self.device),
                                        gt_keypoints_2d_orig)
            new_opt_joint_loss = new_opt_joint_loss.mean(dim=-1)

            # Will update the dictionary for the examples where the new loss is less than the current one
            update = (new_opt_joint_loss < opt_joint_loss)
            # print("new_opt_joint_loss{} vs opt_joint_loss{}".format(new_opt_joint_loss))

            if True:  # Debug: visualize the optimized fits (blocks training; set to False to disable)
                for b in range(batch_size):

                    curImgVis = images[b]  #3,224,224
                    curImgVis = self.de_normalize_img(curImgVis).cpu().numpy()
                    curImgVis = np.transpose(curImgVis, (1, 2, 0)) * 255.0
                    curImgVis = curImgVis[:, :, [2, 1, 0]]

                    #Denormalize image
                    curImgVis = np.ascontiguousarray(curImgVis, dtype=np.uint8)
                    viewer2D.ImShow(curImgVis, name='rawIm')
                    originalImg = curImgVis.copy()

                    pred_camera_vis = pred_camera.detach().cpu().numpy()

                    opt_vert_vis = opt_vertices[b].detach().cpu().numpy()
                    opt_vert_vis *= pred_camera_vis[b, 0]
                    opt_vert_vis[:, 0] += pred_camera_vis[b, 1]  # no +1 (or 112) needed; rendering already applies this offset
                    opt_vert_vis[:, 1] += pred_camera_vis[b, 2]
                    opt_vert_vis *= 112
                    opt_meshes = {'ver': opt_vert_vis, 'f': self.smpl.faces}

                    gt_vert_vis = gt_vertices[b].detach().cpu().numpy()
                    gt_vert_vis *= pred_camera_vis[b, 0]
                    gt_vert_vis[:, 0] += pred_camera_vis[b, 1]  # no +1 (or 112) needed; rendering already applies this offset
                    gt_vert_vis[:, 1] += pred_camera_vis[b, 2]
                    gt_vert_vis *= 112
                    gt_meshes = {'ver': gt_vert_vis, 'f': self.smpl.faces}

                    new_opt_output = self.smpl(
                        betas=new_opt_betas,
                        body_pose=new_opt_pose[:, 3:],
                        global_orient=new_opt_pose[:, :3])
                    new_opt_vertices = new_opt_output.vertices
                    new_opt_joints = new_opt_output.joints
                    new_opt_vert_vis = new_opt_vertices[b].detach().cpu().numpy()
                    new_opt_vert_vis *= pred_camera_vis[b, 0]
                    new_opt_vert_vis[:, 0] += pred_camera_vis[b, 1]  # no +1 (or 112) needed; rendering already applies this offset
                    new_opt_vert_vis[:, 1] += pred_camera_vis[b, 2]
                    new_opt_vert_vis *= 112
                    new_opt_meshes = {'ver': new_opt_vert_vis, 'f': self.smpl.faces}

                    glViewer.setMeshData(
                        [opt_meshes, gt_meshes, new_opt_meshes],
                        bComputeNormal=True)

                    glViewer.setBackgroundTexture(originalImg)
                    glViewer.setWindowSize(curImgVis.shape[1],
                                           curImgVis.shape[0])
                    glViewer.SetOrthoCamera(True)

                    print(has_smpl[b])
                    glViewer.show()

            opt_joint_loss[update] = new_opt_joint_loss[update]
            opt_vertices[update, :] = new_opt_vertices[update, :]
            opt_joints[update, :] = new_opt_joints[update, :]
            opt_pose[update, :] = new_opt_pose[update, :]
            opt_betas[update, :] = new_opt_betas[update, :]
            opt_cam_t[update, :] = new_opt_cam_t[update, :]

            self.fits_dict[(dataset_name, indices.cpu(), rot_angle.cpu(),
                            is_flipped.cpu(),
                            update.cpu())] = (opt_pose.cpu(), opt_betas.cpu())

        else:
            update = torch.zeros(batch_size, device=self.device).byte()

        # Replace the optimized parameters with the ground truth parameters, if available
        opt_vertices[has_smpl, :, :] = gt_vertices[has_smpl, :, :]
        opt_cam_t[has_smpl, :] = gt_cam_t[has_smpl, :]
        opt_joints[has_smpl, :, :] = gt_model_joints[has_smpl, :, :]
        opt_pose[has_smpl, :] = gt_pose[has_smpl, :]
        opt_betas[has_smpl, :] = gt_betas[has_smpl, :]

        # Assert whether a fit is valid by comparing the joint loss with the threshold
        valid_fit = (opt_joint_loss < self.options.smplify_threshold).to(
            self.device)

        if self.options.ablation_no_pseudoGT:
            valid_fit[:] = False  #Disable all pseudoGT

        # Add the examples with GT parameters to the list of valid fits
        valid_fit = valid_fit | has_smpl

        # if len(valid_fit) > sum(valid_fit):
        #     print(">> Rejected fit: {}/{}".format(len(valid_fit) - sum(valid_fit), len(valid_fit) ))

        opt_keypoints_2d = perspective_projection(
            opt_joints,
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size, -1, -1),
            translation=opt_cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)

        opt_keypoints_2d = opt_keypoints_2d / (self.options.img_res / 2.)

        # Compute loss on SMPL parameters
        loss_regr_pose, loss_regr_betas = self.smpl_losses(
            pred_rotmat, pred_betas, opt_pose, opt_betas, valid_fit)

        # Compute 2D reprojection loss for the keypoints
        loss_keypoints = self.keypoint_loss(pred_keypoints_2d, gt_keypoints_2d,
                                            self.options.openpose_train_weight,
                                            self.options.gt_train_weight)

        # Compute 3D keypoint loss
        loss_keypoints_3d = self.keypoint_3d_loss(pred_joints, gt_joints,
                                                  has_pose_3d)

        # Per-vertex loss for the shape
        loss_shape = self.shape_loss(pred_vertices, opt_vertices, valid_fit)

        #Regularization term for shape
        loss_regr_betas_noReject = torch.mean(pred_betas**2)

        # Compute total loss
        # The last component is a loss that forces the network to predict positive depth values
        if self.options.ablation_loss_2dkeyonly:  #2D keypoint only
            loss = self.options.keypoint_loss_weight * loss_keypoints +\
                ((torch.exp(-pred_camera[:,0]*10)) ** 2 ).mean() +\
                    self.options.beta_loss_weight * loss_regr_betas_noReject        #Beta regularization

        elif self.options.ablation_loss_noSMPLloss:  #2D no Pose parameter
            loss = self.options.keypoint_loss_weight * loss_keypoints +\
                self.options.keypoint_loss_weight * loss_keypoints_3d +\
                ((torch.exp(-pred_camera[:,0]*10)) ** 2 ).mean() +\
                self.options.beta_loss_weight * loss_regr_betas_noReject        #Beta regularization

        else:
            loss = self.options.shape_loss_weight * loss_shape +\
                self.options.keypoint_loss_weight * loss_keypoints +\
                self.options.keypoint_loss_weight * loss_keypoints_3d +\
                loss_regr_pose + self.options.beta_loss_weight * loss_regr_betas +\
                ((torch.exp(-pred_camera[:,0]*10)) ** 2 ).mean()
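        # The ((torch.exp(-pred_camera[:, 0] * 10)) ** 2).mean() term used in every branch
        # penalizes small or negative camera scales s = pred_camera[:, 0]; since the
        # full-perspective depth is tz = 2 * f / (img_res * s), this is what keeps the
        # predicted depth positive, as noted above.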

        # loss = self.options.keypoint_loss_weight * loss_keypoints #Debug: 2d error only
        # print("DEBUG: 2donly loss")
        loss *= 60

        # Do backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Pack output arguments for tensorboard logging
        output = {
            'pred_vertices': pred_vertices.detach(),
            'opt_vertices': opt_vertices,
            'pred_cam_t': pred_cam_t.detach(),
            'opt_cam_t': opt_cam_t
        }
        losses = {
            'loss': loss.detach().item(),
            'loss_keypoints': loss_keypoints.detach().item(),
            'loss_keypoints_3d': loss_keypoints_3d.detach().item(),
            'loss_regr_pose': loss_regr_pose.detach().item(),
            'loss_regr_betas': loss_regr_betas.detach().item(),
            'loss_shape': loss_shape.detach().item()
        }

        if self.options.bDebug_visEFT:  # Debug: visualize input and predictions
            for b in range(batch_size):
                #denormalizeImg
                curImgVis = images[b]  #3,224,224
                curImgVis = self.de_normalize_img(curImgVis).cpu().numpy()
                curImgVis = np.transpose(curImgVis, (1, 2, 0)) * 255.0
                curImgVis = curImgVis[:, :, [2, 1, 0]]

                #Denormalize image
                curImgVis = np.ascontiguousarray(curImgVis, dtype=np.uint8)
                viewer2D.ImShow(curImgVis, name='rawIm')
                originalImg = curImgVis.copy()

                # curImgVis = viewer2D.Vis_Skeleton_2D_general(gt_keypoints_2d_orig[b,:,:2].cpu().numpy(), gt_keypoints_2d_orig[b,:,2], bVis= False, image=curImgVis)

                pred_keypoints_2d_vis = pred_keypoints_2d[b, :, :2].detach().cpu().numpy()
                pred_keypoints_2d_vis = 0.5 * self.options.img_res * (pred_keypoints_2d_vis + 1)  # 49 = 25 + 24 joints

                curImgVis = viewer2D.Vis_Skeleton_2D_general(
                    pred_keypoints_2d_vis, bVis=False, image=curImgVis)
                viewer2D.ImShow(curImgVis, scale=2.0, waitTime=1)

                #Get camera pred_params
                pred_camera_vis = pred_camera.detach().cpu().numpy()

                ############### Visualize Mesh ###############
                pred_vert_vis = pred_vertices[b].detach().cpu().numpy()
                # meshVertVis = gt_vertices[b].detach().cpu().numpy()
                # meshVertVis = meshVertVis - pelvis        # centering
                pred_vert_vis *= pred_camera_vis[b, 0]
                pred_vert_vis[:, 0] += pred_camera_vis[b, 1]  # no +1 (or 112) needed; rendering already applies this offset
                pred_vert_vis[:, 1] += pred_camera_vis[b, 2]
                pred_vert_vis *= 112
                pred_meshes = {'ver': pred_vert_vis, 'f': self.smpl.faces}

                opt_vert_vis = opt_vertices[b].detach().cpu().numpy()
                opt_vert_vis *= pred_camera_vis[b, 0]
                opt_vert_vis[:, 0] += pred_camera_vis[b, 1]  # no +1 (or 112) needed; rendering already applies this offset
                opt_vert_vis[:, 1] += pred_camera_vis[b, 2]
                opt_vert_vis *= 112
                opt_meshes = {'ver': opt_vert_vis, 'f': self.smpl.faces}

                glViewer.setMeshData([pred_meshes, opt_meshes], bComputeNormal=True)
                # glViewer.setMeshData([opt_meshes], bComputeNormal=True)

                ############### Visualize Skeletons ###############
                # Vis pred-SMPL joints
                pred_joints_vis = pred_joints[b, :, :3].detach().cpu().numpy()  # [49, 3]
                pred_joints_vis = pred_joints_vis.ravel()[:, np.newaxis]
                # weak-perspective projection
                pred_joints_vis *= pred_camera_vis[b, 0]
                pred_joints_vis[::3] += pred_camera_vis[b, 1]
                pred_joints_vis[1::3] += pred_camera_vis[b, 2]
                pred_joints_vis *= 112  # 112 == 0.5 * 224
                glViewer.setSkeleton([pred_joints_vis])

                # GT joints
                gt_jointsVis = gt_joints[b, :, :3].cpu().numpy()  # [24, 3]
                # gt_pelvis = (gt_smpljointsVis[ 25+2,:] + gt_smpljointsVis[ 25+3,:]) / 2
                # gt_smpljointsVis = gt_smpljointsVis- gt_pelvis
                gt_jointsVis = gt_jointsVis.ravel()[:, np.newaxis]
                gt_jointsVis *= pred_camera_vis[b, 0]
                gt_jointsVis[::3] += pred_camera_vis[b, 1]
                gt_jointsVis[1::3] += pred_camera_vis[b, 2]
                gt_jointsVis *= 112
                glViewer.addSkeleton([gt_jointsVis], jointType='spin')

                # #Vis SMPL's Skeleton
                # gt_smpljointsVis = gt_model_joints[b,:,:3].cpu().numpy()        #[N,49,3]
                # # gt_pelvis = (gt_smpljointsVis[ 25+2,:] + gt_smpljointsVis[ 25+3,:]) / 2
                # # gt_smpljointsVis = gt_smpljointsVis- gt_pelvis
                # gt_smpljointsVis = gt_smpljointsVis.ravel()[:,np.newaxis]
                # gt_smpljointsVis*=pred_camera_vis[b,0]
                # gt_smpljointsVis[::3] += pred_camera_vis[b,1]
                # gt_smpljointsVis[1::3] += pred_camera_vis[b,2]
                # gt_smpljointsVis*=112
                # glViewer.addSkeleton( [gt_smpljointsVis])

                # #Vis GT  joint  (not model (SMPL) joint!!)
                # if has_pose_3d[b]:
                #     gt_jointsVis = gt_model_joints[b,:,:3].cpu().numpy()        #[N,49,3]
                #     # gt_jointsVis = gt_joints[b,:,:3].cpu().numpy()        #[N,49,3]
                #     # gt_pelvis = (gt_jointsVis[ 25+2,:] + gt_jointsVis[ 25+3,:]) / 2
                #     # gt_jointsVis = gt_jointsVis- gt_pelvis

                #     gt_jointsVis = gt_jointsVis.ravel()[:,np.newaxis]
                #     gt_jointsVis*=pred_camera_vis[b,0]
                #     gt_jointsVis[::3] += pred_camera_vis[b,1]
                #     gt_jointsVis[1::3] += pred_camera_vis[b,2]
                #     gt_jointsVis*=112

                #     glViewer.addSkeleton( [gt_jointsVis])
                # # glViewer.show()

                glViewer.setBackgroundTexture(originalImg)
                glViewer.setWindowSize(curImgVis.shape[1], curImgVis.shape[0])
                glViewer.SetOrthoCamera(True)
                glViewer.show(0)

                # continue

        return output, losses
Example #3
0
    def train_exemplar_step(self, input_batch):

        self.model.train()

        if self.options.bExemplarMode:
            self.exemplerTrainingMode()

        # Get data from the batch
        images = input_batch['img']                     # input image
        gt_keypoints_2d = input_batch['keypoints']      # 2D keypoints          [N, 49, 3]
        gt_pose = input_batch['pose']                   # SMPL pose parameters  [N, 72]
        gt_betas = input_batch['betas']                 # SMPL beta parameters  [N, 10]
        gt_joints = input_batch['pose_3d']              # 3D pose               [N, 24, 4]
        has_smpl = input_batch['has_smpl'].byte() == 1  # flag: SMPL parameters are valid
        has_pose_3d = input_batch['has_pose_3d'].byte() == 1  # flag: 3D pose is valid
        is_flipped = input_batch['is_flipped']          # flag: image was flipped during data augmentation
        rot_angle = input_batch['rot_angle']            # rotation angle used for data augmentation
        dataset_name = input_batch['dataset_name']      # name of the dataset the image comes from
        indices = input_batch['sample_index']           # index of example inside its dataset
        batch_size = images.shape[0]

        # Get GT vertices and model joints
        # Note that gt_model_joints is different from gt_joints as it comes from SMPL
        gt_out = self.smpl(betas=gt_betas,
                           body_pose=gt_pose[:, 3:],
                           global_orient=gt_pose[:, :3])
        gt_model_joints = gt_out.joints  #[N, 49, 3]     #Note this is different from gt_joints!
        gt_vertices = gt_out.vertices

        # Get current best fits from the dictionary
        opt_pose, opt_betas, opt_validity = self.fits_dict[(dataset_name,
                                                            indices.cpu(),
                                                            rot_angle.cpu(),
                                                            is_flipped.cpu())]
        opt_pose = opt_pose.to(self.device)
        opt_betas = opt_betas.to(self.device)
        opt_output = self.smpl(betas=opt_betas,
                               body_pose=opt_pose[:, 3:],
                               global_orient=opt_pose[:, :3])
        opt_vertices = opt_output.vertices
        # opt_joints = opt_output.joints
        opt_joints = opt_output.joints.detach()  #for smpl-x

        # assert that samples without a valid fit have GT SMPL parameters
        if len(has_smpl[opt_validity == 0]) > 0:
            assert min(has_smpl[opt_validity == 0])  # all should be True

        # else:       #Assue 3D DB!
        #     opt_pose = gt_pose
        #     opt_betas = gt_betas
        #     opt_vertices = gt_vertices
        #     opt_joints = gt_model_joints

        # De-normalize 2D keypoints from [-1,1] to pixel space
        gt_keypoints_2d_orig = gt_keypoints_2d.clone()
        gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
            gt_keypoints_2d_orig[:, :, :-1] + 1)  #49: (25+24) x 3

        # Estimate camera translation given the model joints and 2D keypoints
        # by minimizing a weighted least squares loss
        # gt_cam_t = estimate_translation(gt_model_joints, gt_keypoints_2d_orig, focal_length=self.focal_length, img_size=self.options.img_res)

        opt_cam_t = estimate_translation(opt_joints,
                                         gt_keypoints_2d_orig,
                                         focal_length=self.focal_length,
                                         img_size=self.options.img_res)

        opt_joint2d_loss = self.smplify.get_fitting_loss(
            opt_pose,
            opt_betas,
            opt_cam_t,  #opt_pose (N,72)  (N,10)  opt_cam_t: (N,3)
            0.5 * self.options.img_res *
            torch.ones(batch_size, 2, device=self.device),  #(N,2)   (112, 112)
            gt_keypoints_2d_orig).mean(dim=-1)

        # Feed images in the network to predict camera and SMPL parameters
        pred_rotmat, pred_betas, pred_camera = self.model(images)

        pred_output = self.smpl(betas=pred_betas,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)
        pred_vertices = pred_output.vertices
        pred_joints_3d = pred_output.joints

        # # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
        # # This camera translation can be used in a full perspective projection
        # pred_cam_t = torch.stack([pred_camera[:,1],
        #                           pred_camera[:,2],
        #                           2*self.focal_length/(self.options.img_res * pred_camera[:,0] +1e-9)],dim=-1)

        # camera_center = torch.zeros(batch_size, 2, device=self.device)
        # pred_keypoints_2d = perspective_projection(pred_joints_3d,
        #                                            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(batch_size, -1, -1),
        #                                            translation=pred_cam_t,
        #                                            focal_length=self.focal_length,
        #                                            camera_center=camera_center)
        # # Normalize keypoints to [-1,1]
        # pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res / 2.)

        ### Centering (vertices and model joints)
        if True:
            pred_pelvis = (pred_joints_3d[:, 27:28, :3] + pred_joints_3d[:, 28:29, :3]) / 2
            pred_joints_3d[:, :, :3] = pred_joints_3d[:, :, :3] - pred_pelvis  # centering
            pred_vertices = pred_vertices - pred_pelvis

            gt_pelvis = (gt_joints[:, 2:3, :3] + gt_joints[:, 3:4, :3]) / 2  # midpoint of the two hips
            gt_joints[:, :, :3] = gt_joints[:, :, :3] - gt_pelvis  # centering

            gt_model_pelvis = (gt_model_joints[:, 27:28, :3] + gt_model_joints[:, 28:29, :3]) / 2
            gt_model_joints[:, :, :3] = gt_model_joints[:, :, :3] - gt_model_pelvis  # centering
            gt_vertices = gt_vertices - gt_model_pelvis
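            # Joint indexing assumed above (SPIN-style 49-joint layout = 25 OpenPose + 24 extra):
            # indices 27 and 28 (= 25 + 2, 25 + 3) are the two hips, so their midpoint serves as
            # the pelvis; gt_joints ([N, 24, 4]) uses the bare 24-joint order, with hips at 2 and 3.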

        # Replace the optimized parameters with the ground truth parameters, if available
        opt_vertices[has_smpl, :, :] = gt_vertices[has_smpl, :, :]
        # opt_cam_t[has_smpl, :] = gt_cam_t[has_smpl, :]
        opt_joints[has_smpl, :, :] = gt_model_joints[has_smpl, :, :]
        opt_pose[has_smpl, :] = gt_pose[has_smpl, :]
        opt_betas[has_smpl, :] = gt_betas[has_smpl, :]

        # Weak-perspective projection
        pred_keypoints_2d = weakProjection_gpu(
            pred_joints_3d, pred_camera[:, 0], pred_camera[:, 1:])  # [N, 49, 2]
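        # weakProjection_gpu is assumed to implement the usual weak-perspective model,
        #   x_2d = s * X_xy + [tx, ty]  with s = pred_camera[:, 0] and (tx, ty) = pred_camera[:, 1:],
        # returning keypoints already in the normalized [-1, 1] range expected by keypoint2d_loss.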

        # Assert whether a fit is valid by comparing the joint loss with the threshold
        valid_fit = (opt_joint2d_loss < self.options.smplify_threshold).to(
            self.device)
        # Add the examples with GT parameters to the list of valid fits
        valid_fit = valid_fit | has_smpl

        # opt_keypoints_2d = perspective_projection(opt_joints,
        #                                           rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(batch_size, -1, -1),
        #                                           translation=opt_cam_t,
        #                                           focal_length=self.focal_length,
        #                                           camera_center=camera_center)
        # opt_keypoints_2d = opt_keypoints_2d / (self.options.img_res / 2.)

        # Compute loss on SMPL parameters
        loss_regr_pose, loss_regr_betas = self.smpl_losses(
            pred_rotmat, pred_betas, opt_pose, opt_betas, valid_fit)
        # loss_regr_pose, loss_regr_betas = self.smpl_losses(pred_rotmat, pred_betas, gt_pose, gt_betas, valid_fit)

        # Compute 2D reprojection loss for the keypoints
        loss_keypoints_2d = self.keypoint2d_loss(
            pred_keypoints_2d, gt_keypoints_2d,
            self.options.openpose_train_weight, self.options.gt_train_weight)

        # Compute 3D keypoint loss
        # loss_keypoints_3d = self.keypoint_3d_loss(pred_joints_3d, gt_joints, has_pose_3d)
        loss_keypoints_3d = self.keypoint_3d_loss_modelSkel(
            pred_joints_3d, gt_model_joints[:, 25:, :], has_pose_3d)

        # Per-vertex loss for the shape
        loss_shape = self.shape_loss(pred_vertices, opt_vertices, valid_fit)
        # loss_shape = self.shape_loss(pred_vertices, gt_vertices, valid_fit)

        # Compute total loss
        # The last component is a loss that forces the network to predict positive depth values
        loss = self.options.shape_loss_weight * loss_shape +\
               self.options.keypoint_loss_weight * loss_keypoints_2d +\
               self.options.keypoint_loss_weight * loss_keypoints_3d +\
               loss_regr_pose + self.options.beta_loss_weight * loss_regr_betas +\
               ((torch.exp(-pred_camera[:,0]*10)) ** 2 ).mean()

        # Note: the loss defined above is overridden below.
        if True:  # Exemplar loss: 2D keypoints only + keep the original shape + camera regularization
            # For now just use opt_betas; ideally, the betas estimated from the direct prediction could be used.
            betaMax = abs(torch.max(abs(opt_betas)).item())
            pred_betasMax = abs(torch.max(abs(pred_betas)).item())
            # print(pred_betasMax)
            if False:  #betaMax<2.0:
                loss_regr_betas_noReject = self.criterion_regr(
                    pred_betas, opt_betas)
            else:
                loss_regr_betas_noReject = torch.mean(pred_betas**2)

            #Prevent bending knee?
            # red_rotmat[0,6,:,:] -


            loss = self.options.keypoint_loss_weight * loss_keypoints_2d  + \
                    self.options.beta_loss_weight * loss_regr_betas_noReject  + \
                    ((torch.exp(-pred_camera[:,0]*10)) ** 2 ).mean()
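            # Exemplar (EFT-style) objective: only the 2D reprojection term drives the update,
            # while the beta and camera terms act as regularizers so per-sample fine-tuning
            # cannot drift to degenerate shapes or cameras.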

        # print(loss_regr_betas)
        loss *= 60
        # print("loss2D: {}, loss3D: {}".format( self.options.keypoint_loss_weight * loss_keypoints_2d,self.options.keypoint_loss_weight * loss_keypoints_3d  )  )

        # Do backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Pack output arguments for tensorboard logging
        output = {
            'pred_vertices': 0,  #pred_vertices.detach(),
            'opt_vertices': 0,
            'pred_cam_t': 0,  #pred_cam_t.detach(),
            'opt_cam_t': 0
        }

        # Save result (this dictionary replaces the tensorboard packing above)
        output = {}
        output['pred_pose_rotmat'] = pred_rotmat.detach().cpu().numpy()
        output['pred_shape'] = pred_betas.detach().cpu().numpy()
        output['pred_camera'] = pred_camera.detach().cpu().numpy()
        output['opt_pose'] = opt_pose.detach().cpu().numpy()
        output['opt_beta'] = opt_betas.detach().cpu().numpy()
        output['sampleIdx'] = input_batch['sample_index'].detach().cpu().numpy()  # to use the loader directly
        output['imageName'] = input_batch['imgname']
        output['scale'] = input_batch['scale'].detach().cpu().numpy()
        output['center'] = input_batch['center'].detach().cpu().numpy()

        #To save new db file
        output['keypoint2d'] = input_batch['keypoints_original'].detach().cpu().numpy()

        losses = {
            'loss': loss.detach().item(),
            'loss_keypoints': loss_keypoints_2d.detach().item(),
            'loss_keypoints_3d': loss_keypoints_3d.detach().item(),
            'loss_regr_pose': loss_regr_pose.detach().item(),
            'loss_regr_betas': loss_regr_betas.detach().item(),
            'loss_shape': loss_shape.detach().item()
        }

        if self.options.bDebug_visEFT:  # Debug: visualize input and predictions
            for b in range(batch_size):
                #denormalizeImg
                curImgVis = images[b]  #3,224,224
                curImgVis = self.de_normalize_img(curImgVis).cpu().numpy()
                curImgVis = np.transpose(curImgVis, (1, 2, 0)) * 255.0
                curImgVis = curImgVis[:, :, [2, 1, 0]]

                #Denormalize image
                curImgVis = np.ascontiguousarray(curImgVis, dtype=np.uint8)
                originalImgVis = curImgVis.copy()
                viewer2D.ImShow(curImgVis, name='rawIm')

                curImgVis = viewer2D.Vis_Skeleton_2D_general(
                    gt_keypoints_2d_orig[b, :, :2].cpu().numpy(),
                    gt_keypoints_2d_orig[b, :, 2],
                    bVis=False,
                    image=curImgVis)

                pred_keypoints_2d_vis = pred_keypoints_2d[b, :, :2].detach().cpu().numpy()
                pred_keypoints_2d_vis = 0.5 * self.options.img_res * (pred_keypoints_2d_vis + 1)  # 49 = 25 + 24 joints

                if glViewer.g_bShowSkeleton:
                    curImgVis = viewer2D.Vis_Skeleton_2D_general(
                        pred_keypoints_2d_vis, bVis=False, image=curImgVis)
                viewer2D.ImShow(curImgVis, scale=2.0, waitTime=1)

                # optional debug visualization of the Opt-SMPL joints (disabled)
                if False:
                    smpl_jointsVis = opt_joints[b, :, :3].cpu().numpy() * 100  # [49, 3]
                    smpl_jointsVis = smpl_jointsVis.ravel()[:, np.newaxis]
                    glViewer.setSkeleton([smpl_jointsVis])
                    # glViewer.show()

                    #Vis Opt-Vertex
                    meshVertVis = opt_vertices[b].cpu().numpy() * 100
                    meshes = {'ver': meshVertVis, 'f': self.smpl.faces}
                    glViewer.setMeshData([meshes], bComputeNormal=True)

                #Get camera pred_params
                pred_camera_vis = pred_camera.detach().cpu().numpy()

                ############### Visualize Mesh ###############
                pred_vert_vis = pred_vertices[b].detach().cpu().numpy()
                # meshVertVis = gt_vertices[b].detach().cpu().numpy()
                # meshVertVis = meshVertVis - pelvis        # centering
                pred_vert_vis *= pred_camera_vis[b, 0]
                pred_vert_vis[:, 0] += pred_camera_vis[b, 1]  # no +1 (or 112) needed; rendering already applies this offset
                pred_vert_vis[:, 1] += pred_camera_vis[b, 2]
                pred_vert_vis *= 112
                pred_meshes = {'ver': pred_vert_vis, 'f': self.smpl.faces}

                if not has_smpl[b]:
                    opt_vertices[has_smpl, :, :] = gt_vertices[has_smpl, :, :]
                    opt_joints[has_smpl, :, :] = gt_model_joints[has_smpl, :, :]

                    opt_model_pelvis = (opt_joints[:, 27:28, :3] + opt_joints[:, 28:29, :3]) / 2
                    opt_vertices = opt_vertices - opt_model_pelvis

                opt_vert_vis = opt_vertices[b].detach().cpu().numpy()
                opt_vert_vis *= pred_camera_vis[b, 0]
                opt_vert_vis[:, 0] += pred_camera_vis[b, 1]  # no +1 (or 112) needed; rendering already applies this offset
                opt_vert_vis[:, 1] += pred_camera_vis[b, 2]
                opt_vert_vis *= 112
                opt_meshes = {'ver': opt_vert_vis, 'f': self.smpl.faces}

                # glViewer.setMeshData([pred_meshes], bComputeNormal=True)
                glViewer.setMeshData([pred_meshes, opt_meshes], bComputeNormal=True)

                ############### Visualize Skeletons ###############
                # Vis pred-SMPL joints
                pred_joints_vis = pred_joints_3d[b, :, :3].detach().cpu().numpy()  # [49, 3]
                pred_joints_vis = pred_joints_vis.ravel()[:, np.newaxis]
                # weak-perspective projection
                pred_joints_vis *= pred_camera_vis[b, 0]
                pred_joints_vis[::3] += pred_camera_vis[b, 1]
                pred_joints_vis[1::3] += pred_camera_vis[b, 2]
                pred_joints_vis *= 112  # 112 == 0.5 * 224
                glViewer.setSkeleton([pred_joints_vis])

                #Vis SMPL's Skeleton
                # gt_smpljointsVis = gt_model_joints[b,:,:3].cpu().numpy()        #[N,49,3]
                # # gt_pelvis = (gt_smpljointsVis[ 25+2,:] + gt_smpljointsVis[ 25+3,:]) / 2
                # # gt_smpljointsVis = gt_smpljointsVis- gt_pelvis
                # gt_smpljointsVis = gt_smpljointsVis.ravel()[:,np.newaxis]
                # gt_smpljointsVis*=pred_camera_vis[b,0]
                # gt_smpljointsVis[::3] += pred_camera_vis[b,1]
                # gt_smpljointsVis[1::3] += pred_camera_vis[b,2]
                # gt_smpljointsVis*=112
                # glViewer.addSkeleton( [gt_smpljointsVis])

                # #Vis GT  joint  (not model (SMPL) joint!!)
                # if has_pose_3d[b]:
                #     gt_jointsVis = gt_model_joints[b,:,:3].cpu().numpy()        #[N,49,3]
                #     # gt_jointsVis = gt_joints[b,:,:3].cpu().numpy()        #[N,49,3]
                #     # gt_pelvis = (gt_jointsVis[ 25+2,:] + gt_jointsVis[ 25+3,:]) / 2
                #     # gt_jointsVis = gt_jointsVis- gt_pelvis

                #     gt_jointsVis = gt_jointsVis.ravel()[:,np.newaxis]
                #     gt_jointsVis*=pred_camera_vis[b,0]
                #     gt_jointsVis[::3] += pred_camera_vis[b,1]
                #     gt_jointsVis[1::3] += pred_camera_vis[b,2]
                #     gt_jointsVis*=112

                #     glViewer.addSkeleton( [gt_jointsVis])
                # # glViewer.show()

                glViewer.setBackgroundTexture(originalImgVis)
                glViewer.setWindowSize(curImgVis.shape[1], curImgVis.shape[0])
                glViewer.SetOrthoCamera(True)
                glViewer.show(1)

                # continue

        return output, losses