Example #1
def camera_fitting_loss(model_keypoints,
                        rotation,
                        camera_t,
                        focal_length,
                        camera_center,
                        keypoints_2d,
                        keypoints_conf,
                        distortion=None):

    # Project model keypoints
    projected_keypoints = perspective_projection(model_keypoints, rotation,
                                                 camera_t, focal_length,
                                                 camera_center, distortion)

    # Disable the wing-tip keypoints (indices 5 and 6)
    keypoints_conf = keypoints_conf.detach().clone()
    keypoints_conf[:, 5:7] = 0

    # Weighted robust reprojection loss
    sigma = 50
    reprojection_error = gmof(projected_keypoints - keypoints_2d, sigma)
    reprojection_loss = (keypoints_conf**2) * reprojection_error.sum(dim=-1)

    total_loss = reprojection_loss.sum(dim=-1)

    return total_loss.sum()
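
Examples #1, #2, #4, #7 and #8 all call a gmof helper that is not shown on this page. A minimal sketch, assuming it is the standard Geman-McClure robust error used in SMPLify/SPIN-style fitting (the actual helper may differ):

import torch

def gmof(x, sigma):
    # Geman-McClure error: quadratic for small residuals, saturating
    # towards sigma**2 for large ones, which limits outlier influence.
    x_squared = x ** 2
    sigma_squared = sigma ** 2
    return (sigma_squared * x_squared) / (sigma_squared + x_squared)

# A 1000-pixel residual contributes at most ~sigma**2 to the loss.
print(gmof(torch.tensor([1.0, 1000.0]), sigma=50))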
Example #2
def kpts_fitting_loss(model_keypoints,
                      focal_length,
                      camera_center,
                      keypoints_2d,
                      keypoints_conf,
                      body_pose,
                      bone_length,
                      prior_weight=1,
                      pose_init=None,
                      bone_init=None,
                      sigma=100):

    device = body_pose.device

    # Project model keypoints
    projected_keypoints = perspective_projection(model_keypoints, None, None,
                                                 focal_length, camera_center)

    # Weighted robust reprojection loss
    reprojection_error = gmof(projected_keypoints - keypoints_2d, sigma)
    reprojection_loss = (keypoints_conf**2) * reprojection_error.sum(dim=-1)

    # If a pose/bone initialization is provided, constrain the objective by penalizing deviation from it
    if pose_init is None or bone_init is None:
        total_loss = reprojection_loss.sum(dim=-1)

    else:
        init_loss = (body_pose - pose_init).abs().sum() + (
            bone_length - bone_init).abs().sum()
        init_loss = init_loss * prior_weight
        total_loss = reprojection_loss.sum(dim=-1) + init_loss.sum()

    return total_loss.sum()
Example #3
def reproject_keypoints(mesh_keypoints, frames):

    cam_rot, cam_t, focal, center, distortion = mutils.get_cam(frames)

    kpts = mesh_keypoints.repeat([len(frames), 1, 1])
    proj_kpts = perspective_projection(kpts, cam_rot, cam_t, focal, center,
                                       distortion)

    return proj_kpts
Example #4
def body_fitting_loss(model_keypoints,
                      rotation,
                      camera_t,
                      focal_length,
                      camera_center,
                      keypoints_2d,
                      keypoints_conf,
                      body_pose,
                      bone_length,
                      sigma=50,
                      lim_weight=1,
                      prior_weight=1,
                      bone_weight=1,
                      distortion=None,
                      pose_init=None,
                      bone_init=None):

    # Project model keypoints
    device = body_pose.device
    projected_keypoints = perspective_projection(model_keypoints, rotation,
                                                 camera_t, focal_length,
                                                 camera_center, distortion)

    # Weighted robust reprojection loss
    reprojection_error = gmof(projected_keypoints - keypoints_2d, sigma)
    reprojection_loss = (keypoints_conf**2) * reprojection_error.sum(dim=-1)

    # Joint angle limit loss
    max_lim = torch.tensor(constants.max_lim).repeat(1, 1).to(device)
    min_lim = torch.tensor(constants.min_lim).repeat(1, 1).to(device)
    lim_loss = (body_pose - max_lim).clamp(min=0) + (
        min_lim - body_pose).clamp(min=0)
    lim_loss = lim_weight * lim_loss

    # Prior Loss
    if pose_init is None or bone_init is None:
        prior_loss = body_pose.abs()
        prior_loss = prior_weight * prior_loss
    else:
        prior_loss = (body_pose - pose_init).abs().sum() + (
            bone_length - bone_init).abs().sum()
        prior_loss = prior_weight * prior_loss

    # Bone Length Limit Loss
    max_bone = torch.tensor(constants.max_bone).repeat(1, 1).to(device)
    min_bone = torch.tensor(constants.min_bone).repeat(1, 1).to(device)
    bone_loss = (bone_length - max_bone).clamp(min=0) + (
        min_bone - bone_length).clamp(min=0)
    bone_loss = bone_weight * bone_loss

    total_loss = (reprojection_loss.sum(dim=-1) + lim_loss.sum() +
                  prior_loss.sum() + bone_loss.sum())

    return total_loss.sum()
Example #5
def projection(cam, s3d, eps=1e-9):
    cam_t = torch.stack([
        cam[:, 1], cam[:, 2], 2 * constants.FOCAL_LENGTH /
        (constants.IMG_RES * cam[:, 0] + eps)
    ],
                        dim=-1)
    camera_center = torch.zeros(s3d.shape[0], 2, device=device)
    s2d = perspective_projection(s3d,
                                 rotation=torch.eye(
                                     3, device=device).unsqueeze(0).expand(
                                         s3d.shape[0], -1, -1),
                                 translation=cam_t,
                                 focal_length=constants.FOCAL_LENGTH,
                                 camera_center=camera_center)
    s2d_norm = s2d / (constants.IMG_RES / 2.)  # to [-1,1]
    return {'ori': s2d, 'normed': s2d_norm}
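
perspective_projection itself is assumed by every example on this page. A minimal sketch consistent with the keyword arguments used in Example #5 (batched points, per-example rotation and translation, scalar focal length, 2D camera center); note that the variant called in Examples #1-#4 also accepts a distortion term and None rotation/translation, which this sketch does not cover:

import torch

def perspective_projection(points, rotation, translation, focal_length,
                           camera_center):
    # points: (B, N, 3); rotation: (B, 3, 3); translation: (B, 3)
    # Returns pixel coordinates of shape (B, N, 2).
    batch_size = points.shape[0]
    K = torch.zeros([batch_size, 3, 3], device=points.device)
    K[:, 0, 0] = focal_length
    K[:, 1, 1] = focal_length
    K[:, 2, 2] = 1.
    K[:, :-1, -1] = camera_center
    # Transform points into the camera frame
    points = torch.einsum('bij,bkj->bki', rotation, points)
    points = points + translation.unsqueeze(1)
    # Perspective divide, then apply the intrinsics
    projected = points / points[:, :, -1].unsqueeze(-1)
    projected = torch.einsum('bij,bkj->bki', K, projected)
    return projected[:, :, :-1]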
Example #6
def camera_fitting_loss(model_joints,
                        camera_t,
                        camera_t_est,
                        camera_center,
                        joints_2d,
                        joints_conf,
                        focal_length=5000,
                        depth_loss_weight=100):
    """
    Loss function for camera optimization.
    """

    # Project model joints
    batch_size = model_joints.shape[0]
    rotation = torch.eye(3, device=model_joints.device).unsqueeze(0).expand(
        batch_size, -1, -1)
    projected_joints = perspective_projection(model_joints, rotation, camera_t,
                                              focal_length, camera_center)

    op_joints = ['OP RHip', 'OP LHip', 'OP RShoulder', 'OP LShoulder']
    op_joints_ind = [constants.JOINT_IDS[joint] for joint in op_joints]
    gt_joints = ['Right Hip', 'Left Hip', 'Right Shoulder', 'Left Shoulder']
    gt_joints_ind = [constants.JOINT_IDS[joint] for joint in gt_joints]
    reprojection_error_op = (joints_2d[:, op_joints_ind] -
                             projected_joints[:, op_joints_ind])**2
    reprojection_error_gt = (joints_2d[:, gt_joints_ind] -
                             projected_joints[:, gt_joints_ind])**2
    # Check if for each example in the batch all 4 OpenPose detections are valid, otherwise use the GT detections
    # OpenPose joints are more reliable for this task, so we prefer to use them if possible
    is_valid = (joints_conf[:, op_joints_ind].min(dim=-1)[0][:, None, None] >
                0).float()
    reprojection_loss = (is_valid * reprojection_error_op +
                         (1 - is_valid) * reprojection_error_gt).sum(dim=(1,
                                                                          2))

    # Loss that penalizes deviation from depth estimate
    depth_loss = (depth_loss_weight**
                  2) * (camera_t[:, 2] - camera_t_est[:, 2])**2

    total_loss = reprojection_loss + depth_loss
    return total_loss.sum()
Example #7
def body_fitting_loss(body_pose,
                      betas,
                      model_joints,
                      camera_t,
                      camera_center,
                      joints_2d,
                      joints_conf,
                      pose_prior,
                      focal_length=5000,
                      sigma=100,
                      pose_prior_weight=4.78,
                      shape_prior_weight=5,
                      angle_prior_weight=15.2,
                      output='sum'):
    """
    Loss function for body fitting
    """

    batch_size = body_pose.shape[0]
    rotation = torch.eye(3, device=body_pose.device).unsqueeze(0).expand(
        batch_size, -1, -1)
    projected_joints = perspective_projection(model_joints, rotation, camera_t,
                                              focal_length, camera_center)

    # Weighted robust reprojection error
    reprojection_error = gmof(projected_joints - joints_2d, sigma)
    reprojection_loss = (joints_conf**2) * reprojection_error.sum(dim=-1)

    # Pose prior loss
    pose_prior_loss = (pose_prior_weight**2) * pose_prior(body_pose, betas)

    # Angle prior for knees and elbows
    angle_prior_loss = (angle_prior_weight**
                        2) * angle_prior(body_pose).sum(dim=-1)

    # Regularizer to prevent betas from taking large values
    shape_prior_loss = (shape_prior_weight**2) * (betas**2).sum(dim=-1)

    total_loss = reprojection_loss.sum(
        dim=-1) + pose_prior_loss + angle_prior_loss + shape_prior_loss

    if output == 'sum':
        return total_loss.sum()
    elif output == 'reprojection':
        return reprojection_loss
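
Examples #7 and #8 additionally rely on an angle_prior term that is not shown. A sketch following the SPIN convention, where a fixed set of axis-angle components penalizes hyper-extension of the elbows and knees (the joint indices below are SPIN's and are an assumption about this codebase):

import torch

def angle_prior(pose):
    # pose: (B, 69) SMPL body pose in axis-angle.
    # Indices 55, 58 (elbows) and 12, 15 (knees) are the bending components;
    # the sign pattern makes the prior grow only in the unnatural direction.
    return torch.exp(
        pose[:, [55, 58, 12, 15]] *
        torch.tensor([1., -1., -1., -1.], device=pose.device)) ** 2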
Example #8
def body_fitting_loss_smplify_x(body_pose,
                                betas,
                                pose_embedding,
                                camera_t,
                                camera_center,
                                model_joints,
                                joints_conf,
                                joints_2d,
                                focal_length=5000,
                                sigma=100,
                                body_pose_weight=4.78,
                                shape_prior_weight=5,
                                angle_prior_weight=15.2,
                                output='sum'):
    batch_size = body_pose.shape[0]
    rotation = torch.eye(3, device=body_pose.device).unsqueeze(0).expand(
        batch_size, -1, -1)
    projected_joints = perspective_projection(model_joints, rotation, camera_t,
                                              focal_length, camera_center)

    # Weighted robust reprojection error
    reprojection_error = gmof(projected_joints - joints_2d, sigma)
    reprojection_loss = (joints_conf**2) * reprojection_error.sum(dim=-1)

    pose_prior_loss = (pose_embedding.pow(2).sum() * body_pose_weight**2)
    shape_prior_loss = (shape_prior_weight**2) * (betas**2).sum(dim=-1)
    angle_prior_loss = (angle_prior_weight**
                        2) * angle_prior(body_pose).sum(dim=-1)

    total_loss = reprojection_loss.sum(
        dim=-1) + pose_prior_loss + angle_prior_loss + shape_prior_loss

    if output == 'sum':
        return total_loss.sum()
    elif output == 'reprojection':
        return reprojection_loss
Example #9
    def train_step(self, input_batch):
        self.model.train()

        # Get data from the batch
        images = input_batch['img']  # input image
        gt_keypoints_2d = input_batch['keypoints']  # 2D keypoints
        gt_pose = input_batch['pose']  # SMPL pose parameters
        gt_betas = input_batch['betas']  # SMPL beta parameters
        gt_joints = input_batch['pose_3d']  # 3D pose
        has_smpl = input_batch['has_smpl'].byte(
        )  # flag that indicates whether SMPL parameters are valid
        has_pose_3d = input_batch['has_pose_3d'].byte(
        )  # flag that indicates whether 3D pose is valid
        is_flipped = input_batch[
            'is_flipped']  # flag that indicates whether image was flipped during data augmentation
        rot_angle = input_batch[
            'rot_angle']  # rotation angle used for data augmentation
        dataset_name = input_batch[
            'dataset_name']  # name of the dataset the image comes from
        indices = input_batch[
            'sample_index']  # index of example inside its dataset
        batch_size = images.shape[0]

        # Get GT vertices and model joints
        # Note that gt_model_joints is different from gt_joints as it comes from SMPL
        gt_out = self.smpl(betas=gt_betas,
                           body_pose=gt_pose[:, 3:],
                           global_orient=gt_pose[:, :3])
        gt_model_joints = gt_out.joints
        gt_vertices = gt_out.vertices

        # Get current best fits from the dictionary
        opt_pose, opt_betas = self.fits_dict[(dataset_name, indices.cpu(),
                                              rot_angle.cpu(),
                                              is_flipped.cpu())]
        opt_pose = opt_pose.to(self.device)
        opt_betas = opt_betas.to(self.device)
        opt_output = self.smpl(betas=opt_betas,
                               body_pose=opt_pose[:, 3:],
                               global_orient=opt_pose[:, :3])
        opt_vertices = opt_output.vertices
        if opt_vertices.shape != (self.options.batch_size, 6890, 3):
            opt_vertices = torch.zeros_like(opt_vertices, device=self.device)
        opt_joints = opt_output.joints

        # De-normalize 2D keypoints from [-1,1] to pixel space
        gt_keypoints_2d_orig = gt_keypoints_2d.clone()
        gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
            gt_keypoints_2d_orig[:, :, :-1] + 1)

        # Estimate camera translation given the model joints and 2D keypoints
        # by minimizing a weighted least squares loss
        gt_cam_t = estimate_translation(gt_model_joints,
                                        gt_keypoints_2d_orig,
                                        focal_length=self.focal_length,
                                        img_size=self.options.img_res)

        opt_cam_t = estimate_translation(opt_joints,
                                         gt_keypoints_2d_orig,
                                         focal_length=self.focal_length,
                                         img_size=self.options.img_res)

        opt_joint_loss = self.smplify.get_fitting_loss(
            opt_pose, opt_betas, opt_cam_t, 0.5 * self.options.img_res *
            torch.ones(batch_size, 2, device=self.device),
            gt_keypoints_2d_orig).mean(dim=-1)

        # Feed images in the network to predict camera and SMPL parameters
        pred_rotmat, pred_betas, pred_camera = self.model(images)

        pred_output = self.smpl(betas=pred_betas,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)
        pred_vertices = pred_output.vertices
        if pred_vertices.shape != (self.options.batch_size, 6890, 3):
            pred_vertices = torch.zeros_like(pred_vertices, device=self.device)

        pred_joints = pred_output.joints

        # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
        # This camera translation can be used in a full perspective projection
        pred_cam_t = torch.stack([
            pred_camera[:, 1], pred_camera[:, 2], 2 * self.focal_length /
            (self.options.img_res * pred_camera[:, 0] + 1e-9)
        ],
                                 dim=-1)

        camera_center = torch.zeros(batch_size, 2, device=self.device)
        pred_keypoints_2d = perspective_projection(
            pred_joints,
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size, -1, -1),
            translation=pred_cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)
        # Normalize keypoints to [-1,1]
        pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res / 2.)

        if self.options.run_smplify:

            # Convert predicted rotation matrices to axis-angle
            pred_rotmat_hom = torch.cat([
                pred_rotmat.detach().view(-1, 3, 3),
                torch.tensor(
                    [0, 0, 1], dtype=torch.float32, device=self.device).view(
                        1, 3, 1).expand(batch_size * 24, -1, -1)
            ],
                                        dim=-1)
            pred_pose = rotation_matrix_to_angle_axis(
                pred_rotmat_hom).contiguous().view(batch_size, -1)
            # tgm.rotation_matrix_to_angle_axis returns NaN for 0 rotation, so manually hack it
            pred_pose[torch.isnan(pred_pose)] = 0.0

            # Run SMPLify optimization starting from the network prediction
            new_opt_vertices, new_opt_joints,\
            new_opt_pose, new_opt_betas,\
            new_opt_cam_t, new_opt_joint_loss = self.smplify(
                                        pred_pose.detach(), pred_betas.detach(),
                                        pred_cam_t.detach(),
                                        0.5 * self.options.img_res * torch.ones(batch_size, 2, device=self.device),
                                        gt_keypoints_2d_orig)
            new_opt_joint_loss = new_opt_joint_loss.mean(dim=-1)

            # Will update the dictionary for the examples where the new loss is less than the current one
            update = (new_opt_joint_loss < opt_joint_loss)

            opt_joint_loss[update] = new_opt_joint_loss[update]
            opt_vertices[update, :] = new_opt_vertices[update, :]
            opt_joints[update, :] = new_opt_joints[update, :]
            opt_pose[update, :] = new_opt_pose[update, :]
            opt_betas[update, :] = new_opt_betas[update, :]
            opt_cam_t[update, :] = new_opt_cam_t[update, :]

            self.fits_dict[(dataset_name, indices.cpu(), rot_angle.cpu(),
                            is_flipped.cpu(),
                            update.cpu())] = (opt_pose.cpu(), opt_betas.cpu())

        else:
            update = torch.zeros(batch_size, device=self.device).byte()

        # Replace extreme betas with zero betas
        opt_betas[(opt_betas.abs() > 3).any(dim=-1)] = 0.

        # Replace the optimized parameters with the ground truth parameters, if available
        opt_vertices[has_smpl, :, :] = gt_vertices[has_smpl, :, :]
        opt_cam_t[has_smpl, :] = gt_cam_t[has_smpl, :]
        opt_joints[has_smpl, :, :] = gt_model_joints[has_smpl, :, :]
        opt_pose[has_smpl, :] = gt_pose[has_smpl, :]
        opt_betas[has_smpl, :] = gt_betas[has_smpl, :]

        # Assess whether a fit is valid by comparing the joint loss with the threshold
        valid_fit = (opt_joint_loss < self.options.smplify_threshold).to(
            self.device)
        # Add the examples with GT parameters to the list of valid fits
        valid_fit = valid_fit.to(torch.uint8)
        valid_fit = valid_fit | has_smpl

        opt_keypoints_2d = perspective_projection(
            opt_joints,
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size, -1, -1),
            translation=opt_cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)

        opt_keypoints_2d = opt_keypoints_2d / (self.options.img_res / 2.)

        # Compute loss on SMPL parameters
        loss_regr_pose, loss_regr_betas = self.smpl_losses(
            pred_rotmat, pred_betas, opt_pose, opt_betas, valid_fit)

        # Compute 2D reprojection loss for the keypoints
        loss_keypoints = self.keypoint_loss(pred_keypoints_2d, gt_keypoints_2d,
                                            self.options.openpose_train_weight,
                                            self.options.gt_train_weight)

        # Compute 3D keypoint loss
        loss_keypoints_3d = self.keypoint_3d_loss(pred_joints, gt_joints,
                                                  has_pose_3d)

        # Per-vertex loss for the shape
        loss_shape = self.shape_loss(pred_vertices, opt_vertices, valid_fit)

        # Compute total loss
        # The last component is a loss that forces the network to predict positive depth values
        loss = self.options.shape_loss_weight * loss_shape +\
               self.options.keypoint_loss_weight * loss_keypoints +\
               self.options.keypoint_loss_weight * loss_keypoints_3d +\
               loss_regr_pose + self.options.beta_loss_weight * loss_regr_betas +\
               ((torch.exp(-pred_camera[:,0]*10)) ** 2 ).mean()
        loss *= 60

        # Do backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Pack output arguments for tensorboard logging
        output = {
            'pred_vertices': pred_vertices.detach(),
            'opt_vertices': opt_vertices,
            'pred_cam_t': pred_cam_t.detach(),
            'opt_cam_t': opt_cam_t
        }
        losses = {
            'loss': loss.detach().item(),
            'loss_keypoints': loss_keypoints.detach().item(),
            'loss_keypoints_3d': loss_keypoints_3d.detach().item(),
            'loss_regr_pose': loss_regr_pose.detach().item(),
            'loss_regr_betas': loss_regr_betas.detach().item(),
            'loss_shape': loss_shape.detach().item()
        }

        return output, losses
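
The keypoint term self.keypoint_loss used in this train_step is another assumed helper. A sketch of the usual SPIN-style implementation, where the last channel of the ground-truth keypoints is a per-joint confidence that is further reweighted for OpenPose detections (first 25 joints) versus annotated ground-truth joints; the 25/24 split is an assumption carried over from SPIN:

import torch
import torch.nn as nn

criterion_keypoints = nn.MSELoss(reduction='none')

def keypoint_loss(pred_keypoints_2d, gt_keypoints_2d, openpose_weight, gt_weight):
    # pred_keypoints_2d: (B, 49, 2); gt_keypoints_2d: (B, 49, 3) as (x, y, conf)
    conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
    conf[:, :25] *= openpose_weight   # OpenPose detections
    conf[:, 25:] *= gt_weight         # annotated ground-truth joints
    return (conf * criterion_keypoints(pred_keypoints_2d,
                                       gt_keypoints_2d[:, :, :-1])).mean()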
Example #10
    def train_step(self, input_batch):
        self.model.train()

        # Get data from the batch
        images = input_batch['img']  # input image
        gt_keypoints_2d = input_batch['keypoints']  # 2D keypoints
        gt_pose = input_batch['pose']  # SMPL pose parameters
        gt_betas = input_batch['betas']  # SMPL beta parameters
        gt_joints = input_batch['pose_3d']  # 3D pose
        has_smpl = input_batch['has_smpl'].to(
            torch.bool
        )  # flag that indicates whether SMPL parameters are valid
        has_pose_3d = input_batch['has_pose_3d'].to(
            torch.bool)  # flag that indicates whether 3D pose is valid
        is_flipped = input_batch[
            'is_flipped']  # flag that indicates whether image was flipped during data augmentation
        rot_angle = input_batch[
            'rot_angle']  # rotation angle used for data augmentation
        dataset_name = input_batch[
            'dataset_name']  # name of the dataset the image comes from
        indices = input_batch[
            'sample_index']  # index of example inside its dataset
        batch_size = images.shape[0]

        # Get GT vertices and model joints
        # Note that gt_model_joints is different from gt_joints as it comes from SMPL
        gt_out = self.smpl(betas=gt_betas,
                           body_pose=gt_pose[:, 3:],
                           global_orient=gt_pose[:, :3])
        gt_model_joints = gt_out.joints
        gt_vertices = gt_out.vertices

        # Get current best fits from the dictionary
        opt_pose, opt_betas = self.fits_dict[(dataset_name, indices.cpu(),
                                              rot_angle.cpu(),
                                              is_flipped.cpu())]
        opt_pose = opt_pose.to(self.device)
        opt_betas = opt_betas.to(self.device)

        # Replace extreme betas with zero betas
        opt_betas[(opt_betas.abs() > 3).any(dim=-1)] = 0.
        # Replace the optimized parameters with the ground truth parameters, if available
        opt_pose[has_smpl, :] = gt_pose[has_smpl, :]
        opt_betas[has_smpl, :] = gt_betas[has_smpl, :]

        opt_output = self.smpl(betas=opt_betas,
                               body_pose=opt_pose[:, 3:],
                               global_orient=opt_pose[:, :3])
        opt_vertices = opt_output.vertices
        opt_joints = opt_output.joints

        input_batch['verts'] = opt_vertices

        # De-normalize 2D keypoints from [-1,1] to pixel space
        gt_keypoints_2d_orig = gt_keypoints_2d.clone()
        gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
            gt_keypoints_2d_orig[:, :, :-1] + 1)

        # Estimate camera translation given the model joints and 2D keypoints
        # by minimizing a weighted least squares loss
        gt_cam_t = estimate_translation(gt_model_joints,
                                        gt_keypoints_2d_orig,
                                        focal_length=self.focal_length,
                                        img_size=self.options.img_res)

        opt_cam_t = estimate_translation(opt_joints,
                                         gt_keypoints_2d_orig,
                                         focal_length=self.focal_length,
                                         img_size=self.options.img_res)

        # get fitted smpl parameters as pseudo ground truth
        valid_fit = self.fits_dict.get_vaild_state(
            dataset_name, indices.cpu()).to(torch.bool).to(self.device)

        try:
            valid_fit = valid_fit | has_smpl
        except RuntimeError:
            valid_fit = (valid_fit.byte() | has_smpl.byte()).to(torch.bool)

        # Render Dense Correspondences
        if self.options.regressor == 'pymaf_net' and cfg.MODEL.PyMAF.AUX_SUPV_ON:
            gt_cam_t_nr = opt_cam_t.detach().clone()
            gt_camera = torch.zeros(gt_cam_t_nr.shape).to(gt_cam_t_nr.device)
            gt_camera[:, 1:] = gt_cam_t_nr[:, :2]
            gt_camera[:, 0] = (2. * self.focal_length /
                               self.options.img_res) / gt_cam_t_nr[:, 2]
            iuv_image_gt = torch.zeros(
                (batch_size, 3, cfg.MODEL.PyMAF.DP_HEATMAP_SIZE,
                 cfg.MODEL.PyMAF.DP_HEATMAP_SIZE)).to(self.device)
            if torch.sum(valid_fit.float()) > 0:
                iuv_image_gt[valid_fit] = self.iuv_maker.verts2iuvimg(
                    opt_vertices[valid_fit],
                    cam=gt_camera[valid_fit])  # [B, 3, 56, 56]
            input_batch['iuv_image_gt'] = iuv_image_gt

            uvia_list = iuv_img2map(iuv_image_gt)

        # Feed images in the network to predict camera and SMPL parameters
        if self.options.regressor == 'hmr':
            pred_rotmat, pred_betas, pred_camera = self.model(images)
            # output shapes: rotmat (B, 24, 3, 3), betas (B, 10), camera (B, 3)
        elif self.options.regressor == 'pymaf_net':
            preds_dict, _ = self.model(images)

        output = preds_dict
        loss_dict = {}

        if self.options.regressor == 'pymaf_net' and cfg.MODEL.PyMAF.AUX_SUPV_ON:
            dp_out = preds_dict['dp_out']
            for i in range(len(dp_out)):
                r_i = i - len(dp_out)

                u_pred, v_pred, index_pred, ann_pred = dp_out[r_i][
                    'predict_u'], dp_out[r_i]['predict_v'], dp_out[r_i][
                        'predict_uv_index'], dp_out[r_i]['predict_ann_index']
                if index_pred.shape[-1] == iuv_image_gt.shape[-1]:
                    uvia_list_i = uvia_list
                else:
                    iuv_image_gt_i = F.interpolate(iuv_image_gt,
                                                   u_pred.shape[-1],
                                                   mode='nearest')
                    uvia_list_i = iuv_img2map(iuv_image_gt_i)

                loss_U, loss_V, loss_IndexUV, loss_segAnn = self.body_uv_losses(
                    u_pred, v_pred, index_pred, ann_pred, uvia_list_i,
                    valid_fit)
                loss_dict[f'loss_U{r_i}'] = loss_U
                loss_dict[f'loss_V{r_i}'] = loss_V
                loss_dict[f'loss_IndexUV{r_i}'] = loss_IndexUV
                loss_dict[f'loss_segAnn{r_i}'] = loss_segAnn

        len_loop = len(preds_dict['smpl_out']
                       ) if self.options.regressor == 'pymaf_net' else 1

        for l_i in range(len_loop):

            if self.options.regressor == 'pymaf_net':
                if l_i == 0:
                    # initial parameters (mean poses)
                    continue
                pred_rotmat = preds_dict['smpl_out'][l_i]['rotmat']
                pred_betas = preds_dict['smpl_out'][l_i]['theta'][:, 3:13]
                pred_camera = preds_dict['smpl_out'][l_i]['theta'][:, :3]

            pred_output = self.smpl(betas=pred_betas,
                                    body_pose=pred_rotmat[:, 1:],
                                    global_orient=pred_rotmat[:,
                                                              0].unsqueeze(1),
                                    pose2rot=False)
            pred_vertices = pred_output.vertices
            pred_joints = pred_output.joints

            # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
            # This camera translation can be used in a full perspective projection
            pred_cam_t = torch.stack([
                pred_camera[:, 1], pred_camera[:, 2], 2 * self.focal_length /
                (self.options.img_res * pred_camera[:, 0] + 1e-9)
            ],
                                     dim=-1)

            camera_center = torch.zeros(batch_size, 2, device=self.device)
            pred_keypoints_2d = perspective_projection(
                pred_joints,
                rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                    batch_size, -1, -1),
                translation=pred_cam_t,
                focal_length=self.focal_length,
                camera_center=camera_center)
            # Normalize keypoints to [-1,1]
            pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res / 2.)

            # Compute loss on SMPL parameters
            loss_regr_pose, loss_regr_betas = self.smpl_losses(
                pred_rotmat, pred_betas, opt_pose, opt_betas, valid_fit)
            loss_regr_pose *= cfg.LOSS.POSE_W
            loss_regr_betas *= cfg.LOSS.SHAPE_W
            loss_dict['loss_regr_pose_{}'.format(l_i)] = loss_regr_pose
            loss_dict['loss_regr_betas_{}'.format(l_i)] = loss_regr_betas

            # Compute 2D reprojection loss for the keypoints
            if cfg.LOSS.KP_2D_W > 0:
                loss_keypoints = self.keypoint_loss(
                    pred_keypoints_2d, gt_keypoints_2d,
                    self.options.openpose_train_weight,
                    self.options.gt_train_weight) * cfg.LOSS.KP_2D_W
                loss_dict['loss_keypoints_{}'.format(l_i)] = loss_keypoints

            # Compute 3D keypoint loss
            loss_keypoints_3d = self.keypoint_3d_loss(
                pred_joints, gt_joints, has_pose_3d) * cfg.LOSS.KP_3D_W
            loss_dict['loss_keypoints_3d_{}'.format(l_i)] = loss_keypoints_3d

            # Per-vertex loss for the shape
            if cfg.LOSS.VERT_W > 0:
                loss_shape = self.shape_loss(pred_vertices, opt_vertices,
                                             valid_fit) * cfg.LOSS.VERT_W
                loss_dict['loss_shape_{}'.format(l_i)] = loss_shape

            # Camera
            # force the network to predict positive depth values
            loss_cam = ((torch.exp(-pred_camera[:, 0] * 10))**2).mean()
            loss_dict['loss_cam_{}'.format(l_i)] = loss_cam

        for key in loss_dict:
            if len(loss_dict[key].shape) > 0:
                loss_dict[key] = loss_dict[key][0]

        # Compute total loss
        loss = torch.stack(list(loss_dict.values())).sum()

        # Do backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Pack output arguments for tensorboard logging
        output.update({
            'pred_vertices': pred_vertices.detach(),
            'opt_vertices': opt_vertices,
            'pred_cam_t': pred_cam_t.detach(),
            'opt_cam_t': opt_cam_t
        })
        loss_dict['loss'] = loss.detach().item()

        if self.step_count % 100 == 0:
            if self.options.multiprocessing_distributed:
                for loss_name, val in loss_dict.items():
                    val = val / self.options.world_size
                    if not torch.is_tensor(val):
                        val = torch.Tensor([val]).to(self.device)
                    dist.all_reduce(val)
                    loss_dict[loss_name] = val
            if self.options.rank == 0:
                for loss_name, val in loss_dict.items():
                    self.summary_writer.add_scalar(
                        'losses/{}'.format(loss_name), val, self.step_count)

        return {'preds': output, 'losses': loss_dict}
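
The conversion from a weak-perspective camera [s, tx, ty] to a full-perspective translation [tx, ty, tz] appears verbatim in Examples #9-#12 (tz = 2 * focal_length / (img_res * s)). Factored into a small self-contained helper:

import torch

def weak_perspective_to_translation(pred_camera, focal_length, img_res, eps=1e-9):
    # pred_camera: (B, 3) as [s, tx, ty]; returns a (B, 3) camera translation
    # that can be fed to perspective_projection for the image crop.
    return torch.stack([
        pred_camera[:, 1],
        pred_camera[:, 2],
        2 * focal_length / (img_res * pred_camera[:, 0] + eps)
    ], dim=-1)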
Example #11
    def _forward(self, in_dict):
        iuv_map = in_dict['iuv_map']
        part_iuv_map = in_dict[
            'part_iuv_map'] if 'part_iuv_map' in in_dict else None
        infer_mode = in_dict['infer_mode'] if 'infer_mode' in in_dict else False

        if cfg.DANET.INPUT_MODE in ['feat', 'iuv_feat', 'iuv_gt_feat']:
            device_id = iuv_map['feat'].get_device()
        elif cfg.DANET.INPUT_MODE == 'seg':
            device_id = iuv_map['index'].get_device()
        else:
            device_id = iuv_map.get_device()
        return_dict = {}
        return_dict['losses'] = {}
        return_dict['metrics'] = {}
        return_dict['visualization'] = {}
        return_dict['prediction'] = {}

        if cfg.DANET.DECOMPOSED:
            smpl_out_dict = self.smpl_para_Outs(iuv_map, part_iuv_map)
        else:
            smpl_out_dict = self.smpl_para_Outs(iuv_map)

        if infer_mode:
            return smpl_out_dict

        para = smpl_out_dict['para']

        for k, v in smpl_out_dict['visualization'].items():
            return_dict['visualization'][k] = v

        return_dict['prediction']['cam'] = para[:, :3]
        return_dict['prediction']['shape'] = para[:, 3:13]
        return_dict['prediction']['pose'] = para[:,
                                                 13:].reshape(-1, 24, 3,
                                                              3).contiguous()

        # losses for Training
        if self.training:
            batch_size = len(para)

            target = in_dict['target']
            target_kps = in_dict['target_kps']
            target_kps3d = in_dict['target_kps3d']
            target_vertices = in_dict['target_verts']
            has_kp3d = in_dict['has_kp3d']
            has_smpl = in_dict['has_smpl']

            if cfg.DANET.ORTHOGONAL_WEIGHTS > 0:
                loss_orth = self.orthogonal_loss(para)
                loss_orth *= cfg.DANET.ORTHOGONAL_WEIGHTS
                return_dict['losses']['Rs_orth'] = loss_orth
                return_dict['metrics']['orth'] = loss_orth.detach()

            if len(smpl_out_dict['joint_rotation']) > 0:
                for stack_i in range(len(smpl_out_dict['joint_rotation'])):
                    if torch.sum(has_smpl) > 0:
                        loss_rot = self.criterion_regr(
                            smpl_out_dict['joint_rotation'][stack_i][
                                has_smpl == 1], target[:, 13:][has_smpl == 1])
                        loss_rot *= cfg.DANET.SMPL_POSE_WEIGHTS
                    else:
                        loss_rot = torch.zeros(1).to(para.device)

                    return_dict['losses']['joint_rotation' +
                                          str(stack_i)] = loss_rot

            if cfg.DANET.DECOMPOSED and (
                    'joint_position'
                    in smpl_out_dict) and cfg.DANET.JOINT_POSITION_WEIGHTS > 0:
                gt_beta = target[:, 3:13].contiguous().detach()
                gt_Rs = target[:, 13:].contiguous().view(-1, 24, 3, 3).detach()
                smpl_pts = self.smpl(betas=gt_beta,
                                     body_pose=gt_Rs[:, 1:],
                                     global_orient=gt_Rs[:, 0].unsqueeze(1),
                                     pose2rot=False)
                gt_smpl_coord = smpl_pts.smpl_joints
                for stack_i in range(len(smpl_out_dict['joint_position'])):
                    loss_pos = self.l1_losses(
                        smpl_out_dict['joint_position'][stack_i],
                        gt_smpl_coord, has_smpl)
                    loss_pos *= cfg.DANET.JOINT_POSITION_WEIGHTS
                    return_dict['losses']['joint_position' +
                                          str(stack_i)] = loss_pos

            pred_camera = para[:, :3]
            pred_betas = para[:, 3:13]
            pred_rotmat = para[:, 13:].reshape(-1, 24, 3, 3).contiguous()

            gt_camera = target[:, :3]
            gt_betas = target[:, 3:13]
            gt_rotmat = target[:, 13:]

            pred_output = self.smpl(betas=pred_betas,
                                    body_pose=pred_rotmat[:, 1:],
                                    global_orient=pred_rotmat[:,
                                                              0].unsqueeze(1),
                                    pose2rot=False)
            pred_vertices = pred_output.vertices
            pred_joints = pred_output.joints

            # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
            # This camera translation can be used in a full perspective projection
            pred_cam_t = torch.stack([
                pred_camera[:, 1], pred_camera[:, 2], 2 * self.focal_length /
                (cfg.DANET.INIMG_SIZE * pred_camera[:, 0] + 1e-9)
            ],
                                     dim=-1)

            camera_center = torch.zeros(batch_size, 2, device=self.device)
            pred_keypoints_2d = perspective_projection(
                pred_joints,
                rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                    batch_size, -1, -1),
                translation=pred_cam_t,
                focal_length=self.focal_length,
                camera_center=camera_center)
            # Normalize keypoints to [-1,1]
            pred_keypoints_2d = pred_keypoints_2d / (cfg.DANET.INIMG_SIZE / 2.)

            # Compute loss on predicted camera
            loss_cam = self.l1_losses(pred_camera, gt_camera, has_smpl)

            # Compute loss on SMPL parameters
            loss_regr_pose, loss_regr_betas = self.smpl_losses(
                pred_rotmat, pred_betas, gt_rotmat, gt_betas, has_smpl)

            # Compute 2D reprojection loss for the keypoints
            loss_keypoints = self.keypoint_loss(
                pred_keypoints_2d, target_kps,
                self.options.openpose_train_weight,
                self.options.gt_train_weight)

            # Compute 3D keypoint loss
            loss_keypoints_3d = self.keypoint_3d_loss(pred_joints,
                                                      target_kps3d, has_kp3d)

            # Per-vertex loss for the shape
            loss_verts = self.shape_loss(pred_vertices, target_vertices,
                                         has_smpl)

            # The last component is a loss that forces the network to predict positive depth values
            return_dict['losses'].update({
                'keypoints_2d':
                loss_keypoints * cfg.DANET.PROJ_KPS_WEIGHTS,
                'keypoints_3d':
                loss_keypoints_3d * cfg.DANET.KPS3D_WEIGHTS,
                'smpl_pose':
                loss_regr_pose * cfg.DANET.SMPL_POSE_WEIGHTS,
                'smpl_betas':
                loss_regr_betas * cfg.DANET.SMPL_BETAS_WEIGHTS,
                'smpl_verts':
                loss_verts * cfg.DANET.VERTS_WEIGHTS,
                'cam': ((torch.exp(-pred_camera[:, 0] * 10))**2).mean()
            })

            return_dict['prediction']['vertices'] = pred_vertices
            return_dict['prediction']['cam_t'] = pred_cam_t

        # handle bug on gathering scalar(0-dim) tensors
        for k, v in return_dict['losses'].items():
            if len(v.shape) == 0:
                return_dict['losses'][k] = v.unsqueeze(0)
        for k, v in return_dict['metrics'].items():
            if len(v.shape) == 0:
                return_dict['metrics'][k] = v.unsqueeze(0)

        return return_dict
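
The per-vertex self.shape_loss used above is also assumed. A sketch of the common SPIN-style version: an L1 distance over SMPL vertices, computed only for examples that have valid ground truth or fits:

import torch
import torch.nn as nn

criterion_shape = nn.L1Loss()

def shape_loss(pred_vertices, gt_vertices, has_smpl):
    # has_smpl: (B,) mask marking examples with usable SMPL vertices.
    pred_valid = pred_vertices[has_smpl == 1]
    gt_valid = gt_vertices[has_smpl == 1]
    if len(gt_valid) > 0:
        return criterion_shape(pred_valid, gt_valid)
    return torch.zeros(1, device=pred_vertices.device)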
Example #12
def run_evaluation(model, dataset_name, dataset, result_file,
                   batch_size=32, img_res=224, 
                   num_workers=32, shuffle=False, log_freq=50, options=None):
    """Run evaluation on the datasets and metrics we report in the paper. """

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Transfer model to the GPU
    model.to(device)

    # Load SMPL model
    smpl_neutral = SMPL(path_config.SMPL_MODEL_DIR,
                        create_transl=False).to(device)
    smpl_male = SMPL(path_config.SMPL_MODEL_DIR,
                     gender='male',
                     create_transl=False).to(device)
    smpl_female = SMPL(path_config.SMPL_MODEL_DIR,
                       gender='female',
                       create_transl=False).to(device)
    
    renderer = PartRenderer()
    
    # Regressor for H36m joints
    J_regressor = torch.from_numpy(np.load(path_config.JOINT_REGRESSOR_H36M)).float()
    
    save_results = result_file is not None
    # Disable shuffling if you want to save the results
    if save_results:
        shuffle = False
    # Create dataloader for the dataset
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

    fits_dict = None

    # Pose metrics
    # MPJPE and Reconstruction error for the non-parametric and parametric shapes
    mpjpe = np.zeros(len(dataset))
    recon_err = np.zeros(len(dataset))
    mpjpe_smpl = np.zeros(len(dataset))
    recon_err_smpl = np.zeros(len(dataset))

    # Store SMPL parameters
    smpl_pose = np.zeros((len(dataset), 72))
    smpl_betas = np.zeros((len(dataset), 10))
    smpl_camera = np.zeros((len(dataset), 3))
    pred_joints = np.zeros((len(dataset), 17, 3))

    joint_mapper_gt = constants.J24_TO_JCOCO

    focal_length = 5000

    num_joints = 17
    num_samples = len(dataset)
    print('dataset length: {}'.format(num_samples))
    all_preds = np.zeros(
        (num_samples, num_joints, 3),
        dtype=np.float32
    )
    all_boxes = np.zeros((num_samples, 6))
    image_path = []
    filenames = []
    imgnums = []
    idx = 0
    with torch.no_grad():
        for step, batch in enumerate(tqdm(data_loader, desc='Eval', total=len(data_loader))):
            if len(options.vis_imname) > 0:
                imgnames = [i_n.split('/')[-1] for i_n in batch['imgname']]
                name_hit = False
                for i_n in imgnames:
                    if options.vis_imname in i_n:
                        name_hit = True
                        print('vis: ' + i_n)
                if not name_hit:
                    continue

            images = batch['img'].to(device)

            scale = batch['scale'].numpy()
            center = batch['center'].numpy()

            num_images = images.size(0)

            gt_keypoints_2d = batch['keypoints']  # 2D keypoints
            # De-normalize 2D keypoints from [-1,1] to pixel space
            gt_keypoints_2d_orig = gt_keypoints_2d.clone()
            gt_keypoints_2d_orig[:, :, :-1] = 0.5 * img_res * (gt_keypoints_2d_orig[:, :, :-1] + 1)

            if options.regressor == 'hmr':
                pred_rotmat, pred_betas, pred_camera = model(images)
                # output shapes: rotmat (B, 24, 3, 3), betas (B, 10), camera (B, 3)
            elif options.regressor == 'pymaf_net':
                preds_dict, _ = model(images)
                pred_rotmat = preds_dict['smpl_out'][-1]['rotmat'].contiguous().view(-1, 24, 3, 3)
                pred_betas = preds_dict['smpl_out'][-1]['theta'][:, 3:13].contiguous()
                pred_camera = preds_dict['smpl_out'][-1]['theta'][:, :3].contiguous()

            pred_output = smpl_neutral(betas=pred_betas, body_pose=pred_rotmat[:, 1:],
                                        global_orient=pred_rotmat[:, 0].unsqueeze(1), pose2rot=False)

            # pred_vertices = pred_output.vertices
            pred_J24 = pred_output.joints[:, -24:]
            pred_JCOCO = pred_J24[:, constants.J24_TO_JCOCO]

            # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
            # This camera translation can be used in a full perspective projection
            pred_cam_t = torch.stack([pred_camera[:,1],
                                        pred_camera[:,2],
                                        2*constants.FOCAL_LENGTH/(img_res * pred_camera[:, 0] +1e-9)],dim=-1)
            camera_center = torch.zeros(len(pred_JCOCO), 2, device=pred_camera.device)
            pred_keypoints_2d = perspective_projection(pred_JCOCO,
                                                        rotation=torch.eye(3, device=pred_camera.device).unsqueeze(0).expand(len(pred_JCOCO), -1, -1),
                                                        translation=pred_cam_t,
                                                        focal_length=constants.FOCAL_LENGTH,
                                                        camera_center=camera_center)

            coords = pred_keypoints_2d + (img_res / 2.)
            coords = coords.cpu().numpy()

            gt_keypoints_coco = gt_keypoints_2d_orig[:, -24:][:, constants.J24_TO_JCOCO]
            vert_errors_batch = []
            for i, (gt2d, pred2d) in enumerate(zip(gt_keypoints_coco.cpu().numpy(), coords.copy())):
                vert_error = np.sqrt(np.sum((gt2d[:, :2] - pred2d[:, :2]) ** 2, axis=1))
                vert_error *= gt2d[:, 2]
                vert_mean_error = np.sum(vert_error) / np.sum(gt2d[:, 2] > 0)
                vert_errors_batch.append(10 * vert_mean_error)

            if options.vis_demo:
                imgnames = [i_n.split('/')[-1] for i_n in batch['imgname']]

                if options.regressor == 'hmr':
                    iuv_pred = None

                images_vis = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
                images_vis = images_vis + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)
                vis_smpl_iuv(images_vis.cpu().numpy(), pred_camera.cpu().numpy(), pred_output.vertices.cpu().numpy(),
                             smpl_neutral.faces, iuv_pred,
                             vert_errors_batch, imgnames, os.path.join('./notebooks/output/demo_results', dataset_name,
                                                                            options.checkpoint.split('/')[-3]), options)

            preds = coords.copy()

            scale_ = np.array([scale, scale]).transpose()

            # Transform back
            for i in range(coords.shape[0]):
                preds[i] = transform_preds(
                    coords[i], center[i], scale_[i], [img_res, img_res]
                )

            all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
            all_preds[idx:idx + num_images, :, 2:3] = 1.
            all_boxes[idx:idx + num_images, 5] = 1.
            image_path.extend(batch['imgname'])

            idx += num_images

        if len(options.vis_imname) > 0:
            exit()

        if options.checkpoint is None or 'model_checkpoint.pt' in options.checkpoint:
            ckp_name = 'spin_model'
        else:
            ckp_name = options.checkpoint.split('/')
            ckp_name = ckp_name[2].split('_')[1] + '_' + ckp_name[-1].split('.')[0]
        name_values, perf_indicator = dataset.evaluate(
            cfg, all_preds, options.output_dir, all_boxes, image_path, ckp_name,
            filenames, imgnums
        )

        model_name = options.regressor
        if isinstance(name_values, list):
            for name_value in name_values:
                _print_name_value(name_value, model_name)
        else:
            _print_name_value(name_values, model_name)

    # Save reconstructions to a file for further processing
    if save_results:
        np.savez(result_file, pred_joints=pred_joints, pose=smpl_pose, betas=smpl_betas, camera=smpl_camera)
Example #13
    def train_step(self, input_batch):

        # Learning rate decay
        if self.decay_steps_ind < len(cfg.SOLVER.STEPS) and input_batch[
                'step_count'] == cfg.SOLVER.STEPS[self.decay_steps_ind]:
            lr = self.optimizer.param_groups[0]['lr']
            lr_new = lr * cfg.SOLVER.GAMMA
            print('Decay the learning rate on step {} from {} to {}'.format(
                input_batch['step_count'], lr, lr_new))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr_new
            lr = self.optimizer.param_groups[0]['lr']
            assert lr == lr_new
            self.decay_steps_ind += 1

        self.model.train()

        # Get data from the batch
        images = input_batch['img']  # input image
        gt_keypoints_2d = input_batch['keypoints']  # 2D keypoints
        gt_pose = input_batch['pose']  # SMPL pose parameters
        gt_betas = input_batch['betas']  # SMPL beta parameters
        gt_joints = input_batch['pose_3d']  # 3D pose
        has_smpl = input_batch['has_smpl'].byte(
        )  # flag that indicates whether SMPL parameters are valid
        has_pose_3d = input_batch['has_pose_3d'].byte(
        )  # flag that indicates whether 3D pose is valid
        is_flipped = input_batch[
            'is_flipped']  # flag that indicates whether image was flipped during data augmentation
        rot_angle = input_batch[
            'rot_angle']  # rotation angle used for data augmentation
        dataset_name = input_batch[
            'dataset_name']  # name of the dataset the image comes from
        indices = input_batch[
            'sample_index']  # index of example inside its dataset
        batch_size = images.shape[0]

        # Get GT vertices and model joints
        # Note that gt_model_joints is different from gt_joints as it comes from SMPL
        gt_out = self.smpl(betas=gt_betas,
                           body_pose=gt_pose[:, 3:],
                           global_orient=gt_pose[:, :3])
        gt_model_joints = gt_out.joints
        gt_vertices = gt_out.vertices

        # Get current pseudo labels (final fits of SPIN) from the dictionary
        opt_pose, opt_betas = self.fits_dict[(dataset_name, indices.cpu(),
                                              rot_angle.cpu(),
                                              is_flipped.cpu())]
        opt_pose = opt_pose.to(self.device)
        opt_betas = opt_betas.to(self.device)

        # Replace extreme betas with zero betas
        opt_betas[(opt_betas.abs() > 3).any(dim=-1)] = 0.
        # Replace the optimized parameters with the ground truth parameters, if available
        opt_pose[has_smpl, :] = gt_pose[has_smpl, :]
        opt_betas[has_smpl, :] = gt_betas[has_smpl, :]

        opt_output = self.smpl(betas=opt_betas,
                               body_pose=opt_pose[:, 3:],
                               global_orient=opt_pose[:, :3])
        opt_vertices = opt_output.vertices
        opt_joints = opt_output.joints

        # De-normalize 2D keypoints from [-1,1] to pixel space
        gt_keypoints_2d_orig = gt_keypoints_2d.clone()
        gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
            gt_keypoints_2d_orig[:, :, :-1] + 1)

        # Estimate camera translation given the model joints and 2D keypoints
        # by minimizing a weighted least squares loss
        gt_cam_t = estimate_translation(gt_model_joints,
                                        gt_keypoints_2d_orig,
                                        focal_length=self.focal_length,
                                        img_size=self.options.img_res)

        opt_cam_t = estimate_translation(opt_joints,
                                         gt_keypoints_2d_orig,
                                         focal_length=self.focal_length,
                                         img_size=self.options.img_res)

        if self.options.train_data in ['h36m_coco_itw']:
            valid_fit = self.fits_dict.get_vaild_state(dataset_name,
                                                       indices.cpu()).to(
                                                           self.device)
            valid_fit = valid_fit | has_smpl
        else:
            valid_fit = has_smpl

        # Feed images in the network to predict camera and SMPL parameters
        input_batch['opt_pose'] = opt_pose
        input_batch['opt_betas'] = opt_betas
        input_batch['valid_fit'] = valid_fit

        input_batch['dp_dict'] = {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in input_batch['dp_dict'].items()
        }
        has_iuv = torch.tensor([dn not in ['dp_coco'] for dn in dataset_name],
                               dtype=torch.uint8).to(self.device)
        has_iuv = has_iuv & valid_fit
        input_batch['has_iuv'] = has_iuv
        has_dp = input_batch['has_dp']
        target_smpl_kps = torch.zeros(
            (batch_size, 24, 3)).to(opt_output.smpl_joints.device)
        target_smpl_kps[:, :, :2] = perspective_projection(
            opt_output.smpl_joints.detach().clone(),
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size, -1, -1),
            translation=opt_cam_t,
            focal_length=self.focal_length,
            camera_center=torch.zeros(batch_size, 2, device=self.device) +
            (0.5 * self.options.img_res))
        target_smpl_kps[:, :, :2] = target_smpl_kps[:, :, :2] / (
            0.5 * self.options.img_res) - 1
        target_smpl_kps[has_iuv == 1, :, 2] = 1
        target_smpl_kps[has_dp == 1] = input_batch['smpl_2dkps'][has_dp == 1]
        input_batch['target_smpl_kps'] = target_smpl_kps  # [B, 24, 3]
        input_batch['target_verts'] = opt_vertices.detach().clone(
        )  # [B, 6890, 3]

        # camera translation for neural renderer
        gt_cam_t_nr = opt_cam_t.detach().clone()
        gt_camera = torch.zeros(gt_cam_t_nr.shape).to(gt_cam_t_nr.device)
        gt_camera[:, 1:] = gt_cam_t_nr[:, :2]
        gt_camera[:, 0] = (2. * self.focal_length /
                           self.options.img_res) / gt_cam_t_nr[:, 2]
        input_batch['target_cam'] = gt_camera

        # Do forward
        danet_return_dict = self.model(input_batch)

        loss_total = 0
        losses_dict = {}
        for loss_key in danet_return_dict['losses']:
            loss_total += danet_return_dict['losses'][loss_key]
            losses_dict['loss_{}'.format(loss_key)] = danet_return_dict[
                'losses'][loss_key].detach().item()

        # Do backprop
        self.optimizer.zero_grad()
        loss_total.backward()
        self.optimizer.step()

        if input_batch['pretrain_mode']:
            pred_vertices = None
            pred_cam_t = None
        else:
            pred_vertices = danet_return_dict['prediction']['vertices'].detach(
            )
            pred_cam_t = danet_return_dict['prediction']['cam_t'].detach()

        # Pack output arguments for tensorboard logging
        output = {
            'pred_vertices': pred_vertices,
            'opt_vertices': opt_vertices,
            'pred_cam_t': pred_cam_t,
            'opt_cam_t': opt_cam_t,
            'visualization': danet_return_dict['visualization']
        }

        losses_dict.update({'loss_total': loss_total.detach().item()})

        return output, losses_dict
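
Several of these train steps repeat the same de-normalization of 2D keypoints from [-1, 1] to pixel space, and the inverse after projection. A small pair of helpers capturing just that mapping (when the projection uses a zero camera center, as in Examples #9 and #10, the normalization is simply a division by img_res / 2 without the -1 shift):

def denormalize_keypoints(keypoints, img_res):
    # Map (x, y) from [-1, 1] to [0, img_res] pixels; the last channel
    # (confidence) is left untouched.
    out = keypoints.clone()
    out[:, :, :-1] = 0.5 * img_res * (out[:, :, :-1] + 1)
    return out

def normalize_keypoints(keypoints_px, img_res):
    # Inverse mapping: [0, img_res] pixel coordinates back to [-1, 1].
    return keypoints_px / (0.5 * img_res) - 1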
Example #14
    def train_step(self, input_batch):
        self.model.train()
        # get data from batch
        has_smpl = input_batch['has_smpl'].bool()
        has_pose_3d = input_batch['has_pose_3d'].bool()
        gt_pose1 = input_batch['pose']  # SMPL pose parameters
        gt_betas1 = input_batch['betas']  # SMPL beta parameters
        dataset_name = input_batch['dataset_name']
        indices = input_batch[
            'sample_index']  # index of example inside its dataset
        is_flipped = input_batch[
            'is_flipped']  # flag that indicates whether image was flipped during data augmentation
        rot_angle = input_batch[
            'rot_angle']  # rotation angle used for data augmentation
        # Get GT vertices and model joints
        # Note that gt_model_joints is different from gt_joints as it comes from SMPL
        gt_betas = torch.cat((gt_betas1, gt_betas1, gt_betas1, gt_betas1), 0)
        gt_pose = torch.cat((gt_pose1, gt_pose1, gt_pose1, gt_pose1), 0)
        gt_out = self.smpl(betas=gt_betas,
                           body_pose=gt_pose[:, 3:],
                           global_orient=gt_pose[:, :3])
        gt_model_joints = gt_out.joints
        gt_vertices = gt_out.vertices
        # Get current best fits from the dictionary
        opt_pose1, opt_betas1 = self.fits_dict[(dataset_name, indices.cpu(),
                                                rot_angle.cpu(),
                                                is_flipped.cpu())]
        opt_pose = torch.cat(
            (opt_pose1.to(self.device), opt_pose1.to(self.device),
             opt_pose1.to(self.device), opt_pose1.to(self.device)), 0)
        opt_betas = torch.cat(
            (opt_betas1.to(self.device), opt_betas1.to(self.device),
             opt_betas1.to(self.device), opt_betas1.to(self.device)), 0)
        opt_output = self.smpl(betas=opt_betas,
                               body_pose=opt_pose[:, 3:],
                               global_orient=opt_pose[:, :3])
        opt_vertices = opt_output.vertices
        opt_joints = opt_output.joints
        # images
        images = torch.cat((input_batch['img_0'], input_batch['img_1'],
                            input_batch['img_2'], input_batch['img_3']), 0)
        batch_size = input_batch['img_0'].shape[0]
        # Output of CNN
        pred_rotmat, pred_betas, pred_camera = self.model(images)
        pred_output = self.smpl(betas=pred_betas,
                                body_pose=pred_rotmat[:, 1:],
                                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                                pose2rot=False)
        pred_vertices = pred_output.vertices
        pred_joints = pred_output.joints
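        # Convert the weak-perspective camera [s, tx, ty] into a full translation [tx, ty, tz] for perspective projection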
        pred_cam_t = torch.stack([
            pred_camera[:, 1], pred_camera[:, 2], 2 * self.focal_length /
            (self.options.img_res * pred_camera[:, 0] + 1e-9)
        ],
                                 dim=-1)
        camera_center = torch.zeros(batch_size * 4, 2, device=self.device)
        pred_keypoints_2d = perspective_projection(
            pred_joints,
            rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                batch_size * 4, -1, -1),
            translation=pred_cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)
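        # Normalize the projected keypoints to roughly [-1, 1] (the camera center is at the origin here)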
        pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res / 2.)
        # 2d joint points
        gt_keypoints_2d = torch.cat(
            (input_batch['keypoints_0'], input_batch['keypoints_1'],
             input_batch['keypoints_2'], input_batch['keypoints_3']), 0)
        gt_keypoints_2d_orig = gt_keypoints_2d.clone()
        gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
            gt_keypoints_2d_orig[:, :, :-1] + 1)
        gt_cam_t = estimate_translation(gt_model_joints,
                                        gt_keypoints_2d_orig,
                                        focal_length=self.focal_length,
                                        img_size=self.options.img_res)
        opt_cam_t = estimate_translation(opt_joints,
                                         gt_keypoints_2d_orig,
                                         focal_length=self.focal_length,
                                         img_size=self.options.img_res)
        opt_joint_loss = self.smplify.get_fitting_loss(
            opt_pose, opt_betas, opt_cam_t, 0.5 * self.options.img_res *
            torch.ones(batch_size * 4, 2, device=self.device),
            gt_keypoints_2d_orig).mean(dim=-1)
        if self.options.run_smplify:
            pred_rotmat_hom = torch.cat([
                pred_rotmat.detach().view(-1, 3, 3).detach(),
                torch.tensor(
                    [0, 0, 1], dtype=torch.float32, device=self.device).view(
                        1, 3, 1).expand(batch_size * 4 * 24, -1, -1)
            ],
                                        dim=-1)
            pred_pose = rotation_matrix_to_angle_axis(
                pred_rotmat_hom).contiguous().view(batch_size * 4, -1)
            pred_pose[torch.isnan(pred_pose)] = 0.0
            new_opt_vertices, new_opt_joints,\
            new_opt_pose, new_opt_betas,\
            new_opt_cam_t, new_opt_joint_loss = self.smplify(
                                        pred_pose.detach(), pred_betas.detach(),
                                        pred_cam_t.detach(),
                                        0.5 * self.options.img_res * torch.ones(batch_size*4, 2, device=self.device),
                                        gt_keypoints_2d_orig)
            new_opt_joint_loss = new_opt_joint_loss.mean(dim=-1)
            # Will update the dictionary for the examples where the new loss is less than the current one
            update = (new_opt_joint_loss < opt_joint_loss)
            update1 = torch.cat((update, update, update, update), 0)
            opt_joint_loss[update] = new_opt_joint_loss[update]
            opt_joints[update1, :] = new_opt_joints[update1, :]
            opt_betas[update1, :] = new_opt_betas[update1, :]
            opt_pose[update1, :] = new_opt_pose[update1, :]
            opt_vertices[update1, :] = new_opt_vertices[update1, :]
            opt_cam_t[update1, :] = new_opt_cam_t[update1, :]
        # Now we compute the loss on the four images.
        # Replace the optimized parameters with the ground truth parameters, if available
        has_smpl1 = torch.cat((has_smpl, has_smpl, has_smpl, has_smpl), 0)
        opt_vertices[has_smpl1, :, :] = gt_vertices[has_smpl1, :, :]
        opt_pose[has_smpl1, :] = gt_pose[has_smpl1, :]
        opt_cam_t[has_smpl1, :] = gt_cam_t[has_smpl1, :]
        opt_joints[has_smpl1, :, :] = gt_model_joints[has_smpl1, :, :]
        opt_betas[has_smpl1, :] = gt_betas[has_smpl1, :]
        # Assert whether a fit is valid by comparing the joint loss with the threshold
        valid_fit1 = (opt_joint_loss < self.options.smplify_threshold).to(
            self.device)
        # Add the examples with GT parameters to the list of valid fits
        valid_fit = torch.cat(
            (valid_fit1, valid_fit1, valid_fit1, valid_fit1), 0) | has_smpl1

        # 2D reprojection loss on the four views
        loss_keypoints = self.keypoint_loss(pred_keypoints_2d, gt_keypoints_2d,
                                            0, 1)
        # 3D keypoint loss on the four views (referenced in the total loss and logging below)
        gt_joints = torch.cat(
            (input_batch['pose_3d_0'], input_batch['pose_3d_1'],
             input_batch['pose_3d_2'], input_batch['pose_3d_3']), 0)
        loss_keypoints_3d = self.keypoint_3d_loss(
            pred_joints, gt_joints,
            torch.cat((has_pose_3d, has_pose_3d, has_pose_3d, has_pose_3d), 0))
        loss_regr_pose, loss_regr_betas = self.smpl_losses(
            pred_rotmat, pred_betas, opt_pose, opt_betas, valid_fit)
        loss_shape = self.shape_loss(pred_vertices, opt_vertices, valid_fit)
        loss_all = 0 * loss_shape +\
                   5. * loss_keypoints +\
                   0. * loss_keypoints_3d +\
                   loss_regr_pose + 0.001* loss_regr_betas +\
                   ((torch.exp(-pred_camera[:,0]*10)) ** 2 ).mean()

        loss_all *= 60

        # Do backprop
        self.optimizer.zero_grad()
        loss_all.backward()
        self.optimizer.step()
        output = {
            'pred_vertices': pred_vertices,
            'opt_vertices': opt_vertices,
            'pred_cam_t': pred_cam_t,
            'opt_cam_t': opt_cam_t
        }
        losses = {
            'loss': loss_all.detach().item(),
            'loss_keypoints': loss_keypoints.detach().item(),
            'loss_keypoints_3d': loss_keypoints_3d.detach().item(),
            'loss_regr_pose': loss_regr_pose.detach().item(),
            'loss_regr_betas': loss_regr_betas.detach().item(),
            'loss_shape': loss_shape.detach().item()
        }

        return output, losses
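
Note: Exemple #14 only overwrites the stored SMPLify fit of an example when the new joint loss is lower. Below is a minimal sketch of that update rule, with hypothetical names, assuming per-example losses and parameter tensors stacked along the batch dimension.

import torch

def update_best_fits(opt_params, opt_loss, new_params, new_loss):
    # opt_params / new_params: dicts of (B, ...) tensors; opt_loss / new_loss: (B,) fitting losses
    update = new_loss < opt_loss          # boolean mask of examples that improved
    opt_loss = opt_loss.clone()
    opt_loss[update] = new_loss[update]
    for key in opt_params:
        opt_params[key] = opt_params[key].clone()
        opt_params[key][update] = new_params[key][update]
    return opt_params, opt_loss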
Exemple #15
0
def run_evaluation(model,
                   dataset,
                   result_file,
                   batch_size=32,
                   img_res=224,
                   num_workers=32,
                   shuffle=False,
                   options=None):
    """Run evaluation on the datasets and metrics we report in the paper. """

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Transfer model to the GPU
    model.to(device)

    # Load SMPL model
    smpl_neutral = SMPL(path_config.SMPL_MODEL_DIR,
                        create_transl=False).to(device)

    save_results = result_file is not None
    # Disable shuffling if you want to save the results
    if save_results:
        shuffle = False
    # Create dataloader for the dataset
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=num_workers)

    # Store SMPL parameters
    smpl_pose = np.zeros((len(dataset), 72))
    smpl_betas = np.zeros((len(dataset), 10))
    smpl_camera = np.zeros((len(dataset), 3))
    pred_joints = np.zeros((len(dataset), 17, 3))

    num_joints = 17

    num_samples = len(dataset)
    print('dataset length: {}'.format(num_samples))
    all_preds = np.zeros((num_samples, num_joints, 3), dtype=np.float32)
    all_boxes = np.zeros((num_samples, 6))
    image_path = []
    filenames = []
    imgnums = []
    idx = 0
    with torch.no_grad():
        end = time.time()

        for step, batch in enumerate(
                tqdm(data_loader, desc='Eval', total=len(data_loader))):
            images = batch['img'].to(device)
            scale = batch['scale'].numpy()
            center = batch['center'].numpy()

            num_images = images.size(0)

            gt_keypoints_2d = batch['keypoints']  # 2D keypoints
            # De-normalize 2D keypoints from [-1,1] to pixel space
            gt_keypoints_2d_orig = gt_keypoints_2d.clone()
            gt_keypoints_2d_orig[:, :, :-1] = 0.5 * img_res * (
                gt_keypoints_2d_orig[:, :, :-1] + 1)

            if options.regressor == 'hmr':
                pred_rotmat, pred_betas, pred_camera = model(images)
            elif options.regressor == 'danet':
                danet_pred_dict = model.infer_net(images)
                para_pred = danet_pred_dict['para']
                pred_camera = para_pred[:, 0:3].contiguous()
                pred_betas = para_pred[:, 3:13].contiguous()
                pred_rotmat = para_pred[:, 13:].contiguous().view(-1, 24, 3, 3)

            pred_output = smpl_neutral(
                betas=pred_betas,
                body_pose=pred_rotmat[:, 1:],
                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                pose2rot=False)

            # pred_vertices = pred_output.vertices
            pred_J24 = pred_output.joints[:, -24:]
            pred_JCOCO = pred_J24[:, constants.J24_TO_JCOCO]

            # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
            # This camera translation can be used in a full perspective projection
            pred_cam_t = torch.stack([
                pred_camera[:, 1], pred_camera[:, 2], 2 *
                constants.FOCAL_LENGTH / (img_res * pred_camera[:, 0] + 1e-9)
            ],
                                     dim=-1)

            camera_center = torch.zeros(len(pred_JCOCO),
                                        2,
                                        device=pred_camera.device)
            pred_keypoints_2d = perspective_projection(
                pred_JCOCO,
                rotation=torch.eye(
                    3, device=pred_camera.device).unsqueeze(0).expand(
                        len(pred_JCOCO), -1, -1),
                translation=pred_cam_t,
                focal_length=constants.FOCAL_LENGTH,
                camera_center=camera_center)

            # Shift projected keypoints from the centered frame to pixel coordinates
            coords = pred_keypoints_2d + (img_res / 2.)
            coords = coords.cpu().numpy()

            gt_keypoints_coco = gt_keypoints_2d_orig[:, -24:][:, constants.
                                                              J24_TO_JCOCO]

            preds = coords.copy()

            scale_ = np.array([scale, scale]).transpose()

            # Transform back
            for i in range(coords.shape[0]):
                preds[i] = transform_preds(coords[i], center[i], scale_[i],
                                           [img_res, img_res])

            all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
            all_preds[idx:idx + num_images, :, 2:3] = 1.
            # double check this all_boxes parts
            all_boxes[idx:idx + num_images, 0:2] = center[:, 0:2]
            all_boxes[idx:idx + num_images, 2:4] = scale_[:, 0:2]
            all_boxes[idx:idx + num_images, 4] = np.prod(scale_ * 200, 1)
            all_boxes[idx:idx + num_images, 5] = 1.
            image_path.extend(batch['imgname'])

            idx += num_images

        ckp_name = options.regressor
        name_values, perf_indicator = dataset.evaluate(all_preds,
                                                       options.output_dir,
                                                       all_boxes, image_path,
                                                       ckp_name, filenames,
                                                       imgnums)

        model_name = options.regressor
        if isinstance(name_values, list):
            for name_value in name_values:
                _print_name_value(name_value, model_name)
        else:
            _print_name_value(name_values, model_name)

    # Save reconstructions to a file for further processing
    if save_results:
        np.savez(result_file,
                 pred_joints=pred_joints,
                 pose=smpl_pose,
                 betas=smpl_betas,
                 camera=smpl_camera)
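
Note: the examples above rely on perspective_projection to map camera-frame joints to image points. Below is a minimal sketch of the pinhole model it is assumed to implement (rotate, translate, divide by depth, then scale by the focal length and shift by the camera center; lens distortion is ignored here).

import torch

def pinhole_project(points, rotation, translation, focal_length, camera_center):
    # points: (B, N, 3), rotation: (B, 3, 3), translation: (B, 3), camera_center: (B, 2)
    points = torch.einsum('bij,bnj->bni', rotation, points) + translation.unsqueeze(1)
    xy = points[..., :2] / points[..., 2:3]                 # divide by depth
    return focal_length * xy + camera_center.unsqueeze(1)   # scale and shift to pixels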
Exemple #16
0
    def train_step(self, input_batch):
        self.model.train()

        images_hr = input_batch['img_hr']
        images_lr_list = input_batch['img_lr']
        images_list = [images_hr] + images_lr_list
        scale_names = ['224', '224_128', '128_64', '64_40', '40_24']
        scale_names = scale_names[:len(images_list)]
        feat_names = ['layer4']

        # Get data from the batch
        gt_keypoints_2d = input_batch['keypoints']  # 2D keypoints
        gt_pose = input_batch['pose']  # SMPL pose parameters
        gt_betas = input_batch['betas']  # SMPL beta parameters
        gt_joints = input_batch['pose_3d']  # 3D pose
        has_smpl = input_batch['has_smpl'].byte(
        )  # flag that indicates whether SMPL parameters are valid
        has_pose_3d = input_batch['has_pose_3d'].byte(
        )  # flag that indicates whether 3D pose is valid
        dataset_name = input_batch[
            'dataset_name']  # name of the dataset the image comes from
        indices = input_batch['sample_index'].numpy(
        )  # index of example inside mixed dataset
        batch_size = images_hr.shape[0]

        # Get GT vertices and model joints
        # Note that gt_model_joints is different from gt_joints as it comes from SMPL
        gt_out = self.smpl(betas=gt_betas,
                           body_pose=gt_pose[:, 3:],
                           global_orient=gt_pose[:, :3])
        gt_model_joints = gt_out.joints
        gt_vertices = gt_out.vertices

        # De-normalize 2D keypoints from [-1,1] to pixel space
        gt_keypoints_2d_orig = gt_keypoints_2d.clone()
        gt_keypoints_2d_orig[:, :, :-1] = 0.5 * self.options.img_res * (
            gt_keypoints_2d_orig[:, :, :-1] + 1)

        # Estimate camera translation given the model joints and 2D keypoints
        # by minimizing a weighted least squares loss
        gt_cam_t = estimate_translation(gt_model_joints,
                                        gt_keypoints_2d_orig,
                                        focal_length=self.focal_length,
                                        img_size=self.options.img_res)

        loss_shape = 0
        loss_keypoints = 0
        loss_keypoints_3d = 0
        loss_regr_pose = 0
        loss_regr_betas = 0
        loss_regr_cam_t = 0
        smpl_outputs = []
        for i, (images, scale_name) in enumerate(
                zip(images_list, scale_names[:len(images_list)])):
            images = images.to(self.device)
            # Feed images in the network to predict camera and SMPL parameters
            pred_rotmat, pred_betas, pred_camera, feat_list = self.model(
                images, scale=i)

            pred_output = self.smpl(betas=pred_betas,
                                    body_pose=pred_rotmat[:, 1:],
                                    global_orient=pred_rotmat[:,
                                                              0].unsqueeze(1),
                                    pose2rot=False)
            pred_vertices = pred_output.vertices
            pred_joints = pred_output.joints

            # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz] in 3D given the bounding box size
            # This camera translation can be used in a full perspective projection
            pred_cam_t = torch.stack([
                pred_camera[:, 1], pred_camera[:, 2], 2 * self.focal_length /
                (self.options.img_res * pred_camera[:, 0] + 1e-9)
            ],
                                     dim=-1)

            camera_center = torch.zeros(batch_size, 2, device=self.device)
            pred_keypoints_2d = perspective_projection(
                pred_joints,
                rotation=torch.eye(3, device=self.device).unsqueeze(0).expand(
                    batch_size, -1, -1),
                translation=pred_cam_t,
                focal_length=self.focal_length,
                camera_center=camera_center)
            # Normalize keypoints to [-1,1]
            pred_keypoints_2d = pred_keypoints_2d / (self.options.img_res / 2.)

            # Compute loss on SMPL parameters
            loss_pose, loss_betas, loss_cam_t = self.smpl_losses(
                pred_rotmat, pred_betas, pred_cam_t, gt_pose, gt_betas,
                gt_cam_t, has_smpl)
            loss_regr_pose = loss_regr_pose + (i + 1) * loss_pose
            loss_regr_betas = loss_regr_betas + (i + 1) * loss_betas
            loss_regr_cam_t = loss_regr_cam_t + (i + 1) * loss_cam_t

            # Compute 2D reprojection loss for the keypoints
            loss_keypoints = loss_keypoints + (i + 1) * self.keypoint_loss(
                pred_keypoints_2d, gt_keypoints_2d, self.options.
                openpose_train_weight, self.options.gt_train_weight)

            # Compute 3D keypoint loss
            loss_keypoints_3d = loss_keypoints_3d + (
                i + 1) * self.keypoint_3d_loss(pred_joints, gt_joints,
                                               has_pose_3d)

            # Per-vertex loss for the shape
            loss_shape = loss_shape + (i + 1) * self.shape_loss(
                pred_vertices, gt_vertices, has_smpl)

            # save pred_rotmat, pred_betas, pred_cam_t for later, from large images to smaller images
            smpl_outputs.append(
                [pred_rotmat, pred_betas, pred_cam_t, feat_list])

            # update queue size
            self.feat_queue.update_queue_size(batch_size)
            # update the queue
            self.feat_queue.update_all([feat.detach() for feat in feat_list],
                                       [name for name in feat_names])
            # update dataset name and index for each scale
            self.feat_queue.update('dataset_names', np.array(dataset_name))
            self.feat_queue.update('dataset_indices', indices)

        # Compute total loss except the consistency loss
        loss = self.options.shape_loss_weight * loss_shape +\
               self.options.keypoint_loss_weight * loss_keypoints + \
               self.options.keypoint_loss_weight * loss_keypoints_3d +\
               self.options.pose_loss_weight * loss_regr_pose + \
               self.options.beta_loss_weight * loss_regr_betas + \
               self.options.cam_loss_weight * loss_regr_cam_t
        loss = loss / len(images_list)

        # compute the consistency loss
        loss_consistency = 0
        for i in range(len(smpl_outputs)):
            gt_rotmat, gt_betas, gt_cam_t, gt_feat_list = smpl_outputs[i]
            gt_rotmat = gt_rotmat.detach()
            gt_betas = gt_betas.detach()
            gt_cam_t = gt_cam_t.detach()
            gt_feat_list = [feat.detach() for feat in gt_feat_list]
            # sample negative index
            indices_list = self.feat_queue.select_indices(
                dataset_name, indices, self.options.sample_size)
            neg_feat_list = self.feat_queue.batch_sample_all(indices_list,
                                                             names=feat_names)
            for j in range(i + 1, len(smpl_outputs)):
                # compute the consistency loss between scale pairs i < j: (1,2), (1,3), (2,3), ..., each weighted by (j - i) / num_scales
                pred_rotmat, pred_betas, pred_cam_t, pred_feat_list = smpl_outputs[
                    j]
                loss_consistency_total, loss_consistency_smpl, loss_consistency_feat = self.consistency_losses(
                    pred_rotmat, pred_betas, pred_cam_t, pred_feat_list,
                    gt_rotmat, gt_betas, gt_cam_t, gt_feat_list, neg_feat_list)
                loss_consistency = loss_consistency + (
                    (j - i) / len(smpl_outputs)) * loss_consistency_total
        loss_consistency = loss_consistency * self.consistency_loss_ramp * self.options.consistency_loss_weight

        loss += loss_consistency
        loss *= 60

        # Do backprop
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Pack output arguments
        output = {
            'pred_vertices': pred_vertices.detach(),
            'pred_cam_t': pred_cam_t.detach()
        }
        losses = {
            'lr': self.optimizer.param_groups[0]['lr'],
            'loss_ramp': self.consistency_loss_ramp,
            'loss': loss.detach().item(),
            'loss_consistency': loss_consistency.detach().item(),
            'loss_consistency_smpl': loss_consistency_smpl.detach().item(),
            'loss_consistency_feat': loss_consistency_feat.detach().item(),
            'loss_keypoints': loss_keypoints.detach().item(),
            'loss_keypoints_3d': loss_keypoints_3d.detach().item(),
            'loss_regr_pose': loss_regr_pose.detach().item(),
            'loss_regr_betas': loss_regr_betas.detach().item(),
            'loss_shape': loss_shape.detach().item()
        }

        return output, losses
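
Note: Exemple #16 weights each per-scale loss by (i + 1) before averaging over the image pyramid, and pairs scale i with every lower-resolution scale j > i for the consistency term, weighted by (j - i) / num_scales. Below is a minimal sketch of that weighting with hypothetical per-scale loss values.

def pyramid_weighted_mean(per_scale_losses):
    # later (lower-resolution) scales get a larger weight: 1, 2, 3, ...
    total = sum((i + 1) * loss for i, loss in enumerate(per_scale_losses))
    return total / len(per_scale_losses)

def consistency_pairs(num_scales):
    # ordered pairs (i, j) with i < j, each weighted by (j - i) / num_scales
    return [(i, j, (j - i) / num_scales)
            for i in range(num_scales)
            for j in range(i + 1, num_scales)]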
Exemple #17
0
def evaluate_singleview(root, annfile, device, use_mask=False):
    """
    Function to evaluation singleview reconstruction
    """
    
    # Models and optimizer
    bird = bird_model()
    predictor = load_detector().to(device)
    regressor = load_regressor().to(device)

    if use_mask:
        if device == 'cpu':
            print('Warning: using mask during optimization without GPU acceleration is very slow!')
        silhouette_renderer = base_renderer(size=256, focal=2167, device=device)
        optimizer = OptimizeSV(num_iters=100, prior_weight=1, mask_weight=1, 
                               use_mask=True, renderer=silhouette_renderer, device=device)
        print('Using mask for single view optimization')
    else:
        optimizer = OptimizeSV(num_iters=100, prior_weight=1, mask_weight=1, 
                               use_mask=False, device=device)

    # Dataset to run on
    normalize = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
        ])

    dataset = Cowbird_Dataset(root=root, annfile=annfile, scale_factor=0.25, transform=normalize)
    loader = torch.utils.data.DataLoader(dataset, batch_size=30)
    Pose_, Tran_, Bone_ = [], [], []
    GT_kpts, GT_masks, Sizes = [], [], []

    # Run reconstruction
    for i, (imgs, gt_kpts, gt_masks, meta) in enumerate(loader):
        print('Running on batch:', i+1)
        with torch.no_grad():
            # Prediction
            output = predictor(imgs.to(device))
            pred_kpts, pred_mask = postprocess(output)

            # Regression
            kpts_in = pred_kpts.reshape(pred_kpts.shape[0], -1)
            mask_in = pred_mask
            p_est, b_est = regressor(kpts_in, mask_in)
            pose, tran, bone = regressor.postprocess(p_est, b_est)

        # Optimization
        ignored = pred_kpts[:, :, 2] < 0.3
        opt_kpts = pred_kpts.clone()
        opt_kpts[ignored] = 0
        pose_op, bone_op, tran_op, model_mesh = optimizer(pose, bone, tran, 
                                              focal_length=2167, camera_center=128, 
                                              keypoints=opt_kpts, masks=mask_in.squeeze(1))
        Pose_.append(pose_op)
        Tran_.append(tran_op)
        Bone_.append(bone_op)
        GT_kpts.append(gt_kpts)
        GT_masks.append(gt_masks)
        Sizes.append(meta['size'])

    Pose_ = torch.cat(Pose_)
    Tran_ = torch.cat(Tran_)
    Bone_ = torch.cat(Bone_)
    GT_kpts = torch.cat(GT_kpts)
    GT_masks = torch.cat(GT_masks)
    Sizes = torch.cat(Sizes)

    # Render reprojected kpts and masks
    kpts_3d, vertices = pose_bird(bird, Pose_[:,:3], Pose_[:,3:], Bone_, Tran_, pose2rot=True)
    kpts_2d = perspective_projection(kpts_3d, None, None, focal_length=2167, camera_center=128)
    faces = torch.tensor(bird.dd['F'])

    masks = []
    mask_renderer = Silhouette_Renderer(focal_length=2167, center=(128,128), img_w=256, img_h=256)
    for i in range(len(vertices)):
        m = mask_renderer(vertices[i], faces)
        masks.append(m)
    masks = torch.tensor(np.stack(masks)).long()


    # Evaluation
    PCK05, PCK10 = evaluate_pck(kpts_2d[:,:12,:], GT_kpts, size=Sizes)
    IOU = evaluate_iou(masks, GT_masks)

    avg_PCK05 = torch.mean(torch.stack(PCK05))
    avg_PCK10 = torch.mean(torch.stack(PCK10))
    avg_IOU = torch.mean(torch.stack(IOU))

    print('Average PCK05:', avg_PCK05)
    print('Average PCK10:', avg_PCK10)
    print('Average IOU:', avg_IOU)
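
Note: evaluate_pck is not shown above; a common definition, and a reasonable assumption for the call here, counts a keypoint as correct when its 2D error is below a fraction of a per-example reference size. Below is a minimal sketch under that assumption (visibility weighting omitted).

import torch

def pck(pred_kpts, gt_kpts, size, alpha=0.05):
    # pred_kpts: (B, K, 2), gt_kpts: (B, K, 2) or (B, K, 3), size: (B,) reference size per example
    dist = torch.norm(pred_kpts - gt_kpts[..., :2], dim=-1)       # (B, K) pixel errors
    correct = (dist < alpha * size.unsqueeze(-1)).float()
    return correct.mean(dim=-1)                                   # per-example PCK@alpha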