def setup_postprocessor(CONFIG):
    """Build the DenseCRF post-processor from the CRF section of CONFIG."""
    crf_cfg = CONFIG.CRF
    return DenseCRF(
        iter_max=crf_cfg.ITER_MAX,
        pos_xy_std=crf_cfg.POS_XY_STD,
        pos_w=crf_cfg.POS_W,
        bi_xy_std=crf_cfg.BI_XY_STD,
        bi_rgb_std=crf_cfg.BI_RGB_STD,
        bi_w=crf_cfg.BI_W,
    )
def crf(config_path, n_jobs):
    """
    CRF post-processing on pre-computed logits.

    Loads the per-image logit files written by ``main.py test``, refines
    each probability map with DenseCRF, and dumps evaluation scores
    (pixel acc, mean acc, class/mean IoU, FWIoU) to a JSON file.

    Args:
        config_path: path to the experiment YAML config file.
        n_jobs: number of joblib worker processes.
    """
    # Configuration
    # BUG FIX: the original called yaml.load(config_path), which parses the
    # *path string* itself as YAML instead of the file contents.  Open the
    # file (and close it) and use safe_load — Loader-less yaml.load is
    # deprecated/removed in modern PyYAML.
    with open(config_path) as f:
        CONFIG = Dict(yaml.safe_load(f))
    torch.set_grad_enabled(False)
    print("# jobs:", n_jobs)

    # Dataset (validation split, no augmentation — we evaluate as-is)
    dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.VAL,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        augment=False,
    )
    print(dataset)

    # CRF post-processor
    postprocessor = DenseCRF(
        iter_max=CONFIG.CRF.ITER_MAX,
        pos_xy_std=CONFIG.CRF.POS_XY_STD,
        pos_w=CONFIG.CRF.POS_W,
        bi_xy_std=CONFIG.CRF.BI_XY_STD,
        bi_rgb_std=CONFIG.CRF.BI_RGB_STD,
        bi_w=CONFIG.CRF.BI_W,
    )

    # Path to logit files produced by a prior "main.py test" run
    logit_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "features",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
        "logit",
    )
    print("Logit src:", logit_dir)
    if not os.path.isdir(logit_dir):
        print("Logit not found, run first: python main.py test [OPTIONS]")
        quit()

    # Path to save scores
    save_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "scores",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
    )
    makedirs(save_dir)
    save_path = os.path.join(save_dir, "scores_crf.json")
    print("Score dst:", save_path)

    # Process one sample: load its logits, upsample to label size,
    # softmax, refine with DenseCRF, take the argmax label map.
    def process(i):
        # Idiomatic indexing instead of dataset.__getitem__(i)
        image_id, image, gt_label = dataset[i]

        filename = os.path.join(logit_dir, image_id + ".npy")
        logit = np.load(filename)

        _, H, W = image.shape  # image assumed CHW — labels share its H, W
        logit = torch.FloatTensor(logit)[None, ...]
        logit = F.interpolate(
            logit, size=(H, W), mode="bilinear", align_corners=False
        )
        prob = F.softmax(logit, dim=1)[0].numpy()

        # DenseCRF expects an HWC uint8 image
        image = image.astype(np.uint8).transpose(1, 2, 0)
        prob = postprocessor(image, prob)
        label = np.argmax(prob, axis=0)

        return label, gt_label

    # CRF in multi-process; pre_dispatch="all" queues every task up front
    results = joblib.Parallel(n_jobs=n_jobs, verbose=10, pre_dispatch="all")(
        [joblib.delayed(process)(i) for i in range(len(dataset))]
    )

    preds, gts = zip(*results)

    # Pixel Accuracy, Mean Accuracy, Class IoU, Mean IoU, Freq Weighted IoU
    score = scores(gts, preds, n_class=CONFIG.DATASET.N_CLASSES)

    with open(save_path, "w") as f:
        json.dump(score, f, indent=4, sort_keys=True)
def test(config, model_path, cuda, crf):
    """
    Evaluate a trained model on the validation split.

    Runs the model over the validation DataLoader, optionally refines the
    softmax probabilities with DenseCRF, and writes the evaluation scores
    next to the checkpoint (``<model>.json``).

    Args:
        config: path to the experiment YAML config file.
        model_path: path to the ``.pth`` checkpoint to evaluate.
        cuda: whether to run on GPU (resolved by get_device).
        crf: if True, apply DenseCRF post-processing to each prediction.
    """
    # Disable autograd globally — evaluation only
    torch.set_grad_enabled(False)

    # Setup
    device = get_device(cuda)
    # FIX: the original did yaml.load(open(config)) — the file handle was
    # never closed, and Loader-less yaml.load is deprecated/removed in
    # modern PyYAML.  Use a context manager and safe_load.
    with open(config) as f:
        CONFIG = Dict(yaml.safe_load(f))

    # If the image size never changes, let cuDNN auto-tune its kernels
    if CONFIG.DATASET.WARP_IMAGE:
        torch.backends.cudnn.benchmark = True

    # Dataset 10k or 164k
    dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.VAL,
        base_size=CONFIG.IMAGE.SIZE.TEST,
        crop_size=None,
        mean=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        warp=CONFIG.DATASET.WARP_IMAGE,
        scale=None,
        flip=False,
    )

    # DataLoader (no shuffling — deterministic evaluation order)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.SOLVER.BATCH_SIZE.TEST,
        num_workers=CONFIG.DATALOADER.NUM_WORKERS,
        shuffle=False,
    )

    # Model (inference mode)
    model = setup_model(model_path, CONFIG.DATASET.N_CLASSES, train=False)
    model.to(device)

    # CRF post-processor
    postprocessor = DenseCRF(
        iter_max=CONFIG.CRF.ITER_MAX,
        pos_xy_std=CONFIG.CRF.POS_XY_STD,
        pos_w=CONFIG.CRF.POS_W,
        bi_xy_std=CONFIG.CRF.BI_XY_STD,
        bi_rgb_std=CONFIG.CRF.BI_RGB_STD,
        bi_w=CONFIG.CRF.BI_W,
    )

    preds, gts = [], []
    for images, labels in tqdm(
        loader, total=len(loader), leave=False, dynamic_ncols=True
    ):
        # Image
        images = images.to(device)
        _, H, W = labels.shape

        # Forward propagation; upsample logits to label resolution.
        # NOTE(review): align_corners=True here differs from the crf()
        # pipeline (False) — kept as-is to preserve reported numbers.
        logits = model(images)
        logits = F.interpolate(
            logits, size=(H, W), mode="bilinear", align_corners=True
        )
        probs = F.softmax(logits, dim=1)
        probs = probs.data.cpu().numpy()

        # Postprocessing with DenseCRF (one worker per sample)
        if crf:
            # images: (B,C,H,W) -> (B,H,W,C) uint8 for the CRF
            images = images.data.cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
            probs = joblib.Parallel(n_jobs=-1)(
                [joblib.delayed(postprocessor)(*pair) for pair in zip(images, probs)]
            )

        labelmaps = np.argmax(probs, axis=1)

        preds += list(labelmaps)
        gts += list(labels.numpy())

    # Pixel Accuracy, Mean Accuracy, Class IoU, Mean IoU, Freq Weighted IoU
    score = scores(gts, preds, n_class=CONFIG.DATASET.N_CLASSES)

    with open(model_path.replace(".pth", ".json"), "w") as f:
        json.dump(score, f, indent=4, sort_keys=True)