Example #1
def add_hands_viz(ax, img, pred_hands, camintr, faces_per_pixel=2):
    # Create batched inputs
    camintr_th = torch.Tensor(camintr).unsqueeze(0).cuda()
    all_verts = [
        torch.Tensor(pred["verts"]).unsqueeze(0).cuda()
        for pred in pred_hands.values()
    ]
    all_faces = [
        torch.Tensor(pred["faces"].copy()).unsqueeze(0).cuda()
        for pred in pred_hands.values()
    ]
    verts, faces, _ = catmesh.batch_cat_meshes(all_verts, all_faces)
    # Convert vertices from weak perspective to camera
    unproj3d = perspective.unproject_points(verts[:, :, :2],
                                            verts[:, :, 2:] / 200 + 0.5,
                                            camintr_th)

    # Render
    res = py3drendutils.batch_render(
        unproj3d,
        faces,
        faces_per_pixel=faces_per_pixel,
        K=camintr_th,
        image_sizes=[(img.shape[1], img.shape[0])],
    )
    # Overlay the RGBA rendering (last channel is alpha) on the provided axis
    ax.imshow(res[0].detach().cpu()[:, :, :4], alpha=0.6)
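The weak-perspective conversion above relies on `perspective.unproject_points` to lift pixel coordinates plus a recovered depth into camera-frame 3D points. Below is a minimal pinhole-model sketch of what such a helper presumably computes; the name and signature are assumptions, not the project's API.

# Hedged sketch: camera-frame point = depth * K^-1 @ [u, v, 1]^T.
import torch

def unproject_points_sketch(points2d, depths, camintr):
    # points2d: (B, N, 2) pixel coordinates, depths: (B, N, 1), camintr: (B, 3, 3)
    ones = torch.ones_like(points2d[:, :, :1])
    homog2d = torch.cat([points2d, ones], dim=2)               # (B, N, 3)
    rays = homog2d @ torch.inverse(camintr).transpose(1, 2)    # apply K^-1 to each point
    return rays * depths                                        # scale rays by depth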
Example #2
def fitobj2mask(
    masks,
    bboxes,
    obj_paths,
    z_off=0.5,
    radius=0.1,
    faces_per_pixel=1,
    lr=0.01,
    loss_type="l2",
    iters=100,
    viz_step=1,
    save_folder="tmp/",
    viz_rows=12,
    crop_box=True,
    crop_size=(200, 200),
    rot_nb=1,
):
    # Initialize logging info
    opts = {
        "z_off": z_off,
        "loss_type": loss_type,
        "iters": iters,
        "radius": radius,
        "lr": lr,
        "obj_paths": obj_paths,
        "faces_per_pix": faces_per_pixel,
    }
    results = {"opts": opts}
    save_folder = Path(save_folder)
    print(f"Saving to {save_folder}")
    metrics = defaultdict(list)

    batch_size = len(obj_paths)
    # Load normalized object
    batch_faces = []
    batch_verts = []
    for obj_path in obj_paths:
        verts_loc, faces_idx, _ = py3dload_obj(obj_path)
        faces = faces_idx.verts_idx
        batch_faces.append(faces.cuda())

        verts = normalize.normalize_verts(verts_loc, radius).cuda()
        batch_verts.append(verts)
    batch_verts = torch.stack(batch_verts)
    batch_faces = torch.stack(batch_faces)

    # Dummy intrinsic camera
    height, width = masks[0].shape
    focal = min(masks[0].shape)
    camintr = (
        torch.Tensor(
            [[focal, 0, width // 2], [0, focal, height // 2], [0, 0, 1]]
        )
        .cuda()
        .unsqueeze(0)
        .repeat(batch_size, 1, 1)
    )

    if crop_box:
        adaptive_loss = AdaptiveLossFunction(
            num_dims=crop_size[0] * crop_size[1],
            float_dtype=np.float32,
            device="cuda:0",
        )
    else:
        adaptive_loss = AdaptiveLossFunction(
            num_dims=height * width, float_dtype=np.float32, device="cuda:0"
        )
    # Prepare rigid parameters
    if rot_nb > 1:
        rot_mats = [special_ortho_group.rvs(3) for _ in range(rot_nb)]
        rot_vecs = torch.Tensor(
            [np.linalg.svd(rot_mat)[0][:2].reshape(-1) for rot_mat in rot_mats]
        )
        rot_vec = rot_vecs.repeat(batch_size, 1).cuda()
        # Ordering b1 rot1, b1 rot2, ..., b2 rot1, ...
    else:
        rot_vec = torch.Tensor(
            [[1, 0, 0, 0, 1, 0] for _ in range(batch_size)]
        ).cuda()

    bboxes_tight = torch.stack(bboxes)
    # trans = ops3d.trans_init_from_boxes(bboxes, camintr, (z_off, z_off)).cuda()
    trans = ops3d.trans_init_from_boxes_autodepth(
        bboxes_tight, camintr, batch_verts, z_guess=z_off
    ).cuda()
    # Repeat to match rots
    trans = repeatdim(trans, rot_nb, 1)
    bboxes = boxutils.preprocess_boxes(bboxes_tight, padding=10, squarify=True)
    if crop_box:
        camintr_crop = camutils.get_K_crop_resize(camintr, bboxes, crop_size)
        # Repeat cropped intrinsics to match the rotation hypotheses
        camintr_crop = repeatdim(camintr_crop, rot_nb, 1)

    trans.requires_grad = True
    rot_vec.requires_grad = True
    optim_params = [rot_vec, trans]
    if "adapt" in loss_type:
        optim_params = optim_params + list(adaptive_loss.parameters())
    optimizer = torch.optim.Adam([rot_vec, trans], lr=lr)

    ref_masks = torch.stack(masks).cuda()
    if crop_box:
        ref_masks = cropping.crops(ref_masks.float(), bboxes, crop_size)[:, 0]

    # Prepare reference mask
    if "dtf" in loss_type:
        target_masks = torch.stack(
            [torch.Tensor(dtf.distance_transform(mask)) for mask in ref_masks]
        ).cuda()
    else:
        target_masks = ref_masks
    ref_masks = repeatdim(ref_masks, rot_nb, 1)
    target_masks = repeatdim(target_masks, rot_nb, 1)
    batch_verts = repeatdim(batch_verts, rot_nb, 1)
    batch_faces = repeatdim(batch_faces, rot_nb, 1)

    col_nb = 5
    fig_res = 1.5
    # Aggregate images
    clip_data = []
    for iter_idx in tqdm(range(iters)):
        rot_mat = rotations.compute_rotation_matrix_from_ortho6d(rot_vec)
        optim_verts = batch_verts.bmm(rot_mat) + trans.unsqueeze(1)
        if crop_box:
            rendres = batch_render(
                optim_verts,
                batch_faces,
                K=camintr_crop,
                image_sizes=[(crop_size[1], crop_size[0])],
                mode="silh",
                faces_per_pixel=faces_per_pixel,
            )
        else:
            rendres = batch_render(
                optim_verts,
                batch_faces,
                K=camintr,
                image_sizes=[(width, height)],
                mode="silh",
                faces_per_pixel=faces_per_pixel,
            )
        optim_masks = rendres[:, :, :, -1]
        mask_diff = ref_masks - optim_masks
        mask_l2 = (mask_diff ** 2).mean()
        mask_l1 = mask_diff.abs().mean()
        mask_iou = lyiou.batch_mask_iou(
            (optim_masks > 0), (ref_masks > 0)
        ).mean()
        metrics["l1"].append(mask_l1.item())
        metrics["l2"].append(mask_l2.item())
        metrics["mask"].append(mask_iou.item())

        optim_mask_diff = target_masks - optim_masks
        if "l2" in loss_type:
            loss = (optim_mask_diff ** 2).mean()
        elif "l1" in loss_type:
            loss = optim_mask_diff.abs().mean()
        elif "adapt" in loss_type:
            loss = adaptive_loss.lossfun(
                optim_mask_diff.view(rot_nb * batch_size, -1)
            ).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iter_idx % viz_step == 0:
            row_idxs = np.linspace(
                0, batch_size * rot_nb - 1, viz_rows
            ).astype(int)
            row_nb = viz_rows
            fig, axes = plt.subplots(
                row_nb,
                col_nb,
                figsize=(int(col_nb * fig_res), int(row_nb * fig_res)),
            )
            for row_idx in range(row_nb):
                show_idx = row_idxs[row_idx]
                ax = vizmp.get_axis(
                    axes, row_idx, 0, row_nb=row_nb, col_nb=col_nb
                )
                ax.imshow(npt.numpify(optim_masks[show_idx]))
                ax.set_title("optim mask")
                ax = vizmp.get_axis(
                    axes, row_idx, 1, row_nb=row_nb, col_nb=col_nb
                )
                ax.imshow(npt.numpify(ref_masks[show_idx]))
                ax.set_title("ref mask")
                ax = vizmp.get_axis(
                    axes, row_idx, 2, row_nb=row_nb, col_nb=col_nb
                )
                ax.imshow(
                    npt.numpify(ref_masks[show_idx] - optim_masks[show_idx]),
                    vmin=-1,
                    vmax=1,
                )
                ax.set_title("ref masks diff")
                ax = vizmp.get_axis(
                    axes, row_idx, 3, row_nb=row_nb, col_nb=col_nb
                )
                ax.imshow(npt.numpify(target_masks[show_idx]), vmin=-1, vmax=1)
                ax.set_title("target mask")
                ax = vizmp.get_axis(
                    axes, row_idx, 4, row_nb=row_nb, col_nb=col_nb
                )
                ax.imshow(
                    npt.numpify(
                        target_masks[show_idx] - optim_masks[show_idx]
                    ),
                    vmin=-1,
                    vmax=1,
                )
                ax.set_title("masks diff")
            viz_folder = save_folder / "viz"
            viz_folder.mkdir(parents=True, exist_ok=True)
            data = vizmp.fig2np(fig)
            clip_data.append(data)
            fig.savefig(viz_folder / f"{iter_idx:04d}.png")
            plt.close(fig)

    clip = mpy.ImageSequenceClip(clip_data, fps=4)
    clip.write_videofile(str(viz_folder / "out.mp4"))
    clip.write_videofile(str(viz_folder / "out.webm"))
    results["metrics"] = metrics
    return results
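`rotations.compute_rotation_matrix_from_ortho6d` maps each optimized 6D vector to a valid rotation matrix. A hedged sketch of the continuous 6D parameterization (Zhou et al., CVPR 2019) it presumably implements is shown below; the exact row/column convention used by the project is an assumption.

# Hedged sketch: Gram-Schmidt the two 3-vectors, take their cross product as the
# third axis. The identity initialization [1, 0, 0, 0, 1, 0] used above maps to
# the identity matrix under this construction.
import torch
import torch.nn.functional as F

def rot_from_ortho6d_sketch(ortho6d):
    # ortho6d: (B, 6) -> (B, 3, 3)
    a1, a2 = ortho6d[:, :3], ortho6d[:, 3:]
    b1 = F.normalize(a1, dim=1)
    b2 = F.normalize(a2 - (b1 * a2).sum(1, keepdim=True) * b1, dim=1)
    b3 = torch.cross(b1, b2, dim=1)
    return torch.stack([b1, b2, b3], dim=1)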
Example #3
def show_fits(
        axes,
        frame,
        joints2d=None,
        vertices=None,
        arm_faces=None,
        cam_rot=None,
        cam_trans=None,
        camintr_th=None,
        img_shape=None,
        offset=(0, 0, 0.8),
):
    opp_links = [
        [0, 1, 8, 9, 10, 11, 22, 23],
        [4, 3, 2, 1, 5, 6, 7],
        [8, 12, 13, 14, 19, 20],
    ]
    hand_links = [
        [0, 1, 2, 3, 4],
        [0, 5, 6, 7, 8],
        [0, 9, 10, 11, 12],
        [0, 13, 14, 15, 16],
        [0, 17, 18, 19, 20],
    ]
    for hand_off in [25, 46]:
        for finger_link in hand_links:
            off_finger_link = [hand_off + joint for joint in finger_link]
            opp_links += [off_finger_link]
    ax = axes[0, 0]
    ax.imshow(frame[:, :, ::-1])
    ax.axis("off")

    rendered = batch_render(
        vertices.cuda(),
        arm_faces.int(),
        img_shape,
        camintrs=camintr_th,
        cam_rot=cam_rot,
        cam_trans=cam_trans,
    )
    ax = axes[0, 1]
    ax.axis("off")
    ax.imshow(frame[:, :, ::-1])
    ax.imshow(rendered[0].detach().cpu(), alpha=0.9)

    # Offset
    rendered_mirr = batch_render(
        (vertices + vertices.new(offset)).cuda(),
        arm_faces.int(),
        img_shape,
        camintrs=camintr_th,
        cam_rot=cam_rot,
        cam_trans=cam_trans,
    )
    # Offset render, mirrored horizontally
    ax = axes[1, 0]
    ax.axis("off")
    ax.imshow(torch.zeros_like(rendered_mirr[0, :, :, :3].detach().cpu()),
              alpha=0.5)
    ax.imshow(rendered_mirr[0].detach().cpu().numpy()[:, ::-1], alpha=0.5)
    # Original render on a blank background
    ax = axes[1, 1]
    ax.axis("off")
    ax.imshow(torch.zeros_like(rendered[0, :, :, :3].detach().cpu()),
              alpha=0.5)
    ax.imshow(rendered[0].detach().cpu().numpy(), alpha=0.5)
Example #4
    def forward(self, faces_per_pixel=10, viz_views=True, min_depth=0.001):
        body_info = self.egohuman.forward()
        obj_infos = [obj.forward() for obj in self.objects]
        body_verts = body_info["verts"]
        obj_verts = [obj_info["verts"] for obj_info in obj_infos]
        verts = [body_verts] + obj_verts
        faces = [body_info["faces"]
                 ] + [obj_info["faces"] for obj_info in obj_infos]
        verts2d = self.camera.project(body_info["verts"])
        body_info["verts2d"] = verts2d

        origin_camintr = self.camera.camintr.to(verts2d.device).unsqueeze(0)
        bboxes = self.roi_bboxes.to(origin_camintr.device)
        camintr = camutils.get_K_crop_resize(
            origin_camintr.repeat(self.batch_size, 1, 1),
            bboxes,
            self.render_size,
            invert_xy=True,
        )
        rot = self.camera.rot.to(verts2d.device).unsqueeze(0)
        trans = self.camera.trans.to(verts2d.device).unsqueeze(0)
        height, width = self.camera.image_size

        body_color = ((0.25, 0.73, 1), )  # light_blue
        # obj_color = (0.74117647, 0.85882353, 0.65098039),  # light_green
        obj_color = ((0.25, 0.85, 0.85), )  # light_blue_green
        viz_colors = [get_colors(body_verts, body_color)] + [
            get_colors(obj_vert, obj_color) for obj_vert in obj_verts
        ]
        all_verts, all_faces, all_viz_colors = catmesh.batch_cat_meshes(
            verts, faces, viz_colors)
        face_colors = get_segm_colors(verts, faces)
        rendres = batch_render(
            all_verts,
            all_faces,
            K=camintr,
            rot=rot,
            trans=trans,
            image_sizes=[self.render_size],
            mode="facecolor",
            shading="soft",
            face_colors=face_colors,
            faces_per_pixel=faces_per_pixel,
            blend_gamma=self.blend_gamma,
            min_depth=min_depth,
        )
        scene_res = {
            "body_info": body_info,
            "obj_infos": obj_infos,
            "segm_rend": rendres,
        }
        if viz_views:
            with torch.no_grad():
                # Render scene in camera view
                cam_rendres = batch_render(
                    all_verts,
                    all_faces,
                    colors=all_viz_colors,
                    K=origin_camintr,
                    rot=rot,
                    trans=trans,
                    image_sizes=[(width, height)],
                    mode="rgb",
                    faces_per_pixel=2,
                    min_depth=min_depth,
                ).cpu()
                viz_verts = all_verts.clone()
                viz_verts[:, :, 2] = -viz_verts[:, :, 2]
                # Render front view
                front_rendres = batch_render(
                    all_verts.new([0, 0, 0.7]) + viz_verts,
                    torch.flip(all_faces, (2, )),  # Compensate for - in verts
                    colors=all_viz_colors,
                    K=origin_camintr,
                    rot=rot,
                    trans=trans,
                    image_sizes=[(width, height)],
                    mode="rgb",
                    min_depth=min_depth,
                    faces_per_pixel=2,
                ).cpu()
                # Render side view by rotating around average object point
                rot_center = torch.cat(obj_verts, 1).mean(1)
                rot_verts = ops3d.rot_points(all_verts, rot_center)
                side_rendres = batch_render(
                    rot_verts,
                    torch.flip(all_faces, (2, )),  # Compensate for - in verts
                    colors=all_viz_colors,
                    K=origin_camintr,
                    rot=rot,
                    trans=trans,
                    image_sizes=[(width, height)],
                    min_depth=min_depth,
                    mode="rgb",
                    faces_per_pixel=2,
                ).cpu()
            scene_res["scene_viz_rend"] = [
                cam_rendres,
                front_rendres,
                side_rendres,
            ]
        return scene_res
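The `mode="facecolor"` rendering above depends on `get_segm_colors` assigning every face of a mesh one constant color, so the render doubles as a body/object segmentation map. A hypothetical sketch of such a helper follows; the palette and layout are illustrative only, not the project's implementation.

# Hedged sketch: one constant color per mesh, repeated for all of its faces and
# concatenated along the face dimension to match the concatenated scene mesh.
import torch

def get_segm_colors_sketch(verts_list, faces_list):
    # verts_list is unused here; kept only to mirror the get_segm_colors(verts, faces) call
    palette = torch.eye(3)  # mesh 0 -> red, mesh 1 -> green, mesh 2 -> blue
    face_colors = []
    for mesh_idx, faces in enumerate(faces_list):
        color = palette[mesh_idx % 3]
        batch_size, face_nb = faces.shape[0], faces.shape[1]
        face_colors.append(color.view(1, 1, 3).expand(batch_size, face_nb, 3))
    return torch.cat(face_colors, dim=1)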
Example #5
    def preprocess_supervision(self, fit_infos, grab_objects=False):
        # Initialize tar reader
        tareader = TarReader()
        # sample_masks = []
        sample_verts = []
        sample_confs = []
        sample_imgs = []
        ref_hand_rends = []
        # Regions of interest containing hands and objects
        roi_bboxes = []
        roi_valid_masks = []
        # Crops of hand and object masks
        sample_hand_masks = []
        sample_objs_masks = []

        # Create dummy intrinsic camera for supervision rendering
        # (principal point at the center of an assumed 456x256 frame)
        focal = 200
        camintr = np.array([[focal, 0, 456 // 2], [0, focal, 256 // 2],
                            [0, 0, 1]])
        camintr_th = torch.Tensor(camintr).unsqueeze(0)
        # Modelling hand color
        print("Preprocessing sequence")
        for fit_info in tqdm(fit_infos):
            img = tareader.read_tar_frame(fit_info["img_path"])
            img_size = img.shape[:2]  # height, width
            # img = cv2.imread(fit_info["img_path"])
            hand_infos = fit_info["hands"]
            human_verts = np.zeros((self.smplx_vertex_nb, 3))
            verts_confs = np.zeros((self.smplx_vertex_nb, ))

            # Get hand vertex reference poses
            img_hand_verts = []
            img_hand_faces = []
            for side in hand_infos:
                hand_info = hand_infos[side]
                hand_verts = hand_info["verts"]
                # Aggregate hand vertices and faces for rendering
                img_hand_verts.append(
                    lift_verts(
                        torch.Tensor(hand_verts).unsqueeze(0), camintr_th))
                img_hand_faces.append(
                    torch.Tensor(hand_info["faces"]).unsqueeze(0))
                corresp = self.mano_corresp[f"{side}_hand"]
                human_verts[corresp] = hand_verts
                verts_confs[corresp] = 1

            has_hands = len(img_hand_verts) > 0
            # render reference hands
            if has_hands:
                img_hand_verts, img_hand_faces, _ = catmesh.batch_cat_meshes(
                    img_hand_verts, img_hand_faces)
                with torch.no_grad():
                    res = py3drendutils.batch_render(
                        img_hand_verts.cuda(),
                        img_hand_faces.cuda(),
                        faces_per_pixel=2,
                        color=(1, 0.4, 0.6),
                        K=camintr_th,
                        image_sizes=[(img_size[1], img_size[0])],
                        mode="rgb",
                        shading="soft",
                    )
                ref_hand_rends.append(npt.numpify(res[0, :, :, :3]))
                hand_mask = npt.numpify(res[0, :, :, 3])
            else:
                ref_hand_rends.append(np.zeros(img.shape) + 1)
                hand_mask = np.zeros((img.shape[:2]))
            obj_masks = fit_info["masks"]
            # GrabCut objects
            has_objs = len(obj_masks) > 0
            if has_objs:
                obj_masks_aggreg = (npt.numpify(torch.stack(obj_masks)).sum(0)
                                    > 0)
            else:
                obj_masks_aggreg = np.zeros_like(hand_mask)
            # Detect if some pseudo ground truth masks exist
            has_both_masks = (hand_mask.max() > 0) and (obj_masks_aggreg.max()
                                                        > 0)
            if has_both_masks:
                xs, ys = np.where((hand_mask + obj_masks_aggreg) > 0)
                # Compute region of interest which contains hands and objects
                roi_bbox = boxutils.squarify_box(
                    [xs.min(), ys.min(),
                     xs.max(), ys.max()], scale_factor=1.5)
            else:
                rad = min(img.shape[:2])
                roi_bbox = [0, 0, rad, rad]

            roi_bbox = [int(val) for val in roi_bbox]
            roi_bboxes.append(roi_bbox)
            img_crop = cropping.crop_cv2(img, roi_bbox, resize=self.crop_size)

            # Compute region of crop which belongs to original image (vs padding)
            roi_valid_mask = cropping.crop_cv2(np.ones(img.shape[:2]),
                                               roi_bbox,
                                               resize=self.crop_size)
            roi_valid_masks.append(roi_valid_mask)

            # Crop hand and object image
            hand_mask_crop = (cropping.crop_cv2(
                hand_mask, roi_bbox, resize=self.crop_size) > 0).astype(int)
            objs_masks_crop = cropping.crop_cv2(
                obj_masks_aggreg.astype(int),
                roi_bbox,
                resize=self.crop_size,
            ).astype(int)

            # Remove object region from hand mask
            hand_mask_crop[objs_masks_crop > 0] = 0
            # Extract skeletons
            skel_objs_masks_crop = skeletonize(objs_masks_crop.astype(
                np.uint8))
            skel_hand_mask_crop = skeletonize(hand_mask_crop.astype(np.uint8))

            # Removing the object region from the hand mask can cancel out the whole hand!
            if has_both_masks and hand_mask_crop.max():
                grabinfo = grabcut.grab_cut(
                    img_crop.astype(np.uint8),
                    mask=hand_mask_crop,
                    bbox=roi_bbox,
                    bgd_mask=skel_objs_masks_crop,
                    fgd_mask=skel_hand_mask_crop,
                    debug=self.debug,
                )
                hand_mask = grabinfo["grab_mask"]
                hand_mask[objs_masks_crop > 0] = 0
            else:
                hand_mask = hand_mask_crop
            sample_hand_masks.append(hand_mask)

            # Get crops of object masks
            obj_mask_crops = []
            for obj_mask in obj_masks:
                obj_mask_crop = cropping.crop_cv2(
                    npt.numpify(obj_mask).astype(int),
                    roi_bbox,
                    resize=self.crop_size,
                )
                skel_obj_mask_crop = skeletonize(obj_mask_crop.astype(
                    np.uint8))
                if grab_objects:
                    raise NotImplementedError(
                        "Maybe also needs the skeleton of other objects "
                        "to be labelled as background?")
                    grabinfo = grabcut.grab_cut(
                        img_crop,
                        mask=obj_mask_crop,
                        bbox=roi_bbox,
                        bgd_mask=skel_hand_mask_crop,
                        fgd_mask=skel_obj_mask_crop,
                        debug=self.debug,
                    )
                    obj_mask_crop = grabinfo["grab_mask"]
                obj_mask_crops.append(obj_mask_crop)
            if len(obj_mask_crops):
                sample_objs_masks.append(np.stack(obj_mask_crops))
            else:
                # No object masks: placeholder sized to match the other mask crops
                sample_objs_masks.append(np.zeros((1, *self.crop_size)))

            # Aggregate per-frame supervision
            # sample_masks.append(torch.stack(fit_info["masks"]))
            sample_verts.append(human_verts)
            sample_confs.append(verts_confs)
            sample_imgs.append(img)

        # Stack accumulated per-frame vertices once all frames are processed
        verts = torch.Tensor(np.stack(sample_verts))

        links = [preprocess_links(info["links"]) for info in fit_infos]
        fit_data = {
            # "masks": torch.stack(sample_masks),
            "roi_bboxes": torch.Tensor(np.stack(roi_bboxes)),
            "roi_valid_masks": torch.Tensor(np.stack(roi_valid_masks)),
            "objs_masks_crops": torch.Tensor(np.stack(sample_objs_masks)),
            "hand_masks_crops": torch.Tensor(np.stack(sample_hand_masks)),
            "verts": verts,
            "verts_confs": torch.Tensor(np.stack(sample_confs)),
            "imgs": sample_imgs,
            "ref_hand_rends": ref_hand_rends,
            "links": links,
            "mano_corresp": self.mano_corresp,
        }
        return fit_data
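The region of interest above comes from `boxutils.squarify_box`, which inflates the tight hand+object bounding box and makes it square. A hedged sketch of that kind of helper is given below; the corner ordering and scaling convention are assumptions.

# Hedged sketch: enlarge a tight [x_min, y_min, x_max, y_max] box by scale_factor
# and square it around its center.
def squarify_box_sketch(bbox, scale_factor=1.0):
    x_min, y_min, x_max, y_max = bbox
    center_x, center_y = (x_min + x_max) / 2, (y_min + y_max) / 2
    half_size = scale_factor * max(x_max - x_min, y_max - y_min) / 2
    return [center_x - half_size, center_y - half_size,
            center_x + half_size, center_y + half_size]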
Example #6
)

for iter_idx in tqdm(range(args.iter_nb)):
    obj_infos = [obj.forward() for obj in objects]
    verts = [obj_info["verts"] for obj_info in obj_infos]
    faces = [obj_info["faces"] for obj_info in obj_infos]
    colors = scene.get_segm_colors(verts, faces)
    all_verts, all_faces, all_colors = catmesh.batch_cat_meshes(
        verts, faces, colors
    )
    rendres = batch_render(
        all_verts,
        all_faces,
        K=camintr,
        image_sizes=[(300, 200)],
        shading="soft",
        mode="facecolor",
        # mode="rgb",
        face_colors=colors,
        color=(1, 0, 0),
        faces_per_pixel=args.faces_per_pixel,
    )
    row_nb = args.batch_size
    col_nb = 3
    # Compare only the RGB channels (the render also carries an alpha channel)
    diffs = (rendres - img_th)[:, :, :, :3]
    if args.loss_type == "l1":
        loss = diffs.abs().mean()
    elif args.loss_type == "l2":
        loss = (diffs ** 2).sum(-1).mean()
    # loss = (rendres - img_th).abs().mean()
    optimizer.zero_grad()
    loss.backward()
Example #7
def ego_viz_old(
    pred_verts,
    pred_proj_verts,
    gt_proj_verts,
    vert_flags,
    imgs=None,
    fig_res=2,
    cam=None,
    faces=None,
    step_idx=0,
    save_folder="tmp",
):
    # Render predicted human
    render_verts = pred_verts.cuda()
    batch_size = len(render_verts)
    faces_th = (faces.unsqueeze(0).repeat(batch_size, 1,
                                          1).to(render_verts.device))
    camintr = (cam.get_camintr().to(render_verts.device).unsqueeze(0).repeat(
        batch_size, 1, 1))
    rot = (cam.get_camrot().unsqueeze(0).repeat(batch_size, 1,
                                                1).to(render_verts.device))
    img_size = (imgs[0].shape[1], imgs[0].shape[0])
    with torch.no_grad():
        rends = batch_render(
            render_verts,
            faces_th,
            K=camintr,
            rot=rot,
            image_sizes=[img_size for _ in range(batch_size)],
        )

    show_pred_verts = pred_verts.cpu().detach().numpy()

    row_nb = len(pred_verts)
    col_nb = 4
    fig, axes = plt.subplots(row_nb,
                             col_nb,
                             figsize=(int(col_nb * fig_res),
                                      int(row_nb * fig_res)))
    for row_idx in range(row_nb):
        ax = vizmp.get_axis(axes,
                            row_idx=row_idx,
                            col_idx=0,
                            row_nb=row_nb,
                            col_nb=col_nb)
        super_proj_verts = (gt_proj_verts[row_idx][(vert_flags[row_idx] >
                                                    0)].cpu().detach().numpy())
        super_pred_proj_verts = (pred_proj_verts[row_idx][(
            vert_flags[row_idx] > 0)].cpu().detach().numpy())
        if imgs is not None:
            ax.imshow(imgs[row_idx])
        point_nb = super_pred_proj_verts.shape[0]
        colors = cm.rainbow(np.linspace(0, 1, point_nb))
        ax.scatter(
            super_proj_verts[:, 0],
            super_proj_verts[:, 1],
            s=0.5,
            alpha=0.2,
            c="k",
        )
        ax.scatter(
            super_pred_proj_verts[:, 0],
            super_pred_proj_verts[:, 1],
            s=0.5,
            alpha=0.2,
            c=colors,
        )
        ax.axis("equal")

        row_pred_verts = show_pred_verts[row_idx]
        ax = vizmp.get_axis(axes,
                            row_idx=row_idx,
                            col_idx=1,
                            row_nb=row_nb,
                            col_nb=col_nb)
        ax.scatter(row_pred_verts[:, 0], row_pred_verts[:, 2], s=1)
        ax.axis("equal")

        ax = vizmp.get_axis(axes,
                            row_idx=row_idx,
                            col_idx=2,
                            row_nb=row_nb,
                            col_nb=col_nb)
        ax.scatter(row_pred_verts[:, 1], row_pred_verts[:, 2], s=1)
        ax.axis("equal")

        ax = vizmp.get_axis(axes,
                            row_idx=row_idx,
                            col_idx=3,
                            row_nb=row_nb,
                            col_nb=col_nb)
        ax.imshow(imgs[row_idx])
        ax.imshow(rends[row_idx].sum(-1).cpu().numpy(), alpha=0.5)

    os.makedirs(save_folder, exist_ok=True)
    save_path = os.path.join(save_folder, f"tmp_{step_idx:04d}.png")
    fig.savefig(save_path)
    plt.close(fig)
    print(f"Saved to {save_path}")