예제 #1
0
파일: dataset.py 프로젝트: MaQuanZhi/RITnet
    def __getitem__(self, idx):
        """Load, preprocess and (for training) augment the sample at ``idx``.

        Returns a 5-tuple ``(img, label, name, spatialWeights, distMap)``.
        For the ``'test'`` split no label is loaded, and ``0`` is returned in
        place of the label, the spatial weights and the distance map.
        """
        name = self.list_files[idx]
        is_train = self.split == 'train'
        has_label = self.split != 'test'
        use_transform = self.transform is not None

        gray = Image.open(
            osp.join(self.filepath, 'images', name + '.png')).convert("L")
        width, height = gray.width, gray.height

        # Preprocessing step shared by train, validation and test inputs:
        # gamma correction with a fixed exponent (0.8) via a lookup table.
        gamma_lut = 255.0 * (np.linspace(0, 1, 256) ** 0.8)
        pilimg = cv2.LUT(np.array(gray), gamma_lut)

        if has_label:
            raw_label = np.load(
                osp.join(self.filepath, 'labels', name + '.npy'))
            # np.resize expects (rows, cols), i.e. (height, width)
            label = Image.fromarray(np.resize(raw_label, (height, width)))

        if use_transform and is_train:
            # Every augmentation is drawn independently, in this fixed order.
            if random.random() < 0.2:
                pilimg = Starburst_augment()(np.array(pilimg))
            if random.random() < 0.2:
                pilimg = Line_augment()(np.array(pilimg))
            if random.random() < 0.2:
                pilimg = Gaussian_blur()(np.array(pilimg))
            if random.random() < 0.4:
                # Translation moves the image and its label together
                pilimg, label = Translation()(np.array(pilimg),
                                              np.array(label))

        # self.clahe is applied to the 8-bit version of the image
        equalized = self.clahe.apply(np.array(np.uint8(pilimg)))
        img = Image.fromarray(equalized)

        if use_transform:
            if is_train:
                img, label = RandomHorizontalFlip()(img, label)
            img = self.transform(img)

        if not has_label:
            # label, spatialWeights and distMap are not needed for test images
            return img, 0, name, 0, 0

        label_arr = np.array(label)

        # Spatial weights for the boundary-aware cross entropy:
        # label edges, dilated, then scaled by 20.
        edges = cv2.Canny(label_arr, 0, 3) / 255
        spatialWeights = cv2.dilate(edges, (3, 3), iterations=1) * 20

        # Surface loss: one distance map per class index 0..3
        distMap = np.stack(
            [one_hot2dist(label_arr == cls) for cls in range(4)], 0)

        label = MaskToTensor()(label)
        return img, label, name, spatialWeights, np.float32(distMap)
예제 #2
0
    def test_empty(self):
        """Check the distance maps produced for an all-zero label array.

        The class-1 map must sum to zero and every entry of the class-0 map
        must be non-positive.
        """
        labels = np.zeros((1, 256, 256))

        as_tensor = torch.Tensor(labels).type(torch.float32)
        one_hot = utils.class2one_hot(as_tensor, C=2).numpy()
        dist = utils.one_hot2dist(one_hot[0])
        self.assertEqual(dist.shape, (2, 256, 256))

        background, foreground = dist[0], dist[1]
        self.assertEqual(foreground.sum(), 0)
        self.assertEqual((background <= 0).sum(), labels.size)
예제 #3
0
    def test_full_coverage(self):
        """Check the sign of the class-1 distance map against the label mask.

        The number of non-positive entries must equal the number of labelled
        pixels, and the number of positive entries the number of unlabelled
        ones.
        """
        mask = np.zeros((1, 256, 256))
        mask[:, 50:60, :] = 1  # horizontal band of foreground

        as_tensor = torch.Tensor(mask).type(torch.float32)
        one_hot = utils.class2one_hot(as_tensor, C=2).numpy()
        dist = utils.one_hot2dist(one_hot[0])
        self.assertEqual(dist.shape, (2, 256, 256))

        foreground = dist[1]
        self.assertEqual((foreground <= 0).sum(), mask.sum())
        self.assertEqual((foreground > 0).sum(), (1 - mask).sum())
예제 #4
0
    def test_closure(self):
        """Check that non-positive distances sum to the one-hot-masked sum.

        Summing only the non-positive entries of the distance map must give
        the same value as summing the map weighted by the one-hot encoding.
        """
        mask = np.zeros((1, 256, 256))
        mask[:, 50:60, :] = 1  # horizontal band of foreground

        as_tensor = torch.Tensor(mask).type(torch.float32)
        one_hot = utils.class2one_hot(as_tensor, C=2).numpy()
        dist = utils.one_hot2dist(one_hot[0])
        self.assertEqual(dist.shape, (2, 256, 256))

        non_positive = (dist <= 0) * dist
        inside_objects = one_hot * dist

        self.assertEqual(non_positive.sum(), inside_objects.sum())
예제 #5
0
    def test_border(self):
        """
        Make sure the border inside the object is 0 in the distance map.

        Checked on square objects of side 3 and 4, whose inner border holds
        4 * (side - 1) pixels.
        """
        for side in (3, 4):
            mask = np.zeros((1, 25, 25))
            mask[:, 3:3 + side, 3:3 + side] = 1

            as_tensor = torch.Tensor(mask).type(torch.float32)
            one_hot = utils.class2one_hot(as_tensor, C=2).numpy()
            dist = utils.one_hot2dist(one_hot[0])
            self.assertEqual(dist.shape, (2, 25, 25))

            zero_pixels = (dist[1] == 0)

            self.assertEqual(zero_pixels.sum(), 4 * (side - 1))
예제 #6
0
    def test_max_dist(self):
        """
        The max dist for a box should be at the middle of the object, +-1
        """
        grid = (256, 256)
        mask = np.zeros((1,) + grid)
        mask[:, 1:254, 1:254] = 1

        as_tensor = torch.Tensor(mask).type(torch.float32)
        one_hot = utils.class2one_hot(as_tensor, C=2).numpy()
        dist = utils.one_hot2dist(one_hot[0])
        self.assertEqual(dist.shape, (2, 256, 256))

        background, foreground = dist[0], dist[1]

        # Class-0 map: largest value is expected at the centre of the box
        self.assertEqual(background.max(), 127)
        self.assertEqual(np.unravel_index(background.argmax(), grid),
                         (127, 127))

        # Class-1 map: most negative value is expected at the same pixel
        self.assertEqual(foreground.min(), -126)
        self.assertEqual(np.unravel_index(foreground.argmin(), grid),
                         (127, 127))
예제 #7
0
def save_slices(
        ct_paths,
        cbf_paths,
        cbv_paths,
        mtt_paths,
        tmax_paths,
        gt_paths,
        dest_dir: Path,
        shape: Tuple[int],
        n_augment: int,
        ct_dir: str = "ct",
        cbf_dir="cbf",
        cbv_dir="cbv",
        mtt_dir="mtt",
        tmax_dir="tmax",
        gt_dir: str = "gt",
        in_npy_dir="in_npy",
        gt_npy_dir='gt_npy'
) -> Tuple[int, int, Dict, Tuple[float, float, float]]:
    """Slice one patient's 3D volumes into per-slice 2D files.

    Loads the CT, CBF, CBV, MTT, Tmax and ground-truth volumes with nibabel,
    normalizes every input modality over the whole volume with ``norm_arr``
    and computes a distance map of the one-hot ground truth. Then, for each
    slice along the last axis and for each of the ``n_augment + 1`` variants
    (variant 0 is the un-augmented slice), it writes:

    * one PNG per modality and one for the ground truth,
    * the stacked modalities (ground truth excluded) as ``.npy``,
    * the ground truth as ``.npy``,
    * the corresponding slice of the distance map under ``3d_distmap``.

    Args:
        ct_paths, cbf_paths, cbv_paths, mtt_paths, tmax_paths, gt_paths:
            Despite the plural names, each is the path of a single volume
            file. All must resolve to the same patient id via ``get_p_id``.
        dest_dir: Root directory under which all sub-directories are created.
        shape: Not used by this function.
        n_augment: Number of augmented copies written per slice.
        ct_dir, cbf_dir, cbv_dir, mtt_dir, tmax_dir, gt_dir: Names of the
            PNG sub-directories.
        in_npy_dir, gt_npy_dir: Names of the ``.npy`` sub-directories.

    Returns:
        ``(neg, pos, space_dict, (dx, dy, dz))``: the number of voxels equal
        to 0 and to 1 in the ground truth, a mapping from each written
        filename to ``(dx, dy)``, and the zooms read from the CT header.

    Raises:
        AssertionError: If patient ids, shapes, dtypes or value ranges do not
            match what is expected.
    """
    p_id: str = get_p_id(ct_paths)
    assert len(
        set(
            map_(get_p_id, [
                ct_paths, cbf_paths, cbv_paths, mtt_paths, tmax_paths, gt_paths
            ]))) == 1

    space_dict: Dict[str, Tuple[float, float]] = {}

    # Load the data (the CT file is opened once for header and voxels)
    ct_nib = nib.load(str(ct_paths))
    dx, dy, dz = ct_nib.header.get_zooms()
    assert dx == dy
    ct = np.asarray(ct_nib.dataobj)
    cbf = np.asarray(nib.load(str(cbf_paths)).dataobj)
    cbv = np.asarray(nib.load(str(cbv_paths)).dataobj)
    mtt = np.asarray(nib.load(str(mtt_paths)).dataobj)
    tmax = np.asarray(nib.load(str(tmax_paths)).dataobj)
    gt = np.asarray(nib.load(str(gt_paths)).dataobj)

    assert len(set(map_(np.shape, [ct, cbf, cbv, mtt, tmax, gt]))) == 1
    assert ct.dtype in [np.int32], ct.dtype
    assert cbf.dtype in [np.uint16], cbf.dtype
    assert cbv.dtype in [np.uint16], cbv.dtype
    assert mtt.dtype in [np.float64], mtt.dtype
    assert tmax.dtype in [np.float64], tmax.dtype
    assert gt.dtype in [np.uint8], gt.dtype

    pos: int = (gt == 1).sum()
    neg: int = (gt == 0).sum()

    x, y, z = ct.shape

    # Normalize and check data content
    # We need to normalize the whole 3d img, not 2d slices
    norm_ct = norm_arr(ct)
    norm_cbf = norm_arr(cbf)
    norm_cbv = norm_arr(cbv)
    norm_mtt = norm_arr(mtt)
    norm_tmax = norm_arr(tmax)
    for label, normed in (("ct", norm_ct), ("cbf", norm_cbf),
                          ("cbv", norm_cbv), ("mtt", norm_mtt),
                          ("tmax", norm_tmax)):
        assert 0 == normed.min() and normed.max() == 1, (label, normed.min(),
                                                         normed.max())

    one_hot_gt: Tensor = class2one_hot(torch.tensor(gt[None, ...],
                                                    dtype=torch.int64),
                                       K=2)[0]
    assert one_hot_gt.shape == (2, 256, 256, z), one_hot_gt.shape
    distmap: np.ndarray = one_hot2dist(one_hot_gt.numpy(),
                                       resolution=(dx, dy, dz),
                                       dtype=np.float32)

    save_dir_ct: Path = Path(dest_dir, ct_dir)
    save_dir_cbf: Path = Path(dest_dir, cbf_dir)
    save_dir_cbv: Path = Path(dest_dir, cbv_dir)
    save_dir_mtt: Path = Path(dest_dir, mtt_dir)
    save_dir_tmax: Path = Path(dest_dir, tmax_dir)
    save_dir_gt: Path = Path(dest_dir, gt_dir)
    save_dir_in_npy: Path = Path(dest_dir, in_npy_dir)
    save_dir_gt_npy: Path = Path(dest_dir, gt_npy_dir)
    save_dir_distmap_npy: Path = Path(dest_dir, "3d_distmap")
    # Same order as `slices` below; the ground truth comes last
    save_dirs = [
        save_dir_ct, save_dir_cbf, save_dir_cbv, save_dir_mtt, save_dir_tmax,
        save_dir_gt
    ]

    for j in range(z):
        ct_s = norm_ct[:, :, j]
        cbf_s = norm_cbf[:, :, j]
        cbv_s = norm_cbv[:, :, j]
        mtt_s = norm_mtt[:, :, j]
        tmax_s = norm_tmax[:, :, j]
        gt_s = gt[:, :, j]
        dist_s = distmap[:, :, :, j]
        slices = [ct_s, cbf_s, cbv_s, mtt_s, tmax_s, gt_s]
        # All six slices must agree. A comma between two comparison chains
        # would turn the second chain into the assertion message instead of
        # checking it, so this is written as a single chain.
        assert (ct_s.shape == cbf_s.shape == cbv_s.shape == mtt_s.shape
                == tmax_s.shape == gt_s.shape), [s.shape for s in slices]
        assert gt_s.shape == dist_s[0, ...].shape, ((x, y, z), gt_s.shape,
                                                    dist_s.shape)
        assert set(np.unique(gt_s)).issubset([0, 1])

        for k in range(n_augment + 1):
            if k == 0:
                to_save = slices
            else:
                to_save = map_(np.asarray, augment_arr(*slices))
                assert to_save[0].shape == slices[0].shape, (to_save[0].shape,
                                                             slices[0].shape)

            filename = f"{p_id}_{k:02d}_{j:04d}"
            space_dict[filename] = (dx, dy)
            for save_dir, data in zip(save_dirs, to_save):
                save_dir.mkdir(parents=True, exist_ok=True)

                # The ground truth holds class indices and is stored as-is;
                # the normalized modalities are rescaled to 0..255. The
                # target directory is identified by identity rather than by
                # looking for "gt" in the path, which would misfire whenever
                # dest_dir (or a modality directory) contains "gt".
                if save_dir is save_dir_gt:
                    img = data.astype(np.uint8)
                else:
                    img = (data * 255).astype(np.uint8)

                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    imsave(str(Path(save_dir, filename).with_suffix(".png")),
                           img)

            multimodal = np.stack(
                to_save[:-1])  # Do not include the ground truth
            assert 0 <= multimodal.min() and multimodal.max() <= 1
            save_dir_in_npy.mkdir(parents=True, exist_ok=True)
            save_dir_gt_npy.mkdir(parents=True, exist_ok=True)
            np.save(
                Path(save_dir_in_npy, filename).with_suffix(".npy"),
                multimodal)
            np.save(
                Path(save_dir_gt_npy, filename).with_suffix(".npy"),
                to_save[-1])

            # NOTE(review): dist_s is not passed through augment_arr, so for
            # k > 0 it is the distance map of the un-augmented slice --
            # confirm this is intended.
            save_dir_distmap_npy.mkdir(parents=True, exist_ok=True)
            np.save(
                Path(save_dir_distmap_npy, filename).with_suffix(".npy"),
                dist_s)

    return neg, pos, space_dict, (dx, dy, dz)
예제 #8
0
def save_slices(
        flair_path,
        t1_path,
        gt_path,
        dest_dir: Path,
        shape: Tuple[int],
        n_augment: int,
        discard_negatives: bool,
        flair_dir: str = "flair",
        t1_dir="t1",
        gt_dir: str = "gt",
        in_npy_dir="in_npy",
        gt_npy_dir='gt_npy'
) -> Tuple[int, int, Dict, Tuple[float, float, float]]:
    """Slice one patient's FLAIR/T1/ground-truth volumes into 2D files.

    Loads the three volumes with nibabel, normalizes FLAIR and T1 over the
    whole volume with ``norm_arr``, resizes everything in-plane to 256x256,
    maps ground-truth label 2 to background and computes a distance map of
    the one-hot ground truth at the rescaled resolution. Then, for each
    slice along the last axis and for each of the ``n_augment + 1`` variants
    (variant 0 is the un-augmented slice), it writes:

    * one PNG per modality and one for the ground truth,
    * the stacked modalities (ground truth excluded) as ``.npy``,
    * the ground truth as ``.npy``,
    * the corresponding slice of the distance map under ``3d_distmap``.

    Args:
        flair_path, t1_path, gt_path: Paths of the three volume files; all
            must resolve to the same patient id via ``get_p_id``.
        dest_dir: Root directory under which all sub-directories are created.
        shape: Not used by this function.
        n_augment: Number of augmented copies written per slice.
        discard_negatives: When true, slices whose resized ground truth is
            all zeros are skipped entirely.
        flair_dir, t1_dir, gt_dir: Names of the PNG sub-directories.
        in_npy_dir, gt_npy_dir: Names of the ``.npy`` sub-directories.

    Returns:
        ``(neg, pos, space_dict, (rx, ry, rz))``: the number of voxels
        labelled 0 or 2 and the number labelled 1 in the original (not
        resized) ground truth, a mapping from each written filename to
        ``(rx, ry)``, and the resolution after the in-plane resize.

    Raises:
        AssertionError: If patient ids, labels, shapes, dtypes or value
            ranges do not match what is expected.
    """
    p_id: str = get_p_id(flair_path)
    assert len(set(map_(get_p_id, [flair_path, t1_path, gt_path]))) == 1
    print(p_id)

    space_dict: Dict[str, Tuple[float, float]] = {}

    # Load the data (the FLAIR file is opened once for header and voxels)
    flair_nib = nib.load(str(flair_path))
    dx, dy, dz = flair_nib.header.get_zooms()
    # assert dx == dy, (dx, dy)
    flair = np.asarray(flair_nib.dataobj)
    w, h, z = flair.shape
    t1 = np.asarray(nib.load(str(t1_path)).dataobj)
    gt = np.asarray(nib.load(str(gt_path)).dataobj)
    assert set(np.unique(gt)) <= {0., 1., 2.}

    pos: int = (gt == 1).sum()
    neg: int = ((gt == 0) | (gt == 2)).sum()

    assert len(set(map_(np.shape, [flair, t1, gt]))) == 1
    assert flair.dtype in [np.float32], flair.dtype
    assert t1.dtype in [np.uint16], t1.dtype
    assert gt.dtype in [np.float32], gt.dtype

    # Normalize and check data content
    # We need to normalize the whole 3d img, not 2d slices
    norm_flair = norm_arr(flair)
    norm_t1 = norm_arr(t1)
    norm_gt = gt.astype(np.uint8)
    assert 0 == norm_flair.min() and norm_flair.max() == 1, (norm_flair.min(),
                                                             norm_flair.max())
    assert 0 == norm_t1.min() and norm_t1.max() == 1, (norm_t1.min(),
                                                       norm_t1.max())
    assert np.array_equal(np.unique(gt), np.unique(norm_gt))

    resized_flair = resize(norm_flair, (256, 256, z),
                           mode='constant',
                           preserve_range=True,
                           anti_aliasing=False).astype(np.float32)
    resized_t1 = resize(norm_t1, (256, 256, z),
                        mode='constant',
                        preserve_range=True,
                        anti_aliasing=False).astype(np.float32)
    # order=0 for the labels -- assuming `resize` is skimage's, this is
    # nearest-neighbour and keeps the label values discrete (TODO confirm)
    resized_gt = resize(norm_gt, (256, 256, z),
                        mode='constant',
                        preserve_range=True,
                        anti_aliasing=False,
                        order=0).astype(np.uint8)
    resized_gt[resized_gt == 2] = 0  # Count those labels as background

    # Pre-compute the 3d distance map, at the in-plane resolution obtained
    # after resizing (w, h) to (256, 256)
    rx = dx * w / 256
    ry = dy * h / 256
    rz = dz

    one_hot_gt: Tensor = class2one_hot(torch.tensor(resized_gt[None, ...],
                                                    dtype=torch.int64),
                                       K=2)[0]
    assert one_hot_gt.shape == (2, 256, 256, z), one_hot_gt.shape
    distmap: np.ndarray = one_hot2dist(one_hot_gt.numpy(),
                                       resolution=(rx, ry, rz),
                                       dtype=np.float32)

    save_dir_flair: Path = Path(dest_dir, flair_dir)
    save_dir_t1: Path = Path(dest_dir, t1_dir)
    save_dir_gt: Path = Path(dest_dir, gt_dir)
    save_dir_in_npy: Path = Path(dest_dir, in_npy_dir)
    save_dir_gt_npy: Path = Path(dest_dir, gt_npy_dir)
    save_dir_distmap_npy: Path = Path(dest_dir, "3d_distmap")
    # Same order as `slices` below; the ground truth comes last
    save_dirs = [save_dir_flair, save_dir_t1, save_dir_gt]

    for j in range(z):
        flair_s = resized_flair[:, :, j]
        t1_s = resized_t1[:, :, j]
        gt_s = resized_gt[:, :, j]
        dist_s = distmap[:, :, :, j]

        slices = [flair_s, t1_s, gt_s]
        assert flair_s.shape == t1_s.shape == gt_s.shape == dist_s[
            0, ...].shape, ((w, h, z), flair_s.shape, dist_s.shape)
        assert set(np.unique(gt_s)).issubset([0, 1]), np.unique(gt_s)

        if discard_negatives and (gt_s.sum() == 0):
            continue

        for k in range(n_augment + 1):
            if k == 0:
                to_save = slices
            else:
                to_save = map_(np.asarray, augment_arr(*slices))
                assert to_save[0].shape == slices[0].shape, (to_save[0].shape,
                                                             slices[0].shape)

            filename = f"{p_id}_{k:02d}_{j:04d}"
            space_dict[filename] = (rx, ry)
            for save_dir, data in zip(save_dirs, to_save):
                save_dir.mkdir(parents=True, exist_ok=True)

                # The ground truth holds class indices and is stored as-is;
                # the normalized modalities are rescaled to 0..255. The
                # target directory is identified by identity rather than by
                # looking for "gt" in the path, which would misfire whenever
                # dest_dir (or a modality directory) contains "gt".
                if save_dir is save_dir_gt:
                    img = data.astype(np.uint8)
                else:
                    img = (data * 255).astype(np.uint8)

                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    imsave(str(Path(save_dir, filename).with_suffix(".png")),
                           img)

            multimodal = np.stack(
                to_save[:-1])  # Do not include the ground truth
            assert 0 <= multimodal.min(), multimodal.min()
            assert multimodal.max() <= 1, multimodal.max()
            save_dir_in_npy.mkdir(parents=True, exist_ok=True)
            save_dir_gt_npy.mkdir(parents=True, exist_ok=True)
            np.save(
                Path(save_dir_in_npy, filename).with_suffix(".npy"),
                multimodal)
            np.save(
                Path(save_dir_gt_npy, filename).with_suffix(".npy"),
                to_save[-1])

            # NOTE(review): dist_s is not passed through augment_arr, so for
            # k > 0 it is the distance map of the un-augmented slice --
            # confirm this is intended.
            save_dir_distmap_npy.mkdir(parents=True, exist_ok=True)
            np.save(
                Path(save_dir_distmap_npy, filename).with_suffix(".npy"),
                dist_s)

    return neg, pos, space_dict, (rx, ry, rz)