Example #1
    def init_fn(self):
        # create training dataset
        self.train_ds = create_dataset(self.options.dataset, self.options)

        # create Mesh object
        self.mesh = Mesh()
        self.faces = self.mesh.faces.to(self.device)

        # create GraphCNN
        self.graph_cnn = GraphCNN(self.mesh.adjmat,
                                  self.mesh.ref_vertices.t(),
                                  num_channels=self.options.num_channels,
                                  num_layers=self.options.num_layers).to(
                                      self.device)

        # SMPL Parameter regressor
        self.smpl_param_regressor = SMPLParamRegressor().to(self.device)

        # Setup a joint optimizer for the 2 models
        self.optimizer = torch.optim.Adam(
            params=list(self.graph_cnn.parameters()) +
            list(self.smpl_param_regressor.parameters()),
            lr=self.options.lr,
            betas=(self.options.adam_beta1, 0.999),
            weight_decay=self.options.wd)

        # SMPL model
        self.smpl = SMPL().to(self.device)

        # Create loss functions
        self.criterion_shape = nn.L1Loss().to(self.device)
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        self.criterion_regr = nn.MSELoss().to(self.device)

        # Pack models and optimizers in a dict - necessary for checkpointing
        self.models_dict = {
            'graph_cnn': self.graph_cnn,
            'smpl_param_regressor': self.smpl_param_regressor
        }
        self.optimizers_dict = {'optimizer': self.optimizer}

        # Renderer for visualization
        self.renderer = Renderer(faces=self.smpl.faces.cpu().numpy())

        # LSP indices from full list of keypoints
        self.to_lsp = list(range(14))

        # Optionally start training from a pretrained checkpoint
        # Note that this is different from resuming training
        # For the latter use --resume
        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(
                checkpoint_file=self.options.pretrained_checkpoint)
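
The models_dict / optimizers_dict packing above exists so that a checkpoint helper can persist and restore every component by name. Below is a minimal sketch of that convention; it assumes nothing about the repository's actual CheckpointSaver beyond the dict layout shown above, and the file path is a placeholder.

import torch

def save_checkpoint(models_dict, optimizers_dict, path='checkpoint.pt'):
    # Collect state_dicts keyed by the same names used in the dicts above.
    state = {name: m.state_dict() for name, m in models_dict.items()}
    state.update({name: o.state_dict() for name, o in optimizers_dict.items()})
    torch.save(state, path)

def load_checkpoint(models_dict, optimizers_dict, path='checkpoint.pt'):
    # Restore each model and optimizer from its named entry.
    state = torch.load(path, map_location='cpu')
    for name, m in models_dict.items():
        m.load_state_dict(state[name])
    for name, o in optimizers_dict.items():
        o.load_state_dict(state[name])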
Example #2
def predict_on_frames(args):
    # Load model
    mesh = Mesh(device=device)
    # Our pretrained networks have 5 residual blocks with 256 channels.
    # You might want to change this if you use a different architecture.
    model = CMR(mesh, 5, 256, pretrained_checkpoint=args.checkpoint, device=device)
    model.to(device)
    model.eval()

    image_paths = [os.path.join(args.in_folder, f) for f in sorted(os.listdir(args.in_folder))
                   if f.endswith('.png')]
    print('Predicting on all png images in {}'.format(args.in_folder))

    all_vertices = []
    all_vertices_smpl = []
    all_cams = []

    for image_path in image_paths:
        print("Image: ", image_path)
        # Preprocess input image and generate predictions
        img, norm_img = process_image(image_path, input_res=cfg.INPUT_RES)
        norm_img = norm_img.to(device)
        with torch.no_grad():
            pred_vertices, pred_vertices_smpl, pred_camera, _, _ = model(norm_img)

        pred_vertices = pred_vertices.cpu().numpy()
        pred_vertices_smpl = pred_vertices_smpl.cpu().numpy()
        pred_camera = pred_camera.cpu().numpy()

        all_vertices.append(pred_vertices)
        all_vertices_smpl.append(pred_vertices_smpl)
        all_cams.append(pred_camera)

    # Save predictions as pkl
    all_vertices = np.concatenate(all_vertices, axis=0)
    all_vertices_smpl = np.concatenate(all_vertices_smpl, axis=0)
    all_cams = np.concatenate(all_cams, axis=0)

    pred_dict = {'verts': all_vertices,
                 'verts_smpl': all_vertices_smpl,
                 'pred_cam': all_cams}
    if args.out_folder == 'dataset':
        out_folder = args.in_folder.replace('cropped_frames', 'cmr_results')
    else:
        out_folder = args.out_folder
    print('Saving to', os.path.join(out_folder, 'cmr_results.pkl'))
    os.makedirs(out_folder, exist_ok=True)
    for key in pred_dict.keys():
        print(pred_dict[key].shape)
    with open(os.path.join(out_folder, 'cmr_results.pkl'), 'wb') as f:
        pickle.dump(pred_dict, f)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', default=None, help='Path to network checkpoint')
    parser.add_argument('--gpu', default='0', type=str)
    parser.add_argument('--num_workers', default=4, type=int, help='Number of processes for data loading')
    parser.add_argument('--path_correction', action='store_true')
    args = parser.parse_args()

    # Device
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Load model
    mesh = Mesh(device=device)
    # Our pretrained networks have 5 residual blocks with 256 channels.
    # You might want to change this if you use a different architecture.
    model = CMR(mesh, 5, 256, pretrained_checkpoint=args.checkpoint, device=device)
    model.to(device)
    model.eval()

    # Setup evaluation dataset
    dataset_path = '/scratch/as2562/datasets/sports_videos_smpl/final_dataset'
    dataset = SportsVideosEvalDataset(dataset_path, img_wh=config.INPUT_RES,
                                      path_correction=args.path_correction)
    print("Eval examples found:", len(dataset))

    # Metrics
    metrics = ['pve', 'pve_scale_corrected', 'pve_pa', 'pve-t', 'pve-t_scale_corrected',
               'silhouette_iou', 'j2d_l2e']
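
predict_on_frames above writes its outputs to cmr_results.pkl with the keys 'verts', 'verts_smpl' and 'pred_cam'. A small sketch of reading those predictions back; the output folder is a placeholder, and the per-frame shapes (6890 vertices for the SMPL topology, a 3-parameter weak-perspective camera) are assumptions, not checked here.

import os
import pickle

out_folder = '/path/to/cmr_results'  # placeholder
with open(os.path.join(out_folder, 'cmr_results.pkl'), 'rb') as f:
    pred_dict = pickle.load(f)

# One row per processed frame, in the sorted order of the input .png files.
print(pred_dict['verts'].shape)       # (num_frames, 6890, 3) assumed
print(pred_dict['verts_smpl'].shape)  # (num_frames, 6890, 3) assumed
print(pred_dict['pred_cam'].shape)    # (num_frames, 3) weak-perspective camera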
Example #4
    else:
        if bbox_file is not None:
            center, scale = bbox_from_json(bbox_file)
        elif openpose_file is not None:
            center, scale = bbox_from_openpose(openpose_file)
    img = crop(img, center, scale, (input_res, input_res))
    img = img.astype(np.float32) / 255.
    img = torch.from_numpy(img).permute(2,0,1)
    norm_img = normalize_img(img.clone())[None]
    return img, norm_img

if __name__ == '__main__':
    args = parser.parse_args()
    
    # Load model
    mesh = Mesh()
    # Our pretrained networks have 5 residual blocks with 256 channels. 
    # You might want to change this if you use a different architecture.
    model = CMR(mesh, 5, 256, pretrained_checkpoint=args.checkpoint)
    model.cuda()
    model.eval()

    # Setup renderer for visualization
    renderer = Renderer()

    # Preprocess input image and generate predictions
    img, norm_img = process_image(args.img, args.bbox, args.openpose, input_res=cfg.INPUT_RES)
    with torch.no_grad():
        pred_vertices, pred_vertices_smpl, pred_camera, _, _ = model(norm_img.cuda())
        
    # Calculate camera parameters for rendering
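    # A hedged sketch of what this calculation could look like: convert the
    # predicted weak-perspective camera (s, tx, ty) into a translation vector
    # for the renderer. The focal length (1000) and input resolution (224)
    # mirror the values used in Example #8 below and are assumptions here.
    camera_translation = torch.stack(
        [pred_camera[:, 1], pred_camera[:, 2],
         2 * 1000. / (224. * pred_camera[:, 0] + 1e-9)], dim=-1)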
Example #5
    def init(self):
        
        # Create training and testing dataset
        self.train_ds = MGNDataset(self.options, split = 'train')
        self.test_ds  = MGNDataset(self.options, split = 'test')
        
        # the test data loader is fixed; disable shuffling since it is unnecessary.
        self.test_data_loader = CheckpointDataLoader( self.test_ds,
                    batch_size  = self.options.batch_size,
                    num_workers = self.options.num_workers,
                    pin_memory  = self.options.pin_memory,
                    shuffle     = False)

        # Create SMPL Mesh (graph) object for GCN
        self.mesh = Mesh(self.options, self.options.num_downsampling)
        self.faces = torch.cat(
            self.options.batch_size * [self.mesh.faces.to(self.device)[None]],
            dim=0)
        self.faces = self.faces.type(torch.LongTensor).to(self.device)
        
        # Create SMPL blending model and edges
        self.smpl = SMPL(self.options.smpl_model_path, self.device)
        # self.smplEdge = torch.Tensor(np.load(self.options.smpl_edges_path)) \
        #                 .long().to(self.device)
        
        # create SMPL+D blending model
        self.smplD = Smpl(self.options.smpl_model_path)
        
        # read the SMPL .obj file to get uv coordinates
        _, self.smpl_tri_ind, uv_coord, tri_uv_ind = read_Obj(self.options.smpl_objfile_path)
        uv_coord[:, 1] = 1 - uv_coord[:, 1]
        expUV = uv_coord[tri_uv_ind.flatten()]
        unique, index = np.unique(self.smpl_tri_ind.flatten(), return_index = True)
        self.smpl_verts_uvs = torch.as_tensor(expUV[index,:]).float().to(self.device)
        self.smpl_tri_ind   = torch.as_tensor(self.smpl_tri_ind).to(self.device)
        
        # camera for projection 
        self.perspCam = perspCamera() 

        # mean and std of displacements
        self.dispPara = \
            torch.Tensor(np.load(self.options.MGN_offsMeanStd_path)).to(self.device)
        
        # load the average pose and shape and convert them to camera coordinates;
        # the avg pose is decided by the image id we use for training (0-11)
        avgPose_objCoord = np.load(self.options.MGN_avgPose_path)
        avgPose_objCoord[:3] = rotationMatrix_to_axisAngle(    # for 0,6, front only
            torch.tensor([[[1,  0, 0],
                           [0, -1, 0],
                           [0,  0,-1]]]))
        self.avgPose = \
            axisAngle_to_Rot6d(
                torch.Tensor(avgPose_objCoord[None]).reshape(-1, 3)
                    ).reshape(1, -1).to(self.device)
        self.avgBeta = \
            torch.Tensor(
                np.load(self.options.MGN_avgBeta_path)[None]).to(self.device)
        self.avgCam  = torch.Tensor([1.2755, 0, 0])[None].to(self.device)    # 1.2755 is for our settings
        
        self.model = frameVIBE(
            self.options.smpl_model_path, 
            self.mesh,                
            self.avgPose,
            self.avgBeta,
            self.avgCam,
            self.options.num_channels,
            self.options.num_layers,
            self.smpl_verts_uvs,
            self.smpl_tri_ind
            ).to(self.device)
            
        # Setup a optimizer for models
        self.optimizer = torch.optim.Adam(
            params=list(self.model.parameters()),
            lr=self.options.lr,
            betas=(self.options.adam_beta1, 0.999),
            weight_decay=self.options.wd)
        
        self.criterion = VIBELoss(
            e_loss_weight=50,        # 2D keypoints; helps estimate camera and global orientation
            e_3d_loss_weight=1,      # 3D keypoints, bvt
            e_pose_loss_weight=10,   # body pose parameters
            e_shape_loss_weight=1,   # body shape parameters
            e_disp_loss_weight=1,    # displacements
            e_tex_loss_weight=1,     # uv texture image
            d_motion_loss_weight=0
            )
        
        # Pack models and optimizers in a dict - necessary for checkpointing
        self.models_dict = {self.options.model: self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}
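
Example #5 converts the average pose to a 6D rotation representation via axisAngle_to_Rot6d. Below is a minimal sketch of that representation (the first two columns of the rotation matrix, recovered with Gram-Schmidt), written independently of the repository's helpers, whose exact convention is an assumption here.

import torch
import torch.nn.functional as F

def rotmat_to_rot6d(R):
    # R: (..., 3, 3) -> (..., 6); concatenate the first two columns.
    return torch.cat([R[..., 0], R[..., 1]], dim=-1)

def rot6d_to_rotmat(x):
    # x: (..., 6) -> (..., 3, 3); Gram-Schmidt orthonormalization.
    a1, a2 = x[..., :3], x[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = F.normalize(a2 - (b1 * a2).sum(-1, keepdim=True) * b1, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack([b1, b2, b3], dim=-1)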
Example #6
class Trainer(BaseTrainer):
    """Trainer object.
    Inherits from BaseTrainer that sets up logging, saving/restoring checkpoints etc.
    """
    def init_fn(self):
        # create training dataset
        self.train_ds = create_dataset(self.options.dataset, self.options)

        # create Mesh object
        self.mesh = Mesh()
        self.faces = self.mesh.faces.to(self.device)

        # create GraphCNN
        self.graph_cnn = GraphCNN(self.mesh.adjmat,
                                  self.mesh.ref_vertices.t(),
                                  num_channels=self.options.num_channels,
                                  num_layers=self.options.num_layers).to(
                                      self.device)

        # SMPL Parameter regressor
        self.smpl_param_regressor = SMPLParamRegressor().to(self.device)

        # Setup a joint optimizer for the 2 models
        self.optimizer = torch.optim.Adam(
            params=list(self.graph_cnn.parameters()) +
            list(self.smpl_param_regressor.parameters()),
            lr=self.options.lr,
            betas=(self.options.adam_beta1, 0.999),
            weight_decay=self.options.wd)

        # SMPL model
        self.smpl = SMPL().to(self.device)

        # Create loss functions
        self.criterion_shape = nn.L1Loss().to(self.device)
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        self.criterion_regr = nn.MSELoss().to(self.device)

        # Pack models and optimizers in a dict - necessary for checkpointing
        self.models_dict = {
            'graph_cnn': self.graph_cnn,
            'smpl_param_regressor': self.smpl_param_regressor
        }
        self.optimizers_dict = {'optimizer': self.optimizer}

        # Renderer for visualization
        self.renderer = Renderer(faces=self.smpl.faces.cpu().numpy())

        # LSP indices from full list of keypoints
        self.to_lsp = list(range(14))

        # Optionally start training from a pretrained checkpoint
        # Note that this is different from resuming training
        # For the latter use --resume
        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(
                checkpoint_file=self.options.pretrained_checkpoint)

    def keypoint_loss(self, pred_keypoints_2d, gt_keypoints_2d):
        """Compute 2D reprojection loss on the keypoints.
        The confidence is binary and indicates whether the keypoints exist or not.
        The available keypoints are different for each dataset.
        """
        conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
        loss = (conf * self.criterion_keypoints(
            pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
        return loss

    def keypoint_3d_loss(self, pred_keypoints_3d, gt_keypoints_3d,
                         has_pose_3d):
        """Compute 3D keypoint loss for the examples that 3D keypoint annotations are available.
        The loss is weighted by the confidence
        """
        conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
        gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
        gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
        conf = conf[has_pose_3d == 1]
        pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
        if len(gt_keypoints_3d) > 0:
            gt_pelvis = (gt_keypoints_3d[:, 2, :] +
                         gt_keypoints_3d[:, 3, :]) / 2
            gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :]
            pred_pelvis = (pred_keypoints_3d[:, 2, :] +
                           pred_keypoints_3d[:, 3, :]) / 2
            pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :]
            return (conf * self.criterion_keypoints(pred_keypoints_3d,
                                                    gt_keypoints_3d)).mean()
        else:
            return torch.FloatTensor(1).fill_(0.).to(self.device)

    def shape_loss(self, pred_vertices, gt_vertices, has_smpl):
        """Compute per-vertex loss on the shape for the examples that SMPL annotations are available."""
        pred_vertices_with_shape = pred_vertices[has_smpl == 1]
        gt_vertices_with_shape = gt_vertices[has_smpl == 1]
        if len(gt_vertices_with_shape) > 0:
            return self.criterion_shape(pred_vertices_with_shape,
                                        gt_vertices_with_shape)
        else:
            return torch.FloatTensor(1).fill_(0.).to(self.device)

    def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas,
                    has_smpl):
        """Compute SMPL parameter loss for the examples that SMPL annotations are available."""
        pred_rotmat_valid = pred_rotmat[has_smpl == 1].view(-1, 3, 3)
        gt_rotmat_valid = rodrigues(gt_pose[has_smpl == 1].view(-1, 3))
        pred_betas_valid = pred_betas[has_smpl == 1]
        gt_betas_valid = gt_betas[has_smpl == 1]
        if len(pred_rotmat_valid) > 0:
            loss_regr_pose = self.criterion_regr(pred_rotmat_valid,
                                                 gt_rotmat_valid)
            loss_regr_betas = self.criterion_regr(pred_betas_valid,
                                                  gt_betas_valid)
        else:
            loss_regr_pose = torch.FloatTensor(1).fill_(0.).to(self.device)
            loss_regr_betas = torch.FloatTensor(1).fill_(0.).to(self.device)
        return loss_regr_pose, loss_regr_betas

    def train_step(self, input_batch):
        """Training step."""
        self.graph_cnn.train()
        self.smpl_param_regressor.train()

        # Grab data from the batch
        gt_keypoints_2d = input_batch['keypoints']
        gt_keypoints_3d = input_batch['pose_3d']
        gt_pose = input_batch['pose']
        gt_betas = input_batch['betas']
        has_smpl = input_batch['has_smpl']
        has_pose_3d = input_batch['has_pose_3d']
        images = input_batch['img']

        # Render vertices using SMPL parameters
        gt_vertices = self.smpl(gt_pose, gt_betas)
        batch_size = gt_vertices.shape[0]

        # Feed image in the GraphCNN
        # Returns subsampled mesh and camera parameters
        pred_vertices_sub, pred_camera = self.graph_cnn(images)

        # Upsample mesh in the original size
        pred_vertices = self.mesh.upsample(pred_vertices_sub.transpose(1, 2))

        # Prepare input for SMPL Parameter regressor
        # The input is the predicted and template vertices subsampled by a factor of 4
        # Notice that we detach the GraphCNN
        x = pred_vertices_sub.transpose(1, 2).detach()
        x = torch.cat(
            [x, self.mesh.ref_vertices[None, :, :].expand(batch_size, -1, -1)],
            dim=-1)

        # Estimate SMPL parameters and render vertices
        pred_rotmat, pred_shape = self.smpl_param_regressor(x)
        pred_vertices_smpl = self.smpl(pred_rotmat, pred_shape)

        # Get 3D and projected 2D keypoints from the regressed shape
        pred_keypoints_3d = self.smpl.get_joints(pred_vertices)
        pred_keypoints_2d = orthographic_projection(pred_keypoints_3d,
                                                    pred_camera)[:, :, :2]
        pred_keypoints_3d_smpl = self.smpl.get_joints(pred_vertices_smpl)
        pred_keypoints_2d_smpl = orthographic_projection(
            pred_keypoints_3d_smpl, pred_camera.detach())[:, :, :2]

        # Compute losses

        # GraphCNN losses
        loss_keypoints = self.keypoint_loss(pred_keypoints_2d, gt_keypoints_2d)
        loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d,
                                                  gt_keypoints_3d, has_pose_3d)
        loss_shape = self.shape_loss(pred_vertices, gt_vertices, has_smpl)

        # SMPL regressor losses
        loss_keypoints_smpl = self.keypoint_loss(pred_keypoints_2d_smpl,
                                                 gt_keypoints_2d)
        loss_keypoints_3d_smpl = self.keypoint_3d_loss(pred_keypoints_3d_smpl,
                                                       gt_keypoints_3d,
                                                       has_pose_3d)
        loss_shape_smpl = self.shape_loss(pred_vertices_smpl, gt_vertices,
                                          has_smpl)
        loss_regr_pose, loss_regr_betas = self.smpl_losses(
            pred_rotmat, pred_shape, gt_pose, gt_betas, has_smpl)

        # Add losses to compute the total loss
        loss = loss_shape_smpl + loss_keypoints_smpl + loss_keypoints_3d_smpl +\
               loss_regr_pose + 0.1 * loss_regr_betas + loss_shape + loss_keypoints + loss_keypoints_3d

        # Do backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Pack output arguments to be used for visualization in a list
        out_args = [
            pred_vertices, pred_vertices_smpl, pred_camera, pred_keypoints_2d,
            pred_keypoints_2d_smpl, loss_shape, loss_shape_smpl,
            loss_keypoints, loss_keypoints_smpl, loss_keypoints_3d,
            loss_keypoints_3d_smpl, loss_regr_pose, loss_regr_betas, loss
        ]
        out_args = [arg.detach() for arg in out_args]
        return out_args

    def train_summaries(self, input_batch, pred_vertices, pred_vertices_smpl,
                        pred_camera, pred_keypoints_2d, pred_keypoints_2d_smpl,
                        loss_shape, loss_shape_smpl, loss_keypoints,
                        loss_keypoints_smpl, loss_keypoints_3d,
                        loss_keypoints_3d_smpl, loss_regr_pose,
                        loss_regr_betas, loss):
        """Tensorboard logging."""
        gt_keypoints_2d = input_batch['keypoints'].cpu().numpy()

        rend_imgs = []
        rend_imgs_smpl = []
        batch_size = pred_vertices.shape[0]
        # Do visualization for the first 4 images of the batch
        for i in range(min(batch_size, 4)):
            img = input_batch['img_orig'][i].cpu().numpy().transpose(1, 2, 0)
            # Get LSP keypoints from the full list of keypoints
            gt_keypoints_2d_ = gt_keypoints_2d[i, self.to_lsp]
            pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i,
                                                                 self.to_lsp]
            pred_keypoints_2d_smpl_ = pred_keypoints_2d_smpl.cpu().numpy()[
                i, self.to_lsp]
            # Get GraphCNN and SMPL vertices for the particular example
            vertices = pred_vertices[i].cpu().numpy()
            vertices_smpl = pred_vertices_smpl[i].cpu().numpy()
            cam = pred_camera[i].cpu().numpy()
            # Visualize reconstruction and detected pose
            rend_img = visualize_reconstruction(img, self.options.img_res,
                                                gt_keypoints_2d_, vertices,
                                                pred_keypoints_2d_, cam,
                                                self.renderer)
            rend_img_smpl = visualize_reconstruction(img, self.options.img_res,
                                                     gt_keypoints_2d_,
                                                     vertices_smpl,
                                                     pred_keypoints_2d_smpl_,
                                                     cam, self.renderer)
            rend_img = rend_img.transpose(2, 0, 1)
            rend_img_smpl = rend_img_smpl.transpose(2, 0, 1)
            rend_imgs.append(torch.from_numpy(rend_img))
            rend_imgs_smpl.append(torch.from_numpy(rend_img_smpl))
        rend_imgs = make_grid(rend_imgs, nrow=1)
        rend_imgs_smpl = make_grid(rend_imgs_smpl, nrow=1)

        # Save results in Tensorboard
        self.summary_writer.add_image('imgs', rend_imgs, self.step_count)
        self.summary_writer.add_image('imgs_smpl', rend_imgs_smpl,
                                      self.step_count)
        self.summary_writer.add_scalar('loss_shape', loss_shape,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_shape_smpl', loss_shape_smpl,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_regr_pose', loss_regr_pose,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_regr_betas', loss_regr_betas,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_keypoints', loss_keypoints,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_keypoints_smpl',
                                       loss_keypoints_smpl, self.step_count)
        self.summary_writer.add_scalar('loss_keypoints_3d', loss_keypoints_3d,
                                       self.step_count)
        self.summary_writer.add_scalar('loss_keypoints_3d_smpl',
                                       loss_keypoints_3d_smpl, self.step_count)
        self.summary_writer.add_scalar('loss', loss, self.step_count)
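
train_step above projects the regressed 3D joints to 2D with orthographic_projection and the predicted camera (s, tx, ty). A hedged sketch of a weak-perspective projection consistent with that usage follows; the repository's actual function may differ in detail, and the name below is hypothetical.

def weak_perspective_projection(X, camera):
    # X: (B, N, 3) 3D joints, camera: (B, 3) as (scale, tx, ty).
    s = camera[:, None, 0:1]        # (B, 1, 1)
    t = camera[:, None, 1:]         # (B, 1, 2)
    return s * (X[:, :, :2] + t)    # (B, N, 2) projected 2D keypoints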
Example #7
    plt.imshow(img_smpl)
    plt.title('3D Mesh overlay')
    plt.axis('off')
    plt.subplot(133)
    plt.imshow(img_smpl2)
    plt.title('3D mesh')
    plt.axis('off')
    plt.draw()
    # save before show(); otherwise the figure may already be cleared
    plt.savefig(args.img[:-4] + '_CMRpreds' + '.png', dpi=400)
    plt.show()

if __name__ == '__main__':
    args = parser.parse_args()
    
    # Load model
    mesh = Mesh(device=DEVICE)
    # Our pretrained networks have 5 residual blocks with 256 channels. 
    # You might want to change this if you use a different architecture.
    model = CMR(mesh, 5, 256, pretrained_checkpoint=args.checkpoint)
    if DEVICE == torch.device("cuda"):
        model.cuda()
    model.eval()

    # Setup renderer for visualization
    renderer = Renderer()

    # Preprocess input image and generate predictions
    img, norm_img = process_image(args.img, input_res=cfg.INPUT_RES)
    if DEVICE == torch.device("cuda"):
        norm_img = norm_img.cuda()
    with torch.no_grad():
Example #8
def visbodyPrediction(img_in,
                      prediction,
                      options,
                      path_object,
                      device='cuda',
                      ind=0):

    prediction = prediction[0]

    # <==== vis predicted body and displacements in 3D
    # displacements mean and std
    dispPara = np.load(options.MGN_offsMeanStd_path)

    # SMPLD model
    smplD = Smpl(options.smpl_model_path)

    # gt body and offsets
    gt_offsets_t = np.load(pjn(path_object, 'gt_offsets/offsets_std.npy'))
    pathRegistr = pjn(path_object, 'registration.pkl')
    registration = pickle.load(open(pathRegistr, 'rb'), encoding='iso-8859-1')
    gtBody_p = create_smplD_psbody(smplD,
                                   gt_offsets_t,
                                   registration['pose'],
                                   registration['betas'],
                                   0,
                                   rtnMesh=True)[1]

    # naked posed body
    nakedBody_p = create_smplD_psbody(
        smplD,
        0,
        prediction['theta'][ind][3:75][None].cpu(),
        prediction['theta'][ind][75:][None].cpu(),
        0,
        rtnMesh=True)[1]

    # offsets in t-pose
    displacements = prediction['verts_disp'].cpu().numpy()[ind]
    offPred_t = (displacements * dispPara[1] + dispPara[0])

    # create predicted dressed body
    dressbody_p = create_smplD_psbody(
        smplD,
        offPred_t,
        prediction['theta'][ind][3:75].cpu().numpy(),
        prediction['theta'][ind][75:].cpu().numpy(),
        0,
        rtnMesh=True)[1]

    mvs = MeshViewers((1, 3))
    mvs[0][0].set_static_meshes([gtBody_p])
    mvs[0][1].set_static_meshes([nakedBody_p])
    mvs[0][2].set_static_meshes([dressbody_p])

    offset_p = torch.tensor(dressbody_p.v - nakedBody_p.v).to(device).float()

    # <==== vis the overall prediction, i.e. render the image

    dispPara = torch.tensor(dispPara).to(device)

    # smpl Mesh
    mesh = Mesh(options, options.num_downsampling)
    faces = torch.cat(options.batch_size * [mesh.faces.to(device)[None]],
                      dim=0)
    faces = faces.type(torch.LongTensor).to(device)

    # read the SMPL .obj file to get uv coordinates
    _, smpl_tri_ind, uv_coord, tri_uv_ind = read_Obj(options.smpl_objfile_path)
    uv_coord[:, 1] = 1 - uv_coord[:, 1]
    expUV = uv_coord[tri_uv_ind.flatten()]
    unique, index = np.unique(smpl_tri_ind.flatten(), return_index=True)
    smpl_verts_uvs = torch.as_tensor(expUV[index, :]).float().to(device)
    smpl_tri_ind = torch.as_tensor(smpl_tri_ind).to(device)

    # vis texture
    vis_renderer = simple_renderer(batch_size=1)
    predTrans = torch.stack([
        prediction['theta'][ind, 1], prediction['theta'][ind, 2], 2 * 1000. /
        (224. * prediction['theta'][ind, 0] + 1e-9)
    ],
                            dim=-1)
    tex = prediction['tex_image'][ind].flip(dims=(0, ))[None]
    pred_img = vis_renderer(verts=prediction['verts'][ind][None] + offset_p,
                            faces=faces[ind][None],
                            verts_uvs=smpl_verts_uvs[None],
                            faces_uvs=smpl_tri_ind[None],
                            tex_image=tex,
                            R=torch.eye(3)[None].to('cuda'),
                            T=predTrans,
                            f=torch.ones([1, 1]).to('cuda') * 1000,
                            C=torch.ones([1, 2]).to('cuda') * 112,
                            imgres=224).cpu()
    overlayimg = 0.9 * pred_img[0, :, :, :3] + 0.1 * img_in.permute(1, 2, 0)

    if 'tex_image' in prediction.keys():
        plt.figure()
        plt.imshow(prediction['tex_image'][ind].cpu())
        plt.figure()
        plt.imshow(prediction['unwarp_tex'][ind].cpu())
        plt.figure()
        plt.imshow(pred_img[ind].cpu())
        plt.figure()
        plt.imshow(overlayimg.cpu())
        plt.figure()
        plt.imshow(img_in.cpu().permute(1, 2, 0))
Example #9
def inference_structure(pathCkp: str,
                        pathImg: str = None,
                        pathBgImg: str = None):

    print('If the model was trained locally and the workspace was renamed, do not '
          'forget to change the "checkpoint_dir" in config.json.')

    # Load configuration
    with open(pjn(pathCkp, 'config.json'), 'r') as f:
        options = json.load(f)
        options = namedtuple('options', options.keys())(**options)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    mesh = Mesh(options, options.num_downsampling)

    # read the SMPL .obj file to get uv coordinates
    _, smpl_tri_ind, uv_coord, tri_uv_ind = read_Obj(options.smpl_objfile_path)
    uv_coord[:, 1] = 1 - uv_coord[:, 1]
    expUV = uv_coord[tri_uv_ind.flatten()]
    unique, index = np.unique(smpl_tri_ind.flatten(), return_index=True)
    smpl_verts_uvs = torch.as_tensor(expUV[index, :]).float().to(device)
    smpl_tri_ind = torch.as_tensor(smpl_tri_ind).to(device)

    # load the average pose and shape and convert them to camera coordinates;
    # the avg pose is decided by the image id we use for training (0-11)
    avgPose_objCoord = np.load(options.MGN_avgPose_path)
    avgPose_objCoord[:3] = rotationMatrix_to_axisAngle(  # for 0,6, front only
        torch.tensor([[[1, 0, 0], [0, -1, 0], [0, 0, -1]]]))
    avgPose = \
        axisAngle_to_Rot6d(
            torch.Tensor(avgPose_objCoord[None]).reshape(-1, 3)
            ).reshape(1, -1).to(device)
    avgBeta = \
        torch.Tensor(
            np.load(options.MGN_avgBeta_path)[None]).to(device)
    avgCam = torch.Tensor([1.2755, 0,
                           0])[None].to(device)  # 1.2755 is for our settings

    # Create model
    model = frameVIBE(options.smpl_model_path, mesh, avgPose, avgBeta, avgCam,
                      options.num_channels, options.num_layers, smpl_verts_uvs,
                      smpl_tri_ind).to(device)

    optimizer = torch.optim.Adam(params=list(model.parameters()))
    models_dict = {options.model: model}
    optimizers_dict = {'optimizer': optimizer}

    # Load pretrained model
    saver = CheckpointSaver(save_dir=options.checkpoint_dir)
    saver.load_checkpoint(models_dict,
                          optimizers_dict,
                          checkpoint_file=options.checkpoint)

    # Prepare and preprocess input image
    pathToObj = '/'.join(pathImg.split('/')[:-2])
    cameraIdx = int(pathImg.split('/')[-1].split('_')[0][6:])
    with open(
            pjn(pathToObj,
                'rendering/camera%d_boundingbox.txt' % (cameraIdx))) as f:
        boundbox = literal_eval(f.readline())
    IMG_NORM_MEAN = [0.485, 0.456, 0.406]
    IMG_NORM_STD = [0.229, 0.224, 0.225]
    normalize_img = Normalize(mean=IMG_NORM_MEAN, std=IMG_NORM_STD)

    path_to_rendering = '/'.join(pathImg.split('/')[:-1])
    cameraPath, lightPath = pathImg.split('/')[-1].split('_')[:2]
    cameraIdx, _ = int(cameraPath[6:]), int(lightPath[5:])
    with open(pjn(path_to_rendering,
                  'camera%d_boundingbox.txt' % (cameraIdx))) as f:
        boundbox = literal_eval(f.readline())
    img = cv2.imread(pathImg)[:, :, ::-1].astype(np.float32)

    # prepare background
    if options.replace_background:
        if pathBgImg is None:
            bgimages = []
            for subfolder in sorted(
                    glob(pjn(options.bgimg_dir, 'images/validation/*'))):
                for subsubfolder in sorted(glob(pjn(subfolder, '*'))):
                    if 'room' in subsubfolder:
                        bgimages += sorted(glob(pjn(subsubfolder, '*.jpg')))
            bgimg = cv2.imread(bgimages[np.random.randint(
                0, len(bgimages))])[:, :, ::-1].astype(np.float32)
        else:
            bgimg = cv2.imread(pathBgImg)[:, :, ::-1].astype(np.float32)
        img = background_replacing(img, bgimg)

    # augment image
    center = [(boundbox[0] + boundbox[2]) / 2, (boundbox[1] + boundbox[3]) / 2]
    scale = max((boundbox[2] - boundbox[0]) / 200,
                (boundbox[3] - boundbox[1]) / 200)
    img = torch.Tensor(crop(img, center, scale, [224, 224], rot=0)).permute(
        2, 0, 1) / 255
    img_in = normalize_img(img)

    # Inference
    with torch.no_grad():  # disable grad
        model.eval()
        prediction = model(
            img_in[None].repeat_interleave(options.batch_size,
                                           dim=0).to(device),
            img[None].repeat_interleave(options.batch_size, dim=0).to(device))

    return prediction, img_in, options
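
A hypothetical invocation of inference_structure above; the checkpoint and image paths are placeholders that only follow the camera%d_light%d file naming the function parses, not actual paths from the repository.

# Placeholder paths: the image is expected to live under <object>/rendering/
# and be named like camera<idx>_light<idx>_*.png so the bounding box file
# camera<idx>_boundingbox.txt can be located next to it.
prediction, img_in, options = inference_structure(
    pathCkp='/path/to/checkpoint_workspace',
    pathImg='/path/to/object/rendering/camera0_light0_rendered.png')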