def plot_pseudo_label_for_all_frames():
    """Plot images, pseudo-labels and detections for a subset of frames.

    Loads ``./cfgs/base_drow_jrdb_cfg.yaml`` with pseudo-labels enabled
    (correction level 0). It then restores two models through ``Logger``:
    one from ``ckpt_jrdb_pl_drow3_phce_e40.pth`` and one from
    ``ckpt_drow_drow3_e40.pth``.

    Frames are processed subject to three limits:

    * iteration stops after ``_MAX_COUNT`` frames;
    * at most the first ``_SEQ_MAX_COUNT + 1`` frames of each sequence are
      considered (the counter resets on ``batch_dict["first_frame"]``);
    * of those, only every ``_PLOTTING_INTERVAL``-th frame (by global
      index) is plotted.

    For each plotted frame, ``_plot_frame_im`` and ``_plot_frame_pts`` are
    called per sample.
    """
    with open("./cfgs/base_drow_jrdb_cfg.yaml", "r") as f:
        cfg = yaml.safe_load(f)
    cfg["dataset"]["pseudo_label"] = True
    cfg["dataset"]["pl_correction_level"] = 0

    test_loader = get_dataloader(
        split=_SPLIT,
        batch_size=1,
        num_workers=1,
        shuffle=False,
        dataset_cfg=cfg["dataset"],
    )

    model = get_model(cfg["model"])
    model.cuda()
    model.eval()

    logger = Logger(cfg["pipeline"]["Logger"])
    logger.load_ckpt("./ckpts/ckpt_jrdb_pl_drow3_phce_e40.pth", model)

    model_pretrain = get_model(cfg["model"])
    model_pretrain.cuda()
    model_pretrain.eval()
    logger.load_ckpt("./ckpts/ckpt_drow_drow3_e40.pth", model_pretrain)

    # generate pseudo labels for all sample
    seq_count = 0
    for count, batch_dict in enumerate(tqdm(test_loader)):
        # Check the global limit first. If it came after the per-sequence
        # ``continue`` below, the loop would keep loading frames past
        # _MAX_COUNT whenever the current sequence was already exhausted.
        if count >= _MAX_COUNT:
            break

        if batch_dict["first_frame"][0]:
            print(f"new seq, reset count, idx {count}")
            seq_count = 0

        if seq_count > _SEQ_MAX_COUNT:
            continue
        seq_count += 1

        # Network outputs are only consumed by the plotting below, so do not
        # spend GPU time on frames that will not be plotted.
        if count % _PLOTTING_INTERVAL != 0:
            continue

        with torch.no_grad():
            net_input = torch.from_numpy(batch_dict["input"]).cuda().float()

            pred_cls, pred_reg = model(net_input)
            pred_cls = torch.sigmoid(pred_cls).cpu().numpy()
            pred_reg = pred_reg.cpu().numpy()

            # NOTE(review): the pretrained model's predictions are computed
            # but never plotted; confirm whether they should be passed to
            # _plot_frame_pts.
            pred_cls_p, pred_reg_p = model_pretrain(net_input)
            pred_cls_p = torch.sigmoid(pred_cls_p).cpu().numpy()
            pred_reg_p = pred_reg_p.cpu().numpy()

        for ib in range(len(batch_dict["input"])):
            # generate sequence videos
            _plot_frame_im(batch_dict, ib, False)  # image
            _plot_frame_im(batch_dict, ib, True)  # image
            _plot_frame_pts(batch_dict, ib, None, None, None, None)  # pseudo-label
            _plot_frame_pts(batch_dict, ib, pred_cls, pred_reg, None, None)  # detections
def generate_pseudo_labels():
    """Plot each pseudo-label for every ``_PLOTTING_INTERVAL``-th frame.

    Reads the DR-SPAAM JRDB config with pseudo-labels enabled (correction
    level 0), recreates ``_SAVE_DIR`` from scratch, and calls
    ``_plot_pseudo_labels`` on every sample of the selected frames. It stops
    after ``_MAX_COUNT`` frames.

    Note: another ``generate_pseudo_labels`` is defined further down in this
    file; being later, that definition is the one bound at import time.
    """
    cfg_path = "./cfgs/base_dr_spaam_jrdb_cfg.yaml"
    with open(cfg_path, "r") as cfg_file:
        cfg = yaml.safe_load(cfg_file)

    dataset_cfg = cfg["dataset"]
    dataset_cfg["pseudo_label"] = True
    dataset_cfg["pl_correction_level"] = 0

    loader = get_dataloader(
        split=_SPLIT,
        batch_size=1,
        num_workers=1,
        shuffle=False,
        dataset_cfg=dataset_cfg,
    )

    # Start from an empty output directory.
    if os.path.exists(_SAVE_DIR):
        shutil.rmtree(_SAVE_DIR)
        time.sleep(1.0)
    os.makedirs(_SAVE_DIR)

    for frame_idx, batch in enumerate(tqdm(loader)):
        if frame_idx >= _MAX_COUNT:
            break

        num_samples = len(batch["input"])
        if frame_idx % _PLOTTING_INTERVAL != 0:
            continue

        for sample_idx in range(num_samples):
            _plot_pseudo_labels(batch, sample_idx)
def run_evaluation(model, pipeline, cfg):
    """Evaluate ``model`` on the validation split, then on the test split.

    Each split is read through an unshuffled loader with batch size 1 and a
    single worker. The results are passed to ``pipeline.evaluate`` with
    ``tb_prefix`` "VAL" and "TEST" respectively.
    """
    split_prefixes = (("val", "VAL"), ("test", "TEST"))
    for split_name, prefix in split_prefixes:
        loader = get_dataloader(
            split=split_name,
            batch_size=1,
            num_workers=1,
            shuffle=False,
            dataset_cfg=cfg["dataset"],
        )
        pipeline.evaluate(model, loader, tb_prefix=prefix)
def run_training(model, pipeline, cfg):
    """Train ``model`` and then evaluate it on the test split.

    Train and val loaders are built with ``shuffle=True`` plus the keyword
    arguments in ``cfg["dataloader"]``. The test evaluation is skipped when
    ``pipeline.train`` returns a truthy status.
    """
    train_loader = get_dataloader(
        split="train", shuffle=True, dataset_cfg=cfg["dataset"], **cfg["dataloader"]
    )
    val_loader = get_dataloader(
        split="val", shuffle=True, dataset_cfg=cfg["dataset"], **cfg["dataloader"]
    )

    status = pipeline.train(model, train_loader, val_loader)
    if status:
        return

    # Test after training.
    test_loader = get_dataloader(
        split="test",
        batch_size=1,
        num_workers=1,
        shuffle=False,
        dataset_cfg=cfg["dataset"],
    )
    pipeline.evaluate(model, test_loader, tb_prefix="TEST")
def _test_dataloader():
    """Visually inspect samples from the validation dataloader.

    Loads the DR-SPAAM JRDB config with pseudo-labels disabled and draws each
    sample of up to ``_MAX_COUNT`` batches (batch size 5) with
    ``_plot_sample``.

    * Interactive mode (``_INTERACTIVE`` true): each sample is displayed for
      0.1 s, and any key press in the figure stops the loop.
    * Otherwise: ``_SAVE_DIR`` is recreated and each sample is saved as a PDF
      named ``b<batch>s<sample>f<idx>.pdf``.
    """
    # NOTE(review): the other entry points read
    # "./cfgs/base_dr_spaam_jrdb_cfg.yaml"; confirm that this path without
    # the "cfgs/" directory is intended.
    with open("./base_dr_spaam_jrdb_cfg.yaml", "r") as f:
        cfg = yaml.safe_load(f)
    cfg["dataset"]["pseudo_label"] = False
    cfg["dataset"]["pl_correction_level"] = 0

    test_loader = get_dataloader(
        split="val",
        batch_size=5,
        num_workers=1,
        shuffle=False,
        dataset_cfg=cfg["dataset"],
    )

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)

    # Set by the key-press callback; checked in the loops below so a key
    # press actually ends interactive browsing.
    _break = False

    if _INTERACTIVE:

        def p(event):
            nonlocal _break
            _break = True

        fig.canvas.mpl_connect("key_press_event", p)
    else:
        if os.path.exists(_SAVE_DIR):
            shutil.rmtree(_SAVE_DIR)
        os.makedirs(_SAVE_DIR)

    for count, data_dict in enumerate(test_loader):
        if count >= _MAX_COUNT or _break:
            break

        for ib in range(len(data_dict["input"])):
            _plot_sample(fig, ax, ib, count, data_dict)

            if _INTERACTIVE:
                # Key events are only delivered while the GUI loop runs
                # (inside plt.pause), so re-check the flag right after it.
                plt.pause(0.1)
                if _break:
                    break
            else:
                plt.savefig(
                    os.path.join(
                        _SAVE_DIR,
                        f"b{count:03}s{ib:02}f{data_dict['idx'][ib]:04}.pdf",
                    )
                )

    if _INTERACTIVE:
        plt.show()
def _pl_confusion_counts(target_cls, target_cls_gt):
    """Return ``(tp, fp, tn, fn)`` comparing pseudo-label and real targets.

    ``target_cls`` is compared element-wise with ``target_cls_gt``. A real
    target of -1 is counted as agreeing with either pseudo value.

    * Pseudo value 1 counts as TP when the real value is 1 or -1, else FP.
    * Pseudo value 0 counts as TN when the real value is 0 or -1, else FN.
    * Elements with any other pseudo value are not counted.
    """
    target_cls = np.asarray(target_cls)
    target_cls_gt = np.asarray(target_cls_gt)

    pseudo_pos = target_cls == 1
    pseudo_neg = target_cls == 0
    gt_is_neg1 = target_cls_gt == -1

    tp = np.sum(np.logical_and(pseudo_pos, np.logical_or(target_cls_gt == 1, gt_is_neg1)))
    fp = np.sum(pseudo_pos) - tp
    tn = np.sum(np.logical_and(pseudo_neg, np.logical_or(target_cls_gt == 0, gt_is_neg1)))
    fn = np.sum(pseudo_neg) - tn
    return tp, fp, tn, fn


def generate_pseudo_labels():
    """Plot and export pseudo-labels, ground truth and per-sequence statistics.

    Reads the DR-SPAAM JRDB config with pseudo-labels enabled (correction
    level 0), recreates ``_SAVE_DIR``, and processes up to ``_MAX_COUNT``
    frames. For every sample it:

    * plots the frame and its pseudo-labels (every ``_PLOTTING_INTERVAL``-th
      frame only);
    * writes the pseudo-labels to ``detections/<sequence>/<frame_id>.txt``
      via ``pru.drow_detection_to_kitti_string``;
    * writes the ground truth from ``dets_wp`` to
      ``groundtruth/<sequence>/<frame_id>.txt``, flagging entries whose
      ``anns_valid_mask`` is false as occluded;
    * accumulates per-sequence TP/FP/TN/FN of ``target_cls`` against
      ``target_cls_real``.

    Finally, each sequence's counts are written as ``tp,fp,tn,fn`` to
    ``evaluation/<sequence>/tp_fp_tn_fn.txt``.
    """
    with open("./cfgs/base_dr_spaam_jrdb_cfg.yaml", "r") as f:
        cfg = yaml.safe_load(f)
    cfg["dataset"]["pseudo_label"] = True
    cfg["dataset"]["pl_correction_level"] = 0

    test_loader = get_dataloader(
        split=_SPLIT,
        batch_size=1,
        num_workers=1,
        shuffle=False,
        dataset_cfg=cfg["dataset"],
    )

    if os.path.exists(_SAVE_DIR):
        shutil.rmtree(_SAVE_DIR)
        time.sleep(1.0)
    os.makedirs(_SAVE_DIR)

    # sequence name -> (tp, fp, tn, fn), for true positive/negative rates
    sequences_tp_fp_tn_fn = {}

    # generate pseudo labels for all sample
    for count, batch_dict in enumerate(tqdm(test_loader)):
        if count >= _MAX_COUNT:
            break

        for ib in range(len(batch_dict["input"])):
            frame_id = f"{batch_dict['frame_id'][ib]:06d}"
            sequence = batch_dict["sequence"][ib]

            if count % _PLOTTING_INTERVAL == 0:
                # visualize the whole frame
                _plot_frame(batch_dict, ib)
                # visualize each pseudo labels
                _plot_pseudo_labels(batch_dict, ib)

            # save pseudo labels as detection results for evaluation
            pl_xy = batch_dict["pseudo_label_loc_xy"][ib]
            pl_str = (
                pru.drow_detection_to_kitti_string(pl_xy, None, None)
                if len(pl_xy) > 0
                else ""
            )
            pl_file = os.path.join(_SAVE_DIR, f"detections/{sequence}/{frame_id}.txt")
            _write_file_make_dir(pl_file, pl_str)

            # save groundtruth
            anns_rphi = batch_dict["dets_wp"][ib]
            if len(anns_rphi) > 0:
                anns_rphi = np.array(anns_rphi, dtype=np.float32)
                gts_xy = np.stack(
                    u.rphi_to_xy(anns_rphi[:, 0], anns_rphi[:, 1]), axis=1
                )
                # ``np.int`` (an alias of builtin ``int``) was removed in
                # NumPy 1.24; use ``int`` directly.
                gts_occluded = np.logical_not(
                    batch_dict["anns_valid_mask"][ib]
                ).astype(int)
                gts_str = pru.drow_detection_to_kitti_string(
                    gts_xy, None, gts_occluded
                )
            else:
                gts_str = ""
            gts_file = os.path.join(_SAVE_DIR, f"groundtruth/{sequence}/{frame_id}.txt")
            _write_file_make_dir(gts_file, gts_str)

            # accumulate true/false positive/negative counts per sequence
            frame_counts = _pl_confusion_counts(
                batch_dict["target_cls"][ib], batch_dict["target_cls_real"][ib]
            )
            prev_counts = sequences_tp_fp_tn_fn.get(sequence, (0, 0, 0, 0))
            sequences_tp_fp_tn_fn[sequence] = tuple(
                prev + cur for prev, cur in zip(prev_counts, frame_counts)
            )

    # write sequence statistics to file
    for sequence, (tp, fp, tn, fn) in sequences_tp_fp_tn_fn.items():
        st_file = os.path.join(_SAVE_DIR, f"evaluation/{sequence}/tp_fp_tn_fn.txt")
        _write_file_make_dir(st_file, f"{tp},{fp},{tn},{fn}")