def __getitem__(self, idx):
    """Load, preprocess and (for the training split) augment the sample at ``idx``.

    The image ``<filepath>/images/<name>.png`` is read as grayscale, gamma
    corrected with a fixed exponent of 0.8 and equalized with ``self.clahe``.
    Outside the ``'test'`` split the mask ``<filepath>/labels/<name>.npy`` is
    loaded as well. Random augmentations and the horizontal flip are only
    applied when ``self.transform`` is set and the split is ``'train'``;
    ``self.transform`` itself is applied to every split when set.

    Returns:
        For the ``'test'`` split: ``(img, 0, name, 0, 0)``.
        Otherwise: ``(img, label, name, spatial_weights, dist_map)`` where
        ``spatial_weights`` are the boundary weights used by the boundary-aware
        cross entropy and ``dist_map`` stacks one ``one_hot2dist`` map per
        class 0..3 (as float32) for the surface loss.
    """
    name = self.list_files[idx]
    is_test = self.split == 'test'

    # Preprocessing shared by train, validation and test inputs.
    pil_image = Image.open(osp.join(self.filepath, 'images', name + '.png')).convert("L")
    # PIL reports (width, height); numpy arrays are indexed (rows, cols).
    width, height = pil_image.size
    # Fixed gamma correction (exponent 0.8) applied through a 256-entry lookup table.
    gamma_table = 255.0 * (np.linspace(0, 1, 256) ** 0.8)
    image = cv2.LUT(np.array(pil_image), gamma_table)

    if not is_test:
        label = np.load(osp.join(self.filepath, 'labels', name + '.npy'))
        # np.resize gives the mask the image's (rows, cols) = (height, width) shape.
        label = Image.fromarray(np.resize(label, (height, width)))

    if self.transform is not None and self.split == 'train':
        # Independent random augmentations of the gamma-corrected image.
        if random.random() < 0.2:
            image = Starburst_augment()(np.array(image))
        if random.random() < 0.2:
            image = Line_augment()(np.array(image))
        if random.random() < 0.2:
            image = Gaussian_blur()(np.array(image))
        if random.random() < 0.4:
            # Translation receives and returns both the image and the label.
            image, label = Translation()(np.array(image), np.array(label))

    img = Image.fromarray(self.clahe.apply(np.array(np.uint8(image))))

    if self.transform is not None:
        if self.split == 'train':
            img, label = RandomHorizontalFlip()(img, label)
        img = self.transform(img)

    if is_test:
        # Labels, spatial weights and distance maps are not needed for test images.
        return img, 0, name, 0, 0

    label_array = np.array(label)

    # Boundary-aware cross entropy: Canny edges of the label, dilated, weighted by 20.
    spatial_weights = cv2.Canny(label_array, 0, 3) / 255
    # cv2.dilate needs a kernel *array*: a bare `(3, 3)` tuple is converted by the
    # OpenCV bindings into a 2x1 kernel, not a 3x3 neighbourhood.
    spatial_weights = cv2.dilate(spatial_weights, np.ones((3, 3), np.uint8), iterations=1) * 20

    # Surface loss: one distance map per class (classes 0..3).
    dist_map = np.stack([one_hot2dist(label_array == cls) for cls in range(4)], 0)

    label = MaskToTensor()(label)
    return img, label, name, spatial_weights, np.float32(dist_map)
def test_empty(self):
    """An all-background image yields an all-zero map for the absent class."""
    labels = np.zeros((1, 256, 256))
    onehot = utils.class2one_hot(torch.Tensor(labels).type(torch.float32), C=2).numpy()
    dist = utils.one_hot2dist(onehot[0])

    self.assertEqual(dist.shape, (2, 256, 256))
    # Class 1 never occurs, so its map must sum to zero.
    self.assertEqual(dist[1].sum(), 0)
    # Every pixel is class 0, so class 0's map must be <= 0 everywhere.
    self.assertEqual((dist[0] <= 0).sum(), labels.size)
def test_full_coverage(self):
    """For a horizontal band of class 1, the map's sign matches band membership."""
    labels = np.zeros((1, 256, 256))
    labels[:, 50:60, :] = 1
    onehot = utils.class2one_hot(torch.Tensor(labels).type(torch.float32), C=2).numpy()
    dist = utils.one_hot2dist(onehot[0])

    self.assertEqual(dist.shape, (2, 256, 256))
    band_map = dist[1]
    # Pixels inside the band are <= 0 ...
    self.assertEqual((band_map <= 0).sum(), labels.sum())
    # ... and every pixel outside it is strictly positive.
    self.assertEqual((band_map > 0).sum(), (1 - labels).sum())
def test_closure(self):
    """The non-positive part of the distance maps coincides with the one-hot mask."""
    labels = np.zeros((1, 256, 256))
    labels[:, 50:60, :] = 1
    onehot = utils.class2one_hot(torch.Tensor(labels).type(torch.float32), C=2).numpy()
    dist = utils.one_hot2dist(onehot[0])

    self.assertEqual(dist.shape, (2, 256, 256))
    # Summing only the <= 0 entries must equal summing the entries selected by the one-hot mask.
    non_positive_part = (dist <= 0) * dist
    self.assertEqual(non_positive_part.sum(), (onehot * dist).sum())
def test_border(self):
    """Make sure the border inside the object is 0 in the distance map."""
    for side in (3, 4):
        labels = np.zeros((1, 25, 25))
        labels[:, 3:3 + side, 3:3 + side] = 1
        onehot = utils.class2one_hot(torch.Tensor(labels).type(torch.float32), C=2).numpy()
        dist = utils.one_hot2dist(onehot[0])

        self.assertEqual(dist.shape, (2, 25, 25))
        # A side x side square has 4 * (side - 1) pixels on its perimeter.
        on_border = dist[1] == 0
        self.assertEqual(on_border.sum(), 4 * (side - 1))
def test_max_dist(self):
    """The max dist for a box should be at the middle of the object, +-1."""
    size = 256
    labels = np.zeros((1, size, size))
    labels[:, 1:254, 1:254] = 1
    onehot = utils.class2one_hot(torch.Tensor(labels).type(torch.float32), C=2).numpy()
    dist = utils.one_hot2dist(onehot[0])

    self.assertEqual(dist.shape, (2, size, size))
    center = (127, 127)
    background, foreground = dist[0], dist[1]
    # Largest positive distance of the background map, and where it occurs.
    self.assertEqual(background.max(), 127)
    self.assertEqual(np.unravel_index(background.argmax(), (size, size)), center)
    # Most negative distance of the foreground map, and where it occurs.
    self.assertEqual(foreground.min(), -126)
    self.assertEqual(np.unravel_index(foreground.argmin(), (size, size)), center)
def save_slices(ct_paths, cbf_paths, cbv_paths, mtt_paths, tmax_paths, gt_paths,
                dest_dir: Path, shape: Tuple[int], n_augment: int,
                ct_dir: str = "ct", cbf_dir="cbf", cbv_dir="cbv", mtt_dir="mtt",
                tmax_dir="tmax", gt_dir: str = "gt",
                in_npy_dir="in_npy", gt_npy_dir='gt_npy') -> Tuple[int, int, Dict, Tuple[float, float, float]]:
    """Slice one patient's CT/CBF/CBV/MTT/Tmax volumes and mask into 2D files.

    All six inputs must share one ``get_p_id`` and one shape. The five
    modalities are normalized with ``norm_arr`` over the whole 3D volume
    (asserted to span exactly [0, 1]); the ground truth may only contain 0 and 1.
    For every slice ``j`` along the last axis and every ``k`` in
    ``0..n_augment`` (``k == 0`` is the unaugmented slice, ``k > 0`` comes from
    ``augment_arr``), files named ``f"{p_id}_{k:02d}_{j:04d}"`` are written:

    * one PNG per modality (scaled by 255) and for the ground truth (raw
      labels) under ``dest_dir / <modality>_dir`` and ``dest_dir / gt_dir``;
    * the stacked modalities under ``dest_dir / in_npy_dir``;
    * the ground-truth slice under ``dest_dir / gt_npy_dir``;
    * the matching slice of the 3D distance map (``one_hot2dist`` with the
      voxel spacing as resolution) under ``dest_dir / "3d_distmap"``.

    ``shape`` is accepted but not used.

    Returns:
        ``(neg, pos, space_dict, (dx, dy, dz))``: the number of ground-truth
        voxels equal to 0 and to 1, a dict mapping every saved filename to its
        in-plane spacing ``(dx, dy)``, and the CT volume's voxel spacing.
    """
    p_id: str = get_p_id(ct_paths)
    assert len(set(map_(get_p_id, [ct_paths, cbf_paths, cbv_paths,
                                   mtt_paths, tmax_paths, gt_paths]))) == 1

    space_dict: Dict[str, Tuple[float, float]] = {}

    # Load the data; the CT header provides the voxel spacing (loaded only once).
    ct_nii = nib.load(str(ct_paths))
    dx, dy, dz = ct_nii.header.get_zooms()
    assert dx == dy
    ct = np.asarray(ct_nii.dataobj)
    cbf = np.asarray(nib.load(str(cbf_paths)).dataobj)
    cbv = np.asarray(nib.load(str(cbv_paths)).dataobj)
    mtt = np.asarray(nib.load(str(mtt_paths)).dataobj)
    tmax = np.asarray(nib.load(str(tmax_paths)).dataobj)
    gt = np.asarray(nib.load(str(gt_paths)).dataobj)

    assert len(set(map_(np.shape, [ct, cbf, cbv, mtt, tmax, gt]))) == 1
    assert ct.dtype in [np.int32], ct.dtype
    assert cbf.dtype in [np.uint16], cbf.dtype
    assert cbv.dtype in [np.uint16], cbv.dtype
    assert mtt.dtype in [np.float64], mtt.dtype
    assert tmax.dtype in [np.float64], tmax.dtype
    assert gt.dtype in [np.uint8], gt.dtype

    pos: int = (gt == 1).sum()
    neg: int = (gt == 0).sum()
    x, y, z = ct.shape

    # Normalize each whole 3D volume, not individual 2D slices.
    norm_ct = norm_arr(ct)
    norm_cbf = norm_arr(cbf)
    norm_cbv = norm_arr(cbv)
    norm_mtt = norm_arr(mtt)
    norm_tmax = norm_arr(tmax)
    for norm in (norm_ct, norm_cbf, norm_cbv, norm_mtt, norm_tmax):
        assert 0 == norm.min() and norm.max() == 1, (norm.min(), norm.max())

    one_hot_gt: Tensor = class2one_hot(torch.tensor(gt[None, ...], dtype=torch.int64), K=2)[0]
    # Besides the one-hot layout, this also requires a 256x256 in-plane size.
    assert one_hot_gt.shape == (2, 256, 256, z), one_hot_gt.shape
    distmap: np.ndarray = one_hot2dist(one_hot_gt.numpy(), resolution=(dx, dy, dz), dtype=np.float32)

    save_dir_ct: Path = Path(dest_dir, ct_dir)
    save_dir_cbf: Path = Path(dest_dir, cbf_dir)
    save_dir_cbv: Path = Path(dest_dir, cbv_dir)
    save_dir_mtt: Path = Path(dest_dir, mtt_dir)
    save_dir_tmax: Path = Path(dest_dir, tmax_dir)
    save_dir_gt: Path = Path(dest_dir, gt_dir)
    save_dir_in_npy: Path = Path(dest_dir, in_npy_dir)
    save_dir_gt_npy: Path = Path(dest_dir, gt_npy_dir)
    save_dir_distmap_npy: Path = Path(dest_dir, "3d_distmap")

    # PNG targets in the same order as `slices`; the flag marks the ground truth,
    # which is saved as raw labels instead of being scaled by 255. An explicit flag
    # (rather than `"gt" in str(save_dir)`) keeps a dest_dir that happens to contain
    # "gt" (e.g. ".../segthor") from disabling the scaling of every modality.
    png_targets = [(save_dir_ct, False), (save_dir_cbf, False), (save_dir_cbv, False),
                   (save_dir_mtt, False), (save_dir_tmax, False), (save_dir_gt, True)]

    for j in range(z):
        ct_s = norm_ct[:, :, j]
        cbf_s = norm_cbf[:, :, j]
        cbv_s = norm_cbv[:, :, j]
        mtt_s = norm_mtt[:, :, j]
        tmax_s = norm_tmax[:, :, j]
        gt_s = gt[:, :, j]
        dist_s = distmap[:, :, :, j]
        slices = [ct_s, cbf_s, cbv_s, mtt_s, tmax_s, gt_s]

        # Check all six slice shapes (a stray comma used to turn half of this
        # comparison into the assertion message).
        assert ct_s.shape == cbf_s.shape == cbv_s.shape == mtt_s.shape == tmax_s.shape == gt_s.shape, \
            [s.shape for s in slices]
        assert gt_s.shape == dist_s[0, ...].shape, ((x, y, z), gt_s.shape, dist_s.shape)
        assert set(np.unique(gt_s)).issubset([0, 1])

        for k in range(n_augment + 1):
            if k == 0:
                to_save = slices
            else:
                to_save = map_(np.asarray, augment_arr(*slices))
                assert to_save[0].shape == slices[0].shape, (to_save[0].shape, slices[0].shape)

            filename = f"{p_id}_{k:02d}_{j:04d}"
            space_dict[filename] = (dx, dy)

            for (save_dir, is_gt), data in zip(png_targets, to_save):
                save_dir.mkdir(parents=True, exist_ok=True)
                img = data.astype(np.uint8) if is_gt else (data * 255).astype(np.uint8)
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    imsave(str(Path(save_dir, filename).with_suffix(".png")), img)

            multimodal = np.stack(to_save[:-1])  # Do not include the ground truth
            assert 0 <= multimodal.min() and multimodal.max() <= 1
            save_dir_in_npy.mkdir(parents=True, exist_ok=True)
            save_dir_gt_npy.mkdir(parents=True, exist_ok=True)
            np.save(Path(save_dir_in_npy, filename).with_suffix(".npy"), multimodal)
            np.save(Path(save_dir_gt_npy, filename).with_suffix(".npy"), to_save[-1])

            # NOTE(review): dist_s comes from the unaugmented volume; for k > 0 it is
            # saved next to augmented slices without going through augment_arr.
            # Confirm downstream code expects this.
            save_dir_distmap_npy.mkdir(parents=True, exist_ok=True)
            np.save(Path(save_dir_distmap_npy, filename).with_suffix(".npy"), dist_s)

    return neg, pos, space_dict, (dx, dy, dz)
def save_slices(flair_path, t1_path, gt_path,
                dest_dir: Path, shape: Tuple[int], n_augment: int, discard_negatives: bool,
                flair_dir: str = "flair", t1_dir="t1", gt_dir: str = "gt",
                in_npy_dir="in_npy", gt_npy_dir='gt_npy') -> Tuple[int, int, Dict, Tuple[float, float, float]]:
    """Slice one patient's FLAIR/T1 volumes and mask into 2D PNG and ``.npy`` files.

    The three inputs must share one ``get_p_id`` and one shape. FLAIR and T1
    are normalized with ``norm_arr`` over the whole 3D volume (asserted to span
    exactly [0, 1]); the ground truth may contain labels 0, 1 and 2. All volumes
    are resized in-plane to 256x256 (nearest neighbour for the ground truth),
    after which label 2 is set to 0, i.e. counted as background. A 3D distance
    map is computed with ``one_hot2dist`` at the resized voxel spacing.

    For every slice ``j`` along the last axis -- skipped when
    ``discard_negatives`` is set and the slice has no label-1 pixel -- and every
    ``k`` in ``0..n_augment`` (``k > 0`` produced by ``augment_arr``), files named
    ``f"{p_id}_{k:02d}_{j:04d}"`` are written: PNGs under ``dest_dir / flair_dir``
    and ``dest_dir / t1_dir`` (scaled by 255) and ``dest_dir / gt_dir`` (raw
    labels), the stacked modalities under ``in_npy_dir``, the ground truth under
    ``gt_npy_dir`` and the distance-map slice under ``"3d_distmap"``.
    The patient id is printed to stdout.

    ``shape`` is accepted but not used; the in-plane target size is fixed at 256.

    Returns:
        ``(neg, pos, space_dict, (rx, ry, rz))``: voxel counts of background
        (labels 0 and 2) and of label 1 in the original-resolution ground truth,
        a dict mapping each saved filename to its resized in-plane spacing
        ``(rx, ry)``, and the resized voxel spacing.
    """
    p_id: str = get_p_id(flair_path)
    assert len(set(map_(get_p_id, [flair_path, t1_path, gt_path]))) == 1
    print(p_id)

    space_dict: Dict[str, Tuple[float, float]] = {}

    # Load the data; the FLAIR header provides the voxel spacing (loaded only once).
    flair_nii = nib.load(str(flair_path))
    dx, dy, dz = flair_nii.header.get_zooms()
    flair = np.asarray(flair_nii.dataobj)
    x, y, z = flair.shape
    t1 = np.asarray(nib.load(str(t1_path)).dataobj)
    gt = np.asarray(nib.load(str(gt_path)).dataobj)

    assert set(np.unique(gt)) <= {0., 1., 2.}
    pos: int = (gt == 1).sum()
    neg: int = ((gt == 0) | (gt == 2)).sum()  # label 2 counts as background

    assert len(set(map_(np.shape, [flair, t1, gt]))) == 1
    assert flair.dtype in [np.float32], flair.dtype
    assert t1.dtype in [np.uint16], t1.dtype
    assert gt.dtype in [np.float32], gt.dtype

    # Normalize each whole 3D volume, not individual 2D slices.
    norm_flair = norm_arr(flair)
    norm_t1 = norm_arr(t1)
    norm_gt = gt.astype(np.uint8)
    assert 0 == norm_flair.min() and norm_flair.max() == 1, (norm_flair.min(), norm_flair.max())
    assert 0 == norm_t1.min() and norm_t1.max() == 1, (norm_t1.min(), norm_t1.max())
    assert np.array_equal(np.unique(gt), np.unique(norm_gt))

    size = 256  # in-plane size every slice is resized to
    resize_kwargs = dict(mode='constant', preserve_range=True, anti_aliasing=False)
    resized_flair = resize(norm_flair, (size, size, z), **resize_kwargs).astype(np.float32)
    resized_t1 = resize(norm_t1, (size, size, z), **resize_kwargs).astype(np.float32)
    # order=0 (nearest neighbour) keeps the labels integral.
    resized_gt = resize(norm_gt, (size, size, z), order=0, **resize_kwargs).astype(np.uint8)
    resized_gt[resized_gt == 2] = 0  # Count those labels as background

    # Voxel spacing after the in-plane resize, used for the 3D distance map.
    rx = dx * x / size
    ry = dy * y / size
    rz = dz

    one_hot_gt: Tensor = class2one_hot(torch.tensor(resized_gt[None, ...], dtype=torch.int64), K=2)[0]
    assert one_hot_gt.shape == (2, size, size, z), one_hot_gt.shape
    distmap: np.ndarray = one_hot2dist(one_hot_gt.numpy(), resolution=(rx, ry, rz), dtype=np.float32)

    save_dir_flair: Path = Path(dest_dir, flair_dir)
    save_dir_t1: Path = Path(dest_dir, t1_dir)
    save_dir_gt: Path = Path(dest_dir, gt_dir)
    save_dir_in_npy: Path = Path(dest_dir, in_npy_dir)
    save_dir_gt_npy: Path = Path(dest_dir, gt_npy_dir)
    save_dir_distmap_npy: Path = Path(dest_dir, "3d_distmap")

    # PNG targets in the same order as `slices`; the flag marks the ground truth,
    # which is saved as raw labels instead of being scaled by 255. An explicit flag
    # (rather than `"gt" in str(save_dir)`) keeps a dest_dir that happens to contain
    # "gt" from disabling the scaling of every modality.
    png_targets = [(save_dir_flair, False), (save_dir_t1, False), (save_dir_gt, True)]

    for j in range(z):
        flair_s = resized_flair[:, :, j]
        t1_s = resized_t1[:, :, j]
        gt_s = resized_gt[:, :, j]
        dist_s = distmap[:, :, :, j]
        slices = [flair_s, t1_s, gt_s]

        assert flair_s.shape == t1_s.shape == gt_s.shape == dist_s[0, ...].shape, \
            ((x, y, z), flair_s.shape, dist_s.shape)
        assert set(np.unique(gt_s)).issubset([0, 1]), np.unique(gt_s)

        if discard_negatives and gt_s.sum() == 0:
            continue

        for k in range(n_augment + 1):
            if k == 0:
                to_save = slices
            else:
                to_save = map_(np.asarray, augment_arr(*slices))
                assert to_save[0].shape == slices[0].shape, (to_save[0].shape, slices[0].shape)

            filename = f"{p_id}_{k:02d}_{j:04d}"
            space_dict[filename] = (rx, ry)

            for (save_dir, is_gt), data in zip(png_targets, to_save):
                save_dir.mkdir(parents=True, exist_ok=True)
                img = data.astype(np.uint8) if is_gt else (data * 255).astype(np.uint8)
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    imsave(str(Path(save_dir, filename).with_suffix(".png")), img)

            multimodal = np.stack(to_save[:-1])  # Do not include the ground truth
            assert 0 <= multimodal.min(), multimodal.min()
            assert multimodal.max() <= 1, multimodal.max()
            save_dir_in_npy.mkdir(parents=True, exist_ok=True)
            save_dir_gt_npy.mkdir(parents=True, exist_ok=True)
            np.save(Path(save_dir_in_npy, filename).with_suffix(".npy"), multimodal)
            np.save(Path(save_dir_gt_npy, filename).with_suffix(".npy"), to_save[-1])

            # NOTE(review): dist_s comes from the unaugmented volume; for k > 0 it is
            # saved next to augmented slices without going through augment_arr.
            # Confirm downstream code expects this.
            save_dir_distmap_npy.mkdir(parents=True, exist_ok=True)
            np.save(Path(save_dir_distmap_npy, filename).with_suffix(".npy"), dist_s)

    return neg, pos, space_dict, (rx, ry, rz)