Example #1
def _plot_sample_light(fig, ax, ib, count, data_dict):
    plt.cla()
    ax.set_xlim(_X_LIM[0], _X_LIM[1])
    ax.set_ylim(_Y_LIM[0], _Y_LIM[1])
    ax.set_xlabel("x [m]")
    ax.set_ylabel("y [m]")
    ax.set_aspect("equal")
    # ax.set_title(f"Frame {data_dict['idx'][ib]}. Press any key to exit.")

    # scan and cls label
    scan_r = data_dict["scans"][ib][-1]
    scan_phi = data_dict["scan_phi"][ib]
    scan_x, scan_y = u.rphi_to_xy(scan_r, scan_phi)
    ax.scatter(scan_x, scan_y, s=0.5, c="blue")

    # annotation
    ann = data_dict["dets_wp"][ib]
    ann_valid_mask = data_dict["anns_valid_mask"][ib]
    if len(ann) > 0:
        ann = np.array(ann)
        det_x, det_y = u.rphi_to_xy(ann[:, 0], ann[:, 1])
        for x, y, valid in zip(det_x, det_y, ann_valid_mask):
            if valid:
                # c = plt.Circle((x, y), radius=0.1, color="red", fill=True)
                c = plt.Circle((x, y), radius=0.4, color="red", fill=False)
                ax.add_artist(c)
Example #2
def _plot_sample(fig, ax, ib, count, data_dict):
    plt.cla()
    ax.set_xlim(_X_LIM[0], _X_LIM[1])
    ax.set_ylim(_Y_LIM[0], _Y_LIM[1])
    ax.set_xlabel("x [m]")
    ax.set_ylabel("y [m]")
    ax.set_aspect("equal")
    ax.set_title(f"Frame {data_dict['idx'][ib]}. Press any key to exit.")

    # scan and cls label
    scan_r = data_dict["scans"][ib][-1]
    scan_phi = data_dict["scan_phi"][ib]
    scan_x, scan_y = u.rphi_to_xy(scan_r, scan_phi)

    target_cls = data_dict["target_cls"][ib]
    ax.scatter(scan_x[target_cls == -2],
               scan_y[target_cls == -2],
               s=1,
               c="yellow")
    ax.scatter(scan_x[target_cls == -1],
               scan_y[target_cls == -1],
               s=1,
               c="orange")
    ax.scatter(scan_x[target_cls == 0],
               scan_y[target_cls == 0],
               s=1,
               c="black")
    ax.scatter(scan_x[target_cls > 0], scan_y[target_cls > 0], s=1, c="green")

    # annotation
    ann = data_dict["dets_wp"][ib]
    ann_valid_mask = data_dict["anns_valid_mask"][ib]
    if len(ann) > 0:
        ann = np.array(ann)
        det_x, det_y = u.rphi_to_xy(ann[:, 0], ann[:, 1])
        for x, y, valid in zip(det_x, det_y, ann_valid_mask):
            c = "blue" if valid else "orange"
            c = plt.Circle((x, y), radius=0.4, color=c, fill=False)
            ax.add_artist(c)

    # reg label
    target_reg = data_dict["target_reg"][ib]
    dets_r, dets_phi = u.canonical_to_global(
        scan_r, scan_phi, target_reg[:, 0], target_reg[:, 1])
    dets_r = dets_r[target_cls > 0]
    dets_phi = dets_phi[target_cls > 0]
    dets_x, dets_y = u.rphi_to_xy(dets_r, dets_phi)
    ax.scatter(dets_x, dets_y, s=10, c="red")
def _plot_annotation(ann, ax, color, radius):
    if len(ann) > 0:
        ann = np.array(ann)
        det_x, det_y = u.rphi_to_xy(ann[:, 0], ann[:, 1])
        for x, y in zip(det_x, det_y):
            c = plt.Circle((x, y), radius=radius, color=color, fill=False)
            ax.add_artist(c)
Example #4
def plot_one_frame(
    batch_dict,
    frame_idx,
    pred_cls=None,
    pred_reg=None,
    dets_cls=None,
    dets_xy=None,
    xlim=_X_LIM,
    ylim=_Y_LIM,
):
    """Plot one frame from a batch, specified by frame_idx.

    Returns:
        fig: figure handle
        ax: axis handle
    """
    fig, ax = _create_figure("", xlim, ylim)

    # scan and cls label
    scan_r = batch_dict["scans"][frame_idx][-1]
    scan_phi = batch_dict["scan_phi"][frame_idx]
    target_cls = batch_dict["target_cls"][frame_idx]
    _plot_scan(ax, scan_r, scan_phi, target_cls, s=1)

    # annotation
    ann = batch_dict["dets_wp"][frame_idx]
    ann_valid_mask = batch_dict["anns_valid_mask"][frame_idx]
    if len(ann) > 0:
        ann = np.array(ann)
        det_x, det_y = u.rphi_to_xy(ann[:, 0], ann[:, 1])
        for x, y, valid in zip(det_x, det_y, ann_valid_mask):
            c = "blue" if valid else "orange"
            c = plt.Circle((x, y), radius=0.4, color=c, fill=False)
            ax.add_artist(c)

    # regression target
    target_reg = batch_dict["target_reg"][frame_idx]
    _plot_target(ax,
                 target_reg,
                 target_cls > 0,
                 scan_r,
                 scan_phi,
                 s=10,
                 c="blue")

    # regression result
    if dets_xy is not None and dets_cls is not None:
        _plot_detection(ax, dets_cls, dets_xy, s=40, color_dim=1)

    if pred_cls is not None and pred_reg is not None:
        _plot_prediction(ax,
                         pred_cls,
                         pred_reg,
                         scan_r,
                         scan_phi,
                         s=2,
                         color_dim=1)

    return fig, ax
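
plot_one_frame relies on helpers that are not listed in this collection (_create_figure, _plot_detection). Below is a minimal, hypothetical sketch of _create_figure that only reproduces the axis setup repeated in the other examples, followed by an assumed usage; neither is taken verbatim from the repository.

import matplotlib.pyplot as plt

def _create_figure(title, xlim, ylim):
    # sketch only: same axis setup as the plotting examples above
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    ax.set_xlim(xlim[0], xlim[1])
    ax.set_ylim(ylim[0], ylim[1])
    ax.set_xlabel("x [m]")
    ax.set_ylabel("y [m]")
    ax.set_aspect("equal")
    ax.set_title(title)
    return fig, ax

# hypothetical usage: plot the first frame of a batch and save the figure
# fig, ax = plot_one_frame(batch_dict, frame_idx=0)
# fig.savefig("frame_000000.png", dpi=200)
# plt.close(fig)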
Example #5
def _plot_scan(ax, scan_r, scan_phi, target_cls, s):
    scan_x, scan_y = u.rphi_to_xy(scan_r, scan_phi)
    ax.scatter(scan_x[target_cls < 0], scan_y[target_cls < 0], s=s, c="orange")
    ax.scatter(scan_x[target_cls == 0],
               scan_y[target_cls == 0],
               s=s,
               c="black")
    ax.scatter(scan_x[target_cls > 0], scan_y[target_cls > 0], s=s, c="green")
Example #6
def _plot_target(ax, target_reg, target_flag, scan_r, scan_phi, s, c):
    dets_r, dets_phi = u.canonical_to_global(
        scan_r, scan_phi, target_reg[:, 0], target_reg[:, 1])
    dets_r = dets_r[target_flag]
    dets_phi = dets_phi[target_flag]
    dets_x, dets_y = u.rphi_to_xy(dets_r, dets_phi)
    ax.scatter(dets_x, dets_y, s=s, c=c)
def _plot_sequence():
    drow_handle = DROWHandle(
        split="train",
        cfg={
            "num_scans": 1,
            "scan_stride": 1,
            "data_dir": "./data/DROWv2-data"
        },
    )

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)

    _break = False

    if _INTERACTIVE:

        def p(event):
            nonlocal _break
            _break = True

        fig.canvas.mpl_connect("key_press_event", p)
    else:
        if os.path.exists(_SAVE_DIR):
            shutil.rmtree(_SAVE_DIR)
        os.makedirs(_SAVE_DIR)

    for i, data_dict in enumerate(drow_handle):
        if _break:
            break

        scan_x, scan_y = u.rphi_to_xy(data_dict["scans"][-1],
                                      data_dict["scan_phi"])

        plt.cla()
        ax.set_aspect("equal")
        ax.set_xlim(_X_LIM[0], _X_LIM[1])
        ax.set_ylim(_Y_LIM[0], _Y_LIM[1])
        ax.set_xlabel("x [m]")
        ax.set_ylabel("y [m]")
        ax.set_title(f"Frame {data_dict['idx']}. Press any key to exit.")
        # ax.axis("off")

        ax.scatter(scan_x, scan_y, s=1, c="black")

        _plot_annotation(data_dict["dets_wc"], ax, "red", 0.6)
        _plot_annotation(data_dict["dets_wa"], ax, "green", 0.4)
        _plot_annotation(data_dict["dets_wp"], ax, "blue", 0.35)

        if _INTERACTIVE:
            plt.pause(0.1)
        else:
            plt.savefig(
                os.path.join(_SAVE_DIR, f"frame_{data_dict['idx']:04}.png"))

    if _INTERACTIVE:
        plt.show()
Example #8
def _plot_annotation_detr(ax, anns, radius, color, linestyle="-"):
    if len(anns) == 0:
        return

    det_x, det_y = u.rphi_to_xy(anns[0], anns[1])
    for x, y in zip(det_x, det_y):
        c = plt.Circle((x, y),
                       radius=radius,
                       color=color,
                       fill=False,
                       linestyle=linestyle)
        ax.add_artist(c)
def _model_eval_fn(model, batch_dict):
    _, tb_dict, rtn_dict = _model_fn(model, batch_dict)

    pred_cls = torch.sigmoid(rtn_dict["pred_cls"]).data.cpu().numpy()
    pred_reg = rtn_dict["pred_reg"].data.cpu().numpy()

    # # DEBUG use perfect predictions
    # pred_cls = batch_dict["target_cls"]
    # pred_cls[pred_cls < 0] = 1
    # pred_reg = batch_dict["target_reg"]

    fig_dict = {}
    file_dict = {}

    # postprocess network prediction to get detection
    scans = batch_dict["scans"]
    scan_phi = batch_dict["scan_phi"]
    for ib in range(len(scans)):
        # store detections, which will be used by _model_eval_collate_fn to compute AP
        dets_xy, dets_cls, _ = u.nms_predicted_center(scans[ib][-1],
                                                      scan_phi[ib],
                                                      pred_cls[ib],
                                                      pred_reg[ib])
        frame_id = "%06d" % batch_dict['frame_id'][ib]
        sequence = batch_dict["sequence"][ib]

        # save detection results for evaluation
        det_str = pru.drow_detection_to_kitti_string(dets_xy, dets_cls, None)
        file_dict["detections/" + str(sequence) + "/" +
                  str(frame_id)] = det_str

        # save corresponding groundtruth for evaluation
        anns_rphi = batch_dict["dets_wp"][ib]
        if len(anns_rphi) > 0:
            anns_rphi = np.array(anns_rphi, dtype=np.float32)
            gts_xy = np.stack(u.rphi_to_xy(anns_rphi[:, 0], anns_rphi[:, 1]),
                              axis=1)
            gts_occluded = np.logical_not(
                batch_dict["anns_valid_mask"][ib]).astype(np.int)
            gts_str = pru.drow_detection_to_kitti_string(
                gts_xy, None, gts_occluded)
            file_dict["groundtruth/{}/{}".format(sequence, frame_id)] = gts_str
        else:
            file_dict["groundtruth/{}/{}".format(sequence, frame_id)] = ""

        # TODO When to plot
        if _PLOTTING:
            fig, ax = plot_one_frame(batch_dict, ib, pred_cls[ib],
                                     pred_reg[ib], dets_cls, dets_xy)
            fig_dict["figs/{}/{}".format(sequence, frame_id)] = (fig, ax)

    return tb_dict, file_dict, fig_dict
Example #10
def _plot_sequence():
    jrdb_handle = JRDBHandle(
        split="train",
        cfg={"data_dir": "./data/JRDB", "num_scans": 10, "scan_stride": 1},
    )

    fig = plt.figure(figsize=(20, 10))
    gs = GridSpec(3, 2, figure=fig)

    ax_im = fig.add_subplot(gs[0, :])
    ax_bev = fig.add_subplot(gs[1:, 1])
    ax_fpv_xz = fig.add_subplot(gs[1, 0])
    ax_fpv_yz = fig.add_subplot(gs[2, 0])

    color_pool = np.random.uniform(size=(100, 3))

    _break = False

    if _INTERACTIVE:

        def p(event):
            nonlocal _break
            _break = True

        fig.canvas.mpl_connect("key_press_event", p)
    else:
        if os.path.exists(_SAVE_DIR):
            shutil.rmtree(_SAVE_DIR)
        os.makedirs(_SAVE_DIR)

    for i, data_dict in enumerate(jrdb_handle):
        if _break:
            break

        # lidar
        pc_xyz_upper = jt.transform_pts_upper_velodyne_to_base(
            data_dict["pc_data"]["upper_velodyne"]
        )
        pc_xyz_lower = jt.transform_pts_lower_velodyne_to_base(
            data_dict["pc_data"]["lower_velodyne"]
        )

        # labels
        boxes, label_ids = [], []
        for ann in data_dict["pc_anns"]:
            jrdb_handle.box_is_on_ground(ann)
            boxes.append(jt.box_from_jrdb(ann["box"]))
            label_ids.append(int(ann["label_id"].split(":")[-1]) % len(color_pool))

        # laser
        laser_r = data_dict["laser_data"][-1]
        laser_phi = data_dict["laser_grid"]
        laser_z = data_dict["laser_z"]
        laser_x, laser_y = u.rphi_to_xy(laser_r, laser_phi)
        pc_xyz_laser = jt.transform_pts_laser_to_base(
            np.stack((laser_x, laser_y, laser_z), axis=0)
        )

        # BEV
        ax_bev.cla()
        ax_bev.set_aspect("equal")
        ax_bev.set_xlim(_XY_LIM[0], _XY_LIM[1])
        ax_bev.set_ylim(_XY_LIM[0], _XY_LIM[1])
        ax_bev.set_title(f"Frame {data_dict['idx']}. Press any key to exit.")
        ax_bev.set_xlabel("x [m]")
        ax_bev.set_ylabel("y [m]")
        # ax_bev.axis("off")

        for rgb_dim, pc_xyz in zip(
            (2, 1, 0), (pc_xyz_upper, pc_xyz_lower, pc_xyz_laser)
        ):
            ax_bev.scatter(pc_xyz[0], pc_xyz[1], s=1, c=_get_pts_color(pc_xyz, rgb_dim))

        for label_id, box in zip(label_ids, boxes):
            box.draw_bev(ax_bev, c=color_pool[label_id])

        # side view
        for dim, ax_fpv in zip((0, 1), (ax_fpv_xz, ax_fpv_yz)):
            ax_fpv.cla()
            ax_fpv.set_aspect("equal")
            ax_fpv.set_xlim(_XY_LIM[0], _XY_LIM[1])
            ax_fpv.set_ylim(_Z_LIM[0], _Z_LIM[1])
            ax_fpv.set_title(f"Frame {data_dict['idx']}. Press any key to exit.")
            ax_fpv.set_xlabel("x [m]" if dim == 0 else "y [m]")
            ax_fpv.set_ylabel("z [m]")
            # ax_fpv.axis("off")

            for rgb_dim, pc_xyz in zip(
                (2, 1, 0), (pc_xyz_upper, pc_xyz_lower, pc_xyz_laser)
            ):
                ax_fpv.scatter(
                    pc_xyz[dim], pc_xyz[2], s=1, c=_get_pts_color(pc_xyz, rgb_dim)
                )

            for label_id, box in zip(label_ids, boxes):
                box.draw_fpv(ax_fpv, dim=dim, c=color_pool[label_id])

        # image
        ax_im.cla()
        ax_im.axis("off")
        ax_im.imshow(data_dict["im_data"]["stitched_image0"])

        # detection bounding box
        for box_dict in data_dict["im_dets"]:
            x0, y0, w, h = box_dict["box"]
            verts = np.array(
                [(x0, y0), (x0, y0 + h), (x0 + w, y0 + h), (x0 + w, y0), (x0, y0)]
            )
            c = max(float(box_dict["score"]) - 0.5, 0) * 2.0
            ax_im.plot(verts[:, 0], verts[:, 1], c=(1.0 - c, 1.0 - c, 1.0))

        # laser points on image
        p_xy, ib_mask = jt.transform_pts_base_to_stitched_im(pc_xyz_laser)
        ax_im.scatter(
            p_xy[0, ib_mask],
            p_xy[1, ib_mask],
            s=1,
            c=_get_pts_color(pc_xyz_laser[:, ib_mask], dim=0),
        )

        if _INTERACTIVE:
            plt.pause(0.1)
        else:
            plt.savefig(os.path.join(_SAVE_DIR, f"frame_{data_dict['idx']:04}.png"))

    if _INTERACTIVE:
        plt.show()
Example #11
def generate_pseudo_labels():
    with open("./base_dr_spaam_jrdb_cfg.yaml", "r") as f:
        cfg = yaml.safe_load(f)
    cfg["dataset"]["pseudo_label"] = True
    cfg["dataset"]["pl_correction_level"] = 0

    test_loader = get_dataloader(
        split=_SPLIT,
        batch_size=1,
        num_workers=1,
        shuffle=False,
        dataset_cfg=cfg["dataset"],
    )

    if os.path.exists(_SAVE_DIR):
        shutil.rmtree(_SAVE_DIR)
        time.sleep(1.0)
    os.makedirs(_SAVE_DIR)

    sequences_tp_fp_tn_fn = {}  # for computing true positive and negative rate

    # generate pseudo labels for all samples
    for count, batch_dict in enumerate(tqdm(test_loader)):
        if count >= _MAX_COUNT:
            break

        for ib in range(len(batch_dict["input"])):
            frame_id = f"{batch_dict['frame_id'][ib]:06d}"
            sequence = batch_dict["sequence"][ib]

            if count % _PLOTTING_INTERVAL == 0:
                # visualize the whole frame
                _plot_frame(batch_dict, ib)

                # visualize each pseudo label
                _plot_pseudo_labels(batch_dict, ib)

            # save pseudo labels as detection results for evaluation
            pl_xy = batch_dict["pseudo_label_loc_xy"][ib]
            pl_str = (
                pru.drow_detection_to_kitti_string(pl_xy, None, None)
                if len(pl_xy) > 0
                else ""
            )

            pl_file = os.path.join(_SAVE_DIR, f"detections/{sequence}/{frame_id}.txt")
            _write_file_make_dir(pl_file, pl_str)

            # save groundtruth
            anns_rphi = batch_dict["dets_wp"][ib]
            if len(anns_rphi) > 0:
                anns_rphi = np.array(anns_rphi, dtype=np.float32)
                gts_xy = np.stack(
                    u.rphi_to_xy(anns_rphi[:, 0], anns_rphi[:, 1]), axis=1
                )
                gts_occluded = np.logical_not(batch_dict["anns_valid_mask"][ib]).astype(
                    np.int64
                )
                gts_str = pru.drow_detection_to_kitti_string(gts_xy, None, gts_occluded)
            else:
                gts_str = ""

            gts_file = os.path.join(_SAVE_DIR, f"groundtruth/{sequence}/{frame_id}.txt")
            _write_file_make_dir(gts_file, gts_str)

            # compute true positive and negative rate
            target_cls = batch_dict["target_cls"][ib]
            target_cls_gt = batch_dict["target_cls_real"][ib]

            tn = np.sum(
                np.logical_or(
                    np.logical_and(target_cls == 0, target_cls_gt == 0),
                    np.logical_and(target_cls == 0, target_cls_gt == -1),
                )
            )
            fn = np.sum(target_cls == 0) - tn
            tp = np.sum(
                np.logical_or(
                    np.logical_and(target_cls == 1, target_cls_gt == 1),
                    np.logical_and(target_cls == 1, target_cls_gt == -1),
                )
            )
            fp = np.sum(target_cls == 1) - tp

            if sequence in sequences_tp_fp_tn_fn.keys():
                tp0, fp0, tn0, fn0 = sequences_tp_fp_tn_fn[sequence]
                sequences_tp_fp_tn_fn[sequence] = (
                    tp + tp0,
                    fp + fp0,
                    tn + tn0,
                    fn + fn0,
                )
            else:
                sequences_tp_fp_tn_fn[sequence] = (tp, fp, tn, fn)

    # write sequence statistics to file
    for sequence, (tp, fp, tn, fn) in sequences_tp_fp_tn_fn.items():
        st_file = os.path.join(_SAVE_DIR, f"evaluation/{sequence}/tp_fp_tn_fn.txt")
        _write_file_make_dir(st_file, f"{tp},{fp},{tn},{fn}")
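
The script above calls _write_file_make_dir and stores per-sequence "tp,fp,tn,fn" counts, but the helper itself is not listed in this collection. A minimal sketch of what such a helper might look like, plus how the stored counts could be turned into rates (both are assumptions, not code from the repository):

import os

def _write_file_make_dir(fname, fstr):
    # sketch: create parent directories, then write the string to the file
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    with open(fname, "w") as f:
        f.write(fstr)

def tp_tn_rates(tp, fp, tn, fn):
    # true positive rate (recall) and true negative rate from the stored counts
    tpr = tp / max(tp + fn, 1)
    tnr = tn / max(tn + fp, 1)
    return tpr, tnr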
Example #12
def _plot_prediction(ax, pred_cls, pred_reg, scan_r, scan_phi, s, color_dim):
    pred_r, pred_phi = u.canonical_to_global(scan_r, scan_phi, pred_reg[:, 0],
                                             pred_reg[:, 1])
    pred_x, pred_y = u.rphi_to_xy(pred_r, pred_phi)
    pred_color = _cls_to_color(pred_cls, color_dim=color_dim)
    ax.scatter(pred_x, pred_y, s=s, c=pred_color)
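
_cls_to_color is referenced here and in plot_one_frame but not listed in this collection. A hypothetical sketch, roughly consistent with the confidence-to-green coloring used in test_detector further below (the actual implementation may differ):

import numpy as np

def _cls_to_color(cls, color_dim):
    # sketch: keep the `color_dim` channel at 1 and fade the other two with confidence
    s = np.clip(np.asarray(cls).reshape(-1), 0.0, 1.0)
    c = np.ones((s.shape[0], 3), dtype=np.float32)
    c[:, (color_dim + 1) % 3] = 1.0 - s
    c[:, (color_dim + 2) % 3] = 1.0 - s
    return c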
Example #13
def _get_regression_target(scan_rphi, dets_rphi, person_radius_small,
                           person_radius_large, min_close_points):
    """Generate classification and regression label.

    Args:
        scan_rphi (np.ndarray[2, N]): Scan points in polar coordinate
        dets_rphi (np.ndarray[2, M]): Annotated person centers in polar coordinate
        person_radius_small (float): Points less than this distance away
            from an annotation is assigned to that annotation and marked as fg.
        person_radius_large (float): Points with no annotation smaller
            than this distance is marked as bg.
        min_close_points (int): Annotations with supportive points fewer than this
            value is marked as invalid. Supportive points are those within the small
            radius.

    Returns:
        target_cls (np.ndarray[N]): Classification label, 1=fg, 0=bg, -1=ignore
        target_reg (np.ndarray[N, 2]): Regression label
        anns_valid_mask (np.ndarray[M])
    """
    N = scan_rphi.shape[1]

    # no annotation in this frame
    if len(dets_rphi) == 0:
        return (np.zeros(N, dtype=np.int64),
                np.zeros((N, 2), dtype=np.float32), [])

    scan_xy = np.stack(u.rphi_to_xy(scan_rphi[0], scan_rphi[1]), axis=0)
    dets_xy = np.stack(u.rphi_to_xy(dets_rphi[0], dets_rphi[1]), axis=0)

    dist_scan_dets = np.hypot(
        scan_xy[0].reshape(1, -1) - dets_xy[0].reshape(-1, 1),
        scan_xy[1].reshape(1, -1) - dets_xy[1].reshape(-1, 1),
    )  # (M, N) pairwise distance between scan and detections

    # mark out annotations that have too few scan points
    anns_valid_mask = (
        np.sum(dist_scan_dets < person_radius_small, axis=1) > min_close_points
    )  # (M, )

    # for each point, find the distance to its closest annotation
    argmin_dist_scan_dets = np.argmin(dist_scan_dets, axis=0)  # (N, )
    min_dist_scan_dets = dist_scan_dets[argmin_dist_scan_dets, np.arange(N)]

    # points within the small radius, whose corresponding annotation is valid,
    # are marked as foreground
    target_cls = -1 * np.ones(N, dtype=np.int64)
    fg_mask = np.logical_and(anns_valid_mask[argmin_dist_scan_dets],
                             min_dist_scan_dets < person_radius_small)
    target_cls[fg_mask] = 1
    target_cls[min_dist_scan_dets > person_radius_large] = 0

    # regression target
    dets_matched_rphi = dets_rphi[:, argmin_dist_scan_dets]
    target_reg = np.stack(
        u.global_to_canonical(scan_rphi[0], scan_rphi[1], dets_matched_rphi[0],
                              dets_matched_rphi[1]),
        axis=1,
    )

    return target_cls, target_reg, anns_valid_mask
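A toy illustration (not project code) of the fg/bg/ignore rule described in the docstring above, using plain numpy and skipping the per-annotation validity check for brevity:

import numpy as np

# five scan points on the x axis and one annotated center at x = 0.2
scan_xy = np.array([[0.1, 0.3, 0.6, 1.5, 3.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0]])
det_xy = np.array([[0.2], [0.0]])
r_small, r_large = 0.4, 0.8

# (M, N) pairwise distances between annotations and scan points
dist = np.hypot(scan_xy[0][None, :] - det_xy[0][:, None],
                scan_xy[1][None, :] - det_xy[1][:, None])
min_dist = dist.min(axis=0)

target_cls = -1 * np.ones(scan_xy.shape[1], dtype=np.int64)  # ignore by default
target_cls[min_dist < r_small] = 1   # close to an annotation -> fg
target_cls[min_dist > r_large] = 0   # far from all annotations -> bg
print(target_cls)  # [ 1  1 -1  0  0]
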
def play_sequence(seq_name, tracking=False):
    # scans
    scans_data = np.genfromtxt(seq_name, delimiter=',')
    scans_t = scans_data[:, 1]
    scans = scans_data[:, 2:]
    scan_phi = u.get_laser_phi()

    # odometry, used only for plotting
    odo_name = seq_name[:-3] + 'odom2'
    odos = np.genfromtxt(odo_name, delimiter=',')
    odos_t = odos[:, 1]
    odos_phi = odos[:, 4]

    # detector
    detector = DrSpaamDetector(num_pts=450,
                               ang_inc_degree=0.5,
                               tracking=tracking,
                               gpu=True,
                               ckpt=default_ckpts)
    # detector.set_laser_spec(angle_inc=np.radians(0.5), num_pts=450)

    # scanner location
    rad_tmp = 0.5 * np.ones(len(scan_phi), dtype=np.float32)
    xy_scanner = u.rphi_to_xy(rad_tmp, scan_phi)
    xy_scanner = np.stack(xy_scanner[::-1], axis=1)

    # plot
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)

    _break = False
    _pause = False

    def p(event):
        nonlocal _break, _pause
        if event.key == 'escape':
            _break = True
        if event.key == ' ':
            _pause = not _pause

    fig.canvas.mpl_connect('key_press_event', p)

    # video sequence
    odo_idx = 0
    for i in range(len(scans)):
        # for i in range(0, len(scans), 20):
        plt.cla()

        ax.set_aspect('equal')
        ax.set_xlim(-15, 15)
        ax.set_ylim(-15, 15)

        # ax.set_title('Frame: %s' % i)
        ax.set_title('Press escape key to exit.')
        ax.axis("off")

        # find matching odometry
        while odo_idx < len(odos_t) - 1 and odos_t[odo_idx] < scans_t[i]:
            odo_idx += 1
        odo_phi = odos_phi[odo_idx]
        odo_rot = np.array([[np.cos(odo_phi), -np.sin(odo_phi)],
                            [np.sin(odo_phi), np.cos(odo_phi)]],
                           dtype=np.float32)

        # plot scanner location
        xy_scanner_rot = np.matmul(xy_scanner, odo_rot.T)
        ax.plot(xy_scanner_rot[:, 0], xy_scanner_rot[:, 1], c='black')
        ax.plot((0, xy_scanner_rot[0, 0] * 1.0),
                (0, xy_scanner_rot[0, 1] * 1.0),
                c='black')
        ax.plot((0, xy_scanner_rot[-1, 0] * 1.0),
                (0, xy_scanner_rot[-1, 1] * 1.0),
                c='black')

        # plot points
        scan = scans[i]
        scan_y, scan_x = u.rphi_to_xy(scan, scan_phi + odo_phi)
        ax.scatter(scan_x, scan_y, s=1, c='blue')

        # inference
        dets_xy, dets_cls = detector.detect(scan)

        # plot detection
        dets_xy_rot = np.matmul(dets_xy, odo_rot.T)
        cls_thresh = 0.3
        for j in range(len(dets_xy)):
            if dets_cls[j] < cls_thresh:
                continue
            # c = plt.Circle(dets_xy_rot[j], radius=0.5, color='r', fill=False)
            c = plt.Circle(dets_xy_rot[j],
                           radius=0.5,
                           color='r',
                           fill=False,
                           linewidth=2)
            ax.add_artist(c)

        # plot track
        if tracking:
            cls_thresh = 0.2
            tracks, tracks_cls = detector.get_tracklets()
            for t, tc in zip(tracks, tracks_cls):
                if tc >= cls_thresh and len(t) > 1:
                    t_rot = np.matmul(t, odo_rot.T)
                    ax.plot(t_rot[:, 0], t_rot[:, 1], color='g', linewidth=2)

        # plt.savefig('/home/dan/tmp/det_img/frame_%04d.png' % i)

        plt.pause(0.001)

        if _break:
            break
        if _pause:
            plt.pause(1)
Example #15
    def _get_sample(self, idx):
        data_dict = self.__handle[idx]

        # DROW defines laser frame as x-forward, y-right, z-downward
        # JRDB defines laser frame as x-forward, y-left, z-upward
        # Use DROW frame for DR-SPAAM or DROW3

        # equivalent of flipping the y axis (negating the laser phi angle)
        data_dict["laser_data"] = data_dict["laser_data"][:, ::-1]
        scan_rphi = np.stack(
            (data_dict["laser_data"][-1], data_dict["laser_grid"]), axis=0)

        # get annotation in laser frame
        ann_xyz = [(ann["box"]["cx"], ann["box"]["cy"], ann["box"]["cz"])
                   for ann in data_dict["pc_anns"]]
        if len(ann_xyz) > 0:
            ann_xyz = np.array(ann_xyz, dtype=np.float32).T
            ann_xyz = jt.transform_pts_base_to_laser(ann_xyz)
            ann_xyz[1] = -ann_xyz[1]  # to DROW frame
            dets_rphi = np.stack(u.xy_to_rphi(ann_xyz[0], ann_xyz[1]), axis=0)
        else:
            dets_rphi = []

        # regression target
        target_cls, target_reg, anns_valid_mask = _get_regression_target(
            scan_rphi,
            dets_rphi,
            person_radius_small=0.4,
            person_radius_large=0.8,
            min_close_points=5,
        )

        data_dict["target_cls"] = target_cls
        data_dict["target_reg"] = target_reg
        data_dict["anns_valid_mask"] = anns_valid_mask

        # regression target from pseudo labels
        if self._pseudo_label:
            # get pixels of laser points projected on image
            scan_x, scan_y = u.rphi_to_xy(scan_rphi[0], scan_rphi[1])
            scan_y = -scan_y  # convert DROW frame to JRDB laser frame
            scan_xyz_laser = np.stack((scan_x, scan_y, data_dict["laser_z"]),
                                      axis=0)
            scan_pixel_xy, _ = jt.transform_pts_laser_to_stitched_im(
                scan_xyz_laser)

            # get detection boxes
            boxes = []
            boxes_confs = []
            for box_dict in data_dict["im_dets"]:
                x0, y0, w, h = box_dict["box"]
                boxes.append((x0, y0, x0 + w, y0 + h))
                boxes_confs.append(box_dict["score"])
            ########
            # NOTE for ablation, using 2D annotation to generate pseudo labels
            # for box_dict in data_dict["im_anns"]:
            #     x0, y0, w, h = box_dict["box"]
            #     boxes.append((x0, y0, x0 + w, y0 + h))
            #     boxes_confs.append(1.0)
            ########
            boxes = np.array(boxes, dtype=np.float32)
            boxes_confs = np.array(boxes_confs, dtype=np.float32)

            # pseudo label
            pl_xy, pl_boxes, pl_neg_mask = u.generate_pseudo_labels(
                scan_rphi[0], scan_rphi[1], scan_pixel_xy, boxes, boxes_confs)
            (
                target_cls_pseudo,
                target_reg_pseudo,
            ) = _get_regression_target_from_pseudo_labels(
                scan_rphi,
                pl_xy,
                pl_neg_mask,
                person_radius_small=0.4,
                person_radius_large=0.8,
                min_close_points=5,
                pl_correction_level=self._pl_correction_level,
                target_cls_annotated=data_dict["target_cls"],
                target_reg_annotated=data_dict["target_reg"],
            )

            data_dict["pseudo_label_loc_xy"] = pl_xy
            data_dict["pseudo_label_boxes"] = pl_boxes

            # still keep the original targets for debugging purposes
            data_dict["target_cls_real"] = data_dict["target_cls"]
            data_dict["target_reg_real"] = data_dict["target_reg"]
            data_dict["target_cls"] = target_cls_pseudo
            data_dict["target_reg"] = target_reg_pseudo

        # to be consistent with DROWDataset in order to use the same evaluation function
        dets_wp = []
        for i in range(dets_rphi.shape[1]):
            dets_wp.append((dets_rphi[0, i], dets_rphi[1, i]))
        data_dict["dets_wp"] = dets_wp
        data_dict["scans"] = data_dict["laser_data"]
        data_dict["scan_phi"] = data_dict["laser_grid"]

        if self._augment_data:
            data_dict = u.data_augmentation(data_dict)

        data_dict["input"] = u.scans_to_cutout(
            data_dict["laser_data"],
            data_dict["laser_grid"],
            stride=1,
            **self._cutout_kwargs,
        )

        return data_dict
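The y-flip at the top of _get_sample relies on the angular grid being (at least approximately) symmetric about zero: reversing the order of the range readings then yields the same point set as negating y. A small sanity check of that equivalence on a toy symmetric grid (grid size here is hypothetical):

import numpy as np

phi = np.linspace(-np.pi / 2, np.pi / 2, 5)   # symmetric angular grid
r = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

x_a, y_a = r[::-1] * np.cos(phi), r[::-1] * np.sin(phi)  # reversed scan order
x_b, y_b = r * np.cos(phi), -r * np.sin(phi)             # y negated

# same points, listed in opposite order
assert np.allclose(x_a, x_b[::-1]) and np.allclose(y_a, y_b[::-1])
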
def _test_detr_dataloader():
    with open("./tests/test.yaml", "r") as f:
        cfg = yaml.safe_load(f)
    cfg["dataset"]["DataHandle"]["tracking"] = True
    cfg["dataset"]["DataHandle"]["num_scans"] = 1

    test_loader = get_dataloader(
        split="train",
        batch_size=8,
        num_workers=1,
        shuffle=True,
        dataset_cfg=cfg["dataset"],
    )

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)

    _break = False

    if _INTERACTIVE:

        def p(event):
            nonlocal _break
            _break = True

        fig.canvas.mpl_connect("key_press_event", p)
    else:
        if os.path.exists(_SAVE_DIR):
            shutil.rmtree(_SAVE_DIR)
        os.makedirs(_SAVE_DIR)

    for count, data_dict in enumerate(test_loader):
        for ib in range(len(data_dict["input"])):
            fr_idx = data_dict["frame_dict_curr"][ib]["idx"]

            plt.cla()
            ax.set_xlim(_X_LIM[0], _X_LIM[1])
            ax.set_ylim(_Y_LIM[0], _Y_LIM[1])
            ax.set_xlabel("x [m]")
            ax.set_ylabel("y [m]")
            ax.set_aspect("equal")
            ax.set_title(f"Frame {fr_idx}. Press any key to exit.")

            # scan and cls label
            scan_r = data_dict["frame_dict_curr"][ib]["laser_data"][-1]
            scan_phi = data_dict["frame_dict_curr"][ib]["laser_grid"]
            scan_x, scan_y = u.rphi_to_xy(scan_r, scan_phi)

            target_cls = data_dict["target_cls"][ib]
            ax.scatter(scan_x[target_cls < 0],
                       scan_y[target_cls < 0],
                       s=1,
                       c="orange")
            ax.scatter(scan_x[target_cls == 0],
                       scan_y[target_cls == 0],
                       s=1,
                       c="black")
            ax.scatter(scan_x[target_cls > 0],
                       scan_y[target_cls > 0],
                       s=1,
                       c="green")

            # annotation for tracking
            anns_tracking = data_dict["frame_dict_curr"][ib]["dets_rphi_prev"]
            anns_tracking_mask = data_dict["anns_tracking_mask"][ib]
            anns_tracking = anns_tracking[:, anns_tracking_mask]
            if len(anns_tracking) > 0:
                det_x, det_y = u.rphi_to_xy(anns_tracking[0], anns_tracking[1])
                for x, y in zip(det_x, det_y):
                    c = plt.Circle((x, y),
                                   radius=0.5,
                                   color="gray",
                                   fill=False,
                                   linestyle="--")
                    ax.add_artist(c)

            # annotation
            anns = data_dict["frame_dict_curr"][ib]["dets_rphi"]
            anns_valid_mask = data_dict["anns_valid_mask"][ib]
            if len(anns) > 0:
                det_x, det_y = u.rphi_to_xy(anns[0], anns[1])
                for x, y, valid in zip(det_x, det_y, anns_valid_mask):
                    c = "blue" if valid else "orange"
                    c = plt.Circle((x, y), radius=0.4, color=c, fill=False)
                    ax.add_artist(c)

            # reg label for previous frame
            target_reg_prev = data_dict["target_reg_prev"][ib]
            target_tracking_flag = data_dict["target_tracking_flag"][ib]
            dets_r_prev, dets_phi_prev = u.canonical_to_global(
                scan_r, scan_phi, target_reg_prev[:, 0], target_reg_prev[:, 1])
            dets_r_prev = dets_r_prev[target_tracking_flag]
            dets_phi_prev = dets_phi_prev[target_tracking_flag]
            dets_x_prev, dets_y_prev = u.rphi_to_xy(dets_r_prev, dets_phi_prev)
            ax.scatter(dets_x_prev, dets_y_prev, s=25, c="gray")

            # reg label for current frame
            target_reg = data_dict["target_reg"][ib]
            dets_r, dets_phi = u.canonical_to_global(scan_r, scan_phi,
                                                     target_reg[:, 0],
                                                     target_reg[:, 1])
            dets_r = dets_r[target_cls > 0]
            dets_phi = dets_phi[target_cls > 0]
            dets_x, dets_y = u.rphi_to_xy(dets_r, dets_phi)
            ax.scatter(dets_x, dets_y, s=10, c="red")

            if _INTERACTIVE:
                plt.pause(0.1)
            else:
                plt.savefig(
                    os.path.join(
                        _SAVE_DIR,
                        f"b{count:03}s{ib:02}f{fr_idx:04}.png",
                    ))

    if _INTERACTIVE:
        plt.show()
Example #17
def test_detector():
    data_handle = JRDBHandle(
        split="train",
        cfg={
            "data_dir": "./data/JRDB",
            "num_scans": 10,
            "scan_stride": 1
        },
    )

    # ckpt_file = "/home/jia/ckpts/ckpt_jrdb_ann_drow3_e40.pth"
    # d = Detector(
    #     ckpt_file, model="DROW3", gpu=True, stride=1, panoramic_scan=True
    # )

    ckpt_file = "/home/jia/ckpts/ckpt_jrdb_ann_dr_spaam_e20.pth"
    d = Detector(ckpt_file,
                 model="DR-SPAAM",
                 gpu=True,
                 stride=1,
                 panoramic_scan=True)

    d.set_laser_fov(360)

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)

    _break = False

    if _INTERACTIVE:

        def p(event):
            nonlocal _break
            _break = True

        fig.canvas.mpl_connect("key_press_event", p)
    else:
        if os.path.exists(_SAVE_DIR):
            shutil.rmtree(_SAVE_DIR)
        os.makedirs(_SAVE_DIR)

    for i, data_dict in enumerate(data_handle):
        if _break:
            break

        # plot scans
        scan_r = data_dict["laser_data"][-1, ::-1]  # to DROW frame
        scan_x, scan_y = u.rphi_to_xy(scan_r, data_dict["laser_grid"])

        plt.cla()
        ax.set_aspect("equal")
        ax.set_xlim(_X_LIM[0], _X_LIM[1])
        ax.set_ylim(_Y_LIM[0], _Y_LIM[1])
        ax.set_xlabel("x [m]")
        ax.set_ylabel("y [m]")
        ax.set_title(f"Frame {data_dict['idx']}. Press any key to exit.")
        # ax.axis("off")

        ax.scatter(scan_x, scan_y, s=1, c="black")

        # plot annotation
        ann_xyz = [(ann["box"]["cx"], ann["box"]["cy"], ann["box"]["cz"])
                   for ann in data_dict["pc_anns"]]
        if len(ann_xyz) > 0:
            ann_xyz = np.array(ann_xyz, dtype=np.float32).T
            ann_xyz = jt.transform_pts_base_to_laser(ann_xyz)
            ann_xyz[1] = -ann_xyz[1]  # to DROW frame
            for xyz in ann_xyz.T:
                c = plt.Circle(
                    (xyz[0], xyz[1]),
                    radius=0.4,
                    color="red",
                    fill=False,
                    linestyle="--",
                )
                ax.add_artist(c)

        # plot detection
        dets_xy, dets_cls, _ = d(scan_r)
        dets_cls_norm = np.clip(dets_cls, 0, 0.3) / 0.3
        for xy, cls_norm in zip(dets_xy, dets_cls_norm):
            color = (1.0 - cls_norm, 1.0, 1.0 - cls_norm)
            c = plt.Circle((xy[0], xy[1]),
                           radius=0.4,
                           color=color,
                           fill=False,
                           linestyle="-")
            ax.add_artist(c)

        if _INTERACTIVE:
            plt.pause(0.1)
        else:
            plt.savefig(
                os.path.join(_SAVE_DIR, f"frame_{data_dict['idx']:04}.png"))

    if _INTERACTIVE:
        plt.show()
def _plot_frame_pts(batch_dict, ib, pred_cls, pred_reg, pred_cls_p,
                    pred_reg_p):
    frame_id = f"{batch_dict['frame_id'][ib]:06d}"
    sequence = batch_dict["sequence"][ib]

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)

    ax.set_xlim(_X_LIM[0], _X_LIM[1])
    ax.set_ylim(_Y_LIM[0], _Y_LIM[1])
    ax.set_xlabel("x [m]")
    ax.set_ylabel("y [m]")
    ax.set_aspect("equal")
    # ax.set_title(f"Frame {data_dict['idx'][ib]}. Press any key to exit.")

    # scan and cls label
    scan_r = batch_dict["scans"][ib][-1]
    scan_phi = batch_dict["scan_phi"][ib]
    scan_x, scan_y = u.rphi_to_xy(scan_r, scan_phi)
    ax.scatter(scan_x, scan_y, s=0.5, c="blue")

    # annotation
    ann = batch_dict["dets_wp"][ib]
    ann_valid_mask = batch_dict["anns_valid_mask"][ib]
    if len(ann) > 0:
        ann = np.array(ann)
        det_x, det_y = u.rphi_to_xy(ann[:, 0], ann[:, 1])
        for x, y, valid in zip(det_x, det_y, ann_valid_mask):
            if valid:
                # c = plt.Circle((x, y), radius=0.1, color="red", fill=True)
                c = plt.Circle((x, y),
                               radius=0.4,
                               color="red",
                               fill=False,
                               linestyle="--")
                ax.add_artist(c)

    # plot detections
    if pred_cls is not None and pred_reg is not None:
        dets_xy, dets_cls, _ = u.nms_predicted_center(scan_r, scan_phi,
                                                      pred_cls[ib].reshape(-1),
                                                      pred_reg[ib])
        dets_xy = dets_xy[dets_cls >= 0.9438938]  # at EER
        if len(dets_xy) > 0:
            for x, y in dets_xy:
                c = plt.Circle((x, y), radius=0.4, color="green", fill=False)
                ax.add_artist(c)
        fig_file = os.path.join(_SAVE_DIR,
                                f"figs/{sequence}/scan_det_{frame_id}.png")

        # in addition, plot detections from a pre-trained model
        if pred_cls_p is not None and pred_reg_p is not None:
            dets_xy, dets_cls, _ = u.nms_predicted_center(
                scan_r, scan_phi, pred_cls_p[ib].reshape(-1), pred_reg_p[ib])
            dets_xy = dets_xy[dets_cls > 0.29919282]  # at EER
            if len(dets_xy) > 0:
                for x, y in dets_xy:
                    c = plt.Circle((x, y),
                                   radius=0.4,
                                   color="green",
                                   fill=False)
                    ax.add_artist(c)
    # plot pre-trained detections only
    elif pred_cls_p is not None and pred_reg_p is not None:
        dets_xy, dets_cls, _ = u.nms_predicted_center(
            scan_r, scan_phi, pred_cls_p[ib].reshape(-1), pred_reg_p[ib])
        dets_xy = dets_xy[dets_cls > 0.29919282]  # at EER
        if len(dets_xy) > 0:
            for x, y in dets_xy:
                c = plt.Circle((x, y), radius=0.4, color="green", fill=False)
                ax.add_artist(c)
        fig_file = os.path.join(
            _SAVE_DIR, f"figs/{sequence}/scan_pretrain_{frame_id}.png")
    # plot pseudo-labels only
    else:
        pl_neg_mask = batch_dict["target_cls"][ib] == 0
        ax.scatter(scan_x[pl_neg_mask], scan_y[pl_neg_mask], s=0.5, c="orange")

        pl_xy = batch_dict["pseudo_label_loc_xy"][ib]
        if len(pl_xy) > 0:
            for x, y in pl_xy:
                c = plt.Circle((x, y), radius=0.4, color="green", fill=False)
                ax.add_artist(c)
        fig_file = os.path.join(_SAVE_DIR,
                                f"figs/{sequence}/scan_pl_{frame_id}.png")

    # save fig
    os.makedirs(os.path.dirname(fig_file), exist_ok=True)
    fig.savefig(fig_file, dpi=200)
    plt.close(fig)
Example #19
def play_sequence_with_tracking():
    # scans
    seq_name = './data/DROWv2-data/train/lunch_2015-11-26-12-04-23.bag.csv'
    seq0, seq1 = 109170, 109360
    scans, scans_t = [], []
    with open(seq_name) as f:
        for line in f:
            scan_seq, scan_t, scan = line.split(",", 2)
            scan_seq = int(scan_seq)
            if scan_seq < seq0:
                continue
            scans.append(np.fromstring(scan, sep=','))
            scans_t.append(float(scan_t))
            if scan_seq > seq1:
                break
    scans = np.stack(scans, axis=0)
    scans_t = np.array(scans_t)
    scan_phi = u.get_laser_phi()

    # odometry, used only for plotting
    odo_name = seq_name[:-3] + 'odom2'
    odos = np.genfromtxt(odo_name, delimiter=',')
    odos_t = odos[:, 1]
    odos_phi = odos[:, 4]

    # detector
    ckpt = './ckpts/dr_spaam_e40.pth'
    detector = Detector(model_name="DR-SPAAM",
                        ckpt_file=ckpt,
                        gpu=True,
                        stride=1,
                        tracking=True)
    detector.set_laser_spec(angle_inc=np.radians(0.5), num_pts=450)

    # scanner location
    rad_tmp = 0.5 * np.ones(len(scan_phi), dtype=np.float32)
    xy_scanner = u.rphi_to_xy(rad_tmp, scan_phi)
    xy_scanner = np.stack(xy_scanner, axis=1)

    # plot
    fig = plt.figure(figsize=(6, 8))
    ax = fig.add_subplot(111)

    _break = False

    def p(event):
        nonlocal _break
        _break = True

    fig.canvas.mpl_connect('key_press_event', p)

    # video sequence
    odo_idx = 0
    for i in range(len(scans)):
        plt.cla()

        ax.set_aspect('equal')
        ax.set_xlim(-10, 5)
        ax.set_ylim(-5, 15)

        # ax.set_title('Frame: %s' % i)
        ax.set_title('Press any key to exit.')
        ax.axis("off")

        # find matching odometry
        while odo_idx < len(odos_t) - 1 and odos_t[odo_idx] < scans_t[i]:
            odo_idx += 1
        odo_phi = odos_phi[odo_idx]
        odo_rot = np.array(
            [[np.cos(odo_phi), np.sin(odo_phi)],
             [-np.sin(odo_phi), np.cos(odo_phi)]],
            dtype=np.float32)

        # plot scanner location
        xy_scanner_rot = np.matmul(xy_scanner, odo_rot.T)
        ax.plot(xy_scanner_rot[:, 0], xy_scanner_rot[:, 1], c='black')
        ax.plot((0, xy_scanner_rot[0, 0] * 1.0),
                (0, xy_scanner_rot[0, 1] * 1.0),
                c='black')
        ax.plot((0, xy_scanner_rot[-1, 0] * 1.0),
                (0, xy_scanner_rot[-1, 1] * 1.0),
                c='black')

        # plot points
        scan = scans[i]
        scan_x, scan_y = u.rphi_to_xy(scan, scan_phi + odo_phi)
        ax.scatter(scan_x, scan_y, s=1, c='blue')

        # inference
        dets_xy, dets_cls, instance_mask = detector(scan)

        # plot detection
        dets_xy_rot = np.matmul(dets_xy, odo_rot.T)
        cls_thresh = 0.3
        for j in range(len(dets_xy)):
            if dets_cls[j] < cls_thresh:
                continue
            c = plt.Circle(dets_xy_rot[j],
                           radius=0.5,
                           color='r',
                           fill=False,
                           linewidth=2)
            ax.add_artist(c)

        # plot track
        cls_thresh = 0.2
        tracks, tracks_cls = detector.get_tracklets()
        for t, tc in zip(tracks, tracks_cls):
            if tc >= cls_thresh and len(t) > 1:
                t_rot = np.matmul(t, odo_rot.T)
                ax.plot(t_rot[:, 0], t_rot[:, 1], color='g', linewidth=2)

        # plt.savefig('/home/dan/tmp/track3_img/frame_%04d.png' % i)

        plt.pause(0.001)

        if _break:
            break
Example #20
def _plot_pseudo_labels(batch_dict, ib):
    # pseudo labels
    pl_xy = batch_dict["pseudo_label_loc_xy"][ib]
    pl_boxes = batch_dict["pseudo_label_boxes"][ib]

    if len(pl_xy) == 0:
        return

    # groundtruth
    anns_rphi = np.array(batch_dict["dets_wp"][ib], dtype=np.float32)[
        batch_dict["anns_valid_mask"][ib]
    ]

    # match pseudo labels with groundtruth
    if len(anns_rphi) > 0:
        gts_x, gts_y = u.rphi_to_xy(anns_rphi[:, 0], anns_rphi[:, 1])

        x_diff = pl_xy[:, 0].reshape(-1, 1) - gts_x.reshape(1, -1)
        y_diff = pl_xy[:, 1].reshape(-1, 1) - gts_y.reshape(1, -1)
        d_diff = np.sqrt(x_diff * x_diff + y_diff * y_diff)
        match_found = d_diff < 0.3  # (pl, gt)
        match_found = match_found.max(axis=1)
    else:
        match_found = np.zeros(len(pl_xy), dtype=bool)

    # overlay image with laser
    im = batch_dict["im_data"][ib]["stitched_image0"]
    scan_r = batch_dict["scans"][ib][-1]
    scan_phi = batch_dict["scan_phi"][ib]
    scan_x, scan_y = u.rphi_to_xy(scan_r, scan_phi)
    scan_z = batch_dict["laser_z"][ib]
    scan_xyz_laser = np.stack((scan_x, -scan_y, scan_z), axis=0)  # in JRDB laser frame
    p_xy, ib_mask = jt.transform_pts_laser_to_stitched_im(scan_xyz_laser)
    p_xy = p_xy[:, ib_mask]

    far_v = 1.0
    far_s = 0
    close_v = 0.75
    close_s = 0.59
    dist_normalized = np.clip(scan_r[ib_mask], 0.0, 20.0) / 20.0

    c_hsv = np.empty((1, p_xy.shape[1], 3), dtype=np.float32)
    c_hsv[0, :, 0] = 0.0
    # c_hsv[0, :, 1] = 1.0 - np.clip(scan_r[ib_mask], 0.0, 20.0) / 20.0
    c_hsv[0, :, 1] = close_s * (1.0 - dist_normalized) + far_s * dist_normalized
    c_hsv[0, :, 2] = close_v * (1.0 - dist_normalized) + far_v * dist_normalized
    c_bgr = cv2.cvtColor(c_hsv, cv2.COLOR_HSV2RGB)[0]

    # plot
    frame_id = f"{batch_dict['frame_id'][ib]:06d}"
    sequence = batch_dict["sequence"][ib]

    for count, (xy, box, is_pos) in enumerate(zip(pl_xy, pl_boxes, match_found)):
        # image
        x0, y0, x1, y1 = box
        x0 = int(x0)
        x1 = int(x1)
        y0 = int(y0)
        y1 = int(y1)
        im_box = im[y0 : y1 + 1, x0 : x1 + 1]
        height = y1 - y0
        width = x1 - x0

        fig_w_inch = 0.314961 * 2.0
        fig_h_inch = 0.708661 * 2.0

        fig_im = plt.figure()
        fig_im.set_size_inches(fig_w_inch, fig_h_inch, forward=False)
        ax_im = plt.Axes(fig_im, [0.0, 0.0, 1.0, 1.0])
        ax_im.imshow(im_box)
        ax_im.set_axis_off()
        ax_im.axis([0, width, height, 0])
        ax_im.set_aspect((fig_h_inch / fig_w_inch) / (height / width))
        fig_im.add_axes(ax_im)

        in_box_mask = np.logical_and(
            np.logical_and(p_xy[0] >= x0, p_xy[0] <= x1),
            np.logical_and(p_xy[1] >= y0, p_xy[1] <= y1),
        )
        plt.scatter(
            p_xy[0, in_box_mask] - x0,
            p_xy[1, in_box_mask] - y0,
            s=3,
            c=c_bgr[in_box_mask],
        )

        pos_neg_dir = "true" if is_pos else "false"
        fig_file = os.path.join(
            _SAVE_DIR, f"samples/{sequence}/{pos_neg_dir}/{frame_id}_{count}_im.pdf"
        )
        os.makedirs(os.path.dirname(fig_file), exist_ok=True)
        plt.savefig(fig_file, dpi=height / fig_h_inch)
        plt.close(fig_im)

        # lidar
        plot_range = 0.5
        close_mask = np.hypot(scan_x - xy[0], scan_y - xy[1]) < plot_range

        fig = plt.figure(figsize=(5, 5))
        ax = fig.add_subplot()
        ax.set_aspect("equal")
        ax.axis("off")
        # ax.set_xlim(-plot_range, plot_range)
        # ax.set_ylim(-plot_range, plot_range)
        # ax.set_xlabel("x [m]")
        # ax.set_ylabel("y [m]")
        # ax.set_aspect("equal")
        # ax.set_title(f"Frame {batch_dict['idx'][ib]}")

        # plot points in local frame (so it looks aligned with image)
        ang = np.mean(scan_phi[close_mask]) - 0.5 * np.pi
        ca, sa = np.cos(ang), np.sin(ang)
        xy_plotting = np.array([[ca, sa], [-sa, ca]]) @ np.stack(
            (scan_x[close_mask] - xy[0], scan_y[close_mask] - xy[1]), axis=0
        )

        ax.scatter(
            -xy_plotting[0], xy_plotting[1], s=80, color=(191 / 255, 83 / 255, 79 / 255)
        )
        ax.scatter(
            0, 0, s=500, color=(18 / 255, 105 / 255, 176 / 255), marker="+", linewidth=5
        )

        fig_file = os.path.join(
            _SAVE_DIR, f"samples/{sequence}/{pos_neg_dir}/{frame_id}_{count}_pt.pdf"
        )
        fig.savefig(fig_file)
        plt.close(fig)
Example #21
def play_sequence():
    # scans
    seq_name = './data/DROWv2-data/test/run_t_2015-11-26-11-22-03.bag.csv'
    # seq_name = './data/DROWv2-data/val/run_2015-11-26-15-52-55-k.bag.csv'
    scans_data = np.genfromtxt(seq_name, delimiter=',')
    scans_t = scans_data[:, 1]
    scans = scans_data[:, 2:]
    scan_phi = u.get_laser_phi()

    # odometry, used only for plotting
    odo_name = seq_name[:-3] + 'odom2'
    odos = np.genfromtxt(odo_name, delimiter=',')
    odos_t = odos[:, 1]
    odos_phi = odos[:, 4]

    # detector
    ckpt = './ckpts/dr_spaam_e40.pth'
    detector = Detector(model_name="DR-SPAAM",
                        ckpt_file=ckpt,
                        gpu=True,
                        stride=1)
    detector.set_laser_spec(angle_inc=np.radians(0.5), num_pts=450)

    # scanner location
    rad_tmp = 0.5 * np.ones(len(scan_phi), dtype=np.float32)
    xy_scanner = u.rphi_to_xy(rad_tmp, scan_phi)
    xy_scanner = np.stack(xy_scanner, axis=1)

    # plot
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)

    _break = False

    def p(event):
        nonlocal _break
        _break = True

    fig.canvas.mpl_connect('key_press_event', p)

    # video sequence
    odo_idx = 0
    for i in range(len(scans)):
        # for i in range(0, len(scans), 20):
        plt.cla()

        ax.set_aspect('equal')
        ax.set_xlim(-15, 15)
        ax.set_ylim(-15, 15)

        # ax.set_title('Frame: %s' % i)
        ax.set_title('Press any key to exit.')
        ax.axis("off")

        # find matching odometry
        while odo_idx < len(odos_t) - 1 and odos_t[odo_idx] < scans_t[i]:
            odo_idx += 1
        odo_phi = odos_phi[odo_idx]
        odo_rot = np.array(
            [[np.cos(odo_phi), np.sin(odo_phi)],
             [-np.sin(odo_phi), np.cos(odo_phi)]],
            dtype=np.float32)

        # plot scanner location
        xy_scanner_rot = np.matmul(xy_scanner, odo_rot.T)
        ax.plot(xy_scanner_rot[:, 0], xy_scanner_rot[:, 1], c='black')
        ax.plot((0, xy_scanner_rot[0, 0] * 1.0),
                (0, xy_scanner_rot[0, 1] * 1.0),
                c='black')
        ax.plot((0, xy_scanner_rot[-1, 0] * 1.0),
                (0, xy_scanner_rot[-1, 1] * 1.0),
                c='black')

        # plot points
        scan = scans[i]
        scan_x, scan_y = u.rphi_to_xy(scan, scan_phi + odo_phi)
        ax.scatter(scan_x, scan_y, s=1, c='blue')

        # inference
        dets_xy, dets_cls, instance_mask = detector(scan)

        # plot detection
        dets_xy_rot = np.matmul(dets_xy, odo_rot.T)
        cls_thresh = 0.5
        for j in range(len(dets_xy)):
            if dets_cls[j] < cls_thresh:
                continue
            # c = plt.Circle(dets_xy_rot[j], radius=0.5, color='r', fill=False)
            c = plt.Circle(dets_xy_rot[j],
                           radius=0.5,
                           color='r',
                           fill=False,
                           linewidth=2)
            ax.add_artist(c)

        # plt.savefig('/home/dan/tmp/det_img/frame_%04d.png' % i)

        plt.pause(0.001)

        if _break:
            break
Example #22
def _plot_frame(batch_dict, ib):
    frame_id = f"{batch_dict['frame_id'][ib]:06d}"
    sequence = batch_dict["sequence"][ib]

    fig = plt.figure(figsize=(10, 10))
    gs = GridSpec(3, 1, figure=fig)

    ax_im = fig.add_subplot(gs[0, 0])
    ax_bev = fig.add_subplot(gs[1:, 0])

    ax_bev.set_xlim(_X_LIM[0], _X_LIM[1])
    ax_bev.set_ylim(_Y_LIM[0], _Y_LIM[1])
    ax_bev.set_xlabel("x [m]")
    ax_bev.set_ylabel("y [m]")
    ax_bev.set_aspect("equal")
    ax_bev.set_title(f"Frame {batch_dict['idx'][ib]}")

    # scan and cls label
    scan_r = batch_dict["scans"][ib][-1]
    scan_phi = batch_dict["scan_phi"][ib]
    scan_x, scan_y = u.rphi_to_xy(scan_r, scan_phi)

    target_cls = batch_dict["target_cls"][ib]
    ax_bev.scatter(scan_x[target_cls == -2], scan_y[target_cls == -2], s=1, c="yellow")
    ax_bev.scatter(scan_x[target_cls == -1], scan_y[target_cls == -1], s=1, c="orange")
    ax_bev.scatter(scan_x[target_cls == 0], scan_y[target_cls == 0], s=1, c="black")
    ax_bev.scatter(scan_x[target_cls > 0], scan_y[target_cls > 0], s=1, c="green")

    # annotation
    ann = batch_dict["dets_wp"][ib]
    ann_valid_mask = batch_dict["anns_valid_mask"][ib]
    if len(ann) > 0:
        ann = np.array(ann)
        det_x, det_y = u.rphi_to_xy(ann[:, 0], ann[:, 1])
        for x, y, valid in zip(det_x, det_y, ann_valid_mask):
            c = "blue" if valid else "orange"
            c = plt.Circle((x, y), radius=0.4, color=c, fill=False)
            ax_bev.add_artist(c)

    # reg label
    target_reg = batch_dict["target_reg"][ib]
    dets_r, dets_phi = u.canonical_to_global(
        scan_r, scan_phi, target_reg[:, 0], target_reg[:, 1]
    )
    dets_r = dets_r[target_cls > 0]
    dets_phi = dets_phi[target_cls > 0]
    dets_x, dets_y = u.rphi_to_xy(dets_r, dets_phi)
    ax_bev.scatter(dets_x, dets_y, s=10, c="red")

    # image
    ax_im.axis("off")
    ax_im.imshow(batch_dict["im_data"][ib]["stitched_image0"])

    # detection bounding box
    for box_dict in batch_dict["im_dets"][ib]:
        x0, y0, w, h = box_dict["box"]
        x1 = x0 + w
        y1 = y0 + h
        verts = _get_bounding_box_plotting_vertices(x0, y0, x1, y1)
        c = max(float(box_dict["score"]) - 0.5, 0) * 2.0
        ax_im.plot(verts[:, 0], verts[:, 1], c=(1.0 - c, 1.0 - c, 1.0))

    for box in batch_dict["pseudo_label_boxes"][ib]:
        x0, y0, x1, y1 = box
        verts = _get_bounding_box_plotting_vertices(x0, y0, x1, y1)
        ax_im.plot(verts[:, 0], verts[:, 1], c="green")

    # laser points on image
    scan_z = batch_dict["laser_z"][ib]
    scan_xyz_laser = np.stack((scan_x, -scan_y, scan_z), axis=0)  # in JRDB laser frame

    p_xy, ib_mask = jt.transform_pts_laser_to_stitched_im(scan_xyz_laser)
    c = np.clip(scan_r, 0.0, 20.0) / 20.0
    c = c.reshape(-1, 1).repeat(3, axis=1)
    c[:, 0] = 1.0
    ax_im.scatter(p_xy[0, ib_mask], p_xy[1, ib_mask], s=1, c=c[ib_mask])

    # save fig
    fig_file = os.path.join(_SAVE_DIR, f"figs/{sequence}/{frame_id}.png")
    os.makedirs(os.path.dirname(fig_file), exist_ok=True)

    fig.savefig(fig_file)
    plt.close(fig)
def _plot_frame_im(batch_dict, ib, show_pseudo_labels=False):
    frame_id = f"{batch_dict['frame_id'][ib]:06d}"
    sequence = batch_dict["sequence"][ib]

    im = batch_dict["im_data"][ib]["stitched_image0"]
    crop_min_x = 0
    im = im[:, crop_min_x:]
    height = im.shape[0]
    width = im.shape[1]
    dpi = height / 1.0

    fig = plt.figure()
    fig.set_size_inches(1.0 * width / height, 1, forward=False)
    ax_im = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
    fig.add_axes(ax_im)

    # image
    ax_im.axis("off")
    ax_im.imshow(im)
    plt.xlim(0, width)
    plt.ylim(height, 0)

    # laser points on image
    scan_r = batch_dict["scans"][ib][-1]
    scan_phi = batch_dict["scan_phi"][ib]
    scan_x, scan_y = u.rphi_to_xy(scan_r, scan_phi)
    scan_z = batch_dict["laser_z"][ib]
    scan_xyz_laser = np.stack((scan_x, -scan_y, scan_z),
                              axis=0)  # in JRDB laser frame
    p_xy, ib_mask = jt.transform_pts_laser_to_stitched_im(scan_xyz_laser)

    if show_pseudo_labels:
        # detection bounding box
        for box_dict in batch_dict["im_dets"][ib]:
            x0, y0, w, h = box_dict["box"]
            x1 = x0 + w
            y1 = y0 + h
            verts = _get_bounding_box_plotting_vertices(x0, y0, x1, y1)
            ax_im.plot(verts[:, 0] - crop_min_x,
                       verts[:, 1],
                       c=(0.0, 0.0, 1.0),
                       alpha=0.3)
            # c = max(float(box_dict["score"]) - 0.5, 0) * 2.0
            # ax_im.plot(verts[:, 0] - crop_min_x, verts[:, 1], c=(1.0 - c, 1.0 - c, 1.0))
            # ax_im.plot(verts[:, 0] - crop_min_x, verts[:, 1],
            # c=(0.0, 0.0, 1.0), alpha=1.0)

            # x1_large = x1 + 0.05 * w
            # x0_large = x0 - 0.05 * w
            # y1_large = y1 + 0.05 * w
            # y0_large = y0 - 0.05 * w
            # in_box_mask = np.logical_and(
            #     np.logical_and(p_xy[0] > x0_large, p_xy[0] < x1_large),
            #     np.logical_and(p_xy[1] > y0_large, p_xy[1] < y1_large)
            # )
            # neg_mask[in_box_mask] = False

        for box in batch_dict["pseudo_label_boxes"][ib]:
            x0, y0, x1, y1 = box
            verts = _get_bounding_box_plotting_vertices(x0, y0, x1, y1)
            ax_im.plot(verts[:, 0] - crop_min_x, verts[:, 1], c="green")

        # overlay only pseudo-label laser points on image
        pl_pos_mask = np.logical_and(batch_dict["target_cls"][ib] == 1,
                                     ib_mask)
        pl_neg_mask = np.logical_and(batch_dict["target_cls"][ib] == 0,
                                     ib_mask)
        ax_im.scatter(
            p_xy[0, pl_pos_mask] - crop_min_x,
            p_xy[1, pl_pos_mask],
            s=1,
            color="green",
        )
        ax_im.scatter(
            p_xy[0, pl_neg_mask] - crop_min_x,
            p_xy[1, pl_neg_mask],
            s=1,
            color="orange",
        )

        fig_file = os.path.join(_SAVE_DIR,
                                f"figs/{sequence}/im_pl_{frame_id}.png")
    else:
        # overlay all laser points on image
        c_bgr = _distance_to_bgr_color(scan_r)
        ax_im.scatter(p_xy[0, ib_mask] - crop_min_x,
                      p_xy[1, ib_mask],
                      s=1,
                      color=c_bgr[ib_mask])
        fig_file = os.path.join(_SAVE_DIR,
                                f"figs/{sequence}/im_raw_{frame_id}.png")

    # save fig
    os.makedirs(os.path.dirname(fig_file), exist_ok=True)
    fig.savefig(fig_file, dpi=dpi)
    plt.close(fig)
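_distance_to_bgr_color is used above but not listed in this collection. A sketch of a version consistent with the HSV interpolation in _plot_pseudo_labels (treat the defaults as assumptions):

import numpy as np
import cv2

def _distance_to_bgr_color(scan_r, far_v=1.0, far_s=0.0, close_v=0.75, close_s=0.59):
    # sketch: red hue, with saturation/value interpolated by normalized distance
    d = np.clip(scan_r, 0.0, 20.0) / 20.0
    c_hsv = np.empty((1, len(scan_r), 3), dtype=np.float32)
    c_hsv[0, :, 0] = 0.0
    c_hsv[0, :, 1] = close_s * (1.0 - d) + far_s * d
    c_hsv[0, :, 2] = close_v * (1.0 - d) + far_v * d
    return cv2.cvtColor(c_hsv, cv2.COLOR_HSV2RGB)[0]
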
def _plot_sequence():
    jrdb_handle = JRDBHandle(
        split="train",
        cfg={
            "data_dir": "./data/JRDB",
            "num_scans": 10,
            "scan_stride": 1
        },
    )

    color_pool = np.random.uniform(size=(100, 3))

    for i, data_dict in enumerate(jrdb_handle):
        # lidar
        pc_xyz_upper = jt.transform_pts_upper_velodyne_to_base(
            data_dict["pc_data"]["upper_velodyne"])
        pc_xyz_lower = jt.transform_pts_lower_velodyne_to_base(
            data_dict["pc_data"]["lower_velodyne"])

        # laser
        laser_r = data_dict["laser_data"][-1]
        laser_phi = data_dict["laser_grid"]
        laser_z = data_dict["laser_z"]
        laser_x, laser_y = u.rphi_to_xy(laser_r, laser_phi)
        pc_xyz_laser = jt.transform_pts_laser_to_base(
            np.stack((laser_x, laser_y, laser_z), axis=0))

        if _COLOR_INSTANCE:
            # labels
            boxes, label_ids = [], []
            for ann in data_dict["pc_anns"]:
                # jrdb_handle.box_is_on_ground(ann)
                box, b_id = ub3d.box_from_jrdb(ann)
                boxes.append(box)
                label_ids.append(b_id)
            boxes = np.array(boxes)  # (B, 7)
            pc = np.concatenate([pc_xyz_laser, pc_xyz_upper, pc_xyz_lower],
                                axis=1)
            in_box_mask, closest_box_inds = ub3d.associate_points_and_boxes(
                pc, boxes, resize_factor=1.0)

            # plot bg points
            bg_pc = pc[:, np.logical_not(in_box_mask)]
            mlab.points3d(
                bg_pc[0],
                bg_pc[1],
                bg_pc[2],
                scale_factor=0.05,
                color=(1.0, 0.0, 0.0),
            )

            # plot box and fg points
            fg_pc = pc[:, in_box_mask]
            fg_box_inds = closest_box_inds[in_box_mask]
            corners_xyz, connect_inds = ub3d.boxes_to_corners(
                boxes, rtn_connect_inds=True)
            for box_idx, (p_id, corner_xyz) in enumerate(
                    zip(label_ids, corners_xyz)):
                color = tuple(color_pool[p_id % 100])
                # box
                for inds in connect_inds:
                    mlab.plot3d(
                        corner_xyz[0, inds],
                        corner_xyz[1, inds],
                        corner_xyz[2, inds],
                        tube_radius=None,
                        line_width=5,
                        color=color,
                    )

                # point
                in_box_pc = fg_pc[:, fg_box_inds == box_idx]
                mlab.points3d(
                    in_box_pc[0],
                    in_box_pc[1],
                    in_box_pc[2],
                    scale_factor=0.05,
                    color=color,
                )

        else:
            # plot points
            mlab.points3d(
                pc_xyz_lower[0],
                pc_xyz_lower[1],
                pc_xyz_lower[2],
                scale_factor=0.05,
                color=(0.0, 1.0, 0.0),
            )
            mlab.points3d(
                pc_xyz_upper[0],
                pc_xyz_upper[1],
                pc_xyz_upper[2],
                scale_factor=0.05,
                color=(0.0, 0.0, 1.0),
            )
            mlab.points3d(
                pc_xyz_laser[0],
                pc_xyz_laser[1],
                pc_xyz_laser[2],
                scale_factor=0.05,
                color=(1.0, 0.0, 0.0),
            )

            # plot box
            boxes = []
            for ann in data_dict["pc_anns"]:
                # jrdb_handle.box_is_on_ground(ann)
                box = ub3d.box_from_jrdb(ann, fast_mode=False)
                corners_xyz, connect_inds = box.to_corners(
                    resize_factor=1.0, rtn_connect_inds=True)
                for inds in connect_inds:
                    mlab.plot3d(
                        corners_xyz[0, inds],
                        corners_xyz[1, inds],
                        corners_xyz[2, inds],
                        tube_radius=None,
                        line_width=5,
                        color=tuple(color_pool[box.get_id() % 100]),
                    )

        mlab.show()