Example #1
    def init_fn(self):
        # create training dataset
        self.train_ds = create_dataset(self.options.dataset, self.options)

        # create Mesh object
        self.mesh = Mesh()
        self.faces = self.mesh.faces.to(self.device)

        # create GraphCNN
        self.graph_cnn = GraphCNN(self.mesh.adjmat,
                                  self.mesh.ref_vertices.t(),
                                  num_channels=self.options.num_channels,
                                  num_layers=self.options.num_layers).to(
                                      self.device)

        # SMPL Parameter regressor
        self.smpl_param_regressor = SMPLParamRegressor().to(self.device)

        # Set up a joint optimizer for the two models
        self.optimizer = torch.optim.Adam(
            params=list(self.graph_cnn.parameters()) +
            list(self.smpl_param_regressor.parameters()),
            lr=self.options.lr,
            betas=(self.options.adam_beta1, 0.999),
            weight_decay=self.options.wd)

        # SMPL model
        self.smpl = SMPL().to(self.device)

        # Create loss functions
        self.criterion_shape = nn.L1Loss().to(self.device)
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        self.criterion_regr = nn.MSELoss().to(self.device)

        # Pack models and optimizers in a dict - necessary for checkpointing
        self.models_dict = {
            'graph_cnn': self.graph_cnn,
            'smpl_param_regressor': self.smpl_param_regressor
        }
        self.optimizers_dict = {'optimizer': self.optimizer}

        # Renderer for visualization
        self.renderer = Renderer(faces=self.smpl.faces.cpu().numpy())

        # LSP indices from full list of keypoints
        self.to_lsp = list(range(14))

        # Optionally start training from a pretrained checkpoint
        # Note that this is different from resuming training
        # For the latter use --resume
        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(
                checkpoint_file=self.options.pretrained_checkpoint)
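The models_dict / optimizers_dict packing above is what makes generic checkpointing possible. A minimal sketch of such a saver, using only standard PyTorch state_dict calls (the function name and file layout are illustrative, not the repository's actual CheckpointSaver):

import torch

def save_checkpoint(models_dict, optimizers_dict, checkpoint_file, **meta):
    """Illustrative checkpoint writer: store every packed model/optimizer by key."""
    checkpoint = {name: model.state_dict() for name, model in models_dict.items()}
    checkpoint.update({name: opt.state_dict() for name, opt in optimizers_dict.items()})
    checkpoint.update(meta)  # e.g. epoch or total step count
    torch.save(checkpoint, checkpoint_file)

# Inside the trainer this would be called roughly as:
# save_checkpoint(self.models_dict, self.optimizers_dict, 'checkpoint.pt', step=self.step_count)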
Example #2
    def init_fn(self):
        self.train_ds = MixedDataset(self.options, ignore_3d=self.options.ignore_3d, is_train=True)
        self.model = hmr(config.SMPL_MEAN_PARAMS, pretrained=True).to(self.device)      # feature extraction model
        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=self.options.lr,
                                          weight_decay=0)
        self.smpl = SMPL(config.SMPL_MODEL_DIR,
                         batch_size=16,
                         create_transl=False).to(self.device)
        # per vertex loss on the shape
        self.criterion_shape = nn.L1Loss().to(self.device)
        # keypoints loss including 2D and 3D
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        # SMPL parameter loss, when ground-truth SMPL parameters are available
        self.criterion_regr = nn.MSELoss().to(self.device)

        self.models_dict = {'model':self.model}
        self.optimizers_dict = {'optimizer':self.optimizer}
        self.focal_length = constants.FOCAL_LENGTH
        # initialize MVSMPLify
        self.mvsmplify = MVSMPLify(step_size=1e-2, batch_size=16, num_iters=100, focal_length=self.focal_length)
        print(self.options.pretrained_checkpoint)
        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(checkpoint_file=self.options.pretrained_checkpoint)
        #load dictionary of fits
        self.fits_dict = FitsDict(self.options, self.train_ds)
        # create renderer
        self.renderer = Renderer(focal_length=self.focal_length, img_res=224, faces=self.smpl.faces)
Example #3
    def init_fn(self):
        self.train_ds = MixedDataset(self.options, ignore_3d=self.options.ignore_3d, is_train=True)

        self.model = hmr(config.SMPL_MEAN_PARAMS, pretrained=True).to(self.device)
        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=self.options.lr,
                                          weight_decay=0)
        self.smpl = SMPL(config.SMPL_MODEL_DIR,
                         batch_size=self.options.batch_size,
                         create_transl=False).to(self.device)
        # Per-vertex loss on the shape
        self.criterion_shape = nn.L1Loss().to(self.device)
        # Keypoint (2D and 3D) loss
        # No reduction because confidence weighting needs to be applied
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        # Loss for SMPL parameter regression
        self.criterion_regr = nn.MSELoss().to(self.device)
        self.models_dict = {'model': self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}
        self.focal_length = constants.FOCAL_LENGTH
        self.conf_thresh = self.options.conf_thresh

        # Initialize SMPLify fitting module
        self.smplify = SMPLify(step_size=1e-2,
                               batch_size=self.options.batch_size,
                               num_iters=self.options.num_smplify_iters,
                               focal_length=self.focal_length,
                               prior_mul=0.1,
                               conf_thresh=self.conf_thresh)
        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(checkpoint_file=self.options.pretrained_checkpoint)

        # Load dictionary of fits
        self.fits_dict = FitsDict(self.options, self.train_ds)

        # Create renderer
        self.renderer = Renderer(focal_length=self.focal_length, img_res=self.options.img_res, faces=self.smpl.faces)
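Because criterion_keypoints uses reduction='none', each joint's error can be scaled by its annotation confidence before averaging. A short sketch of that weighting, assuming keypoints are stored as [x, y, conf] per joint (this follows the comment above, not the repository's exact loss code):

import torch
import torch.nn as nn

criterion_keypoints = nn.MSELoss(reduction='none')

def keypoint_loss(pred_keypoints_2d, gt_keypoints_2d):
    """Confidence-weighted keypoint loss; gt_keypoints_2d is (B, N, 3) with [x, y, conf]."""
    conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
    per_joint = criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])
    return (conf * per_joint).mean()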
Example #4
 def __init__(self, focal_length=5000., render_res=224):
     # Parameters for rendering
     self.focal_length = focal_length
     self.render_res = render_res
     # We use Neural 3D mesh renderer for rendering masks and part segmentations
     self.neural_renderer = nr.Renderer(dist_coeffs=None, orig_size=self.render_res,
                                        image_size=render_res,
                                        light_intensity_ambient=1,
                                        light_intensity_directional=0,
                                        anti_aliasing=False)
     self.faces = SMPL().faces.cuda().int()
     textures = np.load(cfg.VERTEX_TEXTURE_FILE)
     self.textures = torch.from_numpy(textures).cuda().float()
     self.cube_parts = torch.cuda.FloatTensor(np.load(cfg.CUBE_PARTS_FILE))
Example #5
def bbox_from_json(bbox_file):
    """Get center and scale of bounding box from bounding box annotations.
    The expected format is [top_left(x), top_left(y), width, height].
    """
    with open(bbox_file, 'r') as f:
        bbox = np.array(json.load(f)['bbox']).astype(np.float32)
    ul_corner = bbox[:2]
    center = ul_corner + 0.5 * bbox[2:]
    # Load pretrained model
    model = hmr(config.SMPL_MEAN_PARAMS).to(device)
    checkpoint = torch.load(args.checkpoint)
    model.load_state_dict(checkpoint['model'], strict=False)

    # Load SMPL model
    smpl = SMPL(config.SMPL_MODEL_DIR, batch_size=1,
                create_transl=False).to(device)
    model.eval()

    # Setup renderer for visualization
    renderer = Renderer(focal_length=constants.FOCAL_LENGTH,
                        img_res=constants.IMG_RES,
                        faces=smpl.faces)

    # Preprocess input image and generate predictions
    img, norm_img = process_image(args.img,
                                  args.bbox,
                                  args.openpose,
                                  input_res=constants.IMG_RES)
    with torch.no_grad():
        pred_rotmat, pred_betas, pred_camera = model(norm_img.to(device))
        pred_output = smpl(betas=pred_betas,
                           body_pose=pred_rotmat[:, 1:],
                           global_orient=pred_rotmat[:, 0].unsqueeze(1),
                           pose2rot=False)
        pred_vertices = pred_output.vertices

    # Calculate camera parameters for rendering
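    # Under a weak-perspective camera [s, tx, ty], the scale relates to depth via
    # s ~ 2 * f / (img_res * tz), so the depth is recovered as tz = 2 * f / (img_res * s);
    # the 1e-9 term guards against division by zero. tx and ty are used directly.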
    camera_translation = torch.stack([
        pred_camera[:, 1], pred_camera[:, 2], 2 * constants.FOCAL_LENGTH /
        (constants.IMG_RES * pred_camera[:, 0] + 1e-9)
    ],
                                     dim=-1)
    camera_translation = camera_translation[0].cpu().numpy()
    pred_vertices = pred_vertices[0].cpu().numpy()
    img = img.permute(1, 2, 0).cpu().numpy()

    # make sure the crop is square by using the larger side of the bounding box
    width = max(bbox[2], bbox[3])
    scale = width / 200.0
    return center, scale
Example #6
    def __init__(self, filename='data/mesh_downsampling.npz',
                 num_downsampling=1, nsize=1, device=torch.device('cuda')):
        self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
        self._A = [a.to(device) for a in self._A]
        self._U = [u.to(device) for u in self._U]
        self._D = [d.to(device) for d in self._D]
        self.num_downsampling = num_downsampling

        # load template vertices from SMPL and normalize them
        smpl = SMPL()
        ref_vertices = smpl.v_template
        center = 0.5*(ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None]
        ref_vertices -= center
        ref_vertices /= ref_vertices.abs().max().item()

        self._ref_vertices = ref_vertices.to(device)
        self.faces = smpl.faces.int().to(device)
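The sparse matrices in self._D are the mesh downsampling operators; applying one maps vertices of the fine mesh to the next-coarser one. A minimal usage sketch, assuming each operator is a torch sparse tensor of shape (V_coarse, V_fine) as loaded above (the helper name is illustrative):

import torch

def downsample(x, D, num_downsampling=1):
    """Apply sparse downsampling operators to a (V, 3) vertex tensor."""
    for i in range(num_downsampling):
        x = torch.spmm(D[i], x)  # (V_coarse, V_fine) @ (V_fine, 3) -> (V_coarse, 3)
    return x

# e.g. ref_vertices_sub = downsample(mesh._ref_vertices, mesh._D, mesh.num_downsampling)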
Example #7
 def __init__(self, mesh, num_layers, num_channels, pretrained_checkpoint=None):
     super(CMR, self).__init__()
     self.graph_cnn = GraphCNN(mesh.adjmat, mesh.ref_vertices.t(),
                               num_layers, num_channels)
     self.smpl_param_regressor = SMPLParamRegressor()
     self.smpl = SMPL()
     self.mesh = mesh
     if pretrained_checkpoint is not None:
         checkpoint = torch.load(pretrained_checkpoint)
         try:
             self.graph_cnn.load_state_dict(checkpoint['graph_cnn'])
         except KeyError:
             print('Warning: graph_cnn was not found in checkpoint')
         try:
             self.smpl_param_regressor.load_state_dict(checkpoint['smpl_param_regressor'])
         except KeyError:
             print('Warning: smpl_param_regressor was not found in checkpoint')
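For orientation, a hedged sketch of how the two sub-networks above are chained at inference time: the GraphCNN regresses vertices of the down-sampled mesh directly from the image, and the SMPLParamRegressor maps those vertices to SMPL parameters (tensor shapes and call signatures here are assumptions, not the repository's verbatim forward pass):

def cmr_inference(cmr, image):
    """Illustrative CMR inference: image -> non-parametric vertices -> SMPL parameters."""
    # GraphCNN predicts the down-sampled mesh vertices, roughly (B, 3, V_sub)
    pred_vertices_sub = cmr.graph_cnn(image)
    # The MLP regressor maps those vertices to SMPL rotation matrices and shape
    pred_rotmat, pred_betas = cmr.smpl_param_regressor(
        pred_vertices_sub.transpose(1, 2).contiguous())
    # The SMPL layer turns the parameters back into a full-resolution mesh
    pred_vertices_smpl = cmr.smpl(pred_rotmat, pred_betas)
    return pred_vertices_sub, pred_vertices_smpl, pred_rotmat, pred_betas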
Example #8
    def init_fn(self):
        self.train_ds = MixedDataset(self.options,
                                     ignore_3d=self.options.ignore_3d,
                                     is_train=True)

        self.model = hmr(config.SMPL_MEAN_PARAMS,
                         pretrained=True).to(self.device)

        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=self.options.lr)
        self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer, gamma=0.95)

        self.smpl = SMPL(config.SMPL_MODEL_DIR,
                         batch_size=self.options.batch_size,
                         create_transl=False).to(self.device)

        # consistency loss
        self.criterion_consistency_contrastive = NTXent(
            tau=self.options.tau, kernel=self.options.kernel).to(self.device)
        self.criterion_consistency_mse = nn.MSELoss().to(self.device)
        # Per-vertex loss on the shape
        self.criterion_shape = nn.L1Loss().to(self.device)
        # Keypoint (2D and 3D) loss
        # No reduction because confidence weighting needs to be applied
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        # Loss for SMPL parameter regression
        self.criterion_regr = nn.MSELoss().to(self.device)
        self.models_dict = {'model': self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}
        self.focal_length = constants.FOCAL_LENGTH

        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(
                checkpoint_file=self.options.pretrained_checkpoint)

        # Create renderer
        self.renderer = Renderer(focal_length=self.focal_length,
                                 img_res=self.options.img_res,
                                 faces=self.smpl.faces)

        # Create input image flag
        self.input_img = self.options.input_img

        # initialize queue
        self.feat_queue = FeatQueue(max_queue_size=self.options.max_queue_size)
Example #9
def run_evaluation(model, dataset_name, dataset, result_file,
                   batch_size=32, img_res=224, 
                   num_workers=32, shuffle=False, log_freq=50, options=None):
    """Run evaluation on the datasets and metrics we report in the paper. """

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Transfer model to the GPU
    model.to(device)

    # Load SMPL model
    smpl_neutral = SMPL(path_config.SMPL_MODEL_DIR,
                        create_transl=False).to(device)
    smpl_male = SMPL(path_config.SMPL_MODEL_DIR,
                     gender='male',
                     create_transl=False).to(device)
    smpl_female = SMPL(path_config.SMPL_MODEL_DIR,
                       gender='female',
                       create_transl=False).to(device)
    
    renderer = PartRenderer()
    
    # Regressor for H36m joints
    J_regressor = torch.from_numpy(np.load(path_config.JOINT_REGRESSOR_H36M)).float()
    
    save_results = result_file is not None
    # Disable shuffling if you want to save the results
    if save_results:
        shuffle = False
    # Create dataloader for the dataset
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

    fits_dict = None

    # Pose metrics
    # MPJPE and Reconstruction error for the non-parametric and parametric shapes
    mpjpe = np.zeros(len(dataset))
    recon_err = np.zeros(len(dataset))
    mpjpe_smpl = np.zeros(len(dataset))
    recon_err_smpl = np.zeros(len(dataset))

    # Store SMPL parameters
    smpl_pose = np.zeros((len(dataset), 72))
    smpl_betas = np.zeros((len(dataset), 10))
    smpl_camera = np.zeros((len(dataset), 3))
    pred_joints = np.zeros((len(dataset), 17, 3))

    # joint_mapper_coco = constants.H36M_TO_JCOCO
    joint_mapper_gt = constants.J24_TO_JCOCO

    focal_length = 5000

    num_joints = 17
    num_samples = len(dataset)
    print('dataset length: {}'.format(num_samples))
    all_preds = np.zeros(
        (num_samples, num_joints, 3),
        dtype=np.float32
    )
    all_boxes = np.zeros((num_samples, 6))
    image_path = []
    filenames = []
    imgnums = []
    idx = 0
    with torch.no_grad():
        for step, batch in enumerate(tqdm(data_loader, desc='Eval', total=len(data_loader))):
            if len(options.vis_imname) > 0:
                imgnames = [i_n.split('/')[-1] for i_n in batch['imgname']]
                name_hit = False
                for i_n in imgnames:
                    if options.vis_imname in i_n:
                        name_hit = True
                        print('vis: ' + i_n)
                if not name_hit:
                    continue

            images = batch['img'].to(device)

            scale = batch['scale'].numpy()
            center = batch['center'].numpy()

            num_images = images.size(0)

            gt_keypoints_2d = batch['keypoints']  # 2D keypoints
            # De-normalize 2D keypoints from [-1,1] to pixel space
            gt_keypoints_2d_orig = gt_keypoints_2d.clone()
            gt_keypoints_2d_orig[:, :, :-1] = 0.5 * img_res * (gt_keypoints_2d_orig[:, :, :-1] + 1)

            if options.regressor == 'hmr':
                pred_rotmat, pred_betas, pred_camera = model(images)
                # torch.Size([32, 24, 3, 3]) torch.Size([32, 10]) torch.Size([32, 3])
            elif options.regressor == 'pymaf_net':
                preds_dict, _ = model(images)
                pred_rotmat = preds_dict['smpl_out'][-1]['rotmat'].contiguous().view(-1, 24, 3, 3)
                pred_betas = preds_dict['smpl_out'][-1]['theta'][:, 3:13].contiguous()
                pred_camera = preds_dict['smpl_out'][-1]['theta'][:, :3].contiguous()

            pred_output = smpl_neutral(betas=pred_betas, body_pose=pred_rotmat[:, 1:],
                                        global_orient=pred_rotmat[:, 0].unsqueeze(1), pose2rot=False)

            # pred_vertices = pred_output.vertices
            pred_J24 = pred_output.joints[:, -24:]
            pred_JCOCO = pred_J24[:, constants.J24_TO_JCOCO]

            # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
            # This camera translation can be used in a full perspective projection
            pred_cam_t = torch.stack([pred_camera[:, 1],
                                      pred_camera[:, 2],
                                      2 * constants.FOCAL_LENGTH / (img_res * pred_camera[:, 0] + 1e-9)], dim=-1)
            camera_center = torch.zeros(len(pred_JCOCO), 2, device=pred_camera.device)
            pred_keypoints_2d = perspective_projection(pred_JCOCO,
                                                        rotation=torch.eye(3, device=pred_camera.device).unsqueeze(0).expand(len(pred_JCOCO), -1, -1),
                                                        translation=pred_cam_t,
                                                        focal_length=constants.FOCAL_LENGTH,
                                                        camera_center=camera_center)

            coords = pred_keypoints_2d + (img_res / 2.)
            coords = coords.cpu().numpy()

            gt_keypoints_coco = gt_keypoints_2d_orig[:, -24:][:, constants.J24_TO_JCOCO]
            vert_errors_batch = []
            for i, (gt2d, pred2d) in enumerate(zip(gt_keypoints_coco.cpu().numpy(), coords.copy())):
                vert_error = np.sqrt(np.sum((gt2d[:, :2] - pred2d[:, :2]) ** 2, axis=1))
                vert_error *= gt2d[:, 2]
                vert_mean_error = np.sum(vert_error) / np.sum(gt2d[:, 2] > 0)
                vert_errors_batch.append(10 * vert_mean_error)

            if options.vis_demo:
                imgnames = [i_n.split('/')[-1] for i_n in batch['imgname']]

                if options.regressor == 'hmr':
                    iuv_pred = None

                images_vis = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
                images_vis = images_vis + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)
                vis_smpl_iuv(images_vis.cpu().numpy(), pred_camera.cpu().numpy(), pred_output.vertices.cpu().numpy(),
                             smpl_neutral.faces, iuv_pred,
                             vert_errors_batch, imgnames, os.path.join('./notebooks/output/demo_results', dataset_name,
                                                                            options.checkpoint.split('/')[-3]), options)

            preds = coords.copy()

            scale_ = np.array([scale, scale]).transpose()

            # Transform back
            for i in range(coords.shape[0]):
                preds[i] = transform_preds(
                    coords[i], center[i], scale_[i], [img_res, img_res]
                )

            all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
            all_preds[idx:idx + num_images, :, 2:3] = 1.
            all_boxes[idx:idx + num_images, 5] = 1.
            image_path.extend(batch['imgname'])

            idx += num_images

        if len(options.vis_imname) > 0:
            exit()

        if options.checkpoint is None or 'model_checkpoint.pt' in options.checkpoint:
            ckp_name = 'spin_model'
        else:
            ckp_name = options.checkpoint.split('/')
            ckp_name = ckp_name[2].split('_')[1] + '_' + ckp_name[-1].split('.')[0]
        name_values, perf_indicator = dataset.evaluate(
            cfg, all_preds, options.output_dir, all_boxes, image_path, ckp_name,
            filenames, imgnums
        )

        model_name = options.regressor
        if isinstance(name_values, list):
            for name_value in name_values:
                _print_name_value(name_value, model_name)
        else:
            _print_name_value(name_values, model_name)

    # Save reconstructions to a file for further processing
    if save_results:
        np.savez(result_file, pred_joints=pred_joints, pose=smpl_pose, betas=smpl_betas, camera=smpl_camera)
Example #10
    def init(self):
        
        # Create training and testing dataset
        self.train_ds = MGNDataset(self.options, split = 'train')
        self.test_ds  = MGNDataset(self.options, split = 'test')
        
        # The test data loader is fixed; shuffling is disabled since it is unnecessary.
        self.test_data_loader = CheckpointDataLoader( self.test_ds,
                    batch_size  = self.options.batch_size,
                    num_workers = self.options.num_workers,
                    pin_memory  = self.options.pin_memory,
                    shuffle     = False)

        # Create SMPL Mesh (graph) object for GCN
        self.mesh = Mesh(self.options, self.options.num_downsampling)
        self.faces = torch.cat( self.options.batch_size * [
                                self.mesh.faces.to(self.device)[None]
                                ], dim = 0 )
        self.faces = self.faces.type(torch.LongTensor).to(self.device)
        
        # Create SMPL blending model and edges
        self.smpl = SMPL(self.options.smpl_model_path, self.device)
        # self.smplEdge = torch.Tensor(np.load(self.options.smpl_edges_path)) \
        #                 .long().to(self.device)
        
        # create SMPL+D blending model
        self.smplD = Smpl( self.options.smpl_model_path ) 
        
        # read the SMPL .obj file to get UV coordinates
        _, self.smpl_tri_ind, uv_coord, tri_uv_ind = read_Obj(self.options.smpl_objfile_path)
        uv_coord[:, 1] = 1 - uv_coord[:, 1]
        expUV = uv_coord[tri_uv_ind.flatten()]
        unique, index = np.unique(self.smpl_tri_ind.flatten(), return_index = True)
        self.smpl_verts_uvs = torch.as_tensor(expUV[index,:]).float().to(self.device)
        self.smpl_tri_ind   = torch.as_tensor(self.smpl_tri_ind).to(self.device)
        
        # camera for projection 
        self.perspCam = perspCamera() 

        # mean and std of displacements
        self.dispPara = \
            torch.Tensor(np.load(self.options.MGN_offsMeanStd_path)).to(self.device)
        
        # load average pose and shape and convert to camera coordinates;
        # avg pose is decided by the image id we use for training (0-11) 
        avgPose_objCoord = np.load(self.options.MGN_avgPose_path)
        avgPose_objCoord[:3] = rotationMatrix_to_axisAngle(    # for 0,6, front only
            torch.tensor([[[1,  0, 0],
                           [0, -1, 0],
                           [0,  0,-1]]]))
        self.avgPose = \
            axisAngle_to_Rot6d(
                torch.Tensor(avgPose_objCoord[None]).reshape(-1, 3)
                    ).reshape(1, -1).to(self.device)
        self.avgBeta = \
            torch.Tensor(
                np.load(self.options.MGN_avgBeta_path)[None]).to(self.device)
        self.avgCam  = torch.Tensor([1.2755, 0, 0])[None].to(self.device)    # 1.2755 is for our settings
        
        self.model = frameVIBE(
            self.options.smpl_model_path, 
            self.mesh,                
            self.avgPose,
            self.avgBeta,
            self.avgCam,
            self.options.num_channels,
            self.options.num_layers ,  
            self.smpl_verts_uvs,
            self.smpl_tri_ind
            ).to(self.device)
            
        # Set up an optimizer for the model
        self.optimizer = torch.optim.Adam(
            params=list(self.model.parameters()),
            lr=self.options.lr,
            betas=(self.options.adam_beta1, 0.999),
            weight_decay=self.options.wd)
        
        self.criterion = VIBELoss(
            e_loss_weight=50,         # for kp 2d, help to estimate camera, global orientation
            e_3d_loss_weight=1,       # for kp 3d, bvt
            e_pose_loss_weight=10,     # for body pose parameters
            e_shape_loss_weight=1,   # for body shape parameters
            e_disp_loss_weight=1,   # for displacements 
            e_tex_loss_weight=1,       # for uv image 
            d_motion_loss_weight=0
            )
        
        # Pack models and optimizers in a dict - necessary for checkpointing
        self.models_dict = {self.options.model: self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}
Example #11
class trainer(BaseTrain):
    
    def init(self):
        
        # Create training and testing dataset
        self.train_ds = MGNDataset(self.options, split = 'train')
        self.test_ds  = MGNDataset(self.options, split = 'test')
        
        # The test data loader is fixed; shuffling is disabled since it is unnecessary.
        self.test_data_loader = CheckpointDataLoader( self.test_ds,
                    batch_size  = self.options.batch_size,
                    num_workers = self.options.num_workers,
                    pin_memory  = self.options.pin_memory,
                    shuffle     = False)

        # Create SMPL Mesh (graph) object for GCN
        self.mesh = Mesh(self.options, self.options.num_downsampling)
        self.faces = torch.cat( self.options.batch_size * [
                                self.mesh.faces.to(self.device)[None]
                                ], dim = 0 )
        self.faces = self.faces.type(torch.LongTensor).to(self.device)
        
        # Create SMPL blending model and edges
        self.smpl = SMPL(self.options.smpl_model_path, self.device)
        # self.smplEdge = torch.Tensor(np.load(self.options.smpl_edges_path)) \
        #                 .long().to(self.device)
        
        # create SMPL+D blending model
        self.smplD = Smpl( self.options.smpl_model_path ) 
        
        # read the SMPL .obj file to get UV coordinates
        _, self.smpl_tri_ind, uv_coord, tri_uv_ind = read_Obj(self.options.smpl_objfile_path)
        uv_coord[:, 1] = 1 - uv_coord[:, 1]
        expUV = uv_coord[tri_uv_ind.flatten()]
        unique, index = np.unique(self.smpl_tri_ind.flatten(), return_index = True)
        self.smpl_verts_uvs = torch.as_tensor(expUV[index,:]).float().to(self.device)
        self.smpl_tri_ind   = torch.as_tensor(self.smpl_tri_ind).to(self.device)
        
        # camera for projection 
        self.perspCam = perspCamera() 

        # mean and std of displacements
        self.dispPara = \
            torch.Tensor(np.load(self.options.MGN_offsMeanStd_path)).to(self.device)
        
        # load average pose and shape and convert to camera coordinates;
        # avg pose is decided by the image id we use for training (0-11) 
        avgPose_objCoord = np.load(self.options.MGN_avgPose_path)
        avgPose_objCoord[:3] = rotationMatrix_to_axisAngle(    # for 0,6, front only
            torch.tensor([[[1,  0, 0],
                           [0, -1, 0],
                           [0,  0,-1]]]))
        self.avgPose = \
            axisAngle_to_Rot6d(
                torch.Tensor(avgPose_objCoord[None]).reshape(-1, 3)
                    ).reshape(1, -1).to(self.device)
        self.avgBeta = \
            torch.Tensor(
                np.load(self.options.MGN_avgBeta_path)[None]).to(self.device)
        self.avgCam  = torch.Tensor([1.2755, 0, 0])[None].to(self.device)    # 1.2755 is for our settings
        
        self.model = frameVIBE(
            self.options.smpl_model_path, 
            self.mesh,                
            self.avgPose,
            self.avgBeta,
            self.avgCam,
            self.options.num_channels,
            self.options.num_layers ,  
            self.smpl_verts_uvs,
            self.smpl_tri_ind
            ).to(self.device)
            
        # Set up an optimizer for the model
        self.optimizer = torch.optim.Adam(
            params=list(self.model.parameters()),
            lr=self.options.lr,
            betas=(self.options.adam_beta1, 0.999),
            weight_decay=self.options.wd)
        
        self.criterion = VIBELoss(
            e_loss_weight=50,         # for kp 2d, help to estimate camera, global orientation
            e_3d_loss_weight=1,       # for kp 3d, bvt
            e_pose_loss_weight=10,     # for body pose parameters
            e_shape_loss_weight=1,   # for body shape parameters
            e_disp_loss_weight=1,   # for displacements 
            e_tex_loss_weight=1,       # for uv image 
            d_motion_loss_weight=0
            )
        
        # Pack models and optimizers in a dict - necessary for checkpointing
        self.models_dict = {self.options.model: self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}
        
    def convert_GT(self, input_batch):     
        
        smplGT = torch.cat([
            # GTcamera is not used, here is a placeholder
            torch.zeros([self.options.batch_size,3]).to(self.device),
            # the global rotation is incorrect as it is under the original 
            # coordinate system. The rest and betas are fine.
            rot6d_to_axisAngle(input_batch['GTsmplParas']['pose']).reshape(-1, 72),   # 24 * 6 = 144
            input_batch['GTsmplParas']['betas']],
            dim = 1).float()
        
        # vertices in the original coordinates
        vertices = self.smpl(
            pose = rot6d_to_axisAngle(input_batch['GTsmplParas']['pose']).reshape(self.options.batch_size, 72),
            beta = input_batch['GTsmplParas']['betas'].float(),
            trans = input_batch['GTsmplParas']['trans']
            )
        
        # get joints in 3d and 2d in the current camera coordinates
        joints_3d = self.smpl.get_joints(vertices.float())
        joints_2d, _, joints_3d = self.perspCam(
            fx = input_batch['GTcamera']['f_rot'][:,0,0], 
            fy = input_batch['GTcamera']['f_rot'][:,0,0], 
            cx = 112, 
            cy = 112, 
            rotation = rot6d_to_rotmat(input_batch['GTcamera']['f_rot'][:,0,1:]).float(),  
            translation = input_batch['GTcamera']['t'][:,None,:].float(), 
            points = joints_3d,
            visibilityOn = False,
            output3d = True
            )
        joints_3d = (joints_3d - input_batch['GTcamera']['t'][:,None,:]).float() # remove shifts
        
        # visind = 1
        # img = torch.zeros([224,224])
        # joints_2d[visind][joints_2d[visind].round() >= 224] = 223 
        # joints_2d[visind][joints_2d[visind].round() < 0] = 0 
        # img[joints_2d[visind][:,1].round().long(), joints_2d[visind][:,0].round().long()] = 1
        # plt.imshow(img)
        # plt.imshow(input_batch['img'][visind].cpu().permute(1,2,0))
        
        # convert to [-1, +1]
        joints_2d = (torch.cat(joints_2d, dim=0).reshape([self.options.batch_size, 24, 2]) - 112)/112
        joints_2d = joints_2d.float()
        
        # convert vertices to current camera coordinates
        _,_,vertices = self.perspCam(
            fx = input_batch['GTcamera']['f_rot'][:,0,0], 
            fy = input_batch['GTcamera']['f_rot'][:,0,0], 
            cx = 112, 
            cy = 112, 
            rotation = rot6d_to_rotmat(input_batch['GTcamera']['f_rot'][:,0,1:]).float(),  
            translation = input_batch['GTcamera']['t'][:,None,:].float(), 
            points = vertices,
            visibilityOn = False,
            output3d = True
            )
        vertices = (vertices - input_batch['GTcamera']['t'][:,None,:]).float()
        
        # prove the correctness of coord sys,
        # points = vertices (vertices before shift)
        # localcam = perspCamera(smpl_obj = False)    # disable additional trans
        # img_points, _ = localcam(
        #     fx = input_batch['GTcamera']['f_rot'][:,0,0], 
        #     fy = input_batch['GTcamera']['f_rot'][:,0,0], 
        #     cx = 112, 
        #     cy = 112, 
        #     rotation = torch.eye(3)[None].repeat_interleave(points.shape[0], dim = 0).to('cuda'),  
        #     translation = torch.zeros([points.shape[0],1,3]).to('cuda'), 
        #     points = points.float(),
        #     visibilityOn = False,
        #     output3d = False
        #     )
        # img = torch.zeros([224,224])
        # img_points[visind][img_points[visind].round() >= 224] = 223 
        # img_points[visind][img_points[visind].round() < 0] = 0 
        # img[img_points[visind][:,1].round().long(), img_points[visind][:,0].round().long()] = 1
        # plt.imshow(img)
                
        # prove the correctness of displacements
        # import open3d as o3d
        # points = vertices
        # ind = 0
        # o3dMesh = o3d.geometry.TriangleMesh()
        # o3dMesh.vertices = o3d.utility.Vector3dVector(points[ind].cpu())
        # o3dMesh.triangles= o3d.utility.Vector3iVector(self.faces[0].cpu())
        # o3dMesh.compute_vertex_normals()     
        # o3d.visualization.draw_geometries([o3dMesh])
    
        # self.renderer = renderer(batch_size = 1)
        # images = self.renderer(
        #     verts = vertices[0][None],
        #     faces = self.faces[0][None],
        #     verts_uvs = self.smpl_verts_uvs[None].repeat_interleave(2, dim=0)[0][None],
        #     faces_uvs = self.smpl_tri_ind[None].repeat_interleave(2, dim=0)[0][None],
        #     tex_image = tex[None],
        #     R = torch.eye(3)[None].repeat_interleave(2, dim = 0).to('cuda')[0][None],
        #     T = input_batch['GTcamera']['t'].float()[0][None],
        #     f = input_batch['GTcamera']['f_rot'][:,0,0][:,None].float()[0][None],
        #     C = torch.ones([2,2]).to('cuda')[0][None]*112,
        #     imgres = 224
        # )
        
        GT = {'img' : input_batch['img'].float(),
              'img_orig': input_batch['img_orig'].float(),
              'imgname' : input_batch['imgname'],
              
              'camera': input_batch['GTcamera'],                   # wcd 
              'theta': smplGT.float(),        
              
              # joints_2d is col, row == x, y; joints_3d is x,y,z
              'target_2d': joints_2d.float(),
              'target_3d': joints_3d.float(),
              'target_bvt': vertices.float(),   # body vertices
              'target_dp': input_batch['GToffsets_t']['offsets'].float(),    # smpl cd t-pose
              
              'target_uv': input_batch['GTtextureMap'].float()
              }
            
        return GT
    
    def train_step(self, input_batch):
        """Training step."""
        self.model.train()
        
        # prepare data
        GT = self.convert_GT(input_batch)    
  
        # forward pass
        pred = self.model(GT['img'], GT['img_orig'])
        
        # loss
        gen_loss, loss_dict = self.criterion(
            generator_outputs=pred,
            data_2d = GT['target_2d'],
            data_3d = GT,
            )
        # print(gen_loss)
        out_args = loss_dict
        out_args['loss'] = gen_loss
        out_args['prediction'] = pred 
        
        # save one training sample for vis
        if (self.step_count+1) % self.options.summary_steps == 0:
            with torch.no_grad():
                self.save_sample(GT, pred[0], saveTo = 'train')

        return out_args

    def test(self):
        """"Testing process"""
        self.model.eval()    
        
        test_loss, test_tex_loss = 0, 0
        test_loss_pose, test_loss_shape = 0, 0
        test_loss_kp_2d, test_loss_kp_3d = 0, 0
        test_loss_dsp_3d, test_loss_bvt_3d = 0, 0
        for step, batch in enumerate(tqdm(self.test_data_loader, desc='Test',
                                          total=len(self.test_ds) // self.options.batch_size,
                                          initial=self.test_data_loader.checkpoint_batch_idx),
                                     self.test_data_loader.checkpoint_batch_idx):
            # convert data devices
            batch_toDEV = {}
            for key, val in batch.items():
                if isinstance(val, torch.Tensor):
                    batch_toDEV[key] = val.to(self.device)
                else:
                    batch_toDEV[key] = val
                if isinstance(val, dict):
                    batch_toDEV[key] = {}
                    for k, v in val.items():
                        if isinstance(v, torch.Tensor):
                            batch_toDEV[key][k] = v.to(self.device)
            # prepare data
            GT = self.convert_GT(batch_toDEV)

            with torch.no_grad():    # disable grad
                # forward pass
                pred = self.model(GT['img'], GT['img_orig'])
                
                # loss
                gen_loss, loss_dict = self.criterion(
                    generator_outputs=pred,
                    data_2d = GT['target_2d'],
                    data_3d = GT,
                    )
                # save for comparison
                if step == 0:
                    self.save_sample(GT, pred[0])
                
            test_loss += gen_loss
            test_loss_pose  += loss_dict['loss_pose']
            test_loss_shape += loss_dict['loss_shape']
            test_loss_kp_2d += loss_dict['loss_kp_2d']
            test_loss_kp_3d += loss_dict['loss_kp_3d']
            test_loss_bvt_3d += loss_dict['loss_bvt_3d']
            test_loss_dsp_3d += loss_dict['loss_dsp_3d']
            test_tex_loss += loss_dict['loss_tex']
                        
        test_loss = test_loss/len(self.test_data_loader)
        test_loss_pose  = test_loss_pose/len(self.test_data_loader)
        test_loss_shape = test_loss_shape/len(self.test_data_loader)
        test_loss_kp_2d = test_loss_kp_2d/len(self.test_data_loader)
        test_loss_kp_3d = test_loss_kp_3d/len(self.test_data_loader)
        test_loss_bvt_3d = test_loss_bvt_3d/len(self.test_data_loader)
        test_loss_dsp_3d = test_loss_dsp_3d/len(self.test_data_loader)
        test_tex_loss = test_tex_loss/len(self.test_data_loader)
        
        lossSummary = {'test_loss': test_loss, 
                       'test_loss_pose' : test_loss_pose,
                       'test_loss_shape' : test_loss_shape,
                       'test_loss_kp_2d' : test_loss_kp_2d,
                       'test_loss_kp_3d' : test_loss_kp_3d,
                       'test_loss_bvt_3d': test_loss_bvt_3d,
                       'test_loss_dsp_3d': test_loss_dsp_3d,
                       'test_tex_loss': test_tex_loss
                       }
        self.test_summaries(lossSummary)
        
        if test_loss < self.last_test_loss:
            self.last_test_loss = test_loss
            self.saver.save_checkpoint(self.models_dict, 
                                       self.optimizers_dict, 
                                       0, 
                                       step+1, 
                                       self.options.batch_size, 
                                       self.test_data_loader.sampler.dataset_perm, 
                                       self.step_count) 
            tqdm.write('Better test checkpoint saved')
        
    def save_sample(self, data, prediction, ind = 0, saveTo = 'test'):
        """Saving a sample for visualization and comparison"""
        
        assert saveTo in ('test', 'train'), 'save to train or test folder'
        
        folder = pjn(self.options.summary_dir, '%s/'%(saveTo))
        _input = pjn(folder, 'input_image.png')
        # batchs = self.options.batch_size
        
        # save the input image, if not saved
        if not isfile(_input):
            plt.imsave(_input, data['img'][ind].cpu().permute(1,2,0).clamp(0,1).numpy())
            
        # overlay the prediction on the real image; since the mesh .obj has different
        # face/UV coordinates between the MPI library and MGN, we flip the predicted texture.
        save_renderer = simple_renderer(batch_size = 1)
        predTrans = torch.stack(
            [prediction['theta'][ind, 1],
             prediction['theta'][ind, 2],
             2 * 1000. / (224. * prediction['theta'][ind, 0] + 1e-9)], dim=-1)
        tex = prediction['tex_image'][ind].flip(dims=(0,))[None]
        pred_img = save_renderer(
            verts = prediction['verts'][ind][None],
            faces = self.faces[ind][None],
            verts_uvs = self.smpl_verts_uvs[None],
            faces_uvs = self.smpl_tri_ind[None],
            tex_image = tex,
            R = torch.eye(3)[None].to('cuda'),
            T = predTrans,
            f = torch.ones([1,1]).to('cuda')*1000,
            C = torch.ones([1,2]).to('cuda')*112,
            imgres = 224)
        
        overlayimg = 0.9*pred_img[0,:,:,:3] + 0.1*data['img'][ind].permute(1,2,0)
        save_path = pjn(folder,'overlay_%s_iters%d.png'%(data['imgname'][ind].split('/')[-1][:-4], self.step_count))
        plt.imsave(save_path, (overlayimg.clamp(0, 1)*255).cpu().numpy().astype('uint8'))
        save_path = pjn(folder,'tex_%s_iters%d.png'%(data['imgname'][ind].split('/')[-1][:-4], self.step_count))
        plt.imsave(save_path, (pred_img[0,:,:,:3].clamp(0, 1)*255).cpu().numpy().astype('uint8'))
        save_path = pjn(folder,'unwarptex_%s_iters%d.png'%(data['imgname'][ind].split('/')[-1][:-4], self.step_count))
        plt.imsave(save_path, (prediction['unwarp_tex'][ind].clamp(0, 1)*255).cpu().numpy().astype('uint8'))
        save_path = pjn(folder,'predtex_%s_iters%d.png'%(data['imgname'][ind].split('/')[-1][:-4], self.step_count))
        plt.imsave(save_path, (prediction['tex_image'][ind].clamp(0, 1)*255).cpu().numpy().astype('uint8'))


        # create the predicted posed undressed body vertices
        offPred_t  = (prediction['verts_disp'][ind]*self.dispPara[1]+self.dispPara[0]).cpu()[None,:] 
        predDressbody = create_smplD_psbody(
            self.smplD, offPred_t, 
            prediction['theta'][ind][3:75][None].cpu(), 
            prediction['theta'][ind][75:][None].cpu(), 
            0, 
            rtnMesh=True)[1]
                            
        # Create meshes and save 
        savepath = pjn(folder,'%s_iters%d.obj'%\
                       (data['imgname'][ind].split('/')[-1][:-4], self.step_count))
        predDressbody.write_obj(savepath)
Example #12
def run_evaluation(hmr_model, dataset, eval_size, batch_size=32, num_workers=32, log_freq=50):
    """Run evaluation on the datasets and metrics we report in the paper. """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # focal length
    focal_length = constants.FOCAL_LENGTH

    # Transfer hmr_model to the GPU
    hmr_model.to(device)

    # Load SMPL hmr_model
    smpl_neutral = SMPL(config.SMPL_MODEL_DIR,
                        create_transl=False).to(device)
    smpl_male = SMPL(config.SMPL_MODEL_DIR,
                     gender='male',
                     create_transl=False).to(device)
    smpl_female = SMPL(config.SMPL_MODEL_DIR,
                       gender='female',
                       create_transl=False).to(device)

    # Regressor for H36m joints
    J_regressor = torch.from_numpy(np.load(config.JOINT_REGRESSOR_H36M)).float()

    # Create dataloader for the dataset
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Pose metrics
    # MPJPE and Reconstruction error for the non-parametric and parametric shapes
    mpjpe = np.zeros((len(dataset)))
    recon_err = np.zeros((len(dataset)))

    joint_mapper_h36m = constants.H36M_TO_J14

    # Iterate over the entire dataset
    for step, batch in enumerate(tqdm(data_loader, desc='Eval', total=len(data_loader))):
        # Get ground truth annotations from the batch
        gt_pose = batch['pose'].to(device)
        gt_betas = batch['betas'].to(device)

        gender = batch['gender'].to(device)
        images = batch['img_up']
        curr_batch_size = images.shape[0]

        with torch.no_grad():
            images = images.to(device)
            pred_rotmat, pred_betas, pred_camera, _ = hmr_model(images, scale=size_to_scale(eval_size))
            pred_output = smpl_neutral(betas=pred_betas, body_pose=pred_rotmat[:, 1:],
                                       global_orient=pred_rotmat[:, 0].unsqueeze(1), pose2rot=False)
            pred_vertices = pred_output.vertices

            J_regressor_batch = J_regressor[None, :].expand(pred_vertices.shape[0], -1, -1).to(device)

            gt_vertices = smpl_male(global_orient=gt_pose[:, :3], body_pose=gt_pose[:, 3:], betas=gt_betas).vertices
            gt_vertices_female = smpl_female(global_orient=gt_pose[:, :3], body_pose=gt_pose[:, 3:],
                                             betas=gt_betas).vertices
            gt_vertices[gender == 1, :, :] = gt_vertices_female[gender == 1, :, :]
            gt_keypoints_3d = torch.matmul(J_regressor_batch, gt_vertices)
            gt_pelvis = gt_keypoints_3d[:, [0], :].clone()
            gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_h36m, :]
            gt_keypoints_3d = gt_keypoints_3d - gt_pelvis

            pred_keypoints_3d = torch.matmul(J_regressor_batch, pred_vertices)
            pred_pelvis = pred_keypoints_3d[:, [0], :].clone()
            pred_keypoints_3d = pred_keypoints_3d[:, joint_mapper_h36m, :]
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis

            # Absolute error (MPJPE)
            error = torch.sqrt(((pred_keypoints_3d - gt_keypoints_3d) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            mpjpe[step * batch_size:step * batch_size + curr_batch_size] = error

            # Reconstruction error
            r_error = reconstruction_error(pred_keypoints_3d.cpu().numpy(), gt_keypoints_3d.cpu().numpy(), reduction=None)
            recon_err[step * batch_size:step * batch_size + curr_batch_size] = r_error

            # Print intermediate results during evaluation
            if step % log_freq == log_freq - 1:
                print('MPJPE: ' + str(1000 * mpjpe[:step * batch_size].mean()))
                print('Reconstruction Error: ' + str(1000 * recon_err[:step * batch_size].mean()))

    # Print final results during evaluation
    print('*** Final Results ***')
    print()
    print('MPJPE: {}'.format(1000 * mpjpe.mean()))
    print('Reconstruction Error: {}'.format(1000 * recon_err.mean()))
    print()
Example #13
def main():
    """Main function"""
    args = parse_args()
    args.batch_size = 1

    cfg_from_file(args.cfg_file)

    cfg.DANET.REFINEMENT = EasyDict(cfg.DANET.REFINEMENT)
    cfg.MSRES_MODEL.EXTRA = EasyDict(cfg.MSRES_MODEL.EXTRA)

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    if cfg.DANET.SMPL_MODEL_TYPE == 'male':
        smpl_male = SMPL(path_config.SMPL_MODEL_DIR,
                         gender='male',
                         create_transl=False).to(device)
        smpl = smpl_male
    elif cfg.DANET.SMPL_MODEL_TYPE == 'neutral':
        smpl_neutral = SMPL(path_config.SMPL_MODEL_DIR,
                            create_transl=False).to(device)
        smpl = smpl_neutral
    elif cfg.DANET.SMPL_MODEL_TYPE == 'female':
        smpl_female = SMPL(path_config.SMPL_MODEL_DIR,
                           gender='female',
                           create_transl=False).to(device)
        smpl = smpl_female

    if args.use_opendr:
        from utils.renderer import opendr_render
        dr_render = opendr_render()

    # IUV renderer
    iuv_renderer = IUV_Renderer()

    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    ### Model ###
    model = DaNet(args, path_config.SMPL_MEAN_PARAMS,
                  pretrained=False).to(device)

    checkpoint = torch.load(args.checkpoint)
    model.load_state_dict(checkpoint['model'], strict=False)
    model.eval()

    img_path_list = [
        os.path.join(args.img_dir, name) for name in os.listdir(args.img_dir)
        if name.endswith('.jpg')
    ]
    for i, path in enumerate(img_path_list):

        image = Image.open(path).convert('RGB')
        img_id = path.split('/')[-1][:-4]

        image_tensor = torchvision.transforms.ToTensor()(image).unsqueeze(
            0).cuda()

        # run inference
        pred_dict = model.infer_net(image_tensor)
        para_pred = pred_dict['para']
        camera_pred = para_pred[:, 0:3].contiguous()
        betas_pred = para_pred[:, 3:13].contiguous()
        rotmat_pred = para_pred[:, 13:].contiguous().view(-1, 24, 3, 3)

        # input image
        image_np = image_tensor[0].cpu().numpy()
        image_np = np.transpose(image_np, (1, 2, 0))

        ones_np = np.ones(image_np.shape[:2]) * 255
        ones_np = ones_np[:, :, None]

        image_in_rgba = np.concatenate((image_np, ones_np), axis=2)

        # estimated global IUV
        global_iuv = iuv_map2img(
            *pred_dict['visualization']['iuv_pred'])[0].cpu().numpy()
        global_iuv = np.transpose(global_iuv, (1, 2, 0))
        global_iuv = resize(global_iuv, image_np.shape[:2])
        global_iuv_rgba = np.concatenate((global_iuv, ones_np), axis=2)

        # estimated partial IUV
        part_iuv_pred = pred_dict['visualization']['part_iuv_pred'][0]
        p_iuv_vis = []
        for i in range(part_iuv_pred.size(0)):
            p_u_vis, p_v_vis, p_i_vis = [
                part_iuv_pred[i, iuv].unsqueeze(0) for iuv in range(3)
            ]
            if p_u_vis.size(1) == 25:
                p_iuv_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(),
                                          p_i_vis.detach())
            else:
                p_iuv_vis_i = iuv_map2img(p_u_vis.detach(),
                                          p_v_vis.detach(),
                                          p_i_vis.detach(),
                                          ind_mapping=[0] +
                                          model.img2iuv.dp2smpl_mapping[i])
            p_iuv_vis.append(p_iuv_vis_i)
        part_iuv = torch.cat(p_iuv_vis, dim=0)
        part_iuv = make_grid(part_iuv, nrow=6, padding=0).cpu().numpy()
        part_iuv = np.transpose(part_iuv, (1, 2, 0))
        part_iuv_rgba = np.concatenate(
            (part_iuv, np.ones(part_iuv.shape[:2])[:, :, None] * 255), axis=2)

        # rendered IUV of the predicted SMPL model
        smpl_output = smpl(betas=betas_pred,
                           body_pose=rotmat_pred[:, 1:],
                           global_orient=rotmat_pred[:, 0].unsqueeze(1),
                           pose2rot=False)
        verts_pred = smpl_output.vertices
        render_iuv = iuv_renderer.verts2uvimg(verts_pred, camera_pred)
        render_iuv = render_iuv[0].cpu().numpy()

        render_iuv = np.transpose(render_iuv, (1, 2, 0))
        render_iuv = resize(render_iuv, image_np.shape[:2])

        img_render_iuv = image_np.copy()
        img_render_iuv[render_iuv > 0] = render_iuv[render_iuv > 0]

        img_render_iuv_rgba = np.concatenate((img_render_iuv, ones_np), axis=2)

        img_vis_list = [
            image_in_rgba, global_iuv_rgba, part_iuv_rgba, img_render_iuv_rgba
        ]

        if args.use_opendr:
            # visualize the predicted SMPL model using the opendr renderer
            K = iuv_renderer.K[0].cpu().numpy()
            _, _, img_smpl, smpl_rgba = dr_render.render(
                image_tensor[0].cpu().numpy(), camera_pred[0].cpu().numpy(), K,
                verts_pred.cpu().numpy(), smpl.faces)

            img_smpl_rgba = np.concatenate((img_smpl, ones_np), axis=2)
            img_vis_list.extend([img_smpl_rgba, smpl_rgba])

        img_vis = np.concatenate(img_vis_list, axis=1)
        img_vis[img_vis < 0.0] = 0.0
        img_vis[img_vis > 1.0] = 1.0
        imsave(os.path.join(args.out_dir, img_id + '_result.png'), img_vis)

    print('Demo results have been saved in {}.'.format(args.out_dir))
Example #14
def run_evaluation(model,
                   dataset,
                   result_file,
                   batch_size=32,
                   img_res=224,
                   num_workers=32,
                   shuffle=False,
                   options=None):
    """Run evaluation on the datasets and metrics we report in the paper. """

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Transfer model to the GPU
    model.to(device)

    # Load SMPL model
    smpl_neutral = SMPL(path_config.SMPL_MODEL_DIR,
                        create_transl=False).to(device)

    save_results = result_file is not None
    # Disable shuffling if you want to save the results
    if save_results:
        shuffle = False
    # Create dataloader for the dataset
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=num_workers)

    # Store SMPL parameters
    smpl_pose = np.zeros((len(dataset), 72))
    smpl_betas = np.zeros((len(dataset), 10))
    smpl_camera = np.zeros((len(dataset), 3))
    pred_joints = np.zeros((len(dataset), 17, 3))

    num_joints = 17

    num_samples = len(dataset)
    print('dataset length: {}'.format(num_samples))
    all_preds = np.zeros((num_samples, num_joints, 3), dtype=np.float32)
    all_boxes = np.zeros((num_samples, 6))
    image_path = []
    filenames = []
    imgnums = []
    idx = 0
    with torch.no_grad():
        end = time.time()

        for step, batch in enumerate(
                tqdm(data_loader, desc='Eval', total=len(data_loader))):
            images = batch['img'].to(device)
            scale = batch['scale'].numpy()
            center = batch['center'].numpy()

            num_images = images.size(0)

            gt_keypoints_2d = batch['keypoints']  # 2D keypoints
            # De-normalize 2D keypoints from [-1,1] to pixel space
            gt_keypoints_2d_orig = gt_keypoints_2d.clone()
            gt_keypoints_2d_orig[:, :, :-1] = 0.5 * img_res * (
                gt_keypoints_2d_orig[:, :, :-1] + 1)

            if options.regressor == 'hmr':
                pred_rotmat, pred_betas, pred_camera = model(images)
            elif options.regressor == 'danet':
                danet_pred_dict = model.infer_net(images)
                para_pred = danet_pred_dict['para']
                pred_camera = para_pred[:, 0:3].contiguous()
                pred_betas = para_pred[:, 3:13].contiguous()
                pred_rotmat = para_pred[:, 13:].contiguous().view(-1, 24, 3, 3)

            pred_output = smpl_neutral(
                betas=pred_betas,
                body_pose=pred_rotmat[:, 1:],
                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                pose2rot=False)

            # pred_vertices = pred_output.vertices
            pred_J24 = pred_output.joints[:, -24:]
            pred_JCOCO = pred_J24[:, constants.J24_TO_JCOCO]

            # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
            # This camera translation can be used in a full perspective projection
            pred_cam_t = torch.stack([
                pred_camera[:, 1], pred_camera[:, 2], 2 *
                constants.FOCAL_LENGTH / (img_res * pred_camera[:, 0] + 1e-9)
            ],
                                     dim=-1)

            camera_center = torch.zeros(len(pred_JCOCO),
                                        2,
                                        device=pred_camera.device)
            pred_keypoints_2d = perspective_projection(
                pred_JCOCO,
                rotation=torch.eye(
                    3, device=pred_camera.device).unsqueeze(0).expand(
                        len(pred_JCOCO), -1, -1),
                translation=pred_cam_t,
                focal_length=constants.FOCAL_LENGTH,
                camera_center=camera_center)

            coords = pred_keypoints_2d + (img_res / 2.)
            coords = coords.cpu().numpy()
            # Normalize keypoints to [-1,1]
            # pred_keypoints_2d = pred_keypoints_2d / (img_res / 2.)

            gt_keypoints_coco = gt_keypoints_2d_orig[:, -24:][:, constants.J24_TO_JCOCO]

            preds = coords.copy()

            scale_ = np.array([scale, scale]).transpose()

            # Transform back
            for i in range(coords.shape[0]):
                preds[i] = transform_preds(coords[i], center[i], scale_[i],
                                           [img_res, img_res])

            all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
            all_preds[idx:idx + num_images, :, 2:3] = 1.
            # double-check this all_boxes layout
            all_boxes[idx:idx + num_images, 0:2] = center[:, 0:2]
            all_boxes[idx:idx + num_images, 2:4] = scale_[:, 0:2]
            all_boxes[idx:idx + num_images, 4] = np.prod(scale_ * 200, 1)
            all_boxes[idx:idx + num_images, 5] = 1.
            image_path.extend(batch['imgname'])

            idx += num_images

        ckp_name = options.regressor
        name_values, perf_indicator = dataset.evaluate(all_preds,
                                                       options.output_dir,
                                                       all_boxes, image_path,
                                                       ckp_name, filenames,
                                                       imgnums)

        model_name = options.regressor
        if isinstance(name_values, list):
            for name_value in name_values:
                _print_name_value(name_value, model_name)
        else:
            _print_name_value(name_values, model_name)

    # Save reconstructions to a file for further processing
    if save_results:
        np.savez(result_file,
                 pred_joints=pred_joints,
                 pose=smpl_pose,
                 betas=smpl_betas,
                 camera=smpl_camera)
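If results were saved, the arrays can be read back with NumPy; a minimal sketch, assuming the same result_file path:

import numpy as np

results = np.load(result_file)        # NpzFile containing the arrays saved above
pred_joints = results['pred_joints']  # predicted 3D joints
smpl_pose = results['pose']           # SMPL pose parameters
smpl_betas = results['betas']         # SMPL shape parameters
smpl_camera = results['camera']       # weak-perspective camera parameters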
Ejemplo n.º 15
0
    img = img.astype(np.float32) / 255.
    img = torch.from_numpy(img).permute(2, 0, 1)
    norm_img = normalize_img(img.clone())[None]
    return img, norm_img


if __name__ == '__main__':
    args = parser.parse_args()
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Load trained model
    model = hmr(config.SMPL_MEAN_PARAMS).to(device)
    checkpoint = torch.load(args.trained_model)
    model.load_state_dict(checkpoint['model'], strict=False)
    smpl = SMPL(config.SMPL_MODEL_DIR, batch_size=1,
                create_transl=False).to(device)
    model.eval()
    # Generate rendered image
    renderer = Renderer(focal_length=constants.FOCAL_LENGTH,
                        img_res=constants.IMG_RES,
                        faces=smpl.faces)
    # Process the image and predict the parameters
    img, norm_img = process_image(args.test_image,
                                  args.bbox,
                                  input_res=constants.IMG_RES)
    with torch.no_grad():
        pred_rotmat, pred_betas, pred_camera = model(norm_img.to(device))
        pred_output = smpl(betas=pred_betas,
                           body_pose=pred_rotmat[:, 1:],
                           global_orient=pred_rotmat[:, 0].unsqueeze(1),
                           pose2rot=False)
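The snippet ends after the SMPL forward pass. A hedged sketch of a typical continuation, reusing the variables defined in the snippet (img, pred_output, pred_camera, renderer, constants) and the weak-perspective-to-translation conversion used elsewhere in this document; the renderer call signature and output scaling are assumptions:

import cv2

# Convert the weak-perspective camera [s, tx, ty] to a 3D translation
# (same convention as the evaluation code above).
camera_translation = torch.stack([
    pred_camera[:, 1], pred_camera[:, 2],
    2 * constants.FOCAL_LENGTH / (constants.IMG_RES * pred_camera[:, 0] + 1e-9)
], dim=-1)

pred_vertices = pred_output.vertices[0].cpu().numpy()
cam_t = camera_translation[0].cpu().numpy()
img_np = img.permute(1, 2, 0).cpu().numpy()  # HWC float image in [0, 1]

# Overlay the predicted mesh on the input image and save it (assumed renderer API).
img_overlay = renderer(pred_vertices, cam_t, img_np)
cv2.imwrite('output.png', (255 * img_overlay[:, :, ::-1]).astype('uint8'))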
Ejemplo n.º 16
0
def run_evaluation(model,
                   dataset_name,
                   dataset,
                   mesh,
                   batch_size=32,
                   img_res=224,
                   num_workers=32,
                   shuffle=False,
                   log_freq=50):
    """Run evaluation on the datasets and metrics we report in the paper. """

    renderer = PartRenderer()

    # Create SMPL model
    smpl = SMPL().cuda()

    # Regressor for H36m joints
    J_regressor = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_H36M)).float()

    # Create dataloader for the dataset
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=num_workers)

    # Transfer model to the GPU
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    model.eval()

    # Pose metrics
    # MPJPE and Reconstruction error for the non-parametric and parametric shapes
    mpjpe = np.zeros(len(dataset))
    recon_err = np.zeros(len(dataset))
    mpjpe_smpl = np.zeros(len(dataset))
    recon_err_smpl = np.zeros(len(dataset))

    # Shape metrics
    # Mean per-vertex error
    shape_err = np.zeros(len(dataset))
    shape_err_smpl = np.zeros(len(dataset))

    # Mask and part metrics
    # Accuracy
    accuracy = 0.
    parts_accuracy = 0.
    # True positive, false positive and false negative
    tp = np.zeros((2, 1))
    fp = np.zeros((2, 1))
    fn = np.zeros((2, 1))
    parts_tp = np.zeros((7, 1))
    parts_fp = np.zeros((7, 1))
    parts_fn = np.zeros((7, 1))
    # Pixel count accumulators
    pixel_count = 0
    parts_pixel_count = 0

    eval_pose = False
    eval_shape = False
    eval_masks = False
    eval_parts = False
    # Choose appropriate evaluation for each dataset
    if dataset_name == 'h36m-p1' or dataset_name == 'h36m-p2':
        eval_pose = True
    elif dataset_name == 'up-3d':
        eval_shape = True
    elif dataset_name == 'lsp':
        eval_masks = True
        eval_parts = True
        annot_path = cfg.DATASET_FOLDERS['upi-s1h']

    # Iterate over the entire dataset
    for step, batch in enumerate(
            tqdm(data_loader, desc='Eval', total=len(data_loader))):
        # Get ground truth annotations from the batch
        gt_pose = batch['pose'].to(device)
        gt_betas = batch['betas'].to(device)
        gt_vertices = smpl(gt_pose, gt_betas)
        images = batch['img'].to(device)
        curr_batch_size = images.shape[0]

        # Run inference
        with torch.no_grad():
            pred_vertices, pred_vertices_smpl, camera, pred_rotmat, pred_betas = model(
                images)

        # 3D pose evaluation
        if eval_pose:
            # Regressor broadcasting
            J_regressor_batch = J_regressor[None, :].expand(
                pred_vertices.shape[0], -1, -1).to(device)

            # Get 14 ground truth joints
            gt_keypoints_3d = batch['pose_3d'].cuda()
            gt_keypoints_3d = gt_keypoints_3d[:, cfg.J24_TO_J14, :-1]

            # Get 14 predicted joints from the non-parametric mesh
            pred_keypoints_3d = torch.matmul(J_regressor_batch, pred_vertices)
            pred_pelvis = pred_keypoints_3d[:, [0], :].clone()
            pred_keypoints_3d = pred_keypoints_3d[:, cfg.H36M_TO_J14, :]
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis
            # Get 14 predicted joints from the SMPL mesh
            pred_keypoints_3d_smpl = torch.matmul(J_regressor_batch,
                                                  pred_vertices_smpl)
            pred_pelvis_smpl = pred_keypoints_3d_smpl[:, [0], :].clone()
            pred_keypoints_3d_smpl = pred_keypoints_3d_smpl[:,
                                                            cfg.H36M_TO_J14, :]
            pred_keypoints_3d_smpl = pred_keypoints_3d_smpl - pred_pelvis_smpl

            # Compute error metrics

            # Absolute error (MPJPE)
            error = torch.sqrt(
                ((pred_keypoints_3d -
                  gt_keypoints_3d)**2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            error_smpl = torch.sqrt(
                ((pred_keypoints_3d_smpl -
                  gt_keypoints_3d)**2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            mpjpe[step * batch_size:step * batch_size +
                  curr_batch_size] = error
            mpjpe_smpl[step * batch_size:step * batch_size +
                       curr_batch_size] = error_smpl

            # Reconstruction error
            r_error = reconstruction_error(pred_keypoints_3d.cpu().numpy(),
                                           gt_keypoints_3d.cpu().numpy(),
                                           reduction=None)
            r_error_smpl = reconstruction_error(
                pred_keypoints_3d_smpl.cpu().numpy(),
                gt_keypoints_3d.cpu().numpy(),
                reduction=None)
            recon_err[step * batch_size:step * batch_size +
                      curr_batch_size] = r_error
            recon_err_smpl[step * batch_size:step * batch_size +
                           curr_batch_size] = r_error_smpl

        # Shape evaluation (Mean per-vertex error)
        if eval_shape:
            se = torch.sqrt(
                ((pred_vertices -
                  gt_vertices)**2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            se_smpl = torch.sqrt(
                ((pred_vertices_smpl -
                  gt_vertices)**2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            shape_err[step * batch_size:step * batch_size +
                      curr_batch_size] = se
            shape_err_smpl[step * batch_size:step * batch_size +
                           curr_batch_size] = se_smpl

        # If mask or part evaluation, render the mask and part images
        if eval_masks or eval_parts:
            mask, parts = renderer(pred_vertices, camera)

        # Mask evaluation (for LSP)
        if eval_masks:
            center = batch['center'].cpu().numpy()
            scale = batch['scale'].cpu().numpy()
            # Dimensions of original image
            orig_shape = batch['orig_shape'].cpu().numpy()
            for i in range(curr_batch_size):
                # After rendering, convert the image back to the original resolution
                pred_mask = uncrop(mask[i].cpu().numpy(), center[i], scale[i],
                                   orig_shape[i]) > 0
                # Load gt mask
                gt_mask = cv2.imread(
                    os.path.join(annot_path, batch['maskname'][i]), 0) > 0
                # Evaluation consistent with the original UP-3D code
                accuracy += (gt_mask == pred_mask).sum()
                pixel_count += np.prod(np.array(gt_mask.shape))
                for c in range(2):
                    cgt = gt_mask == c
                    cpred = pred_mask == c
                    tp[c] += (cgt & cpred).sum()
                    fp[c] += (~cgt & cpred).sum()
                    fn[c] += (cgt & ~cpred).sum()
                f1 = 2 * tp / (2 * tp + fp + fn)
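                # (2 * TP / (2 * TP + FP + FN) is the per-class F1 / Dice score,
                #  accumulated over all pixels processed so far.)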

        # Part evaluation (for LSP)
        if eval_parts:
            center = batch['center'].cpu().numpy()
            scale = batch['scale'].cpu().numpy()
            orig_shape = batch['orig_shape'].cpu().numpy()
            for i in range(curr_batch_size):
                pred_parts = uncrop(parts[i].cpu().numpy().astype(np.uint8),
                                    center[i], scale[i], orig_shape[i])
                # Load gt part segmentation
                gt_parts = cv2.imread(
                    os.path.join(annot_path, batch['partname'][i]), 0)
                # Evaluation consistent with the original UP-3D code
                # 6 parts + background
                for c in range(7):
                    cgt = gt_parts == c
                    cpred = pred_parts == c
                    cpred[gt_parts == 255] = 0
                    parts_tp[c] += (cgt & cpred).sum()
                    parts_fp[c] += (~cgt & cpred).sum()
                    parts_fn[c] += (cgt & ~cpred).sum()
                gt_parts[gt_parts == 255] = 0
                pred_parts[pred_parts == 255] = 0
                parts_f1 = 2 * parts_tp / (2 * parts_tp + parts_fp + parts_fn)
                parts_accuracy += (gt_parts == pred_parts).sum()
                parts_pixel_count += np.prod(np.array(gt_parts.shape))

        # Print intermediate results during evaluation
        if step % log_freq == log_freq - 1:
            if eval_pose:
                print('MPJPE (NonParam): ' +
                      str(1000 * mpjpe[:step * batch_size].mean()))
                print('Reconstruction Error (NonParam): ' +
                      str(1000 * recon_err[:step * batch_size].mean()))
                print('MPJPE (Param): ' +
                      str(1000 * mpjpe_smpl[:step * batch_size].mean()))
                print('Reconstruction Error (Param): ' +
                      str(1000 * recon_err_smpl[:step * batch_size].mean()))
                print()
            if eval_shape:
                print('Shape Error (NonParam): ' +
                      str(1000 * shape_err[:step * batch_size].mean()))
                print('Shape Error (Param): ' +
                      str(1000 * shape_err_smpl[:step * batch_size].mean()))
                print()
            if eval_masks:
                print('Accuracy: ', accuracy / pixel_count)
                print('F1: ', f1.mean())
                print()
            if eval_parts:
                print('Parts Accuracy: ', parts_accuracy / parts_pixel_count)
                print('Parts F1 (BG): ', parts_f1[[0, 1, 2, 3, 4, 5,
                                                   6]].mean())
                print()

    # Print final results during evaluation
    print('*** Final Results ***')
    print()
    if eval_pose:
        print('MPJPE (NonParam): ' + str(1000 * mpjpe.mean()))
        print('Reconstruction Error (NonParam): ' +
              str(1000 * recon_err.mean()))
        print('MPJPE (Param): ' + str(1000 * mpjpe_smpl.mean()))
        print('Reconstruction Error (Param): ' +
              str(1000 * recon_err_smpl.mean()))
        print()
    if eval_shape:
        print('Shape Error (NonParam): ' + str(1000 * shape_err.mean()))
        print('Shape Error (Param): ' + str(1000 * shape_err_smpl.mean()))
        print()
    if eval_masks:
        print('Accuracy: ', accuracy / pixel_count)
        print('F1: ', f1.mean())
        print()
    if eval_parts:
        print('Parts Accuracy: ', parts_accuracy / parts_pixel_count)
        print('Parts F1 (BG): ', parts_f1[[0, 1, 2, 3, 4, 5, 6]].mean())
        print()
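In these evaluation scripts, reconstruction_error is typically the Procrustes-aligned MPJPE (PA-MPJPE). A minimal NumPy sketch of that metric for a single sample with J joints, illustrative rather than the repository's implementation:

import numpy as np

def pa_mpjpe(pred, gt):
    """Procrustes-aligned per-joint error for one sample; pred, gt are (J, 3) arrays."""
    mu_pred, mu_gt = pred.mean(axis=0), gt.mean(axis=0)
    X, Y = pred - mu_pred, gt - mu_gt                # centre both joint sets
    U, S, Vt = np.linalg.svd(X.T @ Y)                # optimal rotation via SVD
    sign = np.sign(np.linalg.det(U @ Vt))            # fix a possible reflection
    D = np.diag([1.0, 1.0, sign])
    R = U @ D @ Vt                                   # rotation applied as X @ R
    scale = (S * np.diag(D)).sum() / (X ** 2).sum()  # optimal similarity scale
    aligned = scale * X @ R + mu_gt                  # aligned prediction
    return np.linalg.norm(aligned - gt, axis=-1).mean()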
Ejemplo n.º 17
0
    args = parser.parse_args()
    img_path = args.img_path
    checkpoint_path = args.checkpoint

    normalize_img = Normalize(mean=constants.IMG_NORM_MEAN,
                              std=constants.IMG_NORM_STD)
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    hmr_model = hmr(config.SMPL_MEAN_PARAMS)
    checkpoint = torch.load(checkpoint_path)
    hmr_model.load_state_dict(checkpoint, strict=False)
    hmr_model.eval()
    hmr_model.to(device)

    smpl_neutral = SMPL(config.SMPL_MODEL_DIR, create_transl=False).to(device)
    img_renderer = Renderer(focal_length=constants.FOCAL_LENGTH,
                            img_res=constants.IMG_RES,
                            faces=smpl_neutral.faces)

    img = imageio.imread(img_path)
    im_size = img.shape[0]
    im_scale = size_to_scale(im_size)
    img_up = scipy.misc.imresize(img, [224, 224])
    img_up = np.transpose(img_up.astype('float32'), (2, 0, 1)) / 255.0
    img_up = normalize_img(torch.from_numpy(img_up).float())
    images = img_up[None].to(device)

    with torch.no_grad():
        pred_rotmat, pred_betas, pred_camera, _ = hmr_model(images,
                                                            scale=im_scale)
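Note that scipy.misc.imresize was removed in SciPy 1.3, so this snippet requires an older SciPy; on newer environments an equivalent resize can be done with OpenCV, for example:

import cv2

# Drop-in replacement for scipy.misc.imresize(img, [224, 224]); imageio.imread
# typically returns uint8, so the subsequent float32 / 255.0 normalisation still applies.
img_up = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)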
Ejemplo n.º 18
0
class Trainer(BaseTrainer):
    """Trainer object.
    Inherits from BaseTrainer, which sets up logging, saving/restoring checkpoints, etc.
    """
    def init_fn(self):
        # create training dataset
        self.train_ds = create_dataset(self.options.dataset, self.options)

        # create Mesh object
        self.mesh = Mesh()
        self.faces = self.mesh.faces.to(self.device)

        # create GraphCNN
        self.graph_cnn = GraphCNN(self.mesh.adjmat,
                                  self.mesh.ref_vertices.t(),
                                  num_channels=self.options.num_channels,
                                  num_layers=self.options.num_layers).to(
                                      self.device)

        # SMPL Parameter regressor
        self.smpl_param_regressor = SMPLParamRegressor().to(self.device)

        # Setup a joint optimizer for the 2 models
        self.optimizer = torch.optim.Adam(
            params=list(self.graph_cnn.parameters()) +
            list(self.smpl_param_regressor.parameters()),
            lr=self.options.lr,
            betas=(self.options.adam_beta1, 0.999),
            weight_decay=self.options.wd)

        # SMPL model
        self.smpl = SMPL().to(self.device)

        # Create loss functions
        self.criterion_shape = nn.L1Loss().to(self.device)
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        self.criterion_regr = nn.MSELoss().to(self.device)

        # Pack models and optimizers in a dict - necessary for checkpointing
        self.models_dict = {
            'graph_cnn': self.graph_cnn,
            'smpl_param_regressor': self.smpl_param_regressor
        }
        self.optimizers_dict = {'optimizer': self.optimizer}

        # Renderer for visualization
        self.renderer = Renderer(faces=self.smpl.faces.cpu().numpy())

        # LSP indices from full list of keypoints
        self.to_lsp = list(range(14))

        # Optionally start training from a pretrained checkpoint
        # Note that this is different from resuming training
        # For the latter use --resume
        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(
                checkpoint_file=self.options.pretrained_checkpoint)

    def keypoint_loss(self, pred_keypoints_2d, gt_keypoints_2d):
        """Compute 2D reprojection loss on the keypoints.
        The confidence is binary and indicates whether the keypoints exist or not.
        The available keypoints are different for each dataset.
        """
        conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
        loss = (conf * self.criterion_keypoints(
            pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
        return loss

    def keypoint_3d_loss(self, pred_keypoints_3d, gt_keypoints_3d,
                         has_pose_3d):
        """Compute 3D keypoint loss for the examples that 3D keypoint annotations are available.
        The loss is weighted by the confidence
        """
        conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
        gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
        gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
        conf = conf[has_pose_3d == 1]
        pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
        if len(gt_keypoints_3d) > 0:
            gt_pelvis = (gt_keypoints_3d[:, 2, :] +
                         gt_keypoints_3d[:, 3, :]) / 2
            gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :]
            pred_pelvis = (pred_keypoints_3d[:, 2, :] +
                           pred_keypoints_3d[:, 3, :]) / 2
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :]
            return (conf * self.criterion_keypoints(pred_keypoints_3d,
                                                    gt_keypoints_3d)).mean()
        else:
            return torch.FloatTensor(1).fill_(0.).to(self.device)

    def shape_loss(self, pred_vertices, gt_vertices, has_smpl):
        """Compute per-vertex loss on the shape for the examples that SMPL annotations are available."""
        pred_vertices_with_shape = pred_vertices[has_smpl == 1]
        gt_vertices_with_shape = gt_vertices[has_smpl == 1]
        if len(gt_vertices_with_shape) > 0:
            return self.criterion_shape(pred_vertices_with_shape,
                                        gt_vertices_with_shape)
        else:
            return torch.FloatTensor(1).fill_(0.).to(self.device)

    def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas,
                    has_smpl):
        """Compute SMPL parameter loss for the examples that SMPL annotations are available."""
        pred_rotmat_valid = pred_rotmat[has_smpl == 1].view(-1, 3, 3)
        gt_rotmat_valid = rodrigues(gt_pose[has_smpl == 1].view(-1, 3))
        pred_betas_valid = pred_betas[has_smpl == 1]
        gt_betas_valid = gt_betas[has_smpl == 1]
        if len(pred_rotmat_valid) > 0:
            loss_regr_pose = self.criterion_regr(pred_rotmat_valid,
                                                 gt_rotmat_valid)
            loss_regr_betas = self.criterion_regr(pred_betas_valid,
                                                  gt_betas_valid)
        else:
            loss_regr_pose = torch.FloatTensor(1).fill_(0.).to(self.device)
            loss_regr_betas = torch.FloatTensor(1).fill_(0.).to(self.device)
        return loss_regr_pose, loss_regr_betas

    def train_step(self, input_batch):
        """Training step."""
        self.graph_cnn.train()
        self.smpl_param_regressor.train()

        # Grab data from the batch
        gt_keypoints_2d = input_batch['keypoints']
        gt_keypoints_3d = input_batch['pose_3d']
        gt_pose = input_batch['pose']
        gt_betas = input_batch['betas']
        has_smpl = input_batch['has_smpl']
        has_pose_3d = input_batch['has_pose_3d']
        images = input_batch['img']

        # Render vertices using SMPL parameters
        gt_vertices = self.smpl(gt_pose, gt_betas)
        batch_size = gt_vertices.shape[0]

        # Feed image in the GraphCNN
        # Returns subsampled mesh and camera parameters
        pred_vertices_sub, pred_camera = self.graph_cnn(images)

        # Upsample the mesh to the original resolution
        pred_vertices = self.mesh.upsample(pred_vertices_sub.transpose(1, 2))

        # Prepare input for SMPL Parameter regressor
        # The input is the predicted and template vertices subsampled by a factor of 4
        # Notice that we detach the GraphCNN
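        # (Detaching prevents gradients of the SMPL-regressor losses from flowing back
        #  into the GraphCNN through its predicted vertices; the GraphCNN is trained
        #  only through its own shape/keypoint losses.)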
        x = pred_vertices_sub.transpose(1, 2).detach()
        x = torch.cat(
            [x, self.mesh.ref_vertices[None, :, :].expand(batch_size, -1, -1)],
            dim=-1)

        # Estimate SMPL parameters and render vertices
        pred_rotmat, pred_shape = self.smpl_param_regressor(x)
        pred_vertices_smpl = self.smpl(pred_rotmat, pred_shape)

        # Get 3D and projected 2D keypoints from the regressed shape
        pred_keypoints_3d = self.smpl.get_joints(pred_vertices)
        pred_keypoints_2d = orthographic_projection(pred_keypoints_3d,
                                                    pred_camera)[:, :, :2]
        pred_keypoints_3d_smpl = self.smpl.get_joints(pred_vertices_smpl)
        pred_keypoints_2d_smpl = orthographic_projection(
            pred_keypoints_3d_smpl, pred_camera.detach())[:, :, :2]

        # Compute losses

        # GraphCNN losses
        loss_keypoints = self.keypoint_loss(pred_keypoints_2d, gt_keypoints_2d)
        loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d,
                                                  gt_keypoints_3d, has_pose_3d)
        loss_shape = self.shape_loss(pred_vertices, gt_vertices, has_smpl)

        # SMPL regressor losses
        loss_keypoints_smpl = self.keypoint_loss(pred_keypoints_2d_smpl,
                                                 gt_keypoints_2d)
        loss_keypoints_3d_smpl = self.keypoint_3d_loss(pred_keypoints_3d_smpl,
                                                       gt_keypoints_3d,
                                                       has_pose_3d)
        loss_shape_smpl = self.shape_loss(pred_vertices_smpl, gt_vertices,
                                          has_smpl)
        loss_regr_pose, loss_regr_betas = self.smpl_losses(
            pred_rotmat, pred_shape, gt_pose, gt_betas, has_smpl)

        # Add losses to compute the total loss
        loss = loss_shape_smpl + loss_keypoints_smpl + loss_keypoints_3d_smpl +\
               loss_regr_pose + 0.1 * loss_regr_betas + loss_shape + loss_keypoints + loss_keypoints_3d

        # Do backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Pack output arguments to be used for visualization in a list
        out_args = [
            pred_vertices, pred_vertices_smpl, pred_camera, pred_keypoints_2d,
            pred_keypoints_2d_smpl, loss_shape, loss_shape_smpl,
            loss_keypoints, loss_keypoints_smpl, loss_keypoints_3d,
            loss_keypoints_3d_smpl, loss_regr_pose, loss_regr_betas, loss
        ]
        out_args = [arg.detach() for arg in out_args]
        return out_args

    def train_summaries(self, input_batch, pred_vertices, pred_vertices_smpl,
                        pred_camera, pred_keypoints_2d, pred_keypoints_2d_smpl,
                        loss_shape, loss_shape_smpl, loss_keypoints,
                        loss_keypoints_smpl, loss_keypoints_3d,
                        loss_keypoints_3d_smpl, loss_regr_pose,
                        loss_regr_betas, loss):
        """Tensorboard logging."""
        gt_keypoints_2d = input_batch['keypoints'].cpu().numpy()

        rend_imgs = []
        rend_imgs_smpl = []
        batch_size = pred_vertices.shape[0]
        # Do visualization for the first 4 images of the batch
        for i in range(min(batch_size, 4)):
            img = input_batch['img_orig'][i].cpu().numpy().transpose(1, 2, 0)
            # Get LSP keypoints from the full list of keypoints
            gt_keypoints_2d_ = gt_keypoints_2d[i, self.to_lsp]
            pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i,
                                                                 self.to_lsp]
            pred_keypoints_2d_smpl_ = pred_keypoints_2d_smpl.cpu().numpy()[
                i, self.to_lsp]
            # Get GraphCNN and SMPL vertices for the particular example
            vertices = pred_vertices[i].cpu().numpy()
            vertices_smpl = pred_vertices_smpl[i].cpu().numpy()
            cam = pred_camera[i].cpu().numpy()
            # Visualize reconstruction and detected pose
            rend_img = visualize_reconstruction(img, self.options.img_res,
                                                gt_keypoints_2d_, vertices,
                                                pred_keypoints_2d_, cam,
                                                self.renderer)
            rend_img_smpl = visualize_reconstruction(img, self.options.img_res,
                                                     gt_keypoints_2d_,
                                                     vertices_smpl,
                                                     pred_keypoints_2d_smpl_,
                                                     cam, self.renderer)
            rend_img = rend_img.transpose(2, 0, 1)
            rend_img_smpl = rend_img_smpl.transpose(2, 0, 1)
            rend_imgs.append(torch.from_numpy(rend_img))
            rend_imgs_smpl.append(torch.from_numpy(rend_img_smpl))
        rend_imgs = make_grid(rend_imgs, nrow=1)
        rend_imgs_smpl = make_grid(rend_imgs_smpl, nrow=1)

        # Save results in Tensorboard
        self.summary_writer.add_image('imgs', rend_imgs, self.step_count)
        self.summary_writer.add_image('imgs_smpl', rend_imgs_smpl,
                                      self.step_count)
        self.summary_writer.add_scalar('loss_shape', loss_shape,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_shape_smpl', loss_shape_smpl,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_regr_pose', loss_regr_pose,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_regr_betas', loss_regr_betas,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_keypoints', loss_keypoints,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_keypoints_smpl',
                                       loss_keypoints_smpl, self.step_count)
        self.summary_writer.add_scalar('loss_keypoints_3d', loss_keypoints_3d,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_keypoints_3d_smpl',
                                       loss_keypoints_3d_smpl, self.step_count)
        self.summary_writer.add_scalar('loss', loss, self.step_count)
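For reference, orthographic_projection above maps 3D joints to the image plane with the predicted weak-perspective camera [s, tx, ty]. A minimal sketch of such a projection; the exact helper used by the repository may differ:

import torch

def orthographic_projection(X, camera):
    """Weak-perspective projection: X is (B, N, 3), camera is (B, 3) as [s, tx, ty]."""
    s = camera[:, 0].view(-1, 1, 1)       # per-sample scale
    t = camera[:, 1:].view(-1, 1, 2)      # per-sample in-plane translation
    return s * (X[:, :, :2] + t)          # (B, N, 2) projected keypoints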
Ejemplo n.º 19
0
class Trainer(BaseTrainer):
    def init_fn(self):
        # create training dataset
        self.train_ds = create_dataset(self.options.dataset,
                                       self.options,
                                       use_IUV=True)
        self.dp_res = int(self.options.img_res // (2**self.options.warp_level))

        self.CNet = DPNet(warp_lv=self.options.warp_level,
                          norm_type=self.options.norm_type).to(self.device)

        self.LNet = get_LNet(self.options).to(self.device)
        self.smpl = SMPL().to(self.device)
        self.female_smpl = SMPL(cfg.FEMALE_SMPL_FILE).to(self.device)
        self.male_smpl = SMPL(cfg.MALE_SMPL_FILE).to(self.device)

        uv_res = self.options.uv_res
        self.uv_type = self.options.uv_type
        self.sampler = Index_UV_Generator(UV_height=uv_res,
                                          UV_width=-1,
                                          uv_type=self.uv_type).to(self.device)

        weight_file = 'data/weight_p24_h{:04d}_w{:04d}_{}.npy'.format(
            uv_res, uv_res, self.uv_type)
        if not os.path.exists(weight_file):
            cal_uv_weight(self.sampler, weight_file)

        uv_weight = torch.from_numpy(np.load(weight_file)).to(
            self.device).float()
        uv_weight = uv_weight * self.sampler.mask.to(uv_weight.device).float()
        uv_weight = uv_weight / uv_weight.mean()
        self.uv_weight = uv_weight[None, :, :, None]
        self.tv_factor = (uv_res - 1) * (uv_res - 1)

        # Setup an optimizer
        if self.options.stage == 'dp':
            self.optimizer = torch.optim.Adam(
                params=list(self.CNet.parameters()),
                lr=self.options.lr,
                betas=(self.options.adam_beta1, 0.999),
                weight_decay=self.options.wd)
            self.models_dict = {'CNet': self.CNet}
            self.optimizers_dict = {'optimizer': self.optimizer}

        else:
            self.optimizer = torch.optim.Adam(
                params=list(self.LNet.parameters()) +
                list(self.CNet.parameters()),
                lr=self.options.lr,
                betas=(self.options.adam_beta1, 0.999),
                weight_decay=self.options.wd)
            self.models_dict = {'CNet': self.CNet, 'LNet': self.LNet}
            self.optimizers_dict = {'optimizer': self.optimizer}

        # Create loss functions
        self.criterion_shape = nn.L1Loss().to(self.device)
        self.criterion_uv = nn.L1Loss().to(self.device)
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        self.criterion_keypoints_3d = nn.L1Loss(reduction='none').to(
            self.device)
        self.criterion_regr = nn.MSELoss().to(self.device)

        # LSP indices from full list of keypoints
        self.to_lsp = list(range(14))
        self.renderer = Renderer(faces=self.smpl.faces.cpu().numpy())

        # Optionally start training from a pretrained checkpoint
        # Note that this is different from resuming training
        # For the latter use --resume
        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(
                checkpoint_file=self.options.pretrained_checkpoint)

    def train_step(self, input_batch):
        """Training step."""
        dtype = torch.float32

        if self.options.stage == 'dp':
            self.CNet.train()

            # Grab data from the batch
            has_dp = input_batch['has_dp']
            images = input_batch['img']
            gt_dp_iuv = input_batch['gt_iuv']
            gt_dp_iuv[:, 1:] = gt_dp_iuv[:, 1:] / 255.0
            batch_size = images.shape[0]

            if images.is_cuda and self.options.ngpu > 1:
                pred_dp, dp_feature, codes = data_parallel(
                    self.CNet, images, range(self.options.ngpu))
            else:
                pred_dp, dp_feature, codes = self.CNet(images)

            if self.options.adaptive_weight:
                fit_joint_error = input_batch['fit_joint_error']
                ada_weight = self.error_adaptive_weight(fit_joint_error).type(
                    dtype)
            else:
                # ada_weight = pred_scale.new_ones(batch_size).type(dtype)
                ada_weight = None

            losses = {}
            '''loss on dense pose result'''
            loss_dp_mask, loss_dp_uv = self.dp_loss(pred_dp, gt_dp_iuv, has_dp,
                                                    ada_weight)
            loss_dp_mask = loss_dp_mask * self.options.lam_dp_mask
            loss_dp_uv = loss_dp_uv * self.options.lam_dp_uv
            losses['dp_mask'] = loss_dp_mask
            losses['dp_uv'] = loss_dp_uv
            loss_total = sum(loss for loss in losses.values())
            # Do backprop
            self.optimizer.zero_grad()
            loss_total.backward()
            self.optimizer.step()

            # for visualization
            if (self.step_count + 1) % self.options.summary_steps == 0:
                data = {}
                vis_num = min(4, batch_size)
                data['image'] = input_batch['img_orig'][0:vis_num].detach()
                data['pred_dp'] = pred_dp[0:vis_num].detach()
                data['gt_dp'] = gt_dp_iuv[0:vis_num].detach()
                self.vis_data = data

            # Pack the loss values into a dict for logging
            out_args = {
                key: losses[key].detach().item()
                for key in losses.keys()
            }
            out_args['total'] = loss_total.detach().item()
            self.loss_item = out_args

        elif self.options.stage == 'end':
            self.CNet.train()
            self.LNet.train()

            # Grab data from the batch
            # gt_keypoints_2d = input_batch['keypoints']
            # gt_keypoints_3d = input_batch['pose_3d']
            # gt_keypoints_2d = torch.cat([input_batch['keypoints'], input_batch['keypoints_smpl']], dim=1)
            # gt_keypoints_3d = torch.cat([input_batch['pose_3d'], input_batch['pose_3d_smpl']], dim=1)
            gt_keypoints_2d = input_batch['keypoints']
            gt_keypoints_3d = input_batch['pose_3d']
            has_pose_3d = input_batch['has_pose_3d']

            gt_keypoints_2d_smpl = input_batch['keypoints_smpl']
            gt_keypoints_3d_smpl = input_batch['pose_3d_smpl']
            has_pose_3d_smpl = input_batch['has_pose_3d_smpl']

            gt_pose = input_batch['pose']
            gt_betas = input_batch['betas']
            has_smpl = input_batch['has_smpl']
            has_dp = input_batch['has_dp']
            images = input_batch['img']
            gender = input_batch['gender']

            # images.requires_grad_()
            gt_dp_iuv = input_batch['gt_iuv']
            gt_dp_iuv[:, 1:] = gt_dp_iuv[:, 1:] / 255.0
            batch_size = images.shape[0]

            gt_vertices = images.new_zeros([batch_size, 6890, 3])
            if images.is_cuda and self.options.ngpu > 1:
                with torch.no_grad():
                    gt_vertices[gender < 0] = data_parallel(
                        self.smpl, (gt_pose[gender < 0], gt_betas[gender < 0]),
                        range(self.options.ngpu))
                    gt_vertices[gender == 0] = data_parallel(
                        self.male_smpl,
                        (gt_pose[gender == 0], gt_betas[gender == 0]),
                        range(self.options.ngpu))
                    gt_vertices[gender == 1] = data_parallel(
                        self.female_smpl,
                        (gt_pose[gender == 1], gt_betas[gender == 1]),
                        range(self.options.ngpu))
                    gt_uv_map = data_parallel(self.sampler, gt_vertices,
                                              range(self.options.ngpu))
                pred_dp, dp_feature, codes = data_parallel(
                    self.CNet, images, range(self.options.ngpu))
                pred_uv_map, pred_camera = data_parallel(
                    self.LNet, (pred_dp, dp_feature, codes),
                    range(self.options.ngpu))
            else:
                # gt_vertices = self.smpl(gt_pose, gt_betas)
                with torch.no_grad():
                    gt_vertices[gender < 0] = self.smpl(
                        gt_pose[gender < 0], gt_betas[gender < 0])
                    gt_vertices[gender == 0] = self.male_smpl(
                        gt_pose[gender == 0], gt_betas[gender == 0])
                    gt_vertices[gender == 1] = self.female_smpl(
                        gt_pose[gender == 1], gt_betas[gender == 1])
                    gt_uv_map = self.sampler.get_UV_map(gt_vertices.float())
                pred_dp, dp_feature, codes = self.CNet(images)
                pred_uv_map, pred_camera = self.LNet(pred_dp, dp_feature,
                                                     codes)

            if self.options.adaptive_weight:
                # Get the confidence of the GT mesh, which is used to weight the loss terms.
                # The confidence depends on the fitting error; for data with GT SMPL
                # parameters the confidence is 1.0.
                fit_joint_error = input_batch['fit_joint_error']
                ada_weight = self.error_adaptive_weight(fit_joint_error).type(
                    dtype)
            else:
                ada_weight = None
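            # (error_adaptive_weight maps each sample's SMPL fitting error to a
            #  confidence in (0, 1]; a hypothetical monotone mapping would be
            #  weight = exp(-fit_joint_error / sigma). The actual mapping is defined
            #  with the trainer, outside this excerpt.)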

            losses = {}
            '''loss on dense pose result'''
            loss_dp_mask, loss_dp_uv = self.dp_loss(pred_dp, gt_dp_iuv, has_dp,
                                                    ada_weight)
            loss_dp_mask = loss_dp_mask * self.options.lam_dp_mask
            loss_dp_uv = loss_dp_uv * self.options.lam_dp_uv
            losses['dp_mask'] = loss_dp_mask
            losses['dp_uv'] = loss_dp_uv
            '''loss on location map'''
            sampled_vertices = self.sampler.resample(
                pred_uv_map.float()).type(dtype)
            loss_uv = self.uv_loss(
                gt_uv_map.float(), pred_uv_map.float(), has_smpl,
                ada_weight).type(dtype) * self.options.lam_uv
            losses['uv'] = loss_uv

            if self.options.lam_tv > 0:
                loss_tv = self.tv_loss(pred_uv_map) * self.options.lam_tv
                losses['tv'] = loss_tv
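                # (tv_loss is a total-variation smoothness term on the UV location map,
                #  roughly the summed differences between adjacent texels normalised by
                #  self.tv_factor = (uv_res - 1) ** 2; its implementation lies outside
                #  this excerpt.)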
            '''loss on mesh'''
            if self.options.lam_mesh > 0:
                loss_mesh = self.shape_loss(sampled_vertices, gt_vertices,
                                            has_smpl,
                                            ada_weight) * self.options.lam_mesh
                losses['mesh'] = loss_mesh
            '''loss on joints'''
            weight_key = sampled_vertices.new_ones(batch_size)
            if self.options.gtkey3d_from_mesh:
                # For the data without GT 3D keypoints but with SMPL parameters,
                # we can get the GT 3D keypoints from the mesh.
                # The confidence of the keypoints is related to the confidence of the mesh.
                gt_keypoints_3d_mesh = self.smpl.get_train_joints(gt_vertices)
                gt_keypoints_3d_mesh = torch.cat([
                    gt_keypoints_3d_mesh,
                    gt_keypoints_3d_mesh.new_ones([batch_size, 24, 1])
                ],
                                                 dim=-1)
                valid = has_smpl > has_pose_3d
                gt_keypoints_3d[valid] = gt_keypoints_3d_mesh[valid]
                has_pose_3d[valid] = 1
                if ada_weight is not None:
                    weight_key[valid] = ada_weight[valid]

            sampled_joints_3d = self.smpl.get_train_joints(sampled_vertices)
            loss_keypoints_3d = self.keypoint_3d_loss(sampled_joints_3d,
                                                      gt_keypoints_3d,
                                                      has_pose_3d, weight_key)
            loss_keypoints_3d = loss_keypoints_3d * self.options.lam_key3d
            losses['key3D'] = loss_keypoints_3d

            sampled_joints_2d = orthographic_projection(
                sampled_joints_3d, pred_camera)[:, :, :2]
            loss_keypoints_2d = self.keypoint_loss(
                sampled_joints_2d, gt_keypoints_2d) * self.options.lam_key2d
            losses['key2D'] = loss_keypoints_2d

            # We add the 24 joints of the SMPL model for training on the SURREAL dataset.
            weight_key_smpl = sampled_vertices.new_ones(batch_size)
            if self.options.gtkey3d_from_mesh:
                gt_keypoints_3d_mesh = self.smpl.get_smpl_joints(gt_vertices)
                gt_keypoints_3d_mesh = torch.cat([
                    gt_keypoints_3d_mesh,
                    gt_keypoints_3d_mesh.new_ones([batch_size, 24, 1])
                ],
                                                 dim=-1)
                valid = has_smpl > has_pose_3d_smpl
                gt_keypoints_3d_smpl[valid] = gt_keypoints_3d_mesh[valid]
                has_pose_3d_smpl[valid] = 1
                if ada_weight is not None:
                    weight_key_smpl[valid] = ada_weight[valid]

            if self.options.use_smpl_joints:
                sampled_joints_3d_smpl = self.smpl.get_smpl_joints(
                    sampled_vertices)
                loss_keypoints_3d_smpl = self.smpl_keypoint_3d_loss(
                    sampled_joints_3d_smpl, gt_keypoints_3d_smpl,
                    has_pose_3d_smpl, weight_key_smpl)
                loss_keypoints_3d_smpl = loss_keypoints_3d_smpl * self.options.lam_key3d_smpl
                losses['key3D_smpl'] = loss_keypoints_3d_smpl

                sampled_joints_2d_smpl = orthographic_projection(
                    sampled_joints_3d_smpl, pred_camera)[:, :, :2]
                loss_keypoints_2d_smpl = self.keypoint_loss(
                    sampled_joints_2d_smpl,
                    gt_keypoints_2d_smpl) * self.options.lam_key2d_smpl
                losses['key2D_smpl'] = loss_keypoints_2d_smpl
            '''consistent loss'''
            if not self.options.lam_con == 0:
                loss_con = self.consistent_loss(
                    gt_dp_iuv, pred_uv_map, pred_camera,
                    ada_weight) * self.options.lam_con
                losses['con'] = loss_con

            loss_total = sum(loss for loss in losses.values())
            # Do backprop
            self.optimizer.zero_grad()
            loss_total.backward()
            self.optimizer.step()

            # for visualization
            if (self.step_count + 1) % self.options.summary_steps == 0:
                data = {}
                vis_num = min(4, batch_size)
                data['image'] = input_batch['img_orig'][0:vis_num].detach()
                data['gt_vert'] = gt_vertices[0:vis_num].detach()
                data['pred_vert'] = sampled_vertices[0:vis_num].detach()
                data['pred_cam'] = pred_camera[0:vis_num].detach()
                data['pred_joint'] = sampled_joints_2d[0:vis_num].detach()
                data['gt_joint'] = gt_keypoints_2d[0:vis_num].detach()
                data['pred_uv'] = pred_uv_map[0:vis_num].detach()
                data['gt_uv'] = gt_uv_map[0:vis_num].detach()
                data['pred_dp'] = pred_dp[0:vis_num].detach()
                data['gt_dp'] = gt_dp_iuv[0:vis_num].detach()
                self.vis_data = data

            # Pack the loss values into a dict for logging
            out_args = {
                key: losses[key].detach().item()
                for key in losses.keys()
            }
            out_args['total'] = loss_total.detach().item()
            self.loss_item = out_args

        return out_args

    def train_summaries(self, batch, epoch):
        """Tensorboard logging."""
        if self.options.stage == 'dp':
            dtype = self.vis_data['pred_dp'].dtype
            rend_imgs = []
            vis_size = self.vis_data['pred_dp'].shape[0]
            # Do visualization for the first 4 images of the batch
            for i in range(vis_size):
                img = self.vis_data['image'][i].cpu().numpy().transpose(
                    1, 2, 0)
                H, W, C = img.shape
                rend_img = img.transpose(2, 0, 1)

                gt_dp = self.vis_data['gt_dp'][i]
                gt_dp = torch.nn.functional.interpolate(gt_dp[None, :],
                                                        size=[H, W])[0]
                # gt_dp = torch.cat((gt_dp, gt_dp.new_ones(1, H, W)), dim=0).cpu().numpy()
                gt_dp = gt_dp.cpu().numpy()
                rend_img = np.concatenate((rend_img, gt_dp), axis=2)

                pred_dp = self.vis_data['pred_dp'][i]
                pred_dp[0] = (pred_dp[0] > 0.5).type(dtype)
                pred_dp[1:] = pred_dp[1:] * pred_dp[0]
                pred_dp = torch.nn.functional.interpolate(pred_dp[None, :],
                                                          size=[H, W])[0]
                pred_dp = pred_dp.cpu().numpy()
                rend_img = np.concatenate((rend_img, pred_dp), axis=2)

                # import matplotlib.pyplot as plt
                # plt.imshow(rend_img.transpose([1, 2, 0]))
                rend_imgs.append(torch.from_numpy(rend_img))

            rend_imgs = make_grid(rend_imgs, nrow=1)
            self.summary_writer.add_image('imgs', rend_imgs, self.step_count)

        else:
            gt_keypoints_2d = self.vis_data['gt_joint'].cpu().numpy()
            pred_vertices = self.vis_data['pred_vert']
            pred_keypoints_2d = self.vis_data['pred_joint']
            pred_camera = self.vis_data['pred_cam']
            dtype = pred_camera.dtype
            rend_imgs = []
            vis_size = pred_vertices.shape[0]
            # Do visualization for the first 4 images of the batch
            for i in range(vis_size):
                img = self.vis_data['image'][i].cpu().numpy().transpose(
                    1, 2, 0)
                H, W, C = img.shape

                # Get LSP keypoints from the full list of keypoints
                gt_keypoints_2d_ = gt_keypoints_2d[i, self.to_lsp]
                pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[
                    i, self.to_lsp]
                vertices = pred_vertices[i].cpu().numpy()
                cam = pred_camera[i].cpu().numpy()
                # Visualize reconstruction and detected pose
                rend_img = visualize_reconstruction(img, self.options.img_res,
                                                    gt_keypoints_2d_, vertices,
                                                    pred_keypoints_2d_, cam,
                                                    self.renderer)
                rend_img = rend_img.transpose(2, 0, 1)

                if 'gt_vert' in self.vis_data.keys():
                    rend_img2 = vis_mesh(
                        img,
                        self.vis_data['gt_vert'][i].cpu().numpy(),
                        cam,
                        self.renderer,
                        color='blue')
                    rend_img2 = rend_img2.transpose(2, 0, 1)
                    rend_img = np.concatenate((rend_img, rend_img2), axis=2)

                gt_dp = self.vis_data['gt_dp'][i]
                gt_dp = torch.nn.functional.interpolate(gt_dp[None, :],
                                                        size=[H, W])[0]
                gt_dp = gt_dp.cpu().numpy()
                # gt_dp = torch.cat((gt_dp, gt_dp.new_ones(1, H, W)), dim=0).cpu().numpy()
                rend_img = np.concatenate((rend_img, gt_dp), axis=2)

                pred_dp = self.vis_data['pred_dp'][i]
                pred_dp[0] = (pred_dp[0] > 0.5).type(dtype)
                pred_dp[1:] = pred_dp[1:] * pred_dp[0]
                pred_dp = torch.nn.functional.interpolate(pred_dp[None, :],
                                                          size=[H, W])[0]
                pred_dp = pred_dp.cpu().numpy()
                rend_img = np.concatenate((rend_img, pred_dp), axis=2)

                # import matplotlib.pyplot as plt
                # plt.imshow(rend_img.transpose([1, 2, 0]))
                rend_imgs.append(torch.from_numpy(rend_img))

            rend_imgs = make_grid(rend_imgs, nrow=1)

            uv_maps = []
            for i in range(vis_size):
                uv_temp = torch.cat(
                    (self.vis_data['pred_uv'][i], self.vis_data['gt_uv'][i]),
                    dim=1)
                uv_maps.append(uv_temp.permute(2, 0, 1))

            uv_maps = make_grid(uv_maps, nrow=1)
            uv_maps = uv_maps.abs()
            uv_maps = uv_maps / uv_maps.max()

            # Save results in Tensorboard
            self.summary_writer.add_image('imgs', rend_imgs, self.step_count)
            self.summary_writer.add_image('uv_maps', uv_maps, self.step_count)

        for key in self.loss_item.keys():
            self.summary_writer.add_scalar('loss_' + key, self.loss_item[key],
                                           self.step_count)

    def train(self):
        """Training process."""
        # Run training for num_epochs epochs
        for epoch in range(self.epoch_count, self.options.num_epochs):
            # Create new DataLoader every epoch and (possibly) resume from an arbitrary step inside an epoch
            train_data_loader = CheckpointDataLoader(
                self.train_ds,
                checkpoint=self.checkpoint,
                batch_size=self.options.batch_size,
                num_workers=self.options.num_workers,
                pin_memory=self.options.pin_memory,
                shuffle=self.options.shuffle_train)

            # Iterate over all batches in an epoch
            batch_len = len(self.train_ds) // self.options.batch_size
            data_stream = tqdm(train_data_loader,
                               desc='Epoch ' + str(epoch),
                               total=len(self.train_ds) //
                               self.options.batch_size,
                               initial=train_data_loader.checkpoint_batch_idx)
            for step, batch in enumerate(
                    data_stream, train_data_loader.checkpoint_batch_idx):
                if time.time() < self.endtime:

                    batch = {
                        k:
                        v.to(self.device) if isinstance(v, torch.Tensor) else v
                        for k, v in batch.items()
                    }

                    loss_dict = self.train_step(batch)
                    self.step_count += 1

                    tqdm_info = 'Epoch:%d| %d/%d ' % (epoch, step, batch_len)
                    for k, v in loss_dict.items():
                        tqdm_info += ' %s:%.4f' % (k, v)
                    data_stream.set_description(tqdm_info)

                    if self.step_count % self.options.summary_steps == 0:
                        self.train_summaries(step, epoch)

                    # Save checkpoint every checkpoint_steps steps
                    if self.step_count % self.options.checkpoint_steps == 0 and self.step_count > 0:
                        self.saver.save_checkpoint(
                            self.models_dict, self.optimizers_dict, epoch,
                            step + 1, self.options.batch_size,
                            train_data_loader.sampler.dataset_perm,
                            self.step_count)
                        tqdm.write('Checkpoint saved')

                    # Run validation every test_steps steps
                    if self.step_count % self.options.test_steps == 0:
                        self.test()

                else:
                    tqdm.write('Timeout reached')
                    self.saver.save_checkpoint(
                        self.models_dict, self.optimizers_dict, epoch, step,
                        self.options.batch_size,
                        train_data_loader.sampler.dataset_perm,
                        self.step_count)
                    tqdm.write('Checkpoint saved')
                    sys.exit(0)

            # Load a checkpoint only on startup; for subsequent epochs just iterate over the dataset as usual
            self.checkpoint = None
            # Save a checkpoint every 10 epochs
            if (epoch + 1) % 10 == 0:
                self.saver.save_checkpoint(self.models_dict,
                                           self.optimizers_dict, epoch + 1, 0,
                                           self.options.batch_size, None,
                                           self.step_count)

        self.saver.save_checkpoint(self.models_dict,
                                   self.optimizers_dict,
                                   epoch + 1,
                                   0,
                                   self.options.batch_size,
                                   None,
                                   self.step_count,
                                   checkpoint_filename='final')
        return
Ejemplo n.º 20
0
def run_evaluation(model, dataset_name, dataset,
                   mesh, batch_size=32, img_res=224, 
                   num_workers=32, shuffle=False, log_freq=50):
    """Run evaluation on the datasets and metrics we report in the paper. """
    
    renderer = PartRenderer()
    
    # Create SMPL model
    smpl = SMPL().cuda()
    
    # Create dataloader for the dataset
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    
    # Transfer model to the GPU
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    model.eval()
    
    # 2D pose metrics
    pose_2d_m = np.zeros(len(dataset))
    pose_2d_m_smpl = np.zeros(len(dataset))

    # Shape metrics
    # Mean per-vertex error
    shape_err = np.zeros(len(dataset))
    shape_err_smpl = np.zeros(len(dataset))

    eval_2d_pose = False
    eval_shape = False

    # Choose appropriate evaluation for each dataset
    if dataset_name == 'up-3d':
        eval_shape = True
    elif dataset_name == 'lsp':
        eval_2d_pose = True

    # Iterate over the entire dataset
    for step, batch in enumerate(tqdm(data_loader, desc='Eval', total=len(data_loader))):
        # Get ground truth annotations from the batch
        gt_pose = batch['pose'].to(device)
        gt_betas = batch['betas'].to(device)
        gt_vertices = smpl(gt_pose, gt_betas)
        images = batch['img'].to(device)
        curr_batch_size = images.shape[0]
        gt_keypoints_2d = batch['keypoints'].cpu().numpy()
        
        # Run inference
        with torch.no_grad():
            pred_vertices, pred_vertices_smpl, camera, pred_rotmat, pred_betas = model(images)

        # If mask or part evaluation, render the mask and part images
        if eval_2d_pose:
            mask, parts = renderer(pred_vertices, camera)

        # 2D pose evaluation (for LSP)
        if eval_2d_pose:
            for i in range(curr_batch_size):
                gt_kp = gt_keypoints_2d[i, list(range(14))]
                # Get 3D and projected 2D keypoints from the regressed shape
                pred_keypoints_3d = smpl.get_joints(pred_vertices)
                pred_keypoints_2d = orthographic_projection(pred_keypoints_3d, camera)[:, :, :2]
                pred_keypoints_3d_smpl = smpl.get_joints(pred_vertices_smpl)
                pred_keypoints_2d_smpl = orthographic_projection(pred_keypoints_3d_smpl, camera.detach())[:, :, :2]
                pred_kp = pred_keypoints_2d.cpu().numpy()[i, list(range(14))]
                pred_kp_smpl = pred_keypoints_2d_smpl.cpu().numpy()[i, list(range(14))]
                # Compute 2D pose losses
                loss_2d_pose = np.sum((gt_kp[:, :2] - pred_kp)**2)
                loss_2d_pose_smpl = np.sum((gt_kp[:, :2] - pred_kp_smpl)**2)

                #print(gt_kp)
                #print()
                #print(pred_kp_smpl)
                #raw_input("Press Enter to continue...")

                pose_2d_m[step * batch_size + i] = loss_2d_pose
                pose_2d_m_smpl[step * batch_size + i] = loss_2d_pose_smpl

        # Shape evaluation (Mean per-vertex error)
        if eval_shape:
            se = torch.sqrt(((pred_vertices - gt_vertices) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            se_smpl = torch.sqrt(((pred_vertices_smpl - gt_vertices) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            shape_err[step * batch_size:step * batch_size + curr_batch_size] = se
            shape_err_smpl[step * batch_size:step * batch_size + curr_batch_size] = se_smpl

        # Print intermediate results during evaluation
        if step % log_freq == log_freq - 1:
            if eval_2d_pose:
                print('2D keypoints (NonParam): ' + str(1000 * pose_2d_m[:step * batch_size].mean()))
                print('2D keypoints (Param): ' + str(1000 * pose_2d_m_smpl[:step * batch_size].mean()))
                print()
            if eval_shape:
                print('Shape Error (NonParam): ' + str(1000 * shape_err[:step * batch_size].mean()))
                print('Shape Error (Param): ' + str(1000 * shape_err_smpl[:step * batch_size].mean()))
                print()

    # Print and store final results during evaluation
    print('*** Final Results ***')
    print()
    if eval_2d_pose:
        print('2D keypoints (NonParam): ' + str(1000 * pose_2d_m.mean()))
        print('2D keypoints (Param): ' + str(1000 * pose_2d_m_smpl.mean()))
        print()
        # store results
        #np.savez("../eval-output/CMR_no_extra_lsp_model_alt.npz", imgnames = dataset.imgname, kp_2d_err_graph = pose_2d_m, kp_2d_err_smpl = pose_2d_m_smpl)
    if eval_shape:
        print('Shape Error (NonParam): ' + str(1000 * shape_err.mean()))
        print('Shape Error (Param): ' + str(1000 * shape_err_smpl.mean()))
        print()
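
Note: the snippet above relies on `orthographic_projection` to map the regressed 3D joints onto the image plane, but that helper is not part of this snippet. Below is a minimal sketch of the weak-perspective projection such a function typically performs (scale `s` plus in-plane translation `(tx, ty)` taken from the predicted camera); the (B, N, 3) / (B, 3) shapes are assumptions based on how it is called above.

import torch


def weak_perspective_projection(X, camera):
    """Illustrative sketch: project 3D points X (B, N, 3) with camera (B, 3) = [s, tx, ty]."""
    camera = camera.view(-1, 1, 3)
    X_2d = X[:, :, :2] + camera[:, :, 1:]   # translate in the image plane, drop depth
    return camera[:, :, :1] * X_2d          # scale by s
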
Example #21
0
def run_evaluation(model,
                   dataset_name,
                   dataset,
                   result_file,
                   batch_size=32,
                   img_res=224,
                   num_workers=32,
                   shuffle=False,
                   log_freq=50):
    """Run evaluation on the datasets and metrics we report in the paper. """

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Transfer model to the GPU
    model.to(device)

    # Load SMPL model
    smpl_neutral = SMPL(config.SMPL_MODEL_DIR, create_transl=False).to(device)
    smpl_male = SMPL(config.SMPL_MODEL_DIR, gender='male',
                     create_transl=False).to(device)
    smpl_female = SMPL(config.SMPL_MODEL_DIR,
                       gender='female',
                       create_transl=False).to(device)

    renderer = PartRenderer()

    # Regressor for H36m joints
    J_regressor = torch.from_numpy(np.load(
        config.JOINT_REGRESSOR_H36M)).float()

    save_results = result_file is not None
    # Disable shuffling if you want to save the results
    if save_results:
        shuffle = False
    # Create dataloader for the dataset
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=num_workers)

    # Pose metrics
    # MPJPE and Reconstruction error for the non-parametric and parametric shapes
    mpjpe = np.zeros(len(dataset))
    recon_err = np.zeros(len(dataset))
    mpjpe_smpl = np.zeros(len(dataset))
    recon_err_smpl = np.zeros(len(dataset))

    # Shape metrics
    # Mean per-vertex error
    shape_err = np.zeros(len(dataset))
    shape_err_smpl = np.zeros(len(dataset))

    # Store SMPL parameters
    smpl_pose = np.zeros((len(dataset), 72))
    smpl_betas = np.zeros((len(dataset), 10))
    smpl_camera = np.zeros((len(dataset), 3))
    pred_joints = np.zeros((len(dataset), 17, 3))

    eval_pose = False
    eval_masks = False
    eval_parts = False
    # Choose appropriate evaluation for each dataset
    if dataset_name == 'h36m-p1' or dataset_name == 'h36m-p2' or dataset_name == '3dpw' or dataset_name == 'mpi-inf-3dhp':
        eval_pose = True
    elif dataset_name == 'lsp':
        eval_masks = True
        eval_parts = True
        annot_path = config.DATASET_FOLDERS['upi-s1h']

    joint_mapper_h36m = constants.H36M_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.H36M_TO_J14
    joint_mapper_gt = constants.J24_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.J24_TO_J14
    # Iterate over the entire dataset
    for step, batch in enumerate(
            tqdm(data_loader, desc='Eval', total=len(data_loader))):
        # Get ground truth annotations from the batch
        gt_pose = batch['pose'].to(device)
        gt_betas = batch['betas'].to(device)
        gt_vertices = smpl_neutral(betas=gt_betas,
                                   body_pose=gt_pose[:, 3:],
                                   global_orient=gt_pose[:, :3]).vertices
        images = batch['img'].to(device)
        gender = batch['gender'].to(device)
        curr_batch_size = images.shape[0]

        with torch.no_grad():
            pred_rotmat, pred_betas, pred_camera = model(images)
            pred_output = smpl_neutral(
                betas=pred_betas,
                body_pose=pred_rotmat[:, 1:],
                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                pose2rot=False)
            pred_vertices = pred_output.vertices

        if save_results:
            rot_pad = torch.tensor([0, 0, 1],
                                   dtype=torch.float32,
                                   device=device).view(1, 3, 1)
            rotmat = torch.cat((pred_rotmat.view(
                -1, 3, 3), rot_pad.expand(curr_batch_size * 24, -1, -1)),
                               dim=-1)
            pred_pose = tgm.rotation_matrix_to_angle_axis(
                rotmat).contiguous().view(-1, 72)
            smpl_pose[step * batch_size:step * batch_size +
                      curr_batch_size, :] = pred_pose.cpu().numpy()
            smpl_betas[step * batch_size:step * batch_size +
                       curr_batch_size, :] = pred_betas.cpu().numpy()
            smpl_camera[step * batch_size:step * batch_size +
                        curr_batch_size, :] = pred_camera.cpu().numpy()

        # 3D pose evaluation
        if eval_pose:
            # Regressor broadcasting
            J_regressor_batch = J_regressor[None, :].expand(
                pred_vertices.shape[0], -1, -1).to(device)
            # Get 14 ground truth joints
            if 'h36m' in dataset_name or 'mpi-inf' in dataset_name:
                gt_keypoints_3d = batch['pose_3d'].cuda()
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_gt, :-1]
            # For 3DPW get the 14 common joints from the rendered shape
            else:
                gt_vertices = smpl_male(global_orient=gt_pose[:, :3],
                                        body_pose=gt_pose[:, 3:],
                                        betas=gt_betas).vertices
                gt_vertices_female = smpl_female(global_orient=gt_pose[:, :3],
                                                 body_pose=gt_pose[:, 3:],
                                                 betas=gt_betas).vertices
                gt_vertices[gender == 1, :, :] = gt_vertices_female[gender ==
                                                                    1, :, :]
                gt_keypoints_3d = torch.matmul(J_regressor_batch, gt_vertices)
                gt_pelvis = gt_keypoints_3d[:, [0], :].clone()
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_h36m, :]
                gt_keypoints_3d = gt_keypoints_3d - gt_pelvis

            # Get 14 predicted joints from the mesh
            pred_keypoints_3d = torch.matmul(J_regressor_batch, pred_vertices)
            if save_results:
                pred_joints[
                    step * batch_size:step * batch_size +
                    curr_batch_size, :, :] = pred_keypoints_3d.cpu().numpy()
            pred_pelvis = pred_keypoints_3d[:, [0], :].clone()
            pred_keypoints_3d = pred_keypoints_3d[:, joint_mapper_h36m, :]
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis
            # Absolute error (MPJPE)
            error = torch.sqrt(
                ((pred_keypoints_3d -
                  gt_keypoints_3d)**2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            mpjpe[step * batch_size:step * batch_size +
                  curr_batch_size] = error

            # Reconstruction error
            r_error = reconstruction_error(pred_keypoints_3d.cpu().numpy(),
                                           gt_keypoints_3d.cpu().numpy(),
                                           reduction=None)
            recon_err[step * batch_size:step * batch_size +
                      curr_batch_size] = r_error

        # If mask or part evaluation, render the mask and part images
        if eval_masks or eval_parts:
            mask, parts = renderer(pred_vertices, pred_camera)

        # Print intermediate results during evaluation
        if step % log_freq == log_freq - 1:
            if eval_pose:
                print('MPJPE: ' + str(1000 * mpjpe[:step * batch_size].mean()))
                print('Reconstruction Error: ' +
                      str(1000 * recon_err[:step * batch_size].mean()))
                print()

    # Save reconstructions to a file for further processing
    if save_results:
        np.savez(result_file,
                 pred_joints=pred_joints,
                 pose=smpl_pose,
                 betas=smpl_betas,
                 camera=smpl_camera)
        np.savez('error.npz', mpjpe=mpjpe, recon_err=recon_err)
    # Print final results during evaluation
    print('*** Final Results ***')
    print()
    if eval_pose:
        print('MPJPE: ' + str(1000 * mpjpe.mean()))
        print('Reconstruction Error: ' + str(1000 * recon_err.mean()))
        print()
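
Note: the MPJPE reported above regresses joints from the predicted mesh with a joint regressor and centres both sets of joints at the pelvis before taking per-joint Euclidean distances. The self-contained sketch below restates that metric; the tensor shapes are assumptions taken from the code above, and the selection of the 14 common joints is omitted for brevity.

import torch


def mpjpe_from_vertices(pred_vertices, gt_keypoints_3d, J_regressor, pelvis_idx=0):
    """Mean per-joint position error (in metres) per sample.

    pred_vertices: (B, V, 3) mesh vertices, J_regressor: (J, V),
    gt_keypoints_3d: (B, J, 3) already pelvis-centred.
    """
    J_regressor_batch = J_regressor[None].expand(pred_vertices.shape[0], -1, -1)
    pred_joints = torch.matmul(J_regressor_batch, pred_vertices)   # (B, J, 3)
    pred_joints = pred_joints - pred_joints[:, [pelvis_idx], :]    # centre at the pelvis
    return torch.sqrt(((pred_joints - gt_keypoints_3d) ** 2).sum(dim=-1)).mean(dim=-1)
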
Example #22
0
def run_evaluation(model, dataset_name, dataset, result_file,
                   batch_size=32, img_res=224, 
                   num_workers=32, shuffle=False, log_freq=50):
    """Run evaluation on the datasets and metrics we report in the paper. """

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Transfer model to the GPU
    model.to(device)

    # Load SMPL model
    smpl_neutral = SMPL(config.SMPL_MODEL_DIR,
                        create_transl=False).to(device)
    smpl_male = SMPL(config.SMPL_MODEL_DIR,
                     gender='male',
                     create_transl=False).to(device)
    smpl_female = SMPL(config.SMPL_MODEL_DIR,
                       gender='female',
                       create_transl=False).to(device)
    
    renderer = PartRenderer()
    
    # Regressor for H36m joints
    J_regressor = torch.from_numpy(np.load(config.JOINT_REGRESSOR_H36M)).float()
    
    save_results = result_file is not None
    # Disable shuffling if you want to save the results
    if save_results:
        shuffle=False
    # Create dataloader for the dataset
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    
    # Pose metrics
    # MPJPE and Reconstruction error for the non-parametric and parametric shapes
    mpjpe = np.zeros(len(dataset))
    recon_err = np.zeros(len(dataset))
    mpjpe_smpl = np.zeros(len(dataset))
    recon_err_smpl = np.zeros(len(dataset))

    # Shape metrics
    # Mean per-vertex error
    shape_err = np.zeros(len(dataset))
    shape_err_smpl = np.zeros(len(dataset))

    # Mask and part metrics
    # Accuracy
    accuracy = 0.
    parts_accuracy = 0.
    # True positive, false positive and false negative
    tp = np.zeros((2,1))
    fp = np.zeros((2,1))
    fn = np.zeros((2,1))
    parts_tp = np.zeros((7,1))
    parts_fp = np.zeros((7,1))
    parts_fn = np.zeros((7,1))
    # Pixel count accumulators
    pixel_count = 0
    parts_pixel_count = 0

    # Store SMPL parameters
    smpl_pose = np.zeros((len(dataset), 72))
    smpl_betas = np.zeros((len(dataset), 10))
    smpl_camera = np.zeros((len(dataset), 3))
    pred_joints = np.zeros((len(dataset), 17, 3))

    eval_pose = False
    eval_masks = False
    eval_parts = False
    # Choose appropriate evaluation for each dataset
    if dataset_name == 'h36m-p1' or dataset_name == 'h36m-p2' or dataset_name == '3dpw' or dataset_name == 'mpi-inf-3dhp':
        eval_pose = True
    elif dataset_name == 'lsp':
        eval_masks = True
        eval_parts = True
        annot_path = config.DATASET_FOLDERS['upi-s1h']

    joint_mapper_h36m = constants.H36M_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.H36M_TO_J14
    joint_mapper_gt = constants.J24_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.J24_TO_J14
    # Iterate over the entire dataset
    for step, batch in enumerate(tqdm(data_loader, desc='Eval', total=len(data_loader))):
        # Get ground truth annotations from the batch
        gt_pose = batch['pose'].to(device)
        gt_betas = batch['betas'].to(device)
        gt_vertices = smpl_neutral(betas=gt_betas, body_pose=gt_pose[:, 3:], global_orient=gt_pose[:, :3]).vertices
        images = batch['img'].to(device)
        gender = batch['gender'].to(device)
        curr_batch_size = images.shape[0]
        
        with torch.no_grad():
            pred_rotmat, pred_betas, pred_camera = model(images)
            pred_output = smpl_neutral(betas=pred_betas, body_pose=pred_rotmat[:,1:], global_orient=pred_rotmat[:,0].unsqueeze(1), pose2rot=False)
            pred_vertices = pred_output.vertices

        if save_results:
            rot_pad = torch.tensor([0,0,1], dtype=torch.float32, device=device).view(1,3,1)
            rotmat = torch.cat((pred_rotmat.view(-1, 3, 3), rot_pad.expand(curr_batch_size * 24, -1, -1)), dim=-1)
            pred_pose = tgm.rotation_matrix_to_angle_axis(rotmat).contiguous().view(-1, 72)
            smpl_pose[step * batch_size:step * batch_size + curr_batch_size, :] = pred_pose.cpu().numpy()
            smpl_betas[step * batch_size:step * batch_size + curr_batch_size, :]  = pred_betas.cpu().numpy()
            smpl_camera[step * batch_size:step * batch_size + curr_batch_size, :]  = pred_camera.cpu().numpy()
            
        # 3D pose evaluation
        if eval_pose:
            # Regressor broadcasting
            J_regressor_batch = J_regressor[None, :].expand(pred_vertices.shape[0], -1, -1).to(device)
            # Get 14 ground truth joints
            if 'h36m' in dataset_name or 'mpi-inf' in dataset_name:
                gt_keypoints_3d = batch['pose_3d'].cuda()
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_gt, :-1]
            # For 3DPW get the 14 common joints from the rendered shape
            else:
                gt_vertices = smpl_male(global_orient=gt_pose[:,:3], body_pose=gt_pose[:,3:], betas=gt_betas).vertices 
                gt_vertices_female = smpl_female(global_orient=gt_pose[:,:3], body_pose=gt_pose[:,3:], betas=gt_betas).vertices 
                gt_vertices[gender==1, :, :] = gt_vertices_female[gender==1, :, :]
                gt_keypoints_3d = torch.matmul(J_regressor_batch, gt_vertices)
                gt_pelvis = gt_keypoints_3d[:, [0],:].clone()
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_h36m, :]
                gt_keypoints_3d = gt_keypoints_3d - gt_pelvis 


            # Get 14 predicted joints from the mesh
            pred_keypoints_3d = torch.matmul(J_regressor_batch, pred_vertices)
            if save_results:
                pred_joints[step * batch_size:step * batch_size + curr_batch_size, :, :]  = pred_keypoints_3d.cpu().numpy()
            pred_pelvis = pred_keypoints_3d[:, [0],:].clone()
            pred_keypoints_3d = pred_keypoints_3d[:, joint_mapper_h36m, :]
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis 

            # Absolute error (MPJPE)
            error = torch.sqrt(((pred_keypoints_3d - gt_keypoints_3d) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            mpjpe[step * batch_size:step * batch_size + curr_batch_size] = error

            # Reconstruction error
            r_error = reconstruction_error(pred_keypoints_3d.cpu().numpy(), gt_keypoints_3d.cpu().numpy(), reduction=None)
            recon_err[step * batch_size:step * batch_size + curr_batch_size] = r_error


        # If mask or part evaluation, render the mask and part images
        if eval_masks or eval_parts:
            mask, parts = renderer(pred_vertices, pred_camera)

        # Mask evaluation (for LSP)
        if eval_masks:
            center = batch['center'].cpu().numpy()
            scale = batch['scale'].cpu().numpy()
            # Dimensions of original image
            orig_shape = batch['orig_shape'].cpu().numpy()
            for i in range(curr_batch_size):
                # After rendering, convert the image back to the original resolution
                pred_mask = uncrop(mask[i].cpu().numpy(), center[i], scale[i], orig_shape[i]) > 0
                # Load gt mask
                gt_mask = cv2.imread(os.path.join(annot_path, batch['maskname'][i]), 0) > 0
                # Evaluation consistent with the original UP-3D code
                accuracy += (gt_mask == pred_mask).sum()
                pixel_count += np.prod(np.array(gt_mask.shape))
                for c in range(2):
                    cgt = gt_mask == c
                    cpred = pred_mask == c
                    tp[c] += (cgt & cpred).sum()
                    fp[c] +=  (~cgt & cpred).sum()
                    fn[c] +=  (cgt & ~cpred).sum()
                f1 = 2 * tp / (2 * tp + fp + fn)

        # Part evaluation (for LSP)
        if eval_parts:
            center = batch['center'].cpu().numpy()
            scale = batch['scale'].cpu().numpy()
            orig_shape = batch['orig_shape'].cpu().numpy()
            for i in range(curr_batch_size):
                pred_parts = uncrop(parts[i].cpu().numpy().astype(np.uint8), center[i], scale[i], orig_shape[i])
                # Load gt part segmentation
                gt_parts = cv2.imread(os.path.join(annot_path, batch['partname'][i]), 0)
                # Evaluation consistent with the original UP-3D code
                # 6 parts + background
                for c in range(7):
                    cgt = gt_parts == c
                    cpred = pred_parts == c
                    cpred[gt_parts == 255] = 0
                    parts_tp[c] += (cgt & cpred).sum()
                    parts_fp[c] += (~cgt & cpred).sum()
                    parts_fn[c] += (cgt & ~cpred).sum()
                gt_parts[gt_parts == 255] = 0
                pred_parts[pred_parts == 255] = 0
                parts_f1 = 2 * parts_tp / (2 * parts_tp + parts_fp + parts_fn)
                parts_accuracy += (gt_parts == pred_parts).sum()
                parts_pixel_count += np.prod(np.array(gt_parts.shape))

        # Print intermediate results during evaluation
        if step % log_freq == log_freq - 1:
            if eval_pose:
                print('MPJPE: ' + str(1000 * mpjpe[:step * batch_size].mean()))
                print('Reconstruction Error: ' + str(1000 * recon_err[:step * batch_size].mean()))
                print()
            if eval_masks:
                print('Accuracy: ', accuracy / pixel_count)
                print('F1: ', f1.mean())
                print()
            if eval_parts:
                print('Parts Accuracy: ', parts_accuracy / parts_pixel_count)
                print('Parts F1 (BG): ', parts_f1[[0,1,2,3,4,5,6]].mean())
                print()

    # Save reconstructions to a file for further processing
    if save_results:
        np.savez(result_file, pred_joints=pred_joints, pose=smpl_pose, betas=smpl_betas, camera=smpl_camera)
    # Print final results during evaluation
    print('*** Final Results ***')
    print()
    if eval_pose:
        print('MPJPE: ' + str(1000 * mpjpe.mean()))
        print('Reconstruction Error: ' + str(1000 * recon_err.mean()))
        print()
    if eval_masks:
        print('Accuracy: ', accuracy / pixel_count)
        print('F1: ', f1.mean())
        print()
    if eval_parts:
        print('Parts Accuracy: ', parts_accuracy / parts_pixel_count)
        print('Parts F1 (BG): ', parts_f1[[0,1,2,3,4,5,6]].mean())
        print()
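
Note: the mask evaluation above accumulates per-class true positives, false positives and false negatives and reports F1 = 2*tp / (2*tp + fp + fn). The tiny standalone example below just checks that formula on two toy silhouettes; the arrays are made up purely for illustration.

import numpy as np

gt_mask = np.array([[1, 1, 0],
                    [0, 1, 0]], dtype=bool)
pred_mask = np.array([[1, 0, 0],
                      [0, 1, 1]], dtype=bool)

tp, fp, fn = np.zeros(2), np.zeros(2), np.zeros(2)
for c in range(2):                 # class 0 = background, class 1 = person
    cgt = gt_mask == c
    cpred = pred_mask == c
    tp[c] += (cgt & cpred).sum()
    fp[c] += (~cgt & cpred).sum()
    fn[c] += (cgt & ~cpred).sum()

f1 = 2 * tp / (2 * tp + fp + fn)
print(f1, f1.mean())               # roughly [0.667 0.667] and 0.667 for these toy masks
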
Example #23
0
    def init_fn(self):
        # create training dataset
        self.train_ds = create_dataset(self.options.dataset,
                                       self.options,
                                       use_IUV=True)
        self.dp_res = int(self.options.img_res // (2**self.options.warp_level))

        self.CNet = DPNet(warp_lv=self.options.warp_level,
                          norm_type=self.options.norm_type).to(self.device)

        self.LNet = get_LNet(self.options).to(self.device)
        self.smpl = SMPL().to(self.device)
        self.female_smpl = SMPL(cfg.FEMALE_SMPL_FILE).to(self.device)
        self.male_smpl = SMPL(cfg.MALE_SMPL_FILE).to(self.device)

        uv_res = self.options.uv_res
        self.uv_type = self.options.uv_type
        self.sampler = Index_UV_Generator(UV_height=uv_res,
                                          UV_width=-1,
                                          uv_type=self.uv_type).to(self.device)

        weight_file = 'data/weight_p24_h{:04d}_w{:04d}_{}.npy'.format(
            uv_res, uv_res, self.uv_type)
        if not os.path.exists(weight_file):
            cal_uv_weight(self.sampler, weight_file)

        uv_weight = torch.from_numpy(np.load(weight_file)).to(
            self.device).float()
        uv_weight = uv_weight * self.sampler.mask.to(uv_weight.device).float()
        uv_weight = uv_weight / uv_weight.mean()
        self.uv_weight = uv_weight[None, :, :, None]
        self.tv_factor = (uv_res - 1) * (uv_res - 1)

        # Setup an optimizer
        if self.options.stage == 'dp':
            self.optimizer = torch.optim.Adam(
                params=list(self.CNet.parameters()),
                lr=self.options.lr,
                betas=(self.options.adam_beta1, 0.999),
                weight_decay=self.options.wd)
            self.models_dict = {'CNet': self.CNet}
            self.optimizers_dict = {'optimizer': self.optimizer}

        else:
            self.optimizer = torch.optim.Adam(
                params=list(self.LNet.parameters()) +
                list(self.CNet.parameters()),
                lr=self.options.lr,
                betas=(self.options.adam_beta1, 0.999),
                weight_decay=self.options.wd)
            self.models_dict = {'CNet': self.CNet, 'LNet': self.LNet}
            self.optimizers_dict = {'optimizer': self.optimizer}

        # Create loss functions
        self.criterion_shape = nn.L1Loss().to(self.device)
        self.criterion_uv = nn.L1Loss().to(self.device)
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        self.criterion_keypoints_3d = nn.L1Loss(reduction='none').to(
            self.device)
        self.criterion_regr = nn.MSELoss().to(self.device)

        # LSP indices from full list of keypoints
        self.to_lsp = list(range(14))
        self.renderer = Renderer(faces=self.smpl.faces.cpu().numpy())

        # Optionally start training from a pretrained checkpoint
        # Note that this is different from resuming training
        # For the latter use --resume
        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(
                checkpoint_file=self.options.pretrained_checkpoint)
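
Note: init_fn above precomputes tv_factor = (uv_res - 1) * (uv_res - 1), but the loss that consumes it lies outside this snippet. Below is a plausible sketch of a total-variation regulariser over a UV location map that would be normalised by such a factor; the (B, H, W, C) layout is an assumption, not the actual training code.

import torch


def tv_loss(uv_map, tv_factor):
    """Total-variation term over a UV location map of shape (B, H, W, C)."""
    dh = uv_map[:, 1:, :, :] - uv_map[:, :-1, :, :]   # differences along height
    dw = uv_map[:, :, 1:, :] - uv_map[:, :, :-1, :]   # differences along width
    return (dh.abs().sum() + dw.abs().sum()) / tv_factor
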
Example #24
0
def run_evaluation(model, dataset_name, dataset, result_file,
                   batch_size=32, img_res=224, 
                   num_workers=32, shuffle=False, log_freq=50, bVerbose= True):
    """Run evaluation on the datasets and metrics we report in the paper. """

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # # Transfer model to the GPU
    # model.to(device)

    # Load SMPL model
    global g_smpl_neutral, g_smpl_male, g_smpl_female
    if g_smpl_neutral is None:
        g_smpl_neutral = SMPL(config.SMPL_MODEL_DIR,
                              create_transl=False).to(device)
        g_smpl_male = SMPL(config.SMPL_MODEL_DIR,
                           gender='male',
                           create_transl=False).to(device)
        g_smpl_female = SMPL(config.SMPL_MODEL_DIR,
                             gender='female',
                             create_transl=False).to(device)

    smpl_neutral = g_smpl_neutral
    smpl_male = g_smpl_male
    smpl_female = g_smpl_female

    
    # renderer = PartRenderer()
    
    # Regressor for H36m joints
    J_regressor = torch.from_numpy(np.load(config.JOINT_REGRESSOR_H36M)).float()
    
    save_results = result_file is not None
    # Disable shuffling if you want to save the results
    if save_results:
        shuffle=False
    # Create dataloader for the dataset
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    
    # Pose metrics
    # MPJPE and Reconstruction error for the non-parametric and parametric shapes
    # mpjpe = np.zeros(len(dataset))
    # recon_err = np.zeros(len(dataset))
    quant_mpjpe = {}        # per-sequence MPJPE
    quant_recon_err = {}    # per-sequence reconstruction error
    mpjpe = np.zeros(len(dataset))
    recon_err = np.zeros(len(dataset))

    mpjpe_smpl = np.zeros(len(dataset))
    recon_err_smpl = np.zeros(len(dataset))

    # Shape metrics
    # Mean per-vertex error
    shape_err = np.zeros(len(dataset))
    shape_err_smpl = np.zeros(len(dataset))

    # Mask and part metrics
    # Accuracy
    accuracy = 0.
    parts_accuracy = 0.
    # True positive, false positive and false negative
    tp = np.zeros((2,1))
    fp = np.zeros((2,1))
    fn = np.zeros((2,1))
    parts_tp = np.zeros((7,1))
    parts_fp = np.zeros((7,1))
    parts_fn = np.zeros((7,1))
    # Pixel count accumulators
    pixel_count = 0
    parts_pixel_count = 0

    # Store SMPL parameters
    output_pred_pose = np.zeros((len(dataset), 72))
    output_pred_betas = np.zeros((len(dataset), 10))
    output_pred_camera = np.zeros((len(dataset), 3))
    output_pred_joints = np.zeros((len(dataset), 14, 3))

    output_gt_pose = np.zeros((len(dataset), 72))
    output_gt_betas = np.zeros((len(dataset), 10))
    output_gt_joints = np.zeros((len(dataset), 14, 3))

    output_error_MPJPE = np.zeros((len(dataset)))
    output_error_recon = np.zeros((len(dataset)))

    output_imgNames = []
    output_cropScale = np.zeros((len(dataset)))
    output_cropCenter = np.zeros((len(dataset), 2))
    outputStartPointer = 0


    eval_pose = False
    eval_masks = False
    eval_parts = False
    # Choose appropriate evaluation for each dataset
    if dataset_name == 'h36m-p1' or dataset_name == 'h36m-p2' or dataset_name == '3dpw' or dataset_name == 'mpi-inf-3dhp':
        eval_pose = True
    elif dataset_name == 'lsp':
        eval_masks = True
        eval_parts = True
        annot_path = config.DATASET_FOLDERS['upi-s1h']

    joint_mapper_h36m = constants.H36M_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.H36M_TO_J14
    joint_mapper_gt = constants.J24_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.J24_TO_J14
    # Iterate over the entire dataset
    for step, batch in enumerate(tqdm(data_loader, desc='Eval', total=len(data_loader))):
        # Get ground truth annotations from the batch

        imgName = batch['imgname'][0]
        seqName = os.path.basename(os.path.dirname(imgName))

        gt_pose = batch['pose'].to(device)
        gt_betas = batch['betas'].to(device)
        gt_vertices = smpl_neutral(betas=gt_betas, body_pose=gt_pose[:, 3:], global_orient=gt_pose[:, :3]).vertices
        images = batch['img'].to(device)
        gender = batch['gender'].to(device)
        curr_batch_size = images.shape[0]
        
        with torch.no_grad():
            pred_rotmat, pred_betas, pred_camera = model(images)
            pred_output = smpl_neutral(betas=pred_betas, body_pose=pred_rotmat[:,1:], global_orient=pred_rotmat[:,0].unsqueeze(1), pose2rot=False)
            pred_vertices = pred_output.vertices

        
    
        # 3D pose evaluation
        if eval_pose:
            # Regressor broadcasting
            J_regressor_batch = J_regressor[None, :].expand(pred_vertices.shape[0], -1, -1).to(device)
            # Get 14 ground truth joints
            if 'h36m' in dataset_name or 'mpi-inf' in dataset_name:
                gt_keypoints_3d = batch['pose_3d'].cuda()
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_gt, :-1]
            # For 3DPW get the 14 common joints from the rendered shape
            else:
                gt_vertices = smpl_male(global_orient=gt_pose[:,:3], body_pose=gt_pose[:,3:], betas=gt_betas).vertices 
                gt_vertices_female = smpl_female(global_orient=gt_pose[:,:3], body_pose=gt_pose[:,3:], betas=gt_betas).vertices 
                gt_vertices[gender==1, :, :] = gt_vertices_female[gender==1, :, :]
                gt_keypoints_3d = torch.matmul(J_regressor_batch, gt_vertices)
                gt_pelvis = gt_keypoints_3d[:, [0],:].clone()
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_h36m, :]
                gt_keypoints_3d = gt_keypoints_3d - gt_pelvis             

                if False:
                    from renderer import viewer2D
                    from renderer import glViewer
                    import humanModelViewer
                    batchNum = gt_pose.shape[0]
                    for i in range(batchNum):
                        smpl_face = humanModelViewer.GetSMPLFace()
                        meshes_gt = {'ver': gt_vertices[i].cpu().numpy()*100, 'f': smpl_face}
                        meshes_pred = {'ver': pred_vertices[i].cpu().numpy()*100, 'f': smpl_face}

                        glViewer.setMeshData([meshes_gt, meshes_pred], bComputeNormal= True)
                        glViewer.show(5)

            # Get 14 predicted joints from the mesh
            pred_keypoints_3d = torch.matmul(J_regressor_batch, pred_vertices)
            # if save_results:
            #     pred_joints[step * batch_size:step * batch_size + curr_batch_size, :, :]  = pred_keypoints_3d.cpu().numpy()
            pred_pelvis = pred_keypoints_3d[:, [0],:].clone()
            pred_keypoints_3d = pred_keypoints_3d[:, joint_mapper_h36m, :]
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis 

            #Visualize GT mesh and SPIN output mesh
            if False:
                from renderer import viewer2D
                from renderer import glViewer
                import humanModelViewer

                gt_keypoints_3d_vis = gt_keypoints_3d.cpu().numpy()
                gt_keypoints_3d_vis = np.reshape(gt_keypoints_3d_vis, (gt_keypoints_3d_vis.shape[0],-1))        #N,14x3
                gt_keypoints_3d_vis = np.swapaxes(gt_keypoints_3d_vis, 0,1) *100

                pred_keypoints_3d_vis = pred_keypoints_3d.cpu().numpy()
                pred_keypoints_3d_vis = np.reshape(pred_keypoints_3d_vis, (pred_keypoints_3d_vis.shape[0],-1))        #N,14x3
                pred_keypoints_3d_vis = np.swapaxes(pred_keypoints_3d_vis, 0,1) *100
                # output_sample = output_sample[ : , np.newaxis]*0.1
                # gt_sample = gt_sample[: , np.newaxis]*0.1
                # (skelNum, dim, frames)
                glViewer.setSkeleton( [gt_keypoints_3d_vis, pred_keypoints_3d_vis] ,jointType='smplcoco')#(skelNum, dim, frames)
                glViewer.show()
                

            # Absolute error (MPJPE)
            error = torch.sqrt(((pred_keypoints_3d - gt_keypoints_3d) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            # mpjpe[step * batch_size:step * batch_size + curr_batch_size] = error

            # Reconstruction error
            r_error = reconstruction_error(pred_keypoints_3d.cpu().numpy(), gt_keypoints_3d.cpu().numpy(), reduction=None)
            # recon_err[step * batch_size:step * batch_size + curr_batch_size] = r_error

            for ii, p in enumerate(batch['imgname'][:len(r_error)]):
                seqName = os.path.basename( os.path.dirname(p))
                # quant_mpjpe[step * batch_size:step * batch_size + curr_batch_size] = error
                if seqName not in quant_mpjpe.keys():
                    quant_mpjpe[seqName] = []
                    quant_recon_err[seqName] = []
                
                quant_mpjpe[seqName].append(error[ii]) 
                quant_recon_err[seqName].append(r_error[ii])

            list_mpjpe = np.hstack([ quant_mpjpe[k] for k in quant_mpjpe])
            list_reconError = np.hstack([ quant_recon_err[k] for k in quant_recon_err])
            if bVerbose:
                print(">>> {} : MPJPE {:.02f} mm, error: {:.02f} mm | Total MPJPE {:.02f} mm, error {:.02f} mm".format(seqName, np.mean(error)*1000, np.mean(r_error)*1000, np.hstack(list_mpjpe).mean()*1000, np.hstack(list_reconError).mean()*1000) )

            # print("MPJPE {}, error: {}".format(np.mean(error)*100, np.mean(r_error)*100))

        # If mask or part evaluation, render the mask and part images
        # if eval_masks or eval_parts:
        #     mask, parts = renderer(pred_vertices, pred_camera)

        # Mask evaluation (for LSP)
        if eval_masks:
            center = batch['center'].cpu().numpy()
            scale = batch['scale'].cpu().numpy()
            # Dimensions of original image
            orig_shape = batch['orig_shape'].cpu().numpy()
            for i in range(curr_batch_size):
                # After rendering, convert the image back to the original resolution
                pred_mask = uncrop(mask[i].cpu().numpy(), center[i], scale[i], orig_shape[i]) > 0
                # Load gt mask
                gt_mask = cv2.imread(os.path.join(annot_path, batch['maskname'][i]), 0) > 0
                # Evaluation consistent with the original UP-3D code
                accuracy += (gt_mask == pred_mask).sum()
                pixel_count += np.prod(np.array(gt_mask.shape))
                for c in range(2):
                    cgt = gt_mask == c
                    cpred = pred_mask == c
                    tp[c] += (cgt & cpred).sum()
                    fp[c] +=  (~cgt & cpred).sum()
                    fn[c] +=  (cgt & ~cpred).sum()
                f1 = 2 * tp / (2 * tp + fp + fn)

        # Part evaluation (for LSP)
        if eval_parts:
            center = batch['center'].cpu().numpy()
            scale = batch['scale'].cpu().numpy()
            orig_shape = batch['orig_shape'].cpu().numpy()
            for i in range(curr_batch_size):
                pred_parts = uncrop(parts[i].cpu().numpy().astype(np.uint8), center[i], scale[i], orig_shape[i])
                # Load gt part segmentation
                gt_parts = cv2.imread(os.path.join(annot_path, batch['partname'][i]), 0)
                # Evaluation consistent with the original UP-3D code
                # 6 parts + background
                for c in range(7):
                    cgt = gt_parts == c
                    cpred = pred_parts == c
                    cpred[gt_parts == 255] = 0
                    parts_tp[c] += (cgt & cpred).sum()
                    parts_fp[c] += (~cgt & cpred).sum()
                    parts_fn[c] += (cgt & ~cpred).sum()
                gt_parts[gt_parts == 255] = 0
                pred_parts[pred_parts == 255] = 0
                parts_f1 = 2 * parts_tp / (2 * parts_tp + parts_fp + parts_fn)
                parts_accuracy += (gt_parts == pred_parts).sum()
                parts_pixel_count += np.prod(np.array(gt_parts.shape))

        # Print intermediate results during evaluation
        if bVerbose:
            if step % log_freq == log_freq - 1:
                if eval_pose:
                    print('MPJPE: ' + str(1000 * mpjpe[:step * batch_size].mean()))
                    print('Reconstruction Error: ' + str(1000 * recon_err[:step * batch_size].mean()))
                    print()
                if eval_masks:
                    print('Accuracy: ', accuracy / pixel_count)
                    print('F1: ', f1.mean())
                    print()
                if eval_parts:
                    print('Parts Accuracy: ', parts_accuracy / parts_pixel_count)
                    print('Parts F1 (BG): ', parts_f1[[0,1,2,3,4,5,6]].mean())
                    print()

        if save_results:
            rot_pad = torch.tensor([0,0,1], dtype=torch.float32, device=device).view(1,3,1)
            rotmat = torch.cat((pred_rotmat.view(-1, 3, 3), rot_pad.expand(curr_batch_size * 24, -1, -1)), dim=-1)
            pred_pose = tgm.rotation_matrix_to_angle_axis(rotmat).contiguous().view(-1, 72)

            output_pred_pose[outputStartPointer:outputStartPointer+curr_batch_size, :] = pred_pose.cpu().numpy()
            output_pred_betas[outputStartPointer:outputStartPointer+curr_batch_size, :] = pred_betas.cpu().numpy()
            output_pred_camera[outputStartPointer:outputStartPointer+curr_batch_size, :] = pred_camera.cpu().numpy()
            output_pred_joints[outputStartPointer:outputStartPointer+curr_batch_size, :] = pred_keypoints_3d.cpu().numpy()

            output_gt_pose[outputStartPointer:outputStartPointer+curr_batch_size, :]  = gt_pose.cpu().numpy()
            output_gt_betas[outputStartPointer:outputStartPointer+curr_batch_size, :] = gt_betas.cpu().numpy()
            output_gt_joints[outputStartPointer:outputStartPointer+curr_batch_size, :] = gt_keypoints_3d.cpu().numpy()

            output_error_MPJPE[outputStartPointer:outputStartPointer+curr_batch_size] = error * 1000
            output_error_recon[outputStartPointer:outputStartPointer+curr_batch_size] = r_error * 1000

            output_cropScale[outputStartPointer:outputStartPointer+curr_batch_size] = batch['scale'].cpu().numpy()
            output_cropCenter[outputStartPointer:outputStartPointer+curr_batch_size, :] = batch['center'].cpu().numpy()

            output_imgNames += batch['imgname']

            outputStartPointer += curr_batch_size

            # if outputStartPointer>100:     #Debug
            #         break


        
    # if len(output_imgNames) < output_pred_pose.shape[0]:
    output = {}
    finalLen = len(output_imgNames)
    output['imageNames'] = output_imgNames
    output['pred_pose'] = output_pred_pose[:finalLen]
    output['pred_betas'] = output_pred_betas[:finalLen]
    output['pred_camera'] = output_pred_camera[:finalLen]
    output['pred_joints'] = output_pred_joints[:finalLen]

    output['gt_pose'] = output_gt_pose[:finalLen]
    output['gt_betas'] = output_gt_betas[:finalLen]
    output['gt_joints'] = output_gt_joints[:finalLen]

    output['error_MPJPE'] = output_error_MPJPE[:finalLen]
    output['error_recon'] = output_error_recon[:finalLen]

    output['cropScale']  = output_cropScale[:finalLen]
    output['cropCenter'] = output_cropCenter[:finalLen]


    # Save reconstructions to a file for further processing
    if save_results:
        import pickle
        # np.savez(result_file, pred_joints=pred_joints, pred_pose=pred_pose, pred_betas=pred_betas, pred_camera=pred_camera)
        with open(result_file, 'wb') as f:
            pickle.dump(output, f)
        print("Saved to: {}".format(result_file))
        
    # Print final results during evaluation

    if bVerbose:
        print('*** Final Results ***')
        print()
    if eval_pose:
        # if bVerbose:
        #     print('MPJPE: ' + str(1000 * mpjpe.mean()))
        #     print('Reconstruction Error: ' + str(1000 * recon_err.mean()))
        #     print()
        list_mpjpe = np.hstack([ quant_mpjpe[k] for k in quant_mpjpe])
        list_reconError = np.hstack([ quant_recon_err[k] for k in quant_recon_err])

        output_str = 'SeqNames; '
        for seq in quant_mpjpe:
            output_str += seq + ';'
        output_str += '\n MPJPE; '
        quant_mpjpe_avg_mm = np.hstack(list_mpjpe).mean() * 1000
        output_str += "Avg {:.02f} mm; ".format(quant_mpjpe_avg_mm)
        for seq in quant_mpjpe:
            output_str += '{:.02f}; '.format(1000 * np.hstack(quant_mpjpe[seq]).mean())

        output_str += '\n Recon Error; '
        quant_recon_error_avg_mm = np.hstack(list_reconError).mean() * 1000
        output_str += "Avg {:.02f} mm; ".format(quant_recon_error_avg_mm)
        for seq in quant_recon_err:
            output_str += '{:.02f}; '.format(1000 * np.hstack(quant_recon_err[seq]).mean())
        if bVerbose:
            print(output_str)
        else:
            print(">>>  Test on 3DPW: MPJPE: {} | quant_recon_error_avg_mm: {}".format(quant_mpjpe_avg_mm, quant_recon_error_avg_mm) )

       
        return quant_mpjpe_avg_mm, quant_recon_error_avg_mm

    if bVerbose:
        if eval_masks:
            print('Accuracy: ', accuracy / pixel_count)
            print('F1: ', f1.mean())
            print()
        if eval_parts:
            print('Parts Accuracy: ', parts_accuracy / parts_pixel_count)
            print('Parts F1 (BG): ', parts_f1[[0,1,2,3,4,5,6]].mean())
            print()

    return -1       #Should return something
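
Note: reconstruction_error (PA-MPJPE) is used throughout these evaluation examples, but its body is not included here. Below is a generic numpy sketch of the usual similarity-Procrustes alignment followed by the per-joint error; it is not necessarily the exact implementation these repositories ship with.

import numpy as np


def procrustes_align(S1, S2):
    """Align S1 (N, 3) to S2 (N, 3) with the best-fitting scale, rotation and translation."""
    mu1, mu2 = S1.mean(axis=0), S2.mean(axis=0)
    X1, X2 = S1 - mu1, S2 - mu2
    var1 = (X1 ** 2).sum()
    K = X1.T @ X2                                  # 3x3 correlation matrix
    U, _, Vh = np.linalg.svd(K)
    Z = np.eye(3)
    Z[-1, -1] = np.sign(np.linalg.det(U @ Vh))     # guard against reflections
    R = Vh.T @ Z @ U.T
    scale = np.trace(R @ K) / var1
    t = mu2 - scale * (R @ mu1)
    return scale * S1 @ R.T + t


def pa_mpjpe(pred, gt):
    """Per-sample Procrustes-aligned MPJPE for pred, gt of shape (B, J, 3)."""
    return np.array([np.linalg.norm(procrustes_align(p, g) - g, axis=-1).mean()
                     for p, g in zip(pred, gt)])
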
Example #25
0
                     extract_img=True)
        h36m_extract(cfg.H36M_ROOT_ORIGIN,
                     out_path,
                     protocol=2,
                     extract_img=False)

        # LSP dataset preprocessing (test set)
        lsp_dataset_extract(cfg.LSP_ROOT, out_path)

        # UP-3D dataset preprocessing (lsp_test set)
        up_3d_extract(cfg.UP_3D_ROOT, out_path, 'lsp_test')

    if args.gt_iuv:
        device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')
        smpl = SMPL().to(device)
        uv_type = args.uv_type

        if uv_type == 'SMPL':
            data = objfile.read_obj_full(
                'data/uv_sampler/smpl_fbx_template.obj')
        elif uv_type == 'BF':
            data = objfile.read_obj_full(
                'data/uv_sampler/smpl_boundry_free_template.obj')

        vt = np.array(data['texcoords'])
        face = [f[0] for f in data['faces']]
        face = np.array(face) - 1
        vt_face = [f[2] for f in data['faces']]
        vt_face = np.array(vt_face) - 1
        renderer = UVRenderer(faces=face,
Example #26
0
    def init_fn(self):
        if self.options.rank == 0:
            self.summary_writer.add_text('command_args', print_args())

        if self.options.regressor == 'hmr':
            # HMR/SPIN model
            self.model = hmr(path_config.SMPL_MEAN_PARAMS, pretrained=True)
            self.smpl = SMPL(path_config.SMPL_MODEL_DIR,
                             batch_size=cfg.TRAIN.BATCH_SIZE,
                             create_transl=False).to(self.device)
        elif self.options.regressor == 'pymaf_net':
            # PyMAF model
            self.model = pymaf_net(path_config.SMPL_MEAN_PARAMS,
                                   pretrained=True)
            self.smpl = self.model.regressor[0].smpl

        if self.options.distributed:
            # For multiprocessing distributed, DistributedDataParallel constructor
            # should always set the single device scope, otherwise,
            # DistributedDataParallel will use all available devices.
            if self.options.gpu is not None:
                torch.cuda.set_device(self.options.gpu)
                self.model.cuda(self.options.gpu)
                # When using a single GPU per process and per
                # DistributedDataParallel, we need to divide the batch size
                # ourselves based on the total number of GPUs we have
                self.options.batch_size = int(self.options.batch_size /
                                              self.options.ngpus_per_node)
                self.options.workers = int(
                    (self.options.workers + self.options.ngpus_per_node - 1) /
                    self.options.ngpus_per_node)
                self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    self.model)
                self.model = torch.nn.parallel.DistributedDataParallel(
                    self.model,
                    device_ids=[self.options.gpu],
                    output_device=self.options.gpu,
                    find_unused_parameters=True)
            else:
                self.model.cuda()
                # DistributedDataParallel will divide and allocate batch_size to all
                # available GPUs if device_ids are not set
                self.model = torch.nn.parallel.DistributedDataParallel(
                    self.model, find_unused_parameters=True)
            self.models_dict = {'model': self.model.module}
        else:
            self.model = self.model.to(self.device)
            self.models_dict = {'model': self.model}

        cudnn.benchmark = True

        # Per-vertex loss on the shape
        self.criterion_shape = nn.L1Loss().to(self.device)
        # Keypoint (2D and 3D) loss
        # No reduction because confidence weighting needs to be applied
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        # Loss for SMPL parameter regression
        self.criterion_regr = nn.MSELoss().to(self.device)
        self.focal_length = constants.FOCAL_LENGTH

        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(
                checkpoint_file=self.options.pretrained_checkpoint)

        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=cfg.SOLVER.BASE_LR,
                                          weight_decay=0)

        self.optimizers_dict = {'optimizer': self.optimizer}

        if self.options.single_dataset:
            self.train_ds = BaseDataset(self.options,
                                        self.options.single_dataname,
                                        is_train=True)
        else:
            self.train_ds = MixedDataset(self.options, is_train=True)

        self.valid_ds = BaseDataset(self.options,
                                    self.options.eval_dataset,
                                    is_train=False)

        if self.options.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                self.train_ds)
            val_sampler = None
        else:
            train_sampler = None
            val_sampler = None

        self.train_data_loader = DataLoader(self.train_ds,
                                            batch_size=self.options.batch_size,
                                            num_workers=self.options.workers,
                                            pin_memory=cfg.TRAIN.PIN_MEMORY,
                                            shuffle=(train_sampler is None),
                                            sampler=train_sampler)

        self.valid_loader = DataLoader(dataset=self.valid_ds,
                                       batch_size=cfg.TEST.BATCH_SIZE,
                                       shuffle=False,
                                       num_workers=cfg.TRAIN.NUM_WORKERS,
                                       pin_memory=cfg.TRAIN.PIN_MEMORY,
                                       sampler=val_sampler)

        # Load dictionary of fits
        self.fits_dict = FitsDict(self.options, self.train_ds)
        self.evaluation_accumulators = dict.fromkeys([
            'pred_j3d', 'target_j3d', 'target_theta', 'pred_verts',
            'target_verts'
        ])

        # Create renderer
        try:
            self.renderer = OpenDRenderer()
        except Exception:
            print('No renderer for visualization.')
            self.renderer = None

        if cfg.MODEL.PyMAF.AUX_SUPV_ON:
            self.iuv_maker = IUV_Renderer(
                output_size=cfg.MODEL.PyMAF.DP_HEATMAP_SIZE)

        self.decay_steps_ind = 1
        self.decay_epochs_ind = 1
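
Note: in the distributed branch of init_fn above, the global batch size and worker count are divided across the processes on one node. The small standalone example below only replays that arithmetic with made-up numbers; the ceiling division keeps at least the requested number of workers in total.

# Illustrative only: split a global batch size / worker count over the GPUs of one node.
global_batch_size, global_workers, ngpus_per_node = 64, 6, 4

per_gpu_batch_size = int(global_batch_size / ngpus_per_node)                    # 16
per_gpu_workers = int((global_workers + ngpus_per_node - 1) / ngpus_per_node)   # 2 (ceil of 1.5)
print(per_gpu_batch_size, per_gpu_workers)
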
Example #27
0
def run_evaluation(model, dataset):
    """Run evaluation on the datasets and metrics we report in the paper. """

    shuffle = args.shuffle
    log_freq = args.log_freq
    batch_size = args.batch_size
    dataset_name = args.dataset
    result_file = args.result_file
    num_workers = args.num_workers
    device = torch.device('cuda') if torch.cuda.is_available() \
                                else torch.device('cpu')

    # Transfer model to the GPU
    model.to(device)

    # Load SMPL model
    smpl_neutral = SMPL(path_config.SMPL_MODEL_DIR,
                        create_transl=False).to(device)
    smpl_male = SMPL(path_config.SMPL_MODEL_DIR,
                     gender='male',
                     create_transl=False).to(device)
    smpl_female = SMPL(path_config.SMPL_MODEL_DIR,
                       gender='female',
                       create_transl=False).to(device)

    renderer = PartRenderer()

    # Regressor for H36m joints
    J_regressor = torch.from_numpy(np.load(
        path_config.JOINT_REGRESSOR_H36M)).float()

    save_results = result_file is not None
    # Disable shuffling if you want to save the results
    if save_results:
        shuffle = False
    # Create dataloader for the dataset
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=num_workers)

    # Pose metrics
    # MPJPE and Reconstruction error for the non-parametric and parametric shapes
    mpjpe = np.zeros(len(dataset))
    recon_err = np.zeros(len(dataset))
    mpjpe_smpl = np.zeros(len(dataset))
    recon_err_smpl = np.zeros(len(dataset))
    pve = np.zeros(len(dataset))

    # Shape metrics
    # Mean per-vertex error
    shape_err = np.zeros(len(dataset))
    shape_err_smpl = np.zeros(len(dataset))

    # Mask and part metrics
    # Accuracy
    accuracy = 0.
    parts_accuracy = 0.
    # True positive, false positive and false negative
    tp = np.zeros((2, 1))
    fp = np.zeros((2, 1))
    fn = np.zeros((2, 1))
    parts_tp = np.zeros((7, 1))
    parts_fp = np.zeros((7, 1))
    parts_fn = np.zeros((7, 1))
    # Pixel count accumulators
    pixel_count = 0
    parts_pixel_count = 0

    # Store SMPL parameters
    smpl_pose = np.zeros((len(dataset), 72))
    smpl_betas = np.zeros((len(dataset), 10))
    smpl_camera = np.zeros((len(dataset), 3))
    pred_joints = np.zeros((len(dataset), 17, 3))
    action_idxes = {}
    idx_counter = 0
    # for each action
    act_PVE = {}
    act_MPJPE = {}
    act_paMPJPE = {}

    eval_pose = False
    eval_masks = False
    eval_parts = False
    # Choose appropriate evaluation for each dataset
    if dataset_name == 'h36m-p1' or dataset_name == 'h36m-p2' or dataset_name == 'h36m-p2-mosh' \
       or dataset_name == '3dpw' or dataset_name == 'mpi-inf-3dhp' or dataset_name == '3doh50k':
        eval_pose = True
    elif dataset_name == 'lsp':
        eval_masks = True
        eval_parts = True
        annot_path = path_config.DATASET_FOLDERS['upi-s1h']

    joint_mapper_h36m = constants.H36M_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.H36M_TO_J14
    joint_mapper_gt = constants.J24_TO_J17 if dataset_name == 'mpi-inf-3dhp' else constants.J24_TO_J14
    # Iterate over the entire dataset
    cnt = 0
    results_dict = {'id': [], 'pred': [], 'pred_pa': [], 'gt': []}
    for step, batch in enumerate(
            tqdm(data_loader, desc='Eval', total=len(data_loader))):
        # Get ground truth annotations from the batch
        gt_pose = batch['pose'].to(device)
        gt_betas = batch['betas'].to(device)
        gt_smpl_out = smpl_neutral(betas=gt_betas,
                                   body_pose=gt_pose[:, 3:],
                                   global_orient=gt_pose[:, :3])
        gt_vertices_nt = gt_smpl_out.vertices
        images = batch['img'].to(device)
        gender = batch['gender'].to(device)
        curr_batch_size = images.shape[0]

        if save_results:
            s_id = np.array(
                [int(item.split('/')[-3][-1])
                 for item in batch['imgname']]) * 10000
            s_id += np.array(
                [int(item.split('/')[-1][4:-4]) for item in batch['imgname']])
            results_dict['id'].append(s_id)

        if dataset_name == 'h36m-p2':
            action = [
                im_path.split('/')[-1].split('.')[0].split('_')[1]
                for im_path in batch['imgname']
            ]
            for act_i in range(len(action)):

                if action[act_i] in action_idxes:
                    action_idxes[action[act_i]].append(idx_counter + act_i)
                else:
                    action_idxes[action[act_i]] = [idx_counter + act_i]
            idx_counter += len(action)

        with torch.no_grad():
            if args.regressor == 'hmr':
                pred_rotmat, pred_betas, pred_camera = model(images)
                # torch.Size([32, 24, 3, 3]) torch.Size([32, 10]) torch.Size([32, 3])
            elif args.regressor == 'pymaf_net':
                preds_dict, _ = model(images)
                pred_rotmat = preds_dict['smpl_out'][-1]['rotmat'].contiguous(
                ).view(-1, 24, 3, 3)
                pred_betas = preds_dict['smpl_out'][-1][
                    'theta'][:, 3:13].contiguous()
                pred_camera = preds_dict['smpl_out'][-1][
                    'theta'][:, :3].contiguous()

            pred_output = smpl_neutral(
                betas=pred_betas,
                body_pose=pred_rotmat[:, 1:],
                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                pose2rot=False)
            pred_vertices = pred_output.vertices

        if save_results:
            rot_pad = torch.tensor([0, 0, 1],
                                   dtype=torch.float32,
                                   device=device).view(1, 3, 1)
            rotmat = torch.cat((pred_rotmat.view(
                -1, 3, 3), rot_pad.expand(curr_batch_size * 24, -1, -1)),
                               dim=-1)
            pred_pose = tgm.rotation_matrix_to_angle_axis(
                rotmat).contiguous().view(-1, 72)
            smpl_pose[step * batch_size:step * batch_size +
                      curr_batch_size, :] = pred_pose.cpu().numpy()
            smpl_betas[step * batch_size:step * batch_size +
                       curr_batch_size, :] = pred_betas.cpu().numpy()
            smpl_camera[step * batch_size:step * batch_size +
                        curr_batch_size, :] = pred_camera.cpu().numpy()

        # 3D pose evaluation
        if eval_pose:
            # Regressor broadcasting
            J_regressor_batch = J_regressor[None, :].expand(
                pred_vertices.shape[0], -1, -1).to(device)
            # Get 14 ground truth joints
            if 'h36m' in dataset_name or 'mpi-inf' in dataset_name or '3doh50k' in dataset_name:
                gt_keypoints_3d = batch['pose_3d'].cuda()
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_gt, :-1]
            # For 3DPW get the 14 common joints from the rendered shape
            else:
                gt_vertices = smpl_male(global_orient=gt_pose[:, :3],
                                        body_pose=gt_pose[:, 3:],
                                        betas=gt_betas).vertices
                gt_vertices_female = smpl_female(global_orient=gt_pose[:, :3],
                                                 body_pose=gt_pose[:, 3:],
                                                 betas=gt_betas).vertices
                gt_vertices[gender == 1, :, :] = gt_vertices_female[gender ==
                                                                    1, :, :]
                gt_keypoints_3d = torch.matmul(J_regressor_batch, gt_vertices)
                gt_pelvis = gt_keypoints_3d[:, [0], :].clone()
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper_h36m, :]
                gt_keypoints_3d = gt_keypoints_3d - gt_pelvis

            if '3dpw' in dataset_name:
                per_vertex_error = torch.sqrt(
                    ((pred_vertices -
                      gt_vertices)**2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            else:
                per_vertex_error = torch.sqrt(
                    ((pred_vertices - gt_vertices_nt)**2).sum(dim=-1)).mean(
                        dim=-1).cpu().numpy()
            pve[step * batch_size:step * batch_size +
                curr_batch_size] = per_vertex_error

            # Get 14 predicted joints from the mesh
            pred_keypoints_3d = torch.matmul(J_regressor_batch, pred_vertices)
            if save_results:
                pred_joints[
                    step * batch_size:step * batch_size +
                    curr_batch_size, :, :] = pred_keypoints_3d.cpu().numpy()
            pred_pelvis = pred_keypoints_3d[:, [0], :].clone()
            pred_keypoints_3d = pred_keypoints_3d[:, joint_mapper_h36m, :]
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis

            # Absolute error (MPJPE)
            error = torch.sqrt(
                ((pred_keypoints_3d -
                  gt_keypoints_3d)**2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            mpjpe[step * batch_size:step * batch_size +
                  curr_batch_size] = error

            # Reconstruction error
            r_error, pred_keypoints_3d_pa = reconstruction_error(
                pred_keypoints_3d.cpu().numpy(),
                gt_keypoints_3d.cpu().numpy(),
                reduction=None)
            recon_err[step * batch_size:step * batch_size +
                      curr_batch_size] = r_error

            if save_results:
                results_dict['gt'].append(gt_keypoints_3d.cpu().numpy())
                results_dict['pred'].append(pred_keypoints_3d.cpu().numpy())
                results_dict['pred_pa'].append(pred_keypoints_3d_pa)

        if args.vis_demo:
            imgnames = [i_n.split('/')[-1] for i_n in batch['imgname']]

            if args.regressor == 'hmr':
                iuv_pred = None

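            # Undo the ImageNet normalization (multiply by std, then add mean) before visualization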
            images_vis = images * torch.tensor([0.229, 0.224, 0.225],
                                               device=images.device).reshape(
                                                   1, 3, 1, 1)
            images_vis = images_vis + torch.tensor(
                [0.485, 0.456, 0.406], device=images.device).reshape(
                    1, 3, 1, 1)
            vis_smpl_iuv(
                images_vis.cpu().numpy(),
                pred_camera.cpu().numpy(),
                pred_output.vertices.cpu().numpy(), smpl_neutral.faces,
                iuv_pred, 100 * per_vertex_error, imgnames,
                os.path.join('./notebooks/output/demo_results', dataset_name,
                             args.checkpoint.split('/')[-3]), args)

        # If mask or part evaluation, render the mask and part images
        if eval_masks or eval_parts:
            mask, parts = renderer(pred_vertices, pred_camera)
        # Mask evaluation (for LSP)
        if eval_masks:
            center = batch['center'].cpu().numpy()
            scale = batch['scale'].cpu().numpy()
            # Dimensions of original image
            orig_shape = batch['orig_shape'].cpu().numpy()
            for i in range(curr_batch_size):
                # After rendering, convert image back to original resolution
                pred_mask = uncrop(mask[i].cpu().numpy(), center[i], scale[i],
                                   orig_shape[i]) > 0
                # Load gt mask
                gt_mask = cv2.imread(
                    os.path.join(annot_path, batch['maskname'][i]), 0) > 0
                # Evaluation consistent with the original UP-3D code
                accuracy += (gt_mask == pred_mask).sum()
                pixel_count += np.prod(np.array(gt_mask.shape))
                for c in range(2):
                    cgt = gt_mask == c
                    cpred = pred_mask == c
                    tp[c] += (cgt & cpred).sum()
                    fp[c] += (~cgt & cpred).sum()
                    fn[c] += (cgt & ~cpred).sum()
                f1 = 2 * tp / (2 * tp + fp + fn)

        # Part evaluation (for LSP)
        if eval_parts:
            center = batch['center'].cpu().numpy()
            scale = batch['scale'].cpu().numpy()
            orig_shape = batch['orig_shape'].cpu().numpy()
            for i in range(curr_batch_size):
                pred_parts = uncrop(parts[i].cpu().numpy().astype(np.uint8),
                                    center[i], scale[i], orig_shape[i])
                # Load gt part segmentation
                gt_parts = cv2.imread(
                    os.path.join(annot_path, batch['partname'][i]), 0)
                # Evaluation consistent with the original UP-3D code
                # 6 parts + background
                for c in range(7):
                    cgt = gt_parts == c
                    cpred = pred_parts == c
                    cpred[gt_parts == 255] = 0
                    parts_tp[c] += (cgt & cpred).sum()
                    parts_fp[c] += (~cgt & cpred).sum()
                    parts_fn[c] += (cgt & ~cpred).sum()
                gt_parts[gt_parts == 255] = 0
                pred_parts[pred_parts == 255] = 0
                parts_f1 = 2 * parts_tp / (2 * parts_tp + parts_fp + parts_fn)
                parts_accuracy += (gt_parts == pred_parts).sum()
                parts_pixel_count += np.prod(np.array(gt_parts.shape))

        # Print intermediate results during evaluation
        if step % log_freq == log_freq - 1:
            if eval_pose:
                print('MPJPE: ' + str(1000 * mpjpe[:step * batch_size].mean()))
                print('Reconstruction Error: ' +
                      str(1000 * recon_err[:step * batch_size].mean()))
                print()
            if eval_masks:
                print('Accuracy: ', accuracy / pixel_count)
                print('F1: ', f1.mean())
                print()
            if eval_parts:
                print('Parts Accuracy: ', parts_accuracy / parts_pixel_count)
                print('Parts F1 (BG): ', parts_f1[[0, 1, 2, 3, 4, 5,
                                                   6]].mean())
                print()

    # Save reconstructions to a file for further processing
    if save_results:
        np.savez(result_file,
                 pred_joints=pred_joints,
                 pose=smpl_pose,
                 betas=smpl_betas,
                 camera=smpl_camera)
        for k in results_dict.keys():
            results_dict[k] = np.concatenate(results_dict[k])
            print(k, results_dict[k].shape)

        scipy.io.savemat(result_file + '.mat', results_dict)

    # Print final results during evaluation
    print('*** Final Results ***')
    try:
        print(args.checkpoint.split('/')[-3:], args.dataset)
    except Exception:
        pass
    if eval_pose:
        print('PVE: ' + str(1000 * pve.mean()))
        print('MPJPE: ' + str(1000 * mpjpe.mean()))
        print('Reconstruction Error: ' + str(1000 * recon_err.mean()))
        print()
    if eval_masks:
        print('Accuracy: ', accuracy / pixel_count)
        print('F1: ', f1.mean())
        print()
    if eval_parts:
        print('Parts Accuracy: ', parts_accuracy / parts_pixel_count)
        print('Parts F1 (BG): ', parts_f1[[0, 1, 2, 3, 4, 5, 6]].mean())
        print()

    if dataset_name == 'h36m-p2':
        print(
            'Note: PVE is not available for h36m-p2. To evaluate PVE, use h36m-p2-mosh instead.'
        )
        for act in action_idxes:
            act_idx = action_idxes[act]
            act_pve = [pve[i] for i in act_idx]
            act_errors = [mpjpe[i] for i in act_idx]
            act_errors_pa = [recon_err[i] for i in act_idx]

            act_errors_mean = np.mean(np.array(act_errors)) * 1000.
            act_errors_pa_mean = np.mean(np.array(act_errors_pa)) * 1000.
            act_pve_mean = np.mean(np.array(act_pve)) * 1000.
            act_MPJPE[act] = act_errors_mean
            act_paMPJPE[act] = act_errors_pa_mean
            act_PVE[act] = act_pve_mean

        act_err_info = ['action err']
        act_row = [str(act_paMPJPE[act])
                   for act in action_idxes] + [act for act in action_idxes]
        act_err_info.extend(act_row)
        print(act_err_info)
    else:
        act_row = None
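
Note: `reconstruction_error` above is assumed to compute the Procrustes-aligned MPJPE (PA-MPJPE): the prediction is aligned to the ground truth with a similarity transform before measuring joint error. A minimal NumPy sketch under that assumption (the helper names are illustrative, not the repo's API):

import numpy as np

def compute_similarity_transform_sketch(S1, S2):
    """Align S1 (N x 3) to S2 (N x 3) with a similarity transform (Procrustes)."""
    S1, S2 = S1.T, S2.T                               # work with 3 x N
    mu1 = S1.mean(axis=1, keepdims=True)              # centroids
    mu2 = S2.mean(axis=1, keepdims=True)
    X1, X2 = S1 - mu1, S2 - mu2
    var1 = (X1 ** 2).sum()                            # variance of the source points
    K = X1.dot(X2.T)
    U, _, Vh = np.linalg.svd(K)
    V = Vh.T
    Z = np.eye(3)
    Z[-1, -1] = np.sign(np.linalg.det(U.dot(V.T)))    # avoid reflections
    R = V.dot(Z).dot(U.T)                             # optimal rotation
    scale = np.trace(R.dot(K)) / var1                 # optimal scale
    t = mu2 - scale * R.dot(mu1)                      # optimal translation
    return (scale * R.dot(S1) + t).T

def reconstruction_error_sketch(pred, gt):
    """Per-sample mean joint error after Procrustes alignment (PA-MPJPE)."""
    errors, aligned = [], []
    for p, g in zip(pred, gt):
        p_hat = compute_similarity_transform_sketch(p, g)
        errors.append(np.sqrt(((p_hat - g) ** 2).sum(axis=-1)).mean())
        aligned.append(p_hat)
    return np.array(errors), np.stack(aligned)
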
Example No. 28
0
def run_evaluation(model, opt, options, dataset_name, log_freq=50):
    """Run evaluation on the datasets and metrics we report in the paper. """

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # Create SMPL model
    smpl = SMPL().to(device)
    if dataset_name == '3dpw' or dataset_name == 'surreal':
        smpl_male = SMPL(cfg.MALE_SMPL_FILE).to(device)
        smpl_female = SMPL(cfg.FEMALE_SMPL_FILE).to(device)

    batch_size = opt.batch_size

    # Create dataloader for the dataset
    if dataset_name == 'surreal':
        dataset = SurrealDataset(options, use_augmentation=False, is_train=False, use_IUV=False)
    else:
        dataset = BaseDataset(options, dataset_name, use_augmentation=False, is_train=False, use_IUV=False)

    data_loader = DataLoader(dataset,  batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.num_workers),
                             pin_memory=True)

    print('Data loader ready')

    # Transfer model to the GPU
    model.to(device)
    model.eval()

    # Pose metrics
    # MPJPE and Reconstruction error for the non-parametric and parametric shapes
    mpjpe = np.zeros(len(dataset))
    mpjpe_pa = np.zeros(len(dataset))

    # Shape metrics
    # Mean per-vertex error
    shape_err = np.zeros(len(dataset))

    # Mask and part metrics
    # Accuracy
    accuracy = 0.
    parts_accuracy = 0.
    # True positive, false positive and false negative
    tp = np.zeros((2, 1))
    fp = np.zeros((2, 1))
    fn = np.zeros((2, 1))
    parts_tp = np.zeros((7, 1))
    parts_fp = np.zeros((7, 1))
    parts_fn = np.zeros((7, 1))
    # Pixel count accumulators
    pixel_count = 0
    parts_pixel_count = 0

    eval_pose = False
    eval_shape = False
    eval_masks = False
    eval_parts = False
    joint_mapper = cfg.J24_TO_J17 if dataset_name == 'mpi-inf-3dhp' else cfg.J24_TO_J14
    # Choose appropriate evaluation for each dataset
    if 'h36m' in dataset_name or dataset_name == '3dpw' or dataset_name == 'mpi-inf-3dhp':
        eval_pose = True
    elif dataset_name in ['up-3d', 'surreal']:
        eval_shape = True
    elif dataset_name == 'lsp':
        eval_masks = True
        eval_parts = True
        annot_path = cfg.DATASET_FOLDERS['upi-s1h']

    if eval_parts or eval_masks:
        from utils.part_utils import PartRenderer
        renderer = PartRenderer()

    # Iterate over the entire dataset
    for step, batch in enumerate(tqdm(data_loader, desc='Eval', total=len(data_loader))):
        # Get ground truth annotations from the batch
        gt_pose = batch['pose'].to(device)
        gt_betas = batch['betas'].to(device)
        gt_vertices = smpl(gt_pose, gt_betas)
        images = batch['img'].to(device)

        curr_batch_size = images.shape[0]

        # Run inference
        with torch.no_grad():
            out_dict = model(images)

        pred_vertices = out_dict['pred_vertices']
        camera = out_dict['camera']
        # 3D pose evaluation
        if eval_pose:
            # Get 14 ground truth joints
            if 'h36m' in dataset_name or 'mpi-inf' in dataset_name:
                gt_keypoints_3d = batch['pose_3d'].cuda()
                gt_keypoints_3d = gt_keypoints_3d[:, joint_mapper, :-1]
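                # Pelvis taken as the midpoint of the hips (indices 2 and 3 in the 14-joint set)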
                gt_pelvis = (gt_keypoints_3d[:, [2]] + gt_keypoints_3d[:, [3]]) / 2
                gt_keypoints_3d = gt_keypoints_3d - gt_pelvis
            else:
                gender = batch['gender'].to(device)
                gt_vertices = smpl_male(gt_pose, gt_betas)
                gt_vertices_female = smpl_female(gt_pose, gt_betas)
                gt_vertices[gender == 1, :, :] = gt_vertices_female[gender == 1, :, :]

                gt_keypoints_3d = smpl.get_train_joints(gt_vertices)[:, joint_mapper]
                # gt_keypoints_3d = smpl.get_lsp_joints(gt_vertices)    # joints_regressor used in cmr
                gt_pelvis = (gt_keypoints_3d[:, [2]] + gt_keypoints_3d[:, [3]]) / 2
                gt_keypoints_3d = gt_keypoints_3d - gt_pelvis

            # Get 14 predicted joints from the non-parametic mesh
            pred_keypoints_3d = smpl.get_train_joints(pred_vertices)[:, joint_mapper]
            # pred_keypoints_3d = smpl.get_lsp_joints(pred_vertices)    # joints_regressor used in cmr
            pred_pelvis = (pred_keypoints_3d[:, [2]] + pred_keypoints_3d[:, [3]]) / 2
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis

            # Absolute error (MPJPE)
            error = torch.sqrt(((pred_keypoints_3d - gt_keypoints_3d) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            mpjpe[step * batch_size:step * batch_size + curr_batch_size] = error

            # Reconstruction error (Procrustes-aligned MPJPE)
            r_error = reconstruction_error(pred_keypoints_3d.cpu().numpy(), gt_keypoints_3d.cpu().numpy(),
                                           reduction=None)
            mpjpe_pa[step * batch_size:step * batch_size + curr_batch_size] = r_error

        # Shape evaluation (Mean per-vertex error)
        if eval_shape:
            if dataset_name == 'surreal':
                gender = batch['gender'].to(device)
                gt_vertices = smpl_male(gt_pose, gt_betas)
                gt_vertices_female = smpl_female(gt_pose, gt_betas)
                gt_vertices[gender == 1, :, :] = gt_vertices_female[gender == 1, :, :]

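            # Align both meshes at the pelvis (hip midpoint) before measuring the per-vertex error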
            gt_pelvis_mesh = smpl.get_eval_joints(gt_vertices)
            pred_pelvis_mesh = smpl.get_eval_joints(pred_vertices)
            gt_pelvis_mesh = (gt_pelvis_mesh[:, [2]] + gt_pelvis_mesh[:, [3]]) / 2
            pred_pelvis_mesh = (pred_pelvis_mesh[:, [2]] + pred_pelvis_mesh[:, [3]]) / 2

            # se = torch.sqrt(((pred_vertices - gt_vertices) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            se = torch.sqrt(((pred_vertices - pred_pelvis_mesh - gt_vertices + gt_pelvis_mesh) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
            shape_err[step * batch_size:step * batch_size + curr_batch_size] = se

        # If mask or part evaluation, render the mask and part images
        if eval_masks or eval_parts:
            mask, parts = renderer(pred_vertices, camera)
        # Mask evaluation (for LSP)
        if eval_masks:
            center = batch['center'].cpu().numpy()
            scale = batch['scale'].cpu().numpy()
            # Dimensions of original image
            orig_shape = batch['orig_shape'].cpu().numpy()
            for i in range(curr_batch_size):
                # After rendering, convert image back to original resolution
                pred_mask = uncrop(mask[i].cpu().numpy(), center[i], scale[i], orig_shape[i]) > 0
                # Load gt mask
                gt_mask = cv2.imread(os.path.join(annot_path, batch['maskname'][i]), 0) > 0
                # Evaluation consistent with the original UP-3D code
                accuracy += (gt_mask == pred_mask).sum()
                pixel_count += np.prod(np.array(gt_mask.shape))
                for c in range(2):
                    cgt = gt_mask == c
                    cpred = pred_mask == c
                    tp[c] += (cgt & cpred).sum()
                    fp[c] += (~cgt & cpred).sum()
                    fn[c] += (cgt & ~cpred).sum()
                f1 = 2 * tp / (2 * tp + fp + fn)

        # Part evaluation (for LSP)
        if eval_parts:
            center = batch['center'].cpu().numpy()
            scale = batch['scale'].cpu().numpy()
            orig_shape = batch['orig_shape'].cpu().numpy()
            for i in range(curr_batch_size):
                pred_parts = uncrop(parts[i].cpu().numpy().astype(np.uint8), center[i], scale[i], orig_shape[i])
                # Load gt part segmentation
                gt_parts = cv2.imread(os.path.join(annot_path, batch['partname'][i]), 0)
                # Evaluation consistent with the original UP-3D code
                # 6 parts + background
                for c in range(7):
                    cgt = gt_parts == c
                    cpred = pred_parts == c
                    cpred[gt_parts == 255] = 0
                    parts_tp[c] += (cgt & cpred).sum()
                    parts_fp[c] += (~cgt & cpred).sum()
                    parts_fn[c] += (cgt & ~cpred).sum()
                gt_parts[gt_parts == 255] = 0
                pred_parts[pred_parts == 255] = 0
                parts_f1 = 2 * parts_tp / (2 * parts_tp + parts_fp + parts_fn)
                parts_accuracy += (gt_parts == pred_parts).sum()
                parts_pixel_count += np.prod(np.array(gt_parts.shape))

        # Print intermediate results during evaluation
        if step % log_freq == log_freq - 1:
            if eval_pose:
                print('MPJPE: ' + str(1000 * mpjpe[:step * batch_size].mean()))
                print('MPJPE-PA: ' + str(1000 * mpjpe_pa[:step * batch_size].mean()))
                print()
            if eval_shape:
                print('Shape Error: ' + str(1000 * shape_err[:step * batch_size].mean()))
                print()
            if eval_masks:
                print('Accuracy: ', accuracy / pixel_count)
                print('F1: ', f1.mean())
                print()
            if eval_parts:
                print('Parts Accuracy: ', parts_accuracy / parts_pixel_count)
                print('Parts F1 (BG): ', parts_f1[[0, 1, 2, 3, 4, 5, 6]].mean())
                print()

    # Print final results during evaluation
    print('*** Final Results ***')
    print()
    if eval_pose:
        print('MPJPE: ' + str(1000 * mpjpe.mean()))
        print('MPJPE-PA: ' + str(1000 * mpjpe_pa.mean()))
        print()
    if eval_shape:
        print('Shape Error: ' + str(1000 * shape_err.mean()))
        print()
    if eval_masks:
        print('Accuracy: ', accuracy / pixel_count)
        print('F1: ', f1.mean())
        print()
    if eval_parts:
        print('Parts Accuracy: ', parts_accuracy / parts_pixel_count)
        print('Parts F1 (BG): ', parts_f1[[0, 1, 2, 3, 4, 5, 6]].mean())
        print()

    # Save final results to .txt file
    txt_name = join(opt.save_root, dataset_name + '.txt')
    with open(txt_name, 'w') as f:
        f.write('*** Final Results ***')
        f.write('\n')
        if eval_pose:
            f.write('MPJPE: ' + str(1000 * mpjpe.mean()))
            f.write('\n')
            f.write('MPJPE-PA: ' + str(1000 * mpjpe_pa.mean()))
            f.write('\n')
        if eval_shape:
            f.write('Shape Error: ' + str(1000 * shape_err.mean()))
            f.write('\n')
        if eval_masks:
            f.write('Accuracy: ' + str(accuracy / pixel_count))
            f.write('\n')
            f.write('F1: ' + str(f1.mean()))
            f.write('\n')
        if eval_parts:
            f.write('Parts Accuracy: ' + str(parts_accuracy / parts_pixel_count))
            f.write('\n')
            f.write('Parts F1 (BG): ' + str(parts_f1[[0, 1, 2, 3, 4, 5, 6]].mean()))
            f.write('\n')
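
Note: `uncrop` is assumed to invert the person crop: the rendered mask/part map is resized to the original bounding box (side length `scale * 200`, centred at `center`, following the SPIN/CMR convention) and pasted into a blank image of the original resolution. A rough sketch under that assumption (`uncrop_sketch` is a hypothetical helper, not the repo's implementation):

import numpy as np
import cv2

def uncrop_sketch(crop, center, scale, orig_shape):
    """Paste a square crop back into a blank image of the original resolution."""
    crop = np.asarray(crop, dtype=np.float32)
    out_h, out_w = int(orig_shape[0]), int(orig_shape[1])
    box = int(round(scale * 200))                      # assumed bounding-box side length
    # Nearest-neighbour keeps mask / part labels intact
    resized = cv2.resize(crop, (box, box), interpolation=cv2.INTER_NEAREST)
    x1 = int(round(center[0] - box / 2))               # box corners in the original image
    y1 = int(round(center[1] - box / 2))
    x2, y2 = x1 + box, y1 + box
    out = np.zeros((out_h, out_w), dtype=resized.dtype)
    # Clip the box to the image borders and paste the overlapping region
    ox1, oy1 = max(x1, 0), max(y1, 0)
    ox2, oy2 = min(x2, out_w), min(y2, out_h)
    out[oy1:oy2, ox1:ox2] = resized[oy1 - y1:oy2 - y1, ox1 - x1:ox2 - x1]
    return out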
Example No. 29
0
def run_evaluation(model, args, dataset, mesh):
    """Run evaluation on the datasets and metrics we report in the paper. """

    # Create SMPL model
    smpl = SMPL().cuda()

    # Regressor for H36m joints
    J_regressor = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_H36M)).float()

    # Create dataloader for the dataset
    data_loader = DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers)

    # Transfer model to the GPU
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    model.eval()

    # predictions
    all_kps = {}

    # Iterate over the entire dataset
    for step, batch in enumerate(
            tqdm(data_loader, desc='Eval', total=args.num_imgs)):

        # Get ground truth annotations from the batch
        images = batch['img'].to(device)
        curr_batch_size = images.shape[0]

        # Run inference
        with torch.no_grad():
            pred_vertices, pred_vertices_smpl, camera, pred_rotmat, pred_betas = model(
                images)
            pred_keypoints_3d_smpl = smpl.get_joints(pred_vertices_smpl)
            pred_keypoints_2d_smpl = orthographic_projection(
                pred_keypoints_3d_smpl,
                camera.detach())[:, :, :2].cpu().data.numpy()

        eval_part = np.zeros((1, 19, 2))

        # we use custom keypoints for evaluation: MPII + COCO face joints
        # see paper / supplementary for details
        eval_part[0, :14, :] = pred_keypoints_2d_smpl[0][:14]
        eval_part[0, 14:, :] = pred_keypoints_2d_smpl[0][19:]

        all_kps[step] = eval_part
        if args.write_imgs == 'True':
            renderer = Renderer(faces=smpl.faces.cpu().numpy())
            write_imgs(batch, pred_keypoints_2d_smpl, pred_vertices_smpl,
                       camera, args, renderer)

    if args.eval_pck == 'True':
        gt_kp_path = os.path.join(cfg.BASE_DATA_DIR, args.dataset,
                                  args.crop_setting, 'keypoints.pkl')
        log_dir = os.path.join(cfg.BASE_DATA_DIR, 'cmr_pck_results.txt')
        with open(gt_kp_path, 'rb') as f:
            gt = pkl.load(f)

        calc = CalcPCK(
            all_kps,
            gt,
            num_imgs=cfg.DATASET_SIZES[args.dataset][args.crop_setting],
            log_dir=log_dir,
            dataset=args.dataset,
            crop_setting=args.crop_setting,
            pck_eval_threshold=args.pck_eval_threshold)
        calc.eval()
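
Note: `orthographic_projection` is assumed to be the weak-perspective camera used by GraphCMR-style models, where `camera` holds (s, tx, ty) per sample and a 3D point (x, y, z) maps to s * (x + tx, y + ty). A small sketch under that assumption:

import torch

def orthographic_projection_sketch(X, camera):
    """Project 3D points X (B x N x 3) with a weak-perspective camera (B x 3)."""
    camera = camera.view(-1, 1, 3)
    # scale * (xy + translation); the depth coordinate is dropped
    return camera[:, :, 0:1] * (X[:, :, :2] + camera[:, :, 1:])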