示例#1
0
def add_hand_viz(ax, hand_df, joint_idxs=False, score_thresh=0.2):
    """Draw the 2D hand joints of every row of ``hand_df`` on ``ax``.

    Each row must expose ``joints2d`` and ``joints2d_scores`` attributes.
    Drawing is delegated to ``viz2d.visualize_joints_2d``.

    Args:
        ax: Axes object forwarded to ``viz2d.visualize_joints_2d``.
        hand_df: DataFrame iterated with ``iterrows()``; one hand per row.
        joint_idxs: Forwarded unchanged to ``viz2d.visualize_joints_2d``.
        score_thresh: When score filtering is active, joints whose score is
            not strictly greater than this value are dropped from the links.
    """
    # Five chains of joint indices, each starting at joint 0.
    finger_links = (
        (0, 1, 2, 3, 4),
        (0, 5, 6, 7, 8),
        (0, 9, 10, 11, 12),
        (0, 13, 14, 15, 16),
        (0, 17, 18, 19, 20),
    )
    for hand_idx, hand_det in hand_df.iterrows():
        joints2d = hand_det.joints2d
        scores = hand_det.joints2d_scores
        # NOTE(review): the links only index joints 0-20, so a 21-element
        # score vector takes the unfiltered branch below; confirm that
        # `> 21` (rather than `>= 21`) is the intended condition.
        if scores.shape[0] > 21:
            # Keep, per finger, only the joints that are confident enough.
            kept_links = [
                [idx for idx in finger if scores[idx] > score_thresh]
                for finger in finger_links
            ]
            joint_labels = [
                f"{scores[idx]:.2f}" for idx in range(len(joints2d))
            ]
        else:
            # Without score filtering, draw the complete skeleton. Previously
            # `kept_links` was left undefined here, which raised on the first
            # row or silently reused the previous row's filtered links.
            kept_links = [list(finger) for finger in finger_links]
            joint_labels = None
        viz2d.visualize_joints_2d(
            ax,
            joints2d,
            joint_labels=joint_labels,
            joint_idxs=joint_idxs,
            links=kept_links,
        )
示例#2
0
def get_hands(
    hoa_df,
    img_path,
    crop_size=256,
    img_resize_factor=1,
    hand_mode="opp",
    scale_factor=1.5,
    debug=True,
):
    """Estimate 2D hand joints for every hand detection of one image.

    For each row of ``hoa_df`` whose ``det_type`` is ``"hand"``, the detection
    box is scaled by ``img_resize_factor``, turned into a square crop by
    ``preprocess.squarify``, resized to ``crop_size`` x ``crop_size`` and fed
    to ``hand_opp``. Left hands are mirrored before inference and the
    predicted coordinates are flipped back afterwards.

    Args:
        hoa_df: DataFrame of detections with ``det_type``, ``side``, ``left``,
            ``top``, ``right`` and ``bottom`` fields.
        img_path: Path of the image the detections refer to.
        crop_size: Side length (pixels) of the crop passed to ``hand_opp``.
        img_resize_factor: Multiplier applied to the box coordinates.
        hand_mode: Unused; kept for interface compatibility.
        scale_factor: Forwarded to ``preprocess.squarify``.
        debug: When true, overwrite ``tmp.png`` (full image) and
            ``tmp_crop.png`` (crop) with joint overlays for each hand.

    Returns:
        A DataFrame with one row per hand detection: the original detection
        fields plus ``joints2d`` (joints in full-image coordinates) and
        ``joints2d_scores``. Empty when there is no hand detection.
    """
    hand_dicts = []
    img = None
    try:
        for det_idx, det in hoa_df.iterrows():
            if det.det_type != "hand":
                continue
            if img is None:
                # Opened lazily so that inputs without hands never touch the
                # file, and only once instead of once per detection.
                img = Image.open(img_path)
            det_bbox = [
                val * img_resize_factor
                for val in [det.left, det.top, det.right, det.bottom]
            ]
            square_det = preprocess.squarify(det_bbox, scale_factor)
            crop = img.crop(square_det)
            # Width of the square crop in full-image pixels.
            crop_side = square_det[2] - square_det[0]
            # Resize, then reverse the channel axis (RGB -> BGR).
            hand_crop = cv2.resize(np.array(crop), (crop_size, crop_size))[:, :, ::-1]
            if det.side == "left":
                # Process all hands as right hands: mirror the crop, then
                # flip the predicted coordinates back.
                joints2d, peak_scores = hand_opp(np.flip(hand_crop, axis=1))
                joints2d = postprocess.flip_coords(
                    joints2d, crop_size=crop_size, axis=0
                )
            else:
                joints2d, peak_scores = hand_opp(hand_crop)
            mean_scores = np.stack(peak_scores).mean(0)[0]

            # Map crop coordinates back to the full image: undo the resize,
            # then translate by the crop's top-left corner.
            joints2d_abs = joints2d / crop_size * crop_side + np.array(
                [square_det[0], square_det[1]]
            )
            if debug:
                fig, ax = plt.subplots(1)
                ax.imshow(img)
                viz2d.visualize_joints_2d(ax, joints2d_abs, joint_idxs=False)
                fig.savefig("tmp.png")
                # Close (not merely clear) so figures do not accumulate.
                plt.close(fig)
                fig, ax = plt.subplots(1)
                # Reverse the channels again so the crop is displayed as RGB.
                ax.imshow(hand_crop[:, :, ::-1])
                viz2d.visualize_joints_2d(ax, joints2d, joint_idxs=False)
                fig.savefig("tmp_crop.png")
                plt.close(fig)
            det_dict = det.to_dict()
            det_dict["joints2d"] = joints2d_abs
            det_dict["joints2d_scores"] = mean_scores
            hand_dicts.append(det_dict)
    finally:
        if img is not None:
            img.close()
    return pd.DataFrame(hand_dicts)
def sample_vis(sample, results, save_img_path, fig=None, max_rows=5, display_centered=False):
    """Draw a grid of hand/object visualizations for a batch and save it.

    One row is drawn per batch element, up to ``max_rows``. Every quantity is
    fetched with ``get_check_none`` and skipped when it is ``None``.
    Predictions are drawn in red/cyan or with full opacity, ground truth in
    blue/green or with half opacity.

    Columns:
        0: image with 2D hand joints.
        1: image with object keypoints (``"kpt_2d"`` / ``OBJFPS2D``), arrows
           linking each ground-truth keypoint to its prediction, and ellipses
           from ``compute_confidence_ellipses`` when ``"var"`` is available.
        2: image with 2D object vertices; arrows for the first 6 vertices are
           drawn only when the keypoint arrows of column 1 are unavailable.
        3: recovered 3D object vertices and hand joints, coordinate 2 against
           coordinate 0 ("view from the top"), y axis inverted.
        4: same data, coordinate 2 against negated coordinate 1 ("view from
           the right").
        5-7 (only with ``display_centered``): canonical object vertices and
           corners, then object/hand vertices and hand joints on coordinates
           (0, 1) and (1, 2).

    Args:
        sample: Mapping of ground-truth data indexed by ``TransQueries`` /
            ``BaseQueries`` keys; must contain ``TransQueries.IMAGE``.
        results: Mapping of predictions indexed by string keys.
        save_img_path: Destination passed to ``fig.savefig`` (dpi=300).
        fig: Figure to clear and draw into. When ``None`` a temporary pyplot
            figure is created and closed after saving.
        max_rows: Maximum number of batch elements to display.
        display_centered: Also draw columns 5-7.
    """
    box_links = [[0, 1, 3, 2], [4, 5, 7, 6], [1, 5], [3, 7], [4, 0], [0, 2, 6, 4]]

    def pick(batch, row_idx):
        """Return the ``row_idx``-th element of ``batch``, or None."""
        return None if batch is None else batch[row_idx]

    def scatter(ax, pts, dims, flip_y=False, **style):
        """Scatter two coordinate columns of ``pts`` when it is available."""
        if pts is None:
            return
        ys = pts[:, dims[1]]
        ax.scatter(pts[:, dims[0]], -ys if flip_y else ys, **style)

    def joints(ax, pts, alpha, cols=None, **kwargs):
        """Draw ``pts`` (optionally a column subset) with visualize_joints_2d."""
        if pts is None:
            return
        if cols is not None:
            pts = pts[:, cols]
        visualize_joints_2d(ax, pts, alpha=alpha, joint_idxs=False, **kwargs)

    def arrows(ax, gt_pts, pred_pts):
        """Link each ground-truth point to the matching predicted point."""
        arrow_nb = min(gt_pts.shape[0], pred_pts.shape[0])
        links = [[i, i + arrow_nb] for i in range(arrow_nb)]
        stacked = torch.cat([gt_pts[:arrow_nb], pred_pts[:arrow_nb]], 0)
        joints(ax, stacked, 0.5, links=links, color=["k"] * arrow_nb)

    owns_fig = fig is None
    if owns_fig:
        # The default fig=None used to crash on fig.clf(); create a figure
        # that lives only for the duration of this call instead.
        import matplotlib.pyplot as plt

        fig = plt.figure()
    else:
        fig.clf()
    try:
        # Move axis 1 last (presumably channels-first -> channels-last) and
        # shift values by 0.5 (presumably undoing a centering; confirm).
        images = sample[TransQueries.IMAGE].permute(0, 2, 3, 1).cpu() + 0.5
        batch_size = images.shape[0]
        gt_objverts2d = get_check_none(sample, TransQueries.OBJVERTS2D)
        pred_objverts2d = get_check_none(results, "obj_verts2d")
        pred_objcorners3d = get_check_none(results, "obj_corners3d")
        gt_objverts3d = get_check_none(sample, TransQueries.OBJVERTS3D)
        pred_objverts3d = get_check_none(results, "obj_verts3d")
        gt_canobjverts3d = get_check_none(sample, TransQueries.OBJCANROTVERTS)
        pred_objverts3dw = get_check_none(results, "recov_objverts3d")
        gt_canobjcorners3d = get_check_none(sample, TransQueries.OBJCANROTCORNERS)
        gt_handjoints2d = get_check_none(sample, TransQueries.JOINTS2D)
        pred_handjoints2d = get_check_none(results, "joints2d")
        gt_handjoints3d = get_check_none(sample, TransQueries.JOINTS3D)
        pred_handjoints3d = get_check_none(results, "joints3d")
        gt_handverts3d = get_check_none(sample, TransQueries.HANDVERTS3D)
        pred_handverts3d = get_check_none(results, "verts3d")
        gt_objverts3dw = get_check_none(sample, BaseQueries.OBJVERTS3D)
        pred_handjoints3dw = get_check_none(results, "recov_joints3d")
        gt_handjoints3dw = get_check_none(sample, BaseQueries.JOINTS3D)
        pred_objfps2d = get_check_none(results, "kpt_2d")
        gt_objfps2d = get_check_none(sample, BaseQueries.OBJFPS2D)
        pred_objvar2d = get_check_none(results, "var")

        row_nb = min(max_rows, batch_size)
        # Columns 0-4 are always drawn; display_centered adds columns 5-7.
        col_nb = 8 if display_centered else 5
        # squeeze=False keeps `axes` 2-D even when a single row is drawn.
        axes = fig.subplots(row_nb, col_nb, squeeze=False)
        has_fps_pair = gt_objfps2d is not None and pred_objfps2d is not None

        for row_idx in range(row_nb):
            # Column 0: image with 2D hand joints
            ax = axes[row_idx, 0]
            ax.imshow(images[row_idx])
            ax.axis("off")
            joints(ax, pick(pred_handjoints2d, row_idx), 1)
            joints(ax, pick(gt_handjoints2d, row_idx), 0.5)

            # Column 1: image with object keypoints
            ax = axes[row_idx, 1]
            ax.imshow(images[row_idx])
            ax.axis("off")
            if has_fps_pair:
                arrows(ax, gt_objfps2d[row_idx].float(), pred_objfps2d[row_idx].float())
            if pred_objvar2d is not None and pred_objfps2d is not None:
                ellipses = compute_confidence_ellipses(pred_objfps2d[row_idx], pred_objvar2d[row_idx])
                for ellipse in ellipses:
                    ax.add_artist(ellipse)
            scatter(ax, pick(pred_objfps2d, row_idx), (0, 1), c="r", s=2, marker="X", alpha=0.7)
            scatter(ax, pick(gt_objfps2d, row_idx), (0, 1), c="b", s=2, marker="X", alpha=0.7)

            # Column 2: image with 2D object vertices
            ax = axes[row_idx, 2]
            ax.imshow(images[row_idx])
            ax.axis("off")
            scatter(ax, pick(pred_objverts2d, row_idx), (0, 1), c="r", s=1, alpha=0.1)
            scatter(ax, pick(gt_objverts2d, row_idx), (0, 1), c="b", s=1, alpha=0.02)
            # Fall back to vertex error arrows when keypoint arrows are missing.
            if not has_fps_pair and gt_objverts2d is not None and pred_objverts2d is not None:
                arrows(ax, gt_objverts2d[row_idx, :6].float(), pred_objverts2d[row_idx, :6].float())

            # Column 3: view from the top
            ax = axes[row_idx, 3]
            scatter(ax, pick(gt_objverts3dw, row_idx), (2, 0), c="b", s=1, alpha=0.02)
            scatter(ax, pick(pred_objverts3dw, row_idx), (2, 0), c="r", s=1, alpha=0.02)
            # Select the row first, then the columns: mixing an integer, a
            # slice and a list in one index is shaped differently by numpy
            # and torch.
            joints(ax, pick(pred_handjoints3dw, row_idx), 1, cols=[2, 0])
            joints(ax, pick(gt_handjoints3dw, row_idx), 0.5, cols=[2, 0])
            ax.invert_yaxis()

            # Column 4: view from the right
            # invert second axis here for more consistent viewpoints
            ax = axes[row_idx, 4]
            scatter(ax, pick(gt_objverts3dw, row_idx), (2, 1), flip_y=True, c="b", s=1, alpha=0.02)
            scatter(ax, pick(pred_objverts3dw, row_idx), (2, 1), flip_y=True, c="r", s=1, alpha=0.02)
            for joints3dw, alpha in ((pred_handjoints3dw, 1), (gt_handjoints3dw, 0.5)):
                if joints3dw is not None:
                    row_joints = joints3dw[row_idx]
                    flipped = np.stack([row_joints[:, 2], -row_joints[:, 1]], axis=-1)
                    joints(ax, flipped, alpha)

            if display_centered:
                # Column 5: canonical object vertices and corners
                ax = axes[row_idx, 5]
                scatter(ax, pick(gt_canobjverts3d, row_idx), (0, 1), c="b", s=1, alpha=0.02)
                scatter(ax, pick(pred_objverts3d, row_idx), (0, 1), c="r", s=1, alpha=0.02)
                joints(ax, pick(pred_objcorners3d, row_idx), 1, links=box_links)
                joints(ax, pick(gt_canobjcorners3d, row_idx), 0.5, links=box_links)
                if pred_objcorners3d is not None and gt_canobjcorners3d is not None:
                    arrows(ax, gt_canobjcorners3d[row_idx, :6], pred_objcorners3d[row_idx, :6])
                ax.set_aspect("equal")
                ax.invert_yaxis()

                # Column 6: object/hand vertices and joints, coordinates (0, 1)
                ax = axes[row_idx, 6]
                scatter(ax, pick(gt_objverts3d, row_idx), (0, 1), c="b", s=1, alpha=0.02)
                scatter(ax, pick(gt_handverts3d, row_idx), (0, 1), c="g", s=1, alpha=0.2)
                scatter(ax, pick(pred_handverts3d, row_idx), (0, 1), c="c", s=1, alpha=0.2)
                joints(ax, pick(pred_handjoints3d, row_idx), 1)
                joints(ax, pick(gt_handjoints3d, row_idx), 0.5)
                ax.invert_yaxis()

                # Column 7: object/hand vertices and joints, coordinates (1, 2)
                ax = axes[row_idx, 7]
                scatter(ax, pick(gt_objverts3d, row_idx), (1, 2), c="b", s=1, alpha=0.02)
                scatter(ax, pick(gt_handverts3d, row_idx), (1, 2), c="g", s=1, alpha=0.2)
                scatter(ax, pick(pred_handverts3d, row_idx), (1, 2), c="c", s=1, alpha=0.2)
                joints(ax, pick(pred_handjoints3d, row_idx), 1, cols=slice(1, None))
                joints(ax, pick(gt_handjoints3d, row_idx), 0.5, cols=slice(1, None))

        consistdisplay.squashfig(fig)
        fig.savefig(save_img_path, dpi=300)
    finally:
        if owns_fig:
            plt.close(fig)
def sample_vis(sample, results, save_img_path, fig=None, max_rows=5, display_centered=False):
    """Draw a grid of hand/object visualizations for a batch and save it.

    One row is drawn per batch element, up to ``max_rows``. Every quantity is
    fetched with ``get_check_none`` and skipped when it is ``None``.
    Predictions are drawn in red/cyan or with full opacity, ground truth in
    blue/green or with half opacity. Arrows link the first 6 ground-truth
    points to the matching predictions.

    Columns:
        0: image with 2D hand joints.
        1: image with 2D object vertices, 2D object corners and arrows.
        2: recovered 3D object vertices, hand joints and object corners on
           their first two coordinates, with arrows; y axis inverted.
        3: same data on coordinates (1, 2).
        4-6 (only with ``display_centered``): canonical object vertices and
           corners, then object/hand vertices and hand joints on coordinates
           (0, 1) and (1, 2).

    Args:
        sample: Mapping of ground-truth data indexed by ``TransQueries`` /
            ``BaseQueries`` keys; must contain ``TransQueries.IMAGE``.
        results: Mapping of predictions indexed by string keys.
        save_img_path: Destination passed to ``fig.savefig`` (dpi=300).
        fig: Figure to clear and draw into. When ``None`` a temporary pyplot
            figure is created and closed after saving.
        max_rows: Maximum number of batch elements to display.
        display_centered: Also draw columns 4-6.
    """
    box_links = [[0, 1, 3, 2], [4, 5, 7, 6], [1, 5], [3, 7], [4, 0], [0, 2, 6, 4]]

    def pick(batch, row_idx):
        """Return the ``row_idx``-th element of ``batch``, or None."""
        return None if batch is None else batch[row_idx]

    def scatter(ax, pts, dims, **style):
        """Scatter two coordinate columns of ``pts`` when it is available."""
        if pts is not None:
            ax.scatter(pts[:, dims[0]], pts[:, dims[1]], **style)

    def joints(ax, pts, alpha, cols=None, **kwargs):
        """Draw ``pts`` (optionally a column subset) with visualize_joints_2d."""
        if pts is None:
            return
        if cols is not None:
            pts = pts[:, cols]
        visualize_joints_2d(ax, pts, alpha=alpha, joint_idxs=False, **kwargs)

    def arrows(ax, gt_pts, pred_pts):
        """Link each ground-truth point to the matching predicted point."""
        arrow_nb = min(gt_pts.shape[0], pred_pts.shape[0])
        links = [[i, i + arrow_nb] for i in range(arrow_nb)]
        stacked = torch.cat([gt_pts[:arrow_nb], pred_pts[:arrow_nb]], 0)
        joints(ax, stacked, 0.5, links=links, color=["k"] * arrow_nb)

    owns_fig = fig is None
    if owns_fig:
        # The default fig=None used to crash on fig.clf(); create a figure
        # that lives only for the duration of this call instead.
        import matplotlib.pyplot as plt

        fig = plt.figure()
    else:
        fig.clf()
    try:
        # Move axis 1 last (presumably channels-first -> channels-last) and
        # shift values by 0.5 (presumably undoing a centering; confirm).
        images = sample[TransQueries.IMAGE].permute(0, 2, 3, 1).cpu() + 0.5
        batch_size = images.shape[0]
        gt_objverts2d = get_check_none(sample, TransQueries.OBJVERTS2D)
        pred_objverts2d = get_check_none(results, "obj_verts2d")
        gt_objcorners2d = get_check_none(sample, TransQueries.OBJCORNERS2D)
        pred_objcorners2d = get_check_none(results, "obj_corners2d")
        gt_objcorners3dw = get_check_none(sample, BaseQueries.OBJCORNERS3D)
        pred_objcorners3d = get_check_none(results, "obj_corners3d")
        gt_objverts3d = get_check_none(sample, TransQueries.OBJVERTS3D)
        gt_canobjverts3d = get_check_none(sample, TransQueries.OBJCANROTVERTS)
        gt_canobjcorners3d = get_check_none(sample, TransQueries.OBJCANROTCORNERS)
        pred_objverts3d = get_check_none(results, "obj_verts3d")
        gt_handjoints2d = get_check_none(sample, TransQueries.JOINTS2D)
        pred_handjoints2d = get_check_none(results, "joints2d")
        gt_handjoints3d = get_check_none(sample, TransQueries.JOINTS3D)
        pred_handjoints3d = get_check_none(results, "joints3d")
        gt_handverts3d = get_check_none(sample, TransQueries.HANDVERTS3D)
        gt_objverts3dw = get_check_none(sample, BaseQueries.OBJVERTS3D)
        pred_handjoints3dw = get_check_none(results, "recov_joints3d")
        gt_handjoints3dw = get_check_none(sample, BaseQueries.JOINTS3D)
        pred_objverts3dw = get_check_none(results, "recov_objverts3d")
        pred_objcorners3dw = get_check_none(results, "recov_objcorners3d")
        pred_handverts3d = get_check_none(results, "verts3d")

        row_nb = min(max_rows, batch_size)
        # Columns 0-3 are always drawn; display_centered adds columns 4-6.
        col_nb = 7 if display_centered else 4
        # squeeze=False keeps `axes` 2-D even when a single row is drawn.
        axes = fig.subplots(row_nb, col_nb, squeeze=False)
        has_verts3dw_pair = pred_objverts3dw is not None and gt_objverts3dw is not None

        for row_idx in range(row_nb):
            # Column 0: image with 2D hand joints
            ax = axes[row_idx, 0]
            ax.imshow(images[row_idx])
            ax.axis("off")
            joints(ax, pick(pred_handjoints2d, row_idx), 1)
            joints(ax, pick(gt_handjoints2d, row_idx), 0.5)

            # Column 1: image with 2D object vertices and corners
            ax = axes[row_idx, 1]
            ax.imshow(images[row_idx])
            ax.axis("off")
            scatter(ax, pick(pred_objverts2d, row_idx), (0, 1), c="r", s=1, alpha=0.2)
            scatter(ax, pick(gt_objverts2d, row_idx), (0, 1), c="b", s=1, alpha=0.02)
            joints(ax, pick(pred_objcorners2d, row_idx), 1, links=box_links)
            joints(ax, pick(gt_objcorners2d, row_idx), 0.5, links=box_links)
            if gt_objverts2d is not None and pred_objverts2d is not None:
                arrows(ax, gt_objverts2d[row_idx, :6].float(), pred_objverts2d[row_idx, :6].float())

            # Column 2: recovered 3D data, first two coordinates
            ax = axes[row_idx, 2]
            scatter(ax, pick(gt_objverts3dw, row_idx), (0, 1), c="b", s=1, alpha=0.02)
            scatter(ax, pick(pred_objverts3dw, row_idx), (0, 1), c="r", s=1, alpha=0.02)
            joints(ax, pick(pred_handjoints3dw, row_idx), 1)
            joints(ax, pick(gt_handjoints3dw, row_idx), 0.5)
            ax.invert_yaxis()
            joints(ax, pick(pred_objcorners3dw, row_idx), 1, links=box_links)
            joints(ax, pick(gt_objcorners3dw, row_idx), 0.5, links=box_links)
            if has_verts3dw_pair:
                arrows(ax, gt_objverts3dw[row_idx, :6], pred_objverts3dw[row_idx, :6])

            # Column 3: recovered 3D data, coordinates (1, 2)
            ax = axes[row_idx, 3]
            last_two = slice(1, None)
            scatter(ax, pick(gt_objverts3dw, row_idx), (1, 2), c="b", s=1, alpha=0.02)
            scatter(ax, pick(pred_objverts3dw, row_idx), (1, 2), c="r", s=1, alpha=0.02)
            joints(ax, pick(pred_handjoints3dw, row_idx), 1, cols=last_two)
            joints(ax, pick(gt_handjoints3dw, row_idx), 0.5, cols=last_two)
            joints(ax, pick(pred_objcorners3dw, row_idx), 1, cols=last_two, links=box_links)
            joints(ax, pick(gt_objcorners3dw, row_idx), 0.5, cols=last_two, links=box_links)
            if has_verts3dw_pair:
                arrows(ax, gt_objverts3dw[row_idx, :6, 1:], pred_objverts3dw[row_idx, :6, 1:])

            if display_centered:
                # Column 4: canonical object vertices and corners
                ax = axes[row_idx, 4]
                scatter(ax, pick(gt_canobjverts3d, row_idx), (0, 1), c="b", s=1, alpha=0.02)
                scatter(ax, pick(pred_objverts3d, row_idx), (0, 1), c="r", s=1, alpha=0.02)
                joints(ax, pick(pred_objcorners3d, row_idx), 1, links=box_links)
                joints(ax, pick(gt_canobjcorners3d, row_idx), 0.5, links=box_links)
                if pred_objcorners3d is not None and gt_canobjcorners3d is not None:
                    arrows(ax, gt_canobjcorners3d[row_idx, :6], pred_objcorners3d[row_idx, :6])
                ax.set_aspect("equal")
                ax.invert_yaxis()

                # Column 5: object/hand vertices and joints, coordinates (0, 1)
                ax = axes[row_idx, 5]
                scatter(ax, pick(gt_objverts3d, row_idx), (0, 1), c="b", s=1, alpha=0.02)
                scatter(ax, pick(gt_handverts3d, row_idx), (0, 1), c="g", s=1, alpha=0.2)
                scatter(ax, pick(pred_handverts3d, row_idx), (0, 1), c="c", s=1, alpha=0.2)
                joints(ax, pick(pred_handjoints3d, row_idx), 1)
                joints(ax, pick(gt_handjoints3d, row_idx), 0.5)
                ax.invert_yaxis()

                # Column 6: object/hand vertices and joints, coordinates (1, 2)
                ax = axes[row_idx, 6]
                scatter(ax, pick(gt_objverts3d, row_idx), (1, 2), c="b", s=1, alpha=0.02)
                scatter(ax, pick(gt_handverts3d, row_idx), (1, 2), c="g", s=1, alpha=0.2)
                scatter(ax, pick(pred_handverts3d, row_idx), (1, 2), c="c", s=1, alpha=0.2)
                joints(ax, pick(pred_handjoints3d, row_idx), 1, cols=last_two)
                joints(ax, pick(gt_handjoints3d, row_idx), 0.5, cols=last_two)

        consistdisplay.squashfig(fig)
        fig.savefig(save_img_path, dpi=300)
    finally:
        if owns_fig:
            plt.close(fig)
示例#5
0
                    frame = tarutils.cv2_imread_tar(tarf, frame_subpath)
                else:
                    video_folder = os.path.join(args.epic_root,
                                                video_full_id[:3],
                                                video_full_id)
                    frame = cv2.imread(os.path.join(video_folder, frame_name))
                if not args.without_image:
                    ax.imshow(frame[:, :, ::-1])
                else:
                    ax.imshow(np.zeros_like(frame[:, :, ::-1]))
                ax.set_title(action_segms[frame_idx])
                ax.axis("off")

                box_annots = box_props[frame_idx]
                for hand in hands:
                    viz2d.visualize_joints_2d(ax, hand, joint_idxs=False)
                for bboxes, noun in box_annots:
                    bboxes_norm = [
                        epic_box_to_norm(bbox, resize_factor=resize_factor)
                        for bbox in bboxes
                    ]
                    label_color = "w"
                    detect2d.visualize_bboxes(
                        ax,
                        bboxes_norm,
                        labels=[
                            noun,
                        ] * len(bboxes),
                        label_color=label_color,
                    )
                if args.rename: