def plot_pseudo_label_for_all_frames():
    """Plot images, pseudo labels and detections for frames of ``_SPLIT``.

    Builds two models from ``./cfgs/base_drow_jrdb_cfg.yaml``:

    * ``model``, loaded from ``./ckpts/ckpt_jrdb_pl_drow3_phce_e40.pth``;
    * ``model_pretrain``, loaded from ``./ckpts/ckpt_drow_drow3_e40.pth``.

    It then walks the ``_SPLIT`` dataloader with pseudo labels enabled. Limits:

    * Iteration stops once ``_MAX_COUNT`` frames have been visited.
    * Within each sequence, only the first ``_SEQ_MAX_COUNT + 1`` frames are
      considered; a sequence starts where ``batch_dict["first_frame"]`` is set.

    Every considered frame whose global index is a multiple of
    ``_PLOTTING_INTERVAL`` is rendered with ``_plot_frame_im`` and
    ``_plot_frame_pts``.
    """
    with open("./cfgs/base_drow_jrdb_cfg.yaml", "r") as f:
        cfg = yaml.safe_load(f)
    cfg["dataset"]["pseudo_label"] = True
    cfg["dataset"]["pl_correction_level"] = 0

    test_loader = get_dataloader(
        split=_SPLIT,
        batch_size=1,
        num_workers=1,
        shuffle=False,
        dataset_cfg=cfg["dataset"],
    )

    model = get_model(cfg["model"])
    model.cuda()
    model.eval()

    logger = Logger(cfg["pipeline"]["Logger"])
    logger.load_ckpt("./ckpts/ckpt_jrdb_pl_drow3_phce_e40.pth", model)

    model_pretrain = get_model(cfg["model"])
    model_pretrain.cuda()
    model_pretrain.eval()
    logger.load_ckpt("./ckpts/ckpt_drow_drow3_e40.pth", model_pretrain)

    seq_count = 0
    for count, batch_dict in enumerate(tqdm(test_loader)):
        # Check the global budget before the per-sequence skip. Otherwise,
        # once a sequence's budget is used up, frames past _MAX_COUNT keep
        # being loaded until the next sequence starts.
        if count >= _MAX_COUNT:
            break

        if batch_dict["first_frame"][0]:
            print(f"new seq, reset count, idx {count}")
            seq_count = 0

        # Skip the remainder of a sequence once its frame budget is used.
        if seq_count > _SEQ_MAX_COUNT:
            continue
        seq_count += 1

        # Predictions are only consumed on plotted frames; skip the forward
        # passes everywhere else.
        if count % _PLOTTING_INTERVAL != 0:
            continue

        with torch.no_grad():
            net_input = torch.from_numpy(batch_dict["input"]).cuda().float()
            pred_cls, pred_reg = model(net_input)
            pred_cls = torch.sigmoid(pred_cls).cpu().numpy()
            pred_reg = pred_reg.cpu().numpy()

            # NOTE(review): the pre-trained model's predictions are computed
            # but not passed to any plotting call below. Confirm whether they
            # were meant for the last two _plot_frame_pts arguments.
            pred_cls_p, pred_reg_p = model_pretrain(net_input)
            pred_cls_p = torch.sigmoid(pred_cls_p).cpu().numpy()
            pred_reg_p = pred_reg_p.cpu().numpy()

        for ib in range(len(batch_dict["input"])):
            # generate sequence videos
            _plot_frame_im(batch_dict, ib, False)  # image
            _plot_frame_im(batch_dict, ib, True)  # image
            _plot_frame_pts(batch_dict, ib, None, None, None,
                            None)  # pseudo-label
            _plot_frame_pts(batch_dict, ib, pred_cls, pred_reg, None,
                            None)  # detections
# Example No. 2
# 0
def generate_pseudo_labels():
    """Visualize pseudo labels for samples of the ``_SPLIT`` split.

    Loads ``./cfgs/base_dr_spaam_jrdb_cfg.yaml`` and enables pseudo labels
    with correction level 0. Recreates ``_SAVE_DIR`` as an empty directory.

    For each of the first ``_MAX_COUNT`` batches whose index is a multiple of
    ``_PLOTTING_INTERVAL``, calls ``_plot_pseudo_labels`` once per sample.

    NOTE(review): ``generate_pseudo_labels`` is defined again later in this
    module, so the later definition replaces this one at import time.
    """
    with open("./cfgs/base_dr_spaam_jrdb_cfg.yaml", "r") as cfg_file:
        cfg = yaml.safe_load(cfg_file)

    dataset_cfg = cfg["dataset"]
    dataset_cfg["pseudo_label"] = True
    dataset_cfg["pl_correction_level"] = 0

    loader = get_dataloader(split=_SPLIT,
                            batch_size=1,
                            num_workers=1,
                            shuffle=False,
                            dataset_cfg=dataset_cfg)

    # Start from an empty output directory.
    if os.path.exists(_SAVE_DIR):
        shutil.rmtree(_SAVE_DIR)
        time.sleep(1.0)
    os.makedirs(_SAVE_DIR)

    for batch_idx, batch in enumerate(tqdm(loader)):
        if batch_idx >= _MAX_COUNT:
            break

        for sample_idx in range(len(batch["input"])):
            # Only every _PLOTTING_INTERVAL-th batch is visualized.
            if batch_idx % _PLOTTING_INTERVAL:
                continue
            _plot_pseudo_labels(batch, sample_idx)
# Example No. 3
# 0
def run_evaluation(model, pipeline, cfg):
    """Evaluate ``model`` on the validation split, then on the test split.

    Each split uses an unshuffled, single-worker dataloader with batch size 1
    built from ``cfg["dataset"]``. Results are passed to ``pipeline.evaluate``
    under the prefixes ``"VAL"`` and ``"TEST"`` respectively.
    """
    for split_name, prefix in (("val", "VAL"), ("test", "TEST")):
        loader = get_dataloader(
            split=split_name,
            batch_size=1,
            num_workers=1,
            shuffle=False,
            dataset_cfg=cfg["dataset"],
        )
        pipeline.evaluate(model, loader, tb_prefix=prefix)
# Example No. 4
# 0
def run_training(model, pipeline, cfg):
    """Train ``model`` with ``pipeline``, then evaluate it on the test split.

    The train and val loaders are shuffled and take their extra options from
    ``cfg["dataloader"]``. The test evaluation only runs when
    ``pipeline.train`` returns a falsy status. Presumably a falsy status means
    training finished normally; confirm against the pipeline implementation.
    """

    def _shuffled_loader(split_name):
        # Train and val loaders share everything except the split name.
        return get_dataloader(split=split_name,
                              shuffle=True,
                              dataset_cfg=cfg["dataset"],
                              **cfg["dataloader"])

    status = pipeline.train(model, _shuffled_loader("train"),
                            _shuffled_loader("val"))
    if status:
        return

    # Final evaluation on the held-out test split.
    test_loader = get_dataloader(split="test",
                                 batch_size=1,
                                 num_workers=1,
                                 shuffle=False,
                                 dataset_cfg=cfg["dataset"])
    pipeline.evaluate(model, test_loader, tb_prefix="TEST")
# Example No. 5
# 0
def _test_dataloader():
    """Visually inspect samples from the validation dataloader.

    Loads the dataset config with pseudo labels disabled and iterates over at
    most ``_MAX_COUNT`` batches of size 5, drawing each sample with
    ``_plot_sample``.

    * When ``_INTERACTIVE`` is set, each sample is shown briefly, and pressing
      any key in the figure window stops the iteration.
    * Otherwise ``_SAVE_DIR`` is recreated, and every sample is saved there
      as a PDF.
    """
    # NOTE(review): other entry points in this file load their configs from
    # "./cfgs/"; this relative path differs. Confirm it is intentional.
    with open("./base_dr_spaam_jrdb_cfg.yaml", "r") as f:
        cfg = yaml.safe_load(f)

    cfg["dataset"]["pseudo_label"] = False
    cfg["dataset"]["pl_correction_level"] = 0

    test_loader = get_dataloader(
        split="val",
        batch_size=5,
        num_workers=1,
        shuffle=False,
        dataset_cfg=cfg["dataset"],
    )

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)

    # Set by the key-press handler; checked in the loop below so that a key
    # press actually stops the iteration.
    stop_requested = False

    if _INTERACTIVE:

        def _request_stop(event):
            nonlocal stop_requested
            stop_requested = True

        fig.canvas.mpl_connect("key_press_event", _request_stop)
    else:
        if os.path.exists(_SAVE_DIR):
            shutil.rmtree(_SAVE_DIR)
        os.makedirs(_SAVE_DIR)

    for count, data_dict in enumerate(test_loader):
        if count >= _MAX_COUNT:
            break

        for ib in range(len(data_dict["input"])):
            _plot_sample(fig, ax, ib, count, data_dict)

            if _INTERACTIVE:
                # plt.pause runs the GUI event loop, which is where the
                # key-press callback gets a chance to fire.
                plt.pause(0.1)
                if stop_requested:
                    break
            else:
                plt.savefig(
                    os.path.join(
                        _SAVE_DIR,
                        f"b{count:03}s{ib:02}f{data_dict['idx'][ib]:04}.pdf"))

        if stop_requested:
            break

    if _INTERACTIVE:
        plt.show()
def _pseudo_label_tp_fp_tn_fn(target_cls, target_cls_gt):
    """Count TP, FP, TN and FN of pseudo-label targets against real targets.

    * A pseudo-label entry equal to 1 is a true positive when the real target
      is 1 or -1; every other 1 entry is a false positive.
    * A pseudo-label entry equal to 0 is a true negative when the real target
      is 0 or -1; every other 0 entry is a false negative.
    * Entries whose pseudo label is neither 0 nor 1 are not counted.

    NOTE(review): -1 in the real target presumably marks an ignored/unknown
    point. Confirm against the dataset code.

    Returns:
        Tuple ``(tp, fp, tn, fn)`` of counts.
    """
    gt_neg_or_ignored = np.logical_or(target_cls_gt == 0, target_cls_gt == -1)
    gt_pos_or_ignored = np.logical_or(target_cls_gt == 1, target_cls_gt == -1)
    pl_neg = target_cls == 0
    pl_pos = target_cls == 1

    tn = np.sum(np.logical_and(pl_neg, gt_neg_or_ignored))
    fn = np.sum(pl_neg) - tn
    tp = np.sum(np.logical_and(pl_pos, gt_pos_or_ignored))
    fp = np.sum(pl_pos) - tp
    return tp, fp, tn, fn


def generate_pseudo_labels():
    """Export pseudo labels, groundtruth and per-sequence statistics.

    Loads ``./cfgs/base_dr_spaam_jrdb_cfg.yaml`` with pseudo labels enabled
    (correction level 0) and iterates over the ``_SPLIT`` split.
    ``_SAVE_DIR`` is recreated from scratch. For each of the first
    ``_MAX_COUNT`` batches and each sample in it:

    * Every ``_PLOTTING_INTERVAL``-th batch is visualized (the whole frame
      and each pseudo label).
    * Pseudo labels are written to ``detections/<sequence>/<frame_id>.txt``.
    * Groundtruth is written to ``groundtruth/<sequence>/<frame_id>.txt``.
      Both files are formatted with ``pru.drow_detection_to_kitti_string``.
    * TP/FP/TN/FN counts of the pseudo-label classification targets against
      the real targets are accumulated per sequence.

    Finally, each sequence's counts are written as ``tp,fp,tn,fn`` to
    ``evaluation/<sequence>/tp_fp_tn_fn.txt``.
    """
    with open("./cfgs/base_dr_spaam_jrdb_cfg.yaml", "r") as f:
        cfg = yaml.safe_load(f)
    cfg["dataset"]["pseudo_label"] = True
    cfg["dataset"]["pl_correction_level"] = 0

    test_loader = get_dataloader(
        split=_SPLIT,
        batch_size=1,
        num_workers=1,
        shuffle=False,
        dataset_cfg=cfg["dataset"],
    )

    if os.path.exists(_SAVE_DIR):
        shutil.rmtree(_SAVE_DIR)
        # NOTE(review): presumably gives the filesystem time to finish the
        # delete before recreating the directory. Confirm it is still needed.
        time.sleep(1.0)
    os.makedirs(_SAVE_DIR)

    # sequence -> (tp, fp, tn, fn), accumulated over all frames of the sequence
    sequences_tp_fp_tn_fn = {}

    for count, batch_dict in enumerate(tqdm(test_loader)):
        if count >= _MAX_COUNT:
            break

        for ib in range(len(batch_dict["input"])):
            frame_id = f"{batch_dict['frame_id'][ib]:06d}"
            sequence = batch_dict["sequence"][ib]

            if count % _PLOTTING_INTERVAL == 0:
                _plot_frame(batch_dict, ib)  # the whole frame
                _plot_pseudo_labels(batch_dict, ib)  # each pseudo label

            # Save pseudo labels as detection results for evaluation.
            pl_xy = batch_dict["pseudo_label_loc_xy"][ib]
            pl_str = (pru.drow_detection_to_kitti_string(pl_xy, None, None)
                      if len(pl_xy) > 0 else "")
            pl_file = os.path.join(_SAVE_DIR,
                                   f"detections/{sequence}/{frame_id}.txt")
            _write_file_make_dir(pl_file, pl_str)

            # Save groundtruth. Annotations whose anns_valid_mask entry is
            # False are flagged as occluded.
            anns_rphi = batch_dict["dets_wp"][ib]
            if len(anns_rphi) > 0:
                anns_rphi = np.array(anns_rphi, dtype=np.float32)
                gts_xy = np.stack(
                    u.rphi_to_xy(anns_rphi[:, 0], anns_rphi[:, 1]), axis=1)
                # Builtin int: the np.int alias was deprecated in NumPy 1.20
                # and removed in 1.24.
                gts_occluded = np.logical_not(
                    batch_dict["anns_valid_mask"][ib]).astype(int)
                gts_str = pru.drow_detection_to_kitti_string(
                    gts_xy, None, gts_occluded)
            else:
                gts_str = ""
            gts_file = os.path.join(_SAVE_DIR,
                                    f"groundtruth/{sequence}/{frame_id}.txt")
            _write_file_make_dir(gts_file, gts_str)

            # Accumulate true/false positive/negative counts per sequence.
            frame_counts = _pseudo_label_tp_fp_tn_fn(
                batch_dict["target_cls"][ib], batch_dict["target_cls_real"][ib])
            seq_counts = sequences_tp_fp_tn_fn.get(sequence, (0, 0, 0, 0))
            sequences_tp_fp_tn_fn[sequence] = tuple(
                new + old for new, old in zip(frame_counts, seq_counts))

    # Write sequence statistics to file.
    for sequence, (tp, fp, tn, fn) in sequences_tp_fp_tn_fn.items():
        st_file = os.path.join(_SAVE_DIR,
                               f"evaluation/{sequence}/tp_fp_tn_fn.txt")
        _write_file_make_dir(st_file, f"{tp},{fp},{tn},{fn}")