def visualize_volumes(images_batch,
                      volumes_batch,
                      proj_matricies_batch,
                      kind="cmu",
                      cuboids_batch=None,
                      batch_index=0,
                      size=5,
                      max_n_rows=10,
                      max_n_cols=10):
    """Render one sample's per-joint volumes next to its view images.

    Builds a grid with one row per view (at most ``max_n_rows``). Column 0
    shows the view image, with ``cuboids_batch[batch_index]`` rendered into
    it when cuboids are given. Every further column draws one joint volume
    with ``draw_voxels`` on a 3D axis. The grid has at most ``max_n_cols``
    columns in total, the image column included.

    Args:
        images_batch: batched multi-view images. ``images_batch[batch_index]``
            is converted with ``image_batch_to_numpy`` and indexed per view.
            Presumably (batch, n_views, C, H, W) -- TODO confirm.
        volumes_batch: batched volumes. ``volumes_batch[batch_index][j]`` is
            drawn for joint ``j``. Presumably (batch, n_joints, X, Y, Z) --
            TODO confirm.
        proj_matricies_batch: projection matrices indexed as
            ``[batch_index, view]``. Only used when rendering cuboids.
        kind: key into ``JOINT_NAMES_DICT`` for the column titles. When the
            key is missing, the joint index is used instead.
        cuboids_batch: optional per-sample objects exposing
            ``render(proj_matrix, image)``. When ``None``, the plain view
            image is shown in column 0.
        batch_index: which sample of the batch to visualize.
        size: edge length of each subplot, in inches.
        max_n_rows: maximum number of view rows.
        max_n_cols: maximum number of columns (image column + joints).

    Returns:
        The rendered figure, as produced by ``fig_to_array``.
    """
    # Selected sample -> numpy; channels are flipped for display.
    images = image_batch_to_numpy(images_batch[batch_index])
    images = denormalize_image(images).astype(np.uint8)
    images = images[..., ::-1]  # bgr -> rgb

    volumes = to_numpy(volumes_batch[batch_index])

    # Size the grid from the axes that are actually indexed below: rows index
    # views of `images`, joint columns index the first axis of `volumes`.
    # (Reading both counts from volumes_batch.shape could request more rows
    # than there are views and raise IndexError on images[row].)
    n_views, n_joints = len(images), len(volumes)
    n_cols, n_rows = min(n_joints + 1, max_n_cols), min(n_views, max_n_rows)
    fig = plt.figure(figsize=(n_cols * size, n_rows * size))

    for row in range(n_rows):
        # Column 0: the view image, optionally with the cuboid drawn on it.
        ax = fig.add_subplot(n_rows, n_cols, row * n_cols + 1)
        ax.set_ylabel(str(row), size='large')
        if cuboids_batch is None:
            ax.imshow(images[row])
        else:
            proj_matrix = proj_matricies_batch[batch_index,
                                               row].detach().cpu().numpy()
            cuboid = cuboids_batch[batch_index]
            # render() gets a copy so the shared image array is not modified.
            ax.imshow(cuboid.render(proj_matrix, images[row].copy()))

        # Remaining columns: one 3D voxel plot per joint.
        for col in range(1, n_cols):
            joint_i = col - 1
            ax = fig.add_subplot(n_rows,
                                 n_cols,
                                 row * n_cols + col + 1,
                                 projection='3d')

            if row == 0:
                joint_name = JOINT_NAMES_DICT[kind][
                    joint_i] if kind in JOINT_NAMES_DICT else str(joint_i)
                ax.set_title(joint_name)

            draw_voxels(volumes[joint_i], ax, norm=True)

    fig.tight_layout()

    fig_image = fig_to_array(fig)

    plt.close('all')

    return fig_image
def visualize_heatmaps(images_batch,
                       heatmaps_batch,
                       kind="cmu",
                       batch_index=0,
                       size=5,
                       max_n_rows=10,
                       max_n_cols=10):
    """Overlay one sample's per-joint heatmaps on its view images.

    Builds a grid with one row per view (at most ``max_n_rows``). Column 0
    shows the view image. Every further column shows the image, resized to
    the heatmap resolution, with one joint's heatmap drawn on top at 50%
    alpha. The grid has at most ``max_n_cols`` columns in total.

    Args:
        images_batch: batched multi-view images. ``images_batch[batch_index]``
            is converted with ``image_batch_to_numpy`` and indexed per view.
        heatmaps_batch: batched heatmaps. ``shape[1]`` is used as the number
            of views, ``shape[2]`` as the number of joints and ``shape[3:]``
            as the heatmap resolution.
        kind: key into ``JOINT_NAMES_DICT`` for the column titles. When the
            key is missing, the joint index is used instead.
        batch_index: which sample of the batch to visualize.
        size: edge length of each subplot, in inches.
        max_n_rows: maximum number of view rows.
        max_n_cols: maximum number of columns (image column + joints).

    Returns:
        The rendered figure, as produced by ``fig_to_array``.
    """
    n_views, n_joints = heatmaps_batch.shape[1], heatmaps_batch.shape[2]
    heatmap_shape = heatmaps_batch.shape[3:]

    n_cols, n_rows = min(n_joints + 1, max_n_cols), min(n_views, max_n_rows)
    # squeeze=False keeps `axes` a 2D array for every grid size. Without it,
    # a 1x1 grid returns a bare Axes, which has no .reshape.
    fig, axes = plt.subplots(ncols=n_cols,
                             nrows=n_rows,
                             figsize=(n_cols * size, n_rows * size),
                             squeeze=False)
    axes = axes.reshape(n_rows, n_cols)

    # Selected sample -> numpy; channels are flipped for display.
    images = image_batch_to_numpy(images_batch[batch_index])
    images = denormalize_image(images).astype(np.uint8)
    images = images[..., ::-1]  # bgr -> rgb

    heatmaps = to_numpy(heatmaps_batch[batch_index])

    for row in range(n_rows):
        image = images[row]

        # Column 0: the raw view image.
        axes[row, 0].set_ylabel(str(row), size='large')
        axes[row, 0].imshow(image)

        if n_cols < 2:
            continue

        # All joint columns of this view share the same resized background,
        # so resize once per view rather than once per joint.
        image_resized = resize_image(image, heatmap_shape)

        for col in range(1, n_cols):
            joint_i = col - 1
            if row == 0:
                joint_name = JOINT_NAMES_DICT[kind][
                    joint_i] if kind in JOINT_NAMES_DICT else str(joint_i)
                axes[row, col].set_title(joint_name)

            axes[row, col].imshow(image_resized)
            axes[row, col].imshow(heatmaps[row, joint_i], alpha=0.5)

    fig.tight_layout()

    fig_image = fig_to_array(fig)

    plt.close('all')

    return fig_image
def visualize_batch(images_batch,
                    heatmaps_batch,
                    keypoints_2d_batch,
                    proj_matricies_batch,
                    keypoints_3d_batch_gt,
                    keypoints_3d_batch_pred,
                    kind="cmu",
                    cuboids_batch=None,
                    confidences_batch=None,
                    batch_index=0,
                    size=5,
                    max_n_cols=10,
                    pred_kind=None):
    """Render a multi-row summary figure for one sample of a batch.

    Each column is one view (at most ``max_n_cols``). Rows, top to bottom:

    * the view image;
    * predicted 2D keypoints (only if ``keypoints_2d_batch`` is given);
    * ground-truth 3D keypoints projected into the view;
    * predicted 3D keypoints projected into the view;
    * the rendered cuboid (only if ``cuboids_batch`` is given);
    * a bar chart of per-view confidences (only if ``confidences_batch``
      is given).

    Args:
        images_batch: batched multi-view images. ``images_batch[batch_index]``
            is converted with ``image_batch_to_numpy`` and indexed per view.
        heatmaps_batch: only ``shape[1]`` is read, as the number of views.
        keypoints_2d_batch: optional 2D keypoints.
            ``to_numpy(keypoints_2d_batch)[batch_index][view]`` is drawn
            with ``draw_2d_pose``.
        proj_matricies_batch: projection matrices indexed as
            ``[batch_index, view]``.
        keypoints_3d_batch_gt: ground-truth 3D keypoints. Presumably
            (batch, n_joints, 3) -- TODO confirm.
        keypoints_3d_batch_pred: predicted 3D keypoints, same layout as
            the ground truth.
        kind: skeleton kind used to draw 2D and ground-truth poses.
        cuboids_batch: optional per-sample objects exposing
            ``render(proj_matrix, image)``.
        confidences_batch: optional confidences indexed as
            ``[batch_index, view]``. The y-axis is fixed to [0, 1] when the
            whole tensor's maximum is <= 1.
        batch_index: which sample of the batch to visualize.
        size: edge length of each subplot, in inches.
        max_n_cols: maximum number of view columns.
        pred_kind: skeleton kind for the predicted 3D pose. Defaults to
            ``kind``.

    Returns:
        The rendered figure, as produced by ``fig_to_array``.
    """
    if pred_kind is None:
        pred_kind = kind

    n_views = heatmaps_batch.shape[1]

    # Three rows are always drawn (image, gt projection, pred projection);
    # each optional input adds one more.
    optional_inputs = (keypoints_2d_batch, cuboids_batch, confidences_batch)
    n_rows = 3 + sum(x is not None for x in optional_inputs)

    n_cols = min(n_views, max_n_cols)
    # squeeze=False keeps `axes` 2D regardless of grid size.
    fig, axes = plt.subplots(ncols=n_cols,
                             nrows=n_rows,
                             figsize=(n_cols * size, n_rows * size),
                             squeeze=False)
    axes = axes.reshape(n_rows, n_cols)

    # Move everything reused across rows/views to numpy once, instead of
    # once per view and row.
    proj_matricies = [
        proj_matricies_batch[batch_index, view_i].detach().cpu().numpy()
        for view_i in range(n_cols)
    ]
    keypoints_3d_gt = keypoints_3d_batch_gt[batch_index].detach().cpu().numpy()
    keypoints_3d_pred = keypoints_3d_batch_pred[batch_index].detach().cpu(
    ).numpy()

    row_i = 0

    # images
    axes[row_i, 0].set_ylabel("image", size='large')

    images = image_batch_to_numpy(images_batch[batch_index])
    images = denormalize_image(images).astype(np.uint8)
    images = images[..., ::-1]  # bgr -> rgb

    for view_i in range(n_cols):
        axes[row_i][view_i].imshow(images[view_i])
    row_i += 1

    # 2D keypoints (pred)
    if keypoints_2d_batch is not None:
        axes[row_i, 0].set_ylabel("2d keypoints (pred)", size='large')

        keypoints_2d = to_numpy(keypoints_2d_batch)[batch_index]
        for view_i in range(n_cols):
            axes[row_i][view_i].imshow(images[view_i])
            draw_2d_pose(keypoints_2d[view_i], axes[row_i][view_i], kind=kind)
        row_i += 1

    # 2D keypoints (gt projected)
    axes[row_i, 0].set_ylabel("2d keypoints (gt projected)", size='large')

    for view_i in range(n_cols):
        axes[row_i][view_i].imshow(images[view_i])
        keypoints_2d_gt_proj = project_3d_points_to_image_plane_without_distortion(
            proj_matricies[view_i], keypoints_3d_gt)
        draw_2d_pose(keypoints_2d_gt_proj, axes[row_i][view_i], kind=kind)
    row_i += 1

    # 2D keypoints (pred projected)
    axes[row_i, 0].set_ylabel("2d keypoints (pred projected)", size='large')

    for view_i in range(n_cols):
        axes[row_i][view_i].imshow(images[view_i])
        keypoints_2d_pred_proj = project_3d_points_to_image_plane_without_distortion(
            proj_matricies[view_i], keypoints_3d_pred)
        draw_2d_pose(keypoints_2d_pred_proj,
                     axes[row_i][view_i],
                     kind=pred_kind)
    row_i += 1

    # cuboids
    if cuboids_batch is not None:
        axes[row_i, 0].set_ylabel("cuboid", size='large')

        cuboid = cuboids_batch[batch_index]
        for view_i in range(n_cols):
            # render() gets a copy so the shared image array is not modified.
            axes[row_i][view_i].imshow(
                cuboid.render(proj_matricies[view_i], images[view_i].copy()))
        row_i += 1

    # confidences
    if confidences_batch is not None:
        axes[row_i, 0].set_ylabel("confidences", size='large')

        # The y-limit decision depends on the whole tensor, so compute the
        # maximum once rather than once per view.
        confidences_max = torch.max(confidences_batch).item()
        for view_i in range(n_cols):
            confidences = to_numpy(confidences_batch[batch_index, view_i])
            xs = np.arange(len(confidences))

            axes[row_i, view_i].bar(xs, confidences, color='green')
            axes[row_i, view_i].set_xticks(xs)
            if confidences_max <= 1.0:
                axes[row_i, view_i].set_ylim(0.0, 1.0)
        row_i += 1

    fig.tight_layout()

    fig_image = fig_to_array(fig)

    plt.close('all')

    return fig_image