Example no. 1
class trainer(BaseTrain):
    
    def init(self):
        
        # Create training and testing dataset
        self.train_ds = MGNDataset(self.options, split = 'train')
        self.test_ds  = MGNDataset(self.options, split = 'test')
        
        # the test data loader is fixed; shuffling is disabled since it is unnecessary
        self.test_data_loader = CheckpointDataLoader( self.test_ds,
                    batch_size  = self.options.batch_size,
                    num_workers = self.options.num_workers,
                    pin_memory  = self.options.pin_memory,
                    shuffle     = False)

        # Create SMPL Mesh (graph) object for GCN
        self.mesh = Mesh(self.options, self.options.num_downsampling)
        self.faces = torch.cat( self.options.batch_size * [
                                self.mesh.faces.to(self.device)[None]
                                ], dim = 0 )
        self.faces = self.faces.type(torch.LongTensor).to(self.device)
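        # the SMPL face topology is shared by every sample, so the same face
        # list is tiled batch_size times to match the per-sample vertex tensors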
        
        # Create SMPL blending model and edges
        self.smpl = SMPL(self.options.smpl_model_path, self.device)
        # self.smplEdge = torch.Tensor(np.load(self.options.smpl_edges_path)) \
        #                 .long().to(self.device)
        
        # create SMPL+D blending model
        self.smplD = Smpl( self.options.smpl_model_path ) 
        
        # read the SMPL .obj file to get the UV coordinates
        _, self.smpl_tri_ind, uv_coord, tri_uv_ind = read_Obj(self.options.smpl_objfile_path)
        uv_coord[:, 1] = 1 - uv_coord[:, 1]
        expUV = uv_coord[tri_uv_ind.flatten()]
        unique, index = np.unique(self.smpl_tri_ind.flatten(), return_index = True)
        self.smpl_verts_uvs = torch.as_tensor(expUV[index,:]).float().to(self.device)
        self.smpl_tri_ind   = torch.as_tensor(self.smpl_tri_ind).to(self.device)
        
        # camera for projection 
        self.perspCam = perspCamera() 

        # mean and std of displacements
        self.dispPara = \
            torch.Tensor(np.load(self.options.MGN_offsMeanStd_path)).to(self.device)
        
        # load the average pose and shape and convert them to camera coordinates;
        # the average pose depends on the image id used for training (0-11)
        avgPose_objCoord = np.load(self.options.MGN_avgPose_path)
        avgPose_objCoord[:3] = rotationMatrix_to_axisAngle(    # for 0,6, front only
            torch.tensor([[[1,  0, 0],
                           [0, -1, 0],
                           [0,  0,-1]]]))
        self.avgPose = \
            axisAngle_to_Rot6d(
                torch.Tensor(avgPose_objCoord[None]).reshape(-1, 3)
                    ).reshape(1, -1).to(self.device)
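        # 24 per-joint rotations in the 6D representation -> a 1 x 144 pose vector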
        self.avgBeta = \
            torch.Tensor(
                np.load(self.options.MGN_avgBeta_path)[None]).to(self.device)
        self.avgCam  = torch.Tensor([1.2755, 0, 0])[None].to(self.device)    # 1.2755 is for our settings
        
        self.model = frameVIBE(
            self.options.smpl_model_path, 
            self.mesh,                
            self.avgPose,
            self.avgBeta,
            self.avgCam,
            self.options.num_channels,
            self.options.num_layers ,  
            self.smpl_verts_uvs,
            self.smpl_tri_ind
            ).to(self.device)
            
        # Set up an optimizer for the model
        self.optimizer = torch.optim.Adam(
            params=list(self.model.parameters()),
            lr=self.options.lr,
            betas=(self.options.adam_beta1, 0.999),
            weight_decay=self.options.wd)
        
        self.criterion = VIBELoss(
            e_loss_weight=50,         # for 2D keypoints; helps estimate the camera and global orientation
            e_3d_loss_weight=1,       # for 3D keypoints and body vertices
            e_pose_loss_weight=10,    # for body pose parameters
            e_shape_loss_weight=1,    # for body shape parameters
            e_disp_loss_weight=1,     # for displacements
            e_tex_loss_weight=1,      # for the UV texture image
            d_motion_loss_weight=0
            )
        
        # Pack models and optimizers in a dict - necessary for checkpointing
        self.models_dict = {self.options.model: self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}
        
    def convert_GT(self, input_batch):     
        
        smplGT = torch.cat([
            # GTcamera is not used, here is a placeholder
            torch.zeros([self.options.batch_size,3]).to(self.device),
            # the global rotation is incorrect as it is still expressed in the
            # original coordinate system; the remaining joint rotations and the
            # betas are fine.
            rot6d_to_axisAngle(input_batch['GTsmplParas']['pose']).reshape(-1, 72),   # 24 joints: 6D (144) -> axis-angle (72)
            input_batch['GTsmplParas']['betas']],
            dim = 1).float()
        
        # vertices in the original coordinates
        vertices = self.smpl(
            pose = rot6d_to_axisAngle(input_batch['GTsmplParas']['pose']).reshape(self.options.batch_size, 72),
            beta = input_batch['GTsmplParas']['betas'].float(),
            trans = input_batch['GTsmplParas']['trans']
            )
        
        # get the joints in 3D and 2D in the current camera coordinate system
        joints_3d = self.smpl.get_joints(vertices.float())
        joints_2d, _, joints_3d = self.perspCam(
            fx = input_batch['GTcamera']['f_rot'][:,0,0], 
            fy = input_batch['GTcamera']['f_rot'][:,0,0], 
            cx = 112, 
            cy = 112, 
            rotation = rot6d_to_rotmat(input_batch['GTcamera']['f_rot'][:,0,1:]).float(),  
            translation = input_batch['GTcamera']['t'][:,None,:].float(), 
            points = joints_3d,
            visibilityOn = False,
            output3d = True
            )
        joints_3d = (joints_3d - input_batch['GTcamera']['t'][:,None,:]).float() # remove shifts
        
        # visind = 1
        # img = torch.zeros([224,224])
        # joints_2d[visind][joints_2d[visind].round() >= 224] = 223 
        # joints_2d[visind][joints_2d[visind].round() < 0] = 0 
        # img[joints_2d[visind][:,1].round().long(), joints_2d[visind][:,0].round().long()] = 1
        # plt.imshow(img)
        # plt.imshow(input_batch['img'][visind].cpu().permute(1,2,0))
        
        # convert to [-1, +1]
        joints_2d = (torch.cat(joints_2d, dim=0).reshape([self.options.batch_size, 24, 2]) - 112)/112
        joints_2d = joints_2d.float()
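        # e.g. with cx = cy = 112 on a 224 x 224 image, a joint at pixel
        # (112, 112) maps to (0, 0) and one at pixel (0, 224) maps to (-1, +1)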
        
        # convert the vertices to the current camera coordinates
        _,_,vertices = self.perspCam(
            fx = input_batch['GTcamera']['f_rot'][:,0,0], 
            fy = input_batch['GTcamera']['f_rot'][:,0,0], 
            cx = 112, 
            cy = 112, 
            rotation = rot6d_to_rotmat(input_batch['GTcamera']['f_rot'][:,0,1:]).float(),  
            translation = input_batch['GTcamera']['t'][:,None,:].float(), 
            points = vertices,
            visibilityOn = False,
            output3d = True
            )
        vertices = (vertices - input_batch['GTcamera']['t'][:,None,:]).float()
        
        # sanity check for the coordinate system:
        # points = vertices (vertices before shift)
        # localcam = perspCamera(smpl_obj = False)    # disable additional trans
        # img_points, _ = localcam(
        #     fx = input_batch['GTcamera']['f_rot'][:,0,0], 
        #     fy = input_batch['GTcamera']['f_rot'][:,0,0], 
        #     cx = 112, 
        #     cy = 112, 
        #     rotation = torch.eye(3)[None].repeat_interleave(points.shape[0], dim = 0).to('cuda'),  
        #     translation = torch.zeros([points.shape[0],1,3]).to('cuda'), 
        #     points = points.float(),
        #     visibilityOn = False,
        #     output3d = False
        #     )
        # img = torch.zeros([224,224])
        # img_points[visind][img_points[visind].round() >= 224] = 223 
        # img_points[visind][img_points[visind].round() < 0] = 0 
        # img[img_points[visind][:,1].round().long(), img_points[visind][:,0].round().long()] = 1
        # plt.imshow(img)
                
        # sanity check for the displacements
        # import open3d as o3d
        # points = vertices
        # ind = 0
        # o3dMesh = o3d.geometry.TriangleMesh()
        # o3dMesh.vertices = o3d.utility.Vector3dVector(points[ind].cpu())
        # o3dMesh.triangles= o3d.utility.Vector3iVector(self.faces[0].cpu())
        # o3dMesh.compute_vertex_normals()     
        # o3d.visualization.draw_geometries([o3dMesh])
    
        # self.renderer = renderer(batch_size = 1)
        # images = self.renderer(
        #     verts = vertices[0][None],
        #     faces = self.faces[0][None],
        #     verts_uvs = self.smpl_verts_uvs[None].repeat_interleave(2, dim=0)[0][None],
        #     faces_uvs = self.smpl_tri_ind[None].repeat_interleave(2, dim=0)[0][None],
        #     tex_image = tex[None],
        #     R = torch.eye(3)[None].repeat_interleave(2, dim = 0).to('cuda')[0][None],
        #     T = input_batch['GTcamera']['t'].float()[0][None],
        #     f = input_batch['GTcamera']['f_rot'][:,0,0][:,None].float()[0][None],
        #     C = torch.ones([2,2]).to('cuda')[0][None]*112,
        #     imgres = 224
        # )
        
        GT = {'img' : input_batch['img'].float(),
              'img_orig': input_batch['img_orig'].float(),
              'imgname' : input_batch['imgname'],
              
              'camera': input_batch['GTcamera'],                   # wcd 
              'theta': smplGT.float(),        
              
              # joints_2d is col, row == x, y; joints_3d is x,y,z
              'target_2d': joints_2d.float(),
              'target_3d': joints_3d.float(),
              'target_bvt': vertices.float(),   # body vertices
              'target_dp': input_batch['GToffsets_t']['offsets'].float(),    # smpl cd t-pose
              
              'target_uv': input_batch['GTtextureMap'].float()
              }
            
        return GT
    
    def train_step(self, input_batch):
        """Training step."""
        self.model.train()
        
        # prepare data
        GT = self.convert_GT(input_batch)    
  
        # forward pass
        pred = self.model(GT['img'], GT['img_orig'])
        
        # loss
        gen_loss, loss_dict = self.criterion(
            generator_outputs=pred,
            data_2d = GT['target_2d'],
            data_3d = GT,
            )
        # print(gen_loss)
        out_args = loss_dict
        out_args['loss'] = gen_loss
        out_args['prediction'] = pred 
        
        # save one training sample for vis
        if (self.step_count+1) % self.options.summary_steps == 0:
            with torch.no_grad():
                self.save_sample(GT, pred[0], saveTo = 'train')

        return out_args

    def test(self):
        """"Testing process"""
        self.model.eval()    
        
        test_loss, test_tex_loss = 0, 0
        test_loss_pose, test_loss_shape = 0, 0
        test_loss_kp_2d, test_loss_kp_3d = 0, 0
        test_loss_dsp_3d, test_loss_bvt_3d = 0, 0
        for step, batch in enumerate(tqdm(self.test_data_loader, desc='Test',
                                          total=len(self.test_ds) // self.options.batch_size,
                                          initial=self.test_data_loader.checkpoint_batch_idx),
                                     self.test_data_loader.checkpoint_batch_idx):
            # convert data devices
            batch_toDEV = {}
            for key, val in batch.items():
                if isinstance(val, torch.Tensor):
                    batch_toDEV[key] = val.to(self.device)
                else:
                    batch_toDEV[key] = val
                if isinstance(val, dict):
                    batch_toDEV[key] = {}
                    for k, v in val.items():
                        if isinstance(v, torch.Tensor):
                            batch_toDEV[key][k] = v.to(self.device)
            # prepare data
            GT = self.convert_GT(batch_toDEV)

            with torch.no_grad():    # disable grad
                # forward pass
                pred = self.model(GT['img'], GT['img_orig'])
                
                # loss
                gen_loss, loss_dict = self.criterion(
                    generator_outputs=pred,
                    data_2d = GT['target_2d'],
                    data_3d = GT,
                    )
                # save for comparison
                if step == 0:
                    self.save_sample(GT, pred[0])
                
            test_loss += gen_loss
            test_loss_pose  += loss_dict['loss_pose']
            test_loss_shape += loss_dict['loss_shape']
            test_loss_kp_2d += loss_dict['loss_kp_2d']
            test_loss_kp_3d += loss_dict['loss_kp_3d']
            test_loss_bvt_3d += loss_dict['loss_bvt_3d']
            test_loss_dsp_3d += loss_dict['loss_dsp_3d']
            test_tex_loss += loss_dict['loss_tex']
                        
        test_loss = test_loss/len(self.test_data_loader)
        test_loss_pose  = test_loss_pose/len(self.test_data_loader)
        test_loss_shape = test_loss_shape/len(self.test_data_loader)
        test_loss_kp_2d = test_loss_kp_2d/len(self.test_data_loader)
        test_loss_kp_3d = test_loss_kp_3d/len(self.test_data_loader)
        test_loss_bvt_3d = test_loss_bvt_3d/len(self.test_data_loader)
        test_loss_dsp_3d = test_loss_dsp_3d/len(self.test_data_loader)
        test_tex_loss = test_tex_loss/len(self.test_data_loader)
        
        lossSummary = {'test_loss': test_loss, 
                       'test_loss_pose' : test_loss_pose,
                       'test_loss_shape' : test_loss_shape,
                       'test_loss_kp_2d' : test_loss_kp_2d,
                       'test_loss_kp_3d' : test_loss_kp_3d,
                       'test_loss_bvt_3d': test_loss_bvt_3d,
                       'test_loss_dsp_3d': test_loss_dsp_3d,
                       'test_tex_loss': test_tex_loss
                       }
        self.test_summaries(lossSummary)
        
        if test_loss < self.last_test_loss:
            self.last_test_loss = test_loss
            self.saver.save_checkpoint(self.models_dict, 
                                       self.optimizers_dict, 
                                       0, 
                                       step+1, 
                                       self.options.batch_size, 
                                       self.test_data_loader.sampler.dataset_perm, 
                                       self.step_count) 
            tqdm.write('Better test checkpoint saved')
        
    def save_sample(self, data, prediction, ind = 0, saveTo = 'test'):
        """Saving a sample for visualization and comparison"""
        
        assert saveTo in ('test', 'train'), 'save to train or test folder'
        
        folder = pjn(self.options.summary_dir, '%s/'%(saveTo))
        _input = pjn(folder, 'input_image.png')
        # batchs = self.options.batch_size
        
        # save the input image, if not saved
        if not isfile(_input):
            plt.imsave(_input, data['img'][ind].cpu().permute(1,2,0).clamp(0,1).numpy())
            
        # overlay the prediction on the real image; since the mesh .obj has
        # different ft/uv coordinates between the MPI library and MGN, we flip
        # the predicted texture.
        save_renderer = simple_renderer(batch_size = 1)
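        # the predicted weak-perspective camera [s, tx, ty] is converted to a
        # perspective translation with t_z = 2 * f / (img_res * s), using
        # f = 1000 and img_res = 224 to match the renderer settings below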
        predTrans = torch.stack(
            [prediction['theta'][ind, 1],
             prediction['theta'][ind, 2],
             2 * 1000. / (224. * prediction['theta'][ind, 0] + 1e-9)], dim=-1)
        tex = prediction['tex_image'][ind].flip(dims=(0,))[None]
        pred_img = save_renderer(
            verts = prediction['verts'][ind][None],
            faces = self.faces[ind][None],
            verts_uvs = self.smpl_verts_uvs[None],
            faces_uvs = self.smpl_tri_ind[None],
            tex_image = tex,
            R = torch.eye(3)[None].to('cuda'),
            T = predTrans,
            f = torch.ones([1,1]).to('cuda')*1000,
            C = torch.ones([1,2]).to('cuda')*112,
            imgres = 224)
        
        overlayimg = 0.9*pred_img[0,:,:,:3] + 0.1*data['img'][ind].permute(1,2,0)
        save_path = pjn(folder,'overlay_%s_iters%d.png'%(data['imgname'][ind].split('/')[-1][:-4], self.step_count))
        plt.imsave(save_path, (overlayimg.clamp(0, 1)*255).cpu().numpy().astype('uint8'))
        save_path = pjn(folder,'tex_%s_iters%d.png'%(data['imgname'][ind].split('/')[-1][:-4], self.step_count))
        plt.imsave(save_path, (pred_img[0,:,:,:3].clamp(0, 1)*255).cpu().numpy().astype('uint8'))
        save_path = pjn(folder,'unwarptex_%s_iters%d.png'%(data['imgname'][ind].split('/')[-1][:-4], self.step_count))
        plt.imsave(save_path, (prediction['unwarp_tex'][ind].clamp(0, 1)*255).cpu().numpy().astype('uint8'))
        save_path = pjn(folder,'predtex_%s_iters%d.png'%(data['imgname'][ind].split('/')[-1][:-4], self.step_count))
        plt.imsave(save_path, (prediction['tex_image'][ind].clamp(0, 1)*255).cpu().numpy().astype('uint8'))


        # create the predicted posed body vertices (SMPL+D)
        offPred_t  = (prediction['verts_disp'][ind]*self.dispPara[1]+self.dispPara[0]).cpu()[None,:] 
        predDressbody = create_smplD_psbody(
            self.smplD, offPred_t, 
            prediction['theta'][ind][3:75][None].cpu(), 
            prediction['theta'][ind][75:][None].cpu(), 
            0, 
            rtnMesh=True)[1]
                            
        # Create meshes and save 
        savepath = pjn(folder,'%s_iters%d.obj'%\
                       (data['imgname'][ind].split('/')[-1][:-4], self.step_count))
        predDressbody.write_obj(savepath)
Example no. 2
def run_evaluation(model, dataset_name, dataset,
                   mesh, batch_size=32, img_res=224, 
                   num_workers=32, shuffle=False, log_freq=50):
    """Run evaluation on the datasets and metrics we report in the paper. """
    
    renderer = PartRenderer()
    
    # Create SMPL model
    smpl = SMPL().cuda()
    
    # Create dataloader for the dataset
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    
    # Transfer model to the GPU
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    model.eval()
    
    # 2D pose metrics
    pose_2d_m = np.zeros(len(dataset))
    pose_2d_m_smpl = np.zeros(len(dataset))

    # Shape metrics
    # Mean per-vertex error
    shape_err = np.zeros(len(dataset))
    shape_err_smpl = np.zeros(len(dataset))

    eval_2d_pose = False
    eval_shape = False

    # Choose appropriate evaluation for each dataset
    if dataset_name == 'up-3d':
        eval_shape = True
    elif dataset_name == 'lsp':
        eval_2d_pose = True

    # Iterate over the entire dataset
    for step, batch in enumerate(tqdm(data_loader, desc='Eval', total=len(data_loader))):
        # Get ground truth annotations from the batch
        gt_pose = batch['pose'].to(device)
        gt_betas = batch['betas'].to(device)
        gt_vertices = smpl(gt_pose, gt_betas)
        images = batch['img'].to(device)
        curr_batch_size = images.shape[0]
        gt_keypoints_2d = batch['keypoints'].cpu().numpy()
        
        # Run inference
        with torch.no_grad():
            pred_vertices, pred_vertices_smpl, camera, pred_rotmat, pred_betas = model(images)

        # If mask or part evaluation, render the mask and part images
        if eval_2d_pose:
            mask, parts = renderer(pred_vertices, camera)

        # 2D pose evaluation (for LSP)
        if eval_2d_pose:
            # Get 3D and projected 2D keypoints from the regressed shape
            # (computed once per batch, then indexed per example)
            pred_keypoints_3d = smpl.get_joints(pred_vertices)
            pred_keypoints_2d = orthographic_projection(pred_keypoints_3d, camera)[:, :, :2]
            pred_keypoints_3d_smpl = smpl.get_joints(pred_vertices_smpl)
            pred_keypoints_2d_smpl = orthographic_projection(pred_keypoints_3d_smpl, camera.detach())[:, :, :2]
            for i in range(curr_batch_size):
                gt_kp = gt_keypoints_2d[i, list(range(14))]
                pred_kp = pred_keypoints_2d.cpu().numpy()[i, list(range(14))]
                pred_kp_smpl = pred_keypoints_2d_smpl.cpu().numpy()[i, list(range(14))]
                # Compute 2D pose losses
                loss_2d_pose = np.sum((gt_kp[:, :2] - pred_kp)**2)
                loss_2d_pose_smpl = np.sum((gt_kp[:, :2] - pred_kp_smpl)**2)

                #print(gt_kp)
                #print()
                #print(pred_kp_smpl)
                #raw_input("Press Enter to continue...")

                pose_2d_m[step * batch_size + i] = loss_2d_pose
                pose_2d_m_smpl[step * batch_size + i] = loss_2d_pose_smpl

        # Shape evaluation (Mean per-vertex error)
        if eval_shape:
            se = torch.sqrt(((pred_vertices - gt_vertices) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            se_smpl = torch.sqrt(((pred_vertices_smpl - gt_vertices) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            shape_err[step * batch_size:step * batch_size + curr_batch_size] = se
            shape_err_smpl[step * batch_size:step * batch_size + curr_batch_size] = se_smpl

        # Print intermediate results during evaluation
        if step % log_freq == log_freq - 1:
            if eval_2d_pose:
                print('2D keypoints (NonParam): ' + str(1000 * pose_2d_m[:step * batch_size].mean()))
                print('2D keypoints (Param): ' + str(1000 * pose_2d_m_smpl[:step * batch_size].mean()))
                print()
            if eval_shape:
                print('Shape Error (NonParam): ' + str(1000 * shape_err[:step * batch_size].mean()))
                print('Shape Error (Param): ' + str(1000 * shape_err_smpl[:step * batch_size].mean()))
                print()

    # Print and store final results during evaluation
    print('*** Final Results ***')
    print()
    if eval_2d_pose:
        print('2D keypoints (NonParam): ' + str(1000 * pose_2d_m.mean()))
        print('2D keypoints (Param): ' + str(1000 * pose_2d_m_smpl.mean()))
        print()
        # store results
        #np.savez("../eval-output/CMR_no_extra_lsp_model_alt.npz", imgnames = dataset.imgname, kp_2d_err_graph = pose_2d_m, kp_2d_err_smpl = pose_2d_m_smpl)
    if eval_shape:
        print('Shape Error (NonParam): ' + str(1000 * shape_err.mean()))
        print('Shape Error (Param): ' + str(1000 * shape_err_smpl.mean()))
        print()
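
# A minimal sketch of how this evaluation entry point could be driven.  The
# build_model / build_dataset helpers below are hypothetical placeholders and
# not part of the original example; only the run_evaluation call itself is
# taken from the code above.
if __name__ == '__main__':
    mesh = Mesh()
    model = build_model(mesh)          # hypothetical: construct/load the network
    dataset = build_dataset('up-3d')   # hypothetical: wrap the evaluation split
    run_evaluation(model, 'up-3d', dataset, mesh,
                   batch_size=32, img_res=224,
                   num_workers=8, shuffle=False, log_freq=50)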
Example no. 3
class Trainer(BaseTrainer):
    """Trainer object.
    Inherits from BaseTrainer that sets up logging, saving/restoring checkpoints etc.
    """
    def init_fn(self):
        # create training dataset
        self.train_ds = create_dataset(self.options.dataset, self.options)

        # create Mesh object
        self.mesh = Mesh()
        self.faces = self.mesh.faces.to(self.device)

        # create GraphCNN
        self.graph_cnn = GraphCNN(self.mesh.adjmat,
                                  self.mesh.ref_vertices.t(),
                                  num_channels=self.options.num_channels,
                                  num_layers=self.options.num_layers).to(
                                      self.device)

        # SMPL Parameter regressor
        self.smpl_param_regressor = SMPLParamRegressor().to(self.device)

        # Setup a joint optimizer for the 2 models
        self.optimizer = torch.optim.Adam(
            params=list(self.graph_cnn.parameters()) +
            list(self.smpl_param_regressor.parameters()),
            lr=self.options.lr,
            betas=(self.options.adam_beta1, 0.999),
            weight_decay=self.options.wd)

        # SMPL model
        self.smpl = SMPL().to(self.device)

        # Create loss functions
        self.criterion_shape = nn.L1Loss().to(self.device)
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        self.criterion_regr = nn.MSELoss().to(self.device)

        # Pack models and optimizers in a dict - necessary for checkpointing
        self.models_dict = {
            'graph_cnn': self.graph_cnn,
            'smpl_param_regressor': self.smpl_param_regressor
        }
        self.optimizers_dict = {'optimizer': self.optimizer}

        # Renderer for visualization
        self.renderer = Renderer(faces=self.smpl.faces.cpu().numpy())

        # LSP indices from full list of keypoints
        self.to_lsp = list(range(14))

        # Optionally start training from a pretrained checkpoint
        # Note that this is different from resuming training
        # For the latter use --resume
        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(
                checkpoint_file=self.options.pretrained_checkpoint)

    def keypoint_loss(self, pred_keypoints_2d, gt_keypoints_2d):
        """Compute 2D reprojection loss on the keypoints.
        The confidence is binary and indicates whether the keypoints exist or not.
        The available keypoints are different for each dataset.
        """
        conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
        loss = (conf * self.criterion_keypoints(
            pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
        return loss

    def keypoint_3d_loss(self, pred_keypoints_3d, gt_keypoints_3d,
                         has_pose_3d):
        """Compute 3D keypoint loss for the examples that 3D keypoint annotations are available.
        The loss is weighted by the confidence
        """
        conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
        gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
        gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
        conf = conf[has_pose_3d == 1]
        pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
        if len(gt_keypoints_3d) > 0:
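            # center both sets on the pelvis (midpoint of keypoints 2 and 3)
            # so the loss is invariant to global translation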
            gt_pelvis = (gt_keypoints_3d[:, 2, :] +
                         gt_keypoints_3d[:, 3, :]) / 2
            gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :]
            pred_pelvis = (pred_keypoints_3d[:, 2, :] +
                           pred_keypoints_3d[:, 3, :]) / 2
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :]
            return (conf * self.criterion_keypoints(pred_keypoints_3d,
                                                    gt_keypoints_3d)).mean()
        else:
            return torch.FloatTensor(1).fill_(0.).to(self.device)

    def shape_loss(self, pred_vertices, gt_vertices, has_smpl):
        """Compute per-vertex loss on the shape for the examples that SMPL annotations are available."""
        pred_vertices_with_shape = pred_vertices[has_smpl == 1]
        gt_vertices_with_shape = gt_vertices[has_smpl == 1]
        if len(gt_vertices_with_shape) > 0:
            return self.criterion_shape(pred_vertices_with_shape,
                                        gt_vertices_with_shape)
        else:
            return torch.FloatTensor(1).fill_(0.).to(self.device)

    def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas,
                    has_smpl):
        """Compute SMPL parameter loss for the examples that SMPL annotations are available."""
        pred_rotmat_valid = pred_rotmat[has_smpl == 1].view(-1, 3, 3)
        gt_rotmat_valid = rodrigues(gt_pose[has_smpl == 1].view(-1, 3))
        pred_betas_valid = pred_betas[has_smpl == 1]
        gt_betas_valid = gt_betas[has_smpl == 1]
        if len(pred_rotmat_valid) > 0:
            loss_regr_pose = self.criterion_regr(pred_rotmat_valid,
                                                 gt_rotmat_valid)
            loss_regr_betas = self.criterion_regr(pred_betas_valid,
                                                  gt_betas_valid)
        else:
            loss_regr_pose = torch.FloatTensor(1).fill_(0.).to(self.device)
            loss_regr_betas = torch.FloatTensor(1).fill_(0.).to(self.device)
        return loss_regr_pose, loss_regr_betas

    def train_step(self, input_batch):
        """Training step."""
        self.graph_cnn.train()
        self.smpl_param_regressor.train()

        # Grab data from the batch
        gt_keypoints_2d = input_batch['keypoints']
        gt_keypoints_3d = input_batch['pose_3d']
        gt_pose = input_batch['pose']
        gt_betas = input_batch['betas']
        has_smpl = input_batch['has_smpl']
        has_pose_3d = input_batch['has_pose_3d']
        images = input_batch['img']

        # Render vertices using SMPL parameters
        gt_vertices = self.smpl(gt_pose, gt_betas)
        batch_size = gt_vertices.shape[0]

        # Feed image in the GraphCNN
        # Returns subsampled mesh and camera parameters
        pred_vertices_sub, pred_camera = self.graph_cnn(images)

        # Upsample mesh in the original size
        pred_vertices = self.mesh.upsample(pred_vertices_sub.transpose(1, 2))

        # Prepare input for SMPL Parameter regressor
        # The input is the predicted and template vertices subsampled by a factor of 4
        # Notice that we detach the GraphCNN
        x = pred_vertices_sub.transpose(1, 2).detach()
        x = torch.cat(
            [x, self.mesh.ref_vertices[None, :, :].expand(batch_size, -1, -1)],
            dim=-1)

        # Estimate SMPL parameters and render vertices
        pred_rotmat, pred_shape = self.smpl_param_regressor(x)
        pred_vertices_smpl = self.smpl(pred_rotmat, pred_shape)

        # Get 3D and projected 2D keypoints from the regressed shape
        pred_keypoints_3d = self.smpl.get_joints(pred_vertices)
        pred_keypoints_2d = orthographic_projection(pred_keypoints_3d,
                                                    pred_camera)[:, :, :2]
        pred_keypoints_3d_smpl = self.smpl.get_joints(pred_vertices_smpl)
        pred_keypoints_2d_smpl = orthographic_projection(
            pred_keypoints_3d_smpl, pred_camera.detach())[:, :, :2]
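        # pred_camera is detached above so that gradients from the SMPL-branch
        # 2D keypoint loss do not flow back into the GraphCNN camera prediction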

        # Compute losses

        # GraphCNN losses
        loss_keypoints = self.keypoint_loss(pred_keypoints_2d, gt_keypoints_2d)
        loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d,
                                                  gt_keypoints_3d, has_pose_3d)
        loss_shape = self.shape_loss(pred_vertices, gt_vertices, has_smpl)

        # SMPL regressor losses
        loss_keypoints_smpl = self.keypoint_loss(pred_keypoints_2d_smpl,
                                                 gt_keypoints_2d)
        loss_keypoints_3d_smpl = self.keypoint_3d_loss(pred_keypoints_3d_smpl,
                                                       gt_keypoints_3d,
                                                       has_pose_3d)
        loss_shape_smpl = self.shape_loss(pred_vertices_smpl, gt_vertices,
                                          has_smpl)
        loss_regr_pose, loss_regr_betas = self.smpl_losses(
            pred_rotmat, pred_shape, gt_pose, gt_betas, has_smpl)

        # Add losses to compute the total loss
        loss = loss_shape_smpl + loss_keypoints_smpl + loss_keypoints_3d_smpl +\
               loss_regr_pose + 0.1 * loss_regr_betas + loss_shape + loss_keypoints + loss_keypoints_3d

        # Do backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Pack output arguments to be used for visualization in a list
        out_args = [
            pred_vertices, pred_vertices_smpl, pred_camera, pred_keypoints_2d,
            pred_keypoints_2d_smpl, loss_shape, loss_shape_smpl,
            loss_keypoints, loss_keypoints_smpl, loss_keypoints_3d,
            loss_keypoints_3d_smpl, loss_regr_pose, loss_regr_betas, loss
        ]
        out_args = [arg.detach() for arg in out_args]
        return out_args

    def train_summaries(self, input_batch, pred_vertices, pred_vertices_smpl,
                        pred_camera, pred_keypoints_2d, pred_keypoints_2d_smpl,
                        loss_shape, loss_shape_smpl, loss_keypoints,
                        loss_keypoints_smpl, loss_keypoints_3d,
                        loss_keypoints_3d_smpl, loss_regr_pose,
                        loss_regr_betas, loss):
        """Tensorboard logging."""
        gt_keypoints_2d = input_batch['keypoints'].cpu().numpy()

        rend_imgs = []
        rend_imgs_smpl = []
        batch_size = pred_vertices.shape[0]
        # Do visualization for the first 4 images of the batch
        for i in range(min(batch_size, 4)):
            img = input_batch['img_orig'][i].cpu().numpy().transpose(1, 2, 0)
            # Get LSP keypoints from the full list of keypoints
            gt_keypoints_2d_ = gt_keypoints_2d[i, self.to_lsp]
            pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i,
                                                                 self.to_lsp]
            pred_keypoints_2d_smpl_ = pred_keypoints_2d_smpl.cpu().numpy()[
                i, self.to_lsp]
            # Get GraphCNN and SMPL vertices for the particular example
            vertices = pred_vertices[i].cpu().numpy()
            vertices_smpl = pred_vertices_smpl[i].cpu().numpy()
            cam = pred_camera[i].cpu().numpy()
            # Visualize reconstruction and detected pose
            rend_img = visualize_reconstruction(img, self.options.img_res,
                                                gt_keypoints_2d_, vertices,
                                                pred_keypoints_2d_, cam,
                                                self.renderer)
            rend_img_smpl = visualize_reconstruction(img, self.options.img_res,
                                                     gt_keypoints_2d_,
                                                     vertices_smpl,
                                                     pred_keypoints_2d_smpl_,
                                                     cam, self.renderer)
            rend_img = rend_img.transpose(2, 0, 1)
            rend_img_smpl = rend_img_smpl.transpose(2, 0, 1)
            rend_imgs.append(torch.from_numpy(rend_img))
            rend_imgs_smpl.append(torch.from_numpy(rend_img_smpl))
        rend_imgs = make_grid(rend_imgs, nrow=1)
        rend_imgs_smpl = make_grid(rend_imgs_smpl, nrow=1)

        # Save results in Tensorboard
        self.summary_writer.add_image('imgs', rend_imgs, self.step_count)
        self.summary_writer.add_image('imgs_smpl', rend_imgs_smpl,
                                      self.step_count)
        self.summary_writer.add_scalar('loss_shape', loss_shape,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_shape_smpl', loss_shape_smpl,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_regr_pose', loss_regr_pose,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_regr_betas', loss_regr_betas,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_keypoints', loss_keypoints,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_keypoints_smpl',
                                       loss_keypoints_smpl, self.step_count)
        self.summary_writer.add_scalar('loss_keypoints_3d', loss_keypoints_3d,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_keypoints_3d_smpl',
                                       loss_keypoints_3d_smpl, self.step_count)
        self.summary_writer.add_scalar('loss', loss, self.step_count)
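
# Illustrative only: a bare-bones outer loop showing how train_step and
# train_summaries fit together.  In the real project the loop lives in the
# BaseTrainer base class; the option parsing, Trainer construction and
# data-loader setup below are assumptions, not the original implementation.
if __name__ == '__main__':
    options = parse_options()                      # hypothetical CLI/config parser
    trainer = Trainer(options)
    loader = DataLoader(trainer.train_ds,
                        batch_size=options.batch_size, shuffle=True)
    for batch in loader:
        batch = {k: v.to(trainer.device) if isinstance(v, torch.Tensor) else v
                 for k, v in batch.items()}
        out = trainer.train_step(batch)            # detached losses and predictions
        trainer.train_summaries(batch, *out)       # Tensorboard images and scalars
        trainer.step_count += 1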
Example no. 4
def run_evaluation(model, args, dataset, mesh):
    """Run evaluation on the datasets and metrics we report in the paper. """

    # Create SMPL model
    smpl = SMPL().cuda()

    # Regressor for H36m joints
    J_regressor = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_H36M)).float()

    # Create dataloader for the dataset
    data_loader = DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers)

    # Transfer model to the GPU
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    model.eval()

    # predictions
    all_kps = {}

    # Iterate over the entire dataset
    for step, batch in enumerate(
            tqdm(data_loader, desc='Eval', total=args.num_imgs)):

        # Get ground truth annotations from the batch
        images = batch['img'].to(device)
        curr_batch_size = images.shape[0]

        # Run inference
        with torch.no_grad():
            pred_vertices, pred_vertices_smpl, camera, pred_rotmat, pred_betas = model(
                images)
            pred_keypoints_3d_smpl = smpl.get_joints(pred_vertices_smpl)
            pred_keypoints_2d_smpl = orthographic_projection(
                pred_keypoints_3d_smpl,
                camera.detach())[:, :, :2].cpu().data.numpy()

        eval_part = np.zeros((1, 19, 2))

        # we use custom keypoints for evaluation: MPII + COCO face joints
        # see paper / supplementary for details
        eval_part[0, :14, :] = pred_keypoints_2d_smpl[0][:14]
        eval_part[0, 14:, :] = pred_keypoints_2d_smpl[0][19:]

        all_kps[step] = eval_part
        if args.write_imgs == 'True':
            renderer = Renderer(faces=smpl.faces.cpu().numpy())
            write_imgs(batch, pred_keypoints_2d_smpl, pred_vertices_smpl,
                       camera, args, renderer)

    if args.eval_pck == 'True':
        gt_kp_path = os.path.join(cfg.BASE_DATA_DIR, args.dataset,
                                  args.crop_setting, 'keypoints.pkl')
        log_dir = os.path.join(cfg.BASE_DATA_DIR, 'cmr_pck_results.txt')
        with open(gt_kp_path, 'rb') as f:
            gt = pkl.load(f)

        calc = CalcPCK(
            all_kps,
            gt,
            num_imgs=cfg.DATASET_SIZES[args.dataset][args.crop_setting],
            log_dir=log_dir,
            dataset=args.dataset,
            crop_setting=args.crop_setting,
            pck_eval_threshold=args.pck_eval_threshold)
        calc.eval()
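
# A minimal sketch of the argument namespace this entry point reads; the
# fields below are exactly the attributes accessed on `args` above, while the
# default values are placeholders rather than the original project's choices.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--num_imgs', type=int, default=None)
    parser.add_argument('--write_imgs', default='False')
    parser.add_argument('--eval_pck', default='False')
    parser.add_argument('--dataset', default=None)
    parser.add_argument('--crop_setting', default=None)
    parser.add_argument('--pck_eval_threshold', type=float, default=None)
    args = parser.parse_args()
    # ... then build the model, dataset and mesh and call:
    # run_evaluation(model, args, dataset, mesh)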