def measurementError(predicted_smpl_shape, target_smpl_shape, gender, device):
    reposed_pose_rotmats, reposed_glob_rotmats = getReposedRotmats(1, device)
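    # getReposedRotmats is assumed to return identity rotation matrices for the
    # body joints and the global orientation, i.e. the neutral T-pose;
    # pose2rot=False below tells SMPL the pose is already in matrix form.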
    faces = None
    if gender == 'm':
        smpl_male = SMPL(config.SMPL_MODEL_DIR, batch_size=1,
                         gender='male').to(device)
        pred_smpl_neutral_pose_output = smpl_male(
            betas=predicted_smpl_shape,
            body_pose=reposed_pose_rotmats,
            global_orient=reposed_glob_rotmats,
            pose2rot=False)
        target_smpl_neutral_pose_output = smpl_male(
            betas=target_smpl_shape,
            body_pose=reposed_pose_rotmats,
            global_orient=reposed_glob_rotmats,
            pose2rot=False)
        faces = smpl_male.faces
    elif gender == 'f':
        smpl_female = SMPL(config.SMPL_MODEL_DIR,
                           batch_size=1,
                           gender='female').to(device)
        pred_smpl_neutral_pose_output = smpl_female(
            betas=predicted_smpl_shape,
            body_pose=reposed_pose_rotmats,
            global_orient=reposed_glob_rotmats,
            pose2rot=False)
        target_smpl_neutral_pose_output = smpl_female(
            betas=target_smpl_shape,
            body_pose=reposed_pose_rotmats,
            global_orient=reposed_glob_rotmats,
            pose2rot=False)
        faces = smpl_female.faces
    else:
        raise ValueError("Gender must be 'm' or 'f', got: {}".format(gender))

    # Convert to NumPy: scale_and_translation_transform_batch and the metric
    # computation below operate on NumPy arrays.
    pred_smpl_neutral_pose_vertices = pred_smpl_neutral_pose_output.vertices.detach().cpu().numpy()
    target_smpl_neutral_pose_vertices = target_smpl_neutral_pose_output.vertices.detach().cpu().numpy()

    # Rescale such that RMSD of predicted vertex mesh is the same as RMSD of target mesh.
    # This is done to combat scale vs camera depth ambiguity.
    pred_smpl_neutral_pose_vertices_rescale = scale_and_translation_transform_batch(
        pred_smpl_neutral_pose_vertices, target_smpl_neutral_pose_vertices)

    # Compute PVE-T-SC
    pve_neutral_pose_scale_corrected = np.linalg.norm(
        pred_smpl_neutral_pose_vertices_rescale -
        target_smpl_neutral_pose_vertices,
        axis=-1)  # (1, 6890)
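    # Each entry above is the Euclidean distance between corresponding
    # predicted and target vertices; averaging over the 6890 SMPL vertices
    # gives the scalar PVE-T-SC metric.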

    # Measurements
    weight_pred, height_pred, chest_pred, hip_pred = getBodyMeasurement(
        pred_smpl_neutral_pose_vertices, faces)
    weight_target, height_target, chest_target, hip_target = getBodyMeasurement(
        target_smpl_neutral_pose_vertices, faces)

    weight_error = weight_target - weight_pred
    height_error = height_target - height_pred
    chest_error = chest_target - chest_pred
    hip_error = hip_target - hip_pred

    return pve_neutral_pose_scale_corrected, weight_error, height_error, chest_error, hip_error
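# A minimal NumPy sketch of the scale correction used above (an assumption
# about scale_and_translation_transform_batch, not the repo's own code, hence
# the _sketch suffix): each predicted mesh is shifted to the target's centroid
# and rescaled so its RMS distance from the centroid matches the target's.
def scale_and_translation_transform_batch_sketch(P, T):
    """P, T: (batch, n_vertices, 3) NumPy arrays. Returns P aligned to T."""
    P_mean = P.mean(axis=1, keepdims=True)
    P_rmsd = np.sqrt(((P - P_mean) ** 2).sum(axis=(1, 2), keepdims=True) / P.shape[1])
    T_mean = T.mean(axis=1, keepdims=True)
    T_rmsd = np.sqrt(((T - T_mean) ** 2).sum(axis=(1, 2), keepdims=True) / T.shape[1])
    return (P - P_mean) * (T_rmsd / P_rmsd) + T_mean
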
def compute_pve_neutral_pose_scale_corrected(predicted_smpl_shape,
                                             target_smpl_shape, gender,
                                             device):
    """
    Given predicted and target SMPL shape parameters, computes neutral-pose per-vertex error
    after scale-correction (to account for scale vs camera depth ambiguity).
    :param predicted_smpl_parameters: predicted SMPL shape parameters tensor with shape (1, 10)
    :param target_smpl_parameters: target SMPL shape parameters tensor with shape (1, 10)
    :param gender: gender of target
    """

    # Get neutral pose vertices
    if gender == 'm':
        smpl_male = SMPL(config.SMPL_MODEL_DIR, batch_size=1,
                         gender='male').to(device)
        pred_smpl_neutral_pose_output = smpl_male(betas=predicted_smpl_shape)
        target_smpl_neutral_pose_output = smpl_male(betas=target_smpl_shape)
    elif gender == 'f':
        smpl_female = SMPL(config.SMPL_MODEL_DIR,
                           batch_size=1,
                           gender='female').to(device)
        pred_smpl_neutral_pose_output = smpl_female(betas=predicted_smpl_shape)
        target_smpl_neutral_pose_output = smpl_female(betas=target_smpl_shape)
    else:
        raise ValueError("Gender must be 'm' or 'f', got: {}".format(gender))

    # Convert to NumPy for the alignment and metric computation below.
    pred_smpl_neutral_pose_vertices = pred_smpl_neutral_pose_output.vertices.detach().cpu().numpy()
    target_smpl_neutral_pose_vertices = target_smpl_neutral_pose_output.vertices.detach().cpu().numpy()

    # Rescale such that RMSD of predicted vertex mesh is the same as RMSD of target mesh.
    # This is done to combat scale vs camera depth ambiguity.
    pred_smpl_neutral_pose_vertices_rescale = scale_and_translation_transform_batch(
        pred_smpl_neutral_pose_vertices, target_smpl_neutral_pose_vertices)

    # Compute PVE-T-SC
    pve_neutral_pose_scale_corrected = np.linalg.norm(
        pred_smpl_neutral_pose_vertices_rescale -
        target_smpl_neutral_pose_vertices,
        axis=-1)  # (1, 6890)

    return pve_neutral_pose_scale_corrected
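# Example usage (hypothetical tensors), reducing the per-vertex errors to a
# single scalar PVE-T-SC value, e.g. in millimetres if the meshes are in metres:
#   pred_betas = torch.randn(1, 10, device=device)
#   target_betas = torch.randn(1, 10, device=device)
#   pve = compute_pve_neutral_pose_scale_corrected(pred_betas, target_betas, 'f', device)
#   pve_t_sc_mm = 1000. * pve.mean()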
Example #3
def predict_3D(input,
               regressor,
               device,
               silhouettes_from='densepose',
               proxy_rep_input_wh=512,
               save_proxy_vis=True,
               render_vis=True):

    # Set-up proxy representation predictors.
    joints2D_predictor, silhouette_predictor = setup_detectron2_predictors(
        silhouettes_from=silhouettes_from)

    # Set-up SMPL model.
    smpl = SMPL(config.SMPL_MODEL_DIR, batch_size=1).to(device)

    if render_vis:
        # Set-up renderer for visualisation.
        wp_renderer = Renderer(resolution=(proxy_rep_input_wh,
                                           proxy_rep_input_wh))

    if os.path.isdir(input):
        image_fnames = [
            f for f in sorted(os.listdir(input))
            if f.endswith('.png') or f.endswith('.jpg')
        ]
        for fname in image_fnames:
            print("Predicting on:", fname)
            image = cv2.imread(os.path.join(input, fname))
            # Pre-process for 2D detectors
            image = pad_to_square(image)
            image = cv2.resize(image, (proxy_rep_input_wh, proxy_rep_input_wh),
                               interpolation=cv2.INTER_LINEAR)
            # Predict 2D
            joints2D, joints2D_vis = predict_joints2D(image,
                                                      joints2D_predictor)
            if silhouettes_from == 'pointrend':
                silhouette, silhouette_vis = predict_silhouette_pointrend(
                    image, silhouette_predictor)
            elif silhouettes_from == 'densepose':
                silhouette, silhouette_vis = predict_densepose(
                    image, silhouette_predictor)
                silhouette = convert_multiclass_to_binary_labels(silhouette)
            else:
                raise ValueError("silhouettes_from must be 'pointrend' or 'densepose'.")
            # Crop around silhouette
            silhouette, joints2D, image = crop_and_resize_silhouette_joints(
                silhouette,
                joints2D,
                out_wh=config.REGRESSOR_IMG_WH,
                image=image,
                image_out_wh=proxy_rep_input_wh,
                bbox_scale_factor=1.2)
            # Create proxy representation
            proxy_rep = create_proxy_representation(
                silhouette, joints2D, out_wh=config.REGRESSOR_IMG_WH)
            proxy_rep = proxy_rep[None, :, :, :]  # add batch dimension
            proxy_rep = torch.from_numpy(proxy_rep).float().to(device)
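            # The proxy representation appears to stack the binary silhouette
            # channel-wise with one heatmap per 2D joint, so the regressor sees
            # silhouette + joints rather than raw RGB pixels.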

            # Predict 3D
            regressor.eval()
            with torch.no_grad():
                pred_cam_wp, pred_pose, pred_shape = regressor(proxy_rep)
                # Convert pred pose to rotation matrices
                if pred_pose.shape[-1] == 24 * 3:
                    pred_pose_rotmats = batch_rodrigues(
                        pred_pose.contiguous().view(-1, 3))
                    pred_pose_rotmats = pred_pose_rotmats.view(-1, 24, 3, 3)
                elif pred_pose.shape[-1] == 24 * 6:
                    pred_pose_rotmats = rot6d_to_rotmat(
                        pred_pose.contiguous()).view(-1, 24, 3, 3)
                else:
                    raise ValueError(
                        "Unexpected pose dimensionality: {}".format(pred_pose.shape[-1]))

                pred_smpl_output = smpl(
                    body_pose=pred_pose_rotmats[:, 1:],
                    global_orient=pred_pose_rotmats[:, 0].unsqueeze(1),
                    betas=pred_shape,
                    pose2rot=False)
                pred_vertices = pred_smpl_output.vertices
                pred_vertices2d = orthographic_project_torch(
                    pred_vertices, pred_cam_wp)
                pred_vertices2d = undo_keypoint_normalisation(
                    pred_vertices2d, proxy_rep_input_wh)
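                # orthographic_project_torch is assumed to apply a
                # weak-perspective camera of the usual form
                #   proj_xy = scale * (vertices_xy + [tx, ty])
                # with pred_cam_wp = [scale, tx, ty]; undo_keypoint_normalisation
                # then maps the normalised coordinates back to pixel space.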

                pred_reposed_smpl_output = smpl(betas=pred_shape)
                pred_reposed_vertices = pred_reposed_smpl_output.vertices

            # Convert predictions to NumPy for visualisation
            pred_vertices = pred_vertices.detach().cpu().numpy()[0]
            pred_vertices2d = pred_vertices2d.detach().cpu().numpy()[0]
            pred_reposed_vertices = pred_reposed_vertices.detach().cpu().numpy()[0]
            pred_cam_wp = pred_cam_wp.detach().cpu().numpy()[0]

            if not os.path.isdir(os.path.join(input, 'verts_vis')):
                os.makedirs(os.path.join(input, 'verts_vis'))
            plt.figure()
            plt.imshow(image[:, :, ::-1])
            plt.scatter(pred_vertices2d[:, 0], pred_vertices2d[:, 1], s=0.3)
            plt.gca().set_axis_off()
            plt.subplots_adjust(top=1,
                                bottom=0,
                                right=1,
                                left=0,
                                hspace=0,
                                wspace=0)
            plt.margins(0, 0)
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.savefig(os.path.join(input, 'verts_vis', 'verts_' + fname))
            plt.close()  # free the figure so memory does not grow over the loop

            if render_vis:
                rend_img = wp_renderer.render(verts=pred_vertices,
                                              cam=pred_cam_wp,
                                              img=image)
                rend_reposed_img = wp_renderer.render(
                    verts=pred_reposed_vertices,
                    cam=np.array([0.8, 0., -0.2]),
                    angle=180,
                    axis=[1, 0, 0])
                if not os.path.isdir(os.path.join(input, 'rend_vis')):
                    os.makedirs(os.path.join(input, 'rend_vis'))
                cv2.imwrite(os.path.join(input, 'rend_vis', 'rend_' + fname),
                            rend_img)
                cv2.imwrite(
                    os.path.join(input, 'rend_vis', 'reposed_' + fname),
                    rend_reposed_img)
            if save_proxy_vis:
                if not os.path.isdir(os.path.join(input, 'proxy_vis')):
                    os.makedirs(os.path.join(input, 'proxy_vis'))
                cv2.imwrite(
                    os.path.join(input, 'proxy_vis', 'silhouette_' + fname),
                    silhouette_vis)
                cv2.imwrite(
                    os.path.join(input, 'proxy_vis', 'joints2D_' + fname),
                    joints2D_vis)
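# For reference, a minimal sketch of the 6D-rotation-to-matrix conversion that
# rot6d_to_rotmat is assumed to perform above (following Zhou et al., "On the
# Continuity of Rotation Representations in Neural Networks", CVPR 2019): the
# six numbers per joint are read as two 3-vectors and Gram-Schmidt
# orthonormalised into the first two columns of a rotation matrix.
import torch.nn.functional as F

def rot6d_to_rotmat_sketch(x):
    """x: (B, 24 * 6) tensor -> (B * 24, 3, 3) rotation matrices."""
    x = x.contiguous().view(-1, 3, 2)
    a1, a2 = x[:, :, 0], x[:, :, 1]
    b1 = F.normalize(a1, dim=1)  # first column
    # Second column: a2 with its component along b1 removed, then normalised
    b2 = F.normalize(a2 - (b1 * a2).sum(dim=1, keepdim=True) * b1, dim=1)
    b3 = torch.cross(b1, b2, dim=1)  # third column completes the orthonormal basis
    return torch.stack((b1, b2, b3), dim=-1)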
train_dataset = SyntheticTrainingDataset(npz_path=train_path,
                                         params_from='all')
val_dataset = SyntheticTrainingDataset(npz_path=val_path, params_from='all')
train_val_monitor_datasets = [train_dataset, val_dataset]
print("Training examples found:", len(train_dataset))
print("Validation examples found:", len(val_dataset))

# ----------------------- Models -----------------------
# Regressor
regressor = SingleInputRegressor(resnet_in_channels,
                                 resnet_layers,
                                 ief_iters=ief_iters)
num_params = count_parameters(regressor)
print("\nRegressor model Loaded. ", num_params, "trainable parameters.")

# SMPL model
smpl_model = SMPL(config.SMPL_MODEL_DIR, batch_size=batch_size)

# Camera and NMR part/silhouette renderer
# Camera rotation is assumed to be the identity, since global rotation is
# handled by SMPL's global_orient parameter.
mean_cam_t = np.array([0., 0.2, 42.])
mean_cam_t = torch.from_numpy(mean_cam_t).float().to(device)
mean_cam_t = mean_cam_t[None, :].expand(batch_size, -1)
cam_K = get_intrinsics_matrix(config.REGRESSOR_IMG_WH, config.REGRESSOR_IMG_WH,
                              config.FOCAL_LENGTH)
cam_K = torch.from_numpy(cam_K.astype(np.float32)).to(device)
cam_K = cam_K[None, :, :].expand(batch_size, -1, -1)
cam_R = torch.eye(3).to(device)
cam_R = cam_R[None, :, :].expand(batch_size, -1, -1)
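# A minimal sketch of the intrinsics construction above (an assumption about
# get_intrinsics_matrix, not its actual source): a pinhole camera matrix with
# the principal point at the image centre.
def get_intrinsics_matrix_sketch(img_width, img_height, focal_length):
    return np.array([[focal_length, 0., img_width / 2.],
                     [0., focal_length, img_height / 2.],
                     [0., 0., 1.]])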
nmr_parts_renderer = NMRRenderer(batch_size,
                                 cam_K,
                                 cam_R,