Beispiel #1
0
def dice3d(base_folder, folder, subfoldername, grp_regex, gt_folder, C):
    """Compute per-patient 3D Dice between predicted and ground-truth PNG slices.

    Predicted slices are read from ``<base_folder>/<folder>/<subfoldername>``
    (or ``<folder>/<subfoldername>`` when ``base_folder`` is ``''``), grouped
    into patients with ``grp_regex``, stacked into one volume per patient and
    compared with the identically named slices in ``gt_folder``.

    Args:
        base_folder: Optional root directory; ``''`` means no root.
        folder: Experiment folder.
        subfoldername: Sub-folder holding the predicted ``*.png`` slices.
        grp_regex: Regex whose match (``group(0)``) at the start of a file stem
            is the patient id, and for which ``re.split(grp_regex, stem)[1]``
            is the integer slice index.
        gt_folder: Directory holding the ground-truth slices (same file names).
        C: Width of the per-patient score vector.

    Returns:
        Tuple ``(mean, std)`` over patients of the per-patient scores, each a
        tensor of length ``C``.
    """
    if base_folder == '':
        work_folder = Path(folder, subfoldername)
    else:
        work_folder = Path(base_folder, folder, subfoldername)
    filenames = map_(lambda p: str(p.name), work_folder.glob("*.png"))
    grouping_regex: Pattern = re.compile(grp_regex)

    stems: List[str] = [Path(filename).stem for filename in filenames]  # avoid matching the extension
    matches: List[Match] = map_(grouping_regex.match, stems)
    patients: List[str] = [match.group(0) for match in matches]

    # Group slices by their *exact* patient id. A ``startswith`` test would
    # wrongly pull e.g. "patient10_*" slices into "patient1".
    slices_per_patient = {}
    for stem, patient in zip(stems, patients):
        slices_per_patient.setdefault(patient, []).append(stem)

    w, h = 256, 256  # expected slice size; other sizes go through resize_im
    batch_dice = torch.zeros((len(slices_per_patient), C))
    for i, patient_slices in enumerate(slices_per_patient.values()):
        n = len(patient_slices)
        # np.zeros (not np.ndarray) so a slice index that is never written
        # stays 0 instead of holding uninitialised memory.
        t_seg = np.zeros((w, h, n))
        t_gt = np.zeros((w, h, n))
        for stem in patient_slices:
            slice_nb = int(grouping_regex.split(stem)[1])
            seg = imageio.imread(str(Path(work_folder, stem + '.png')))
            gt = imageio.imread(str(Path(gt_folder, stem + '.png')))
            if seg.shape != (w, h):
                seg = resize_im(seg, 36)
            if gt.shape != (w, h):
                gt = resize_im(gt, 36)
            seg[seg == 255] = 1  # predictions are saved as 0/255 PNGs
            t_seg[:, :, slice_nb] = seg
            t_gt[:, :, slice_nb] = gt
        t_seg = torch.from_numpy(t_seg)
        t_gt = torch.from_numpy(t_gt)
        # NOTE(review): the one-hot encoding uses a hard-coded 3 classes rather
        # than ``C``; confirm whether that is intended.
        batch_dice[i, ...] = dice_batch(class2one_hot(t_seg, 3), class2one_hot(t_gt, 3))[0]  # do not save the interclasses etcetc
    return batch_dice.mean(dim=0), batch_dice.std(dim=0)
Beispiel #2
0
    def test_empty(self):
        """dice_coef of a one-hot mask against itself is exactly 1 at [0, 0]."""
        mask = torch.zeros(1, 100, 100)
        mask[0, 40:60, 40:60] = 1  # 20x20 square of ones

        one_hot = utils.class2one_hot(mask, C=2)
        score = utils.dice_coef(one_hot, one_hot)

        self.assertEqual(score[0, 0], 1)
Beispiel #3
0
    def test_caca(self):
        """With smooth=0, dice_coef is exactly 0 at [0, 0] when the second
        input has ones only at index 1 of dim 1 (so nothing at index 0)."""
        mask = torch.zeros(1, 100, 100)
        mask[0, 40:60, 40:60] = 1
        one_hot = utils.class2one_hot(mask, C=2)

        only_index1 = torch.zeros_like(one_hot)
        only_index1[0, 1, ...] = 1

        score = utils.dice_coef(one_hot, only_index1, smooth=0)[0, 0]
        # smooth=0 makes the result an exact 0, avoiding an almost-equal check
        self.assertEqual(score, 0)
Beispiel #4
0
    def test_empty(self):
        """An all-zero mask yields a (2, 256, 256) map whose channel 1 sums to
        0 and whose channel 0 is <= 0 everywhere."""
        blank = np.zeros((1, 256, 256))
        as_tensor = torch.Tensor(blank).type(torch.float32)

        one_hot = utils.class2one_hot(as_tensor, C=2).numpy()
        dist = utils.one_hot2dist(one_hot[0])
        self.assertEqual(dist.shape, (2, 256, 256))

        self.assertEqual(dist[1].sum(), 0)
        non_positive = dist[0] <= 0
        self.assertEqual(non_positive.sum(), blank.size)
Beispiel #5
0
    def test_full_coverage(self):
        """For a full-width band of ones, the count of channel-1 values <= 0
        equals the band area and the count of values > 0 equals the rest."""
        mask = np.zeros((1, 256, 256))
        mask[:, 50:60, :] = 1
        as_tensor = torch.Tensor(mask).type(torch.float32)

        one_hot = utils.class2one_hot(as_tensor, C=2).numpy()
        dist = utils.one_hot2dist(one_hot[0])
        self.assertEqual(dist.shape, (2, 256, 256))

        channel1 = dist[1]
        self.assertEqual((channel1 <= 0).sum(), mask.sum())
        self.assertEqual((channel1 > 0).sum(), (1 - mask).sum())
Beispiel #6
0
    def test_closure(self):
        """The sum of the non-positive part of the distance map equals the sum
        of the map weighted by the one-hot labels."""
        mask = np.zeros((1, 256, 256))
        mask[:, 50:60, :] = 1
        as_tensor = torch.Tensor(mask).type(torch.float32)

        one_hot = utils.class2one_hot(as_tensor, C=2).numpy()
        dist = utils.one_hot2dist(one_hot[0])
        self.assertEqual(dist.shape, (2, 256, 256))

        non_positive_part = dist * (dist <= 0)
        weighted = one_hot * dist
        self.assertEqual(non_positive_part.sum(), weighted.sum())
Beispiel #7
0
    def test_border(self):
        """
        Make sure the border inside the object is 0 in the distance map.

        For squares of side 3 and 4 placed at (3, 3) in a 25x25 image, the
        number of zeros in channel 1 must be 4 * (side - 1), i.e. the number of
        pixels on the square's outline.
        """
        for side in (3, 4):
            mask = np.zeros((1, 25, 25))
            mask[:, 3:3 + side, 3:3 + side] = 1
            as_tensor = torch.Tensor(mask).type(torch.float32)

            one_hot = utils.class2one_hot(as_tensor, C=2).numpy()
            dist = utils.one_hot2dist(one_hot[0])
            self.assertEqual(dist.shape, (2, 25, 25))

            zero_count = (dist[1] == 0).sum()
            self.assertEqual(zero_count, 4 * (side - 1))
Beispiel #8
0
    def test_max_dist(self):
        """
        The max dist for a box should be at the middle of the object, +-1.

        Channel 0 must peak at 127 and channel 1 bottom out at -126, both at
        pixel (127, 127).
        """
        size = 256
        box = np.zeros((1, size, size))
        box[:, 1:254, 1:254] = 1
        as_tensor = torch.Tensor(box).type(torch.float32)

        one_hot = utils.class2one_hot(as_tensor, C=2).numpy()
        dist = utils.one_hot2dist(one_hot[0])
        self.assertEqual(dist.shape, (2, size, size))

        channel0, channel1 = dist[0], dist[1]
        centre = (127, 127)

        self.assertEqual(channel0.max(), 127)
        self.assertEqual(np.unravel_index(channel0.argmax(), (size, size)), centre)

        self.assertEqual(channel1.min(), -126)
        self.assertEqual(np.unravel_index(channel1.argmin(), (size, size)), centre)
Beispiel #9
0

# Evaluate every file in <root>/in_npy against the matching file in
# <root>/gt_npy. For channels 1 and 2, print Dice, Hausdorff and a
# label-count error, and append the channel-1 line to fold_all_H1.
# Relies on root, net, n_classes, fold_all_H1 and the metric helpers being
# defined before this loop.
# NOTE(review): os.walk recurses, but paths are rebuilt from the top-level
# directory, so files inside sub-folders would be looked up in the wrong place.
for _, _, files in os.walk(os.path.join(root, 'in_npy')):
    for file in files:
        image = np.load(os.path.join(root, 'in_npy', file))
        gt = np.load(os.path.join(root, 'gt_npy', file))

        image = torch.tensor(image.reshape(-1, 1, 256, 256), dtype=torch.float)
        # Pure inference: no autograd graph is needed (a requires_grad input
        # would only build an unused graph and waste memory).
        with torch.no_grad():
            pred = F.softmax(net(image), dim=1)
        predicted_output = probs2one_hot(pred)

        # One-hot ground truth, built once and shared by every metric below.
        gt_one_hot = class2one_hot(torch.tensor(gt), n_classes)
        dice = dice_coef(predicted_output, gt_one_hot)
        hauss = haussdorf(predicted_output, gt_one_hot)

        # |#labels(pred) - #labels(gt)| per channel, where #labels is the number
        # of distinct values returned by `label` (presumably connected-component
        # labelling, i.e. components + background -- confirm).
        pred_label = len(np.unique(label(predicted_output[0][1])))
        gt_label = len(np.unique(label(gt_one_hot[0][1])))
        error = np.abs(pred_label - gt_label)
        pred_label2 = len(np.unique(label(predicted_output[0][2])))
        gt_label2 = len(np.unique(label(gt_one_hot[0][2])))
        error2 = np.abs(pred_label2 - gt_label2)

        # np.float was removed in NumPy 1.24; it was an alias of the builtin float.
        line_h1 = f"{file}, {float(dice[0][1])}, {float(hauss[0][1])},{float(error)} \n"
        line_h2 = f"{file}, {float(dice[0][2])}, {float(hauss[0][2])},{float(error2)} \n"
        print(line_h1)
        print(line_h2)

        fold_all_H1.write(line_h1)
Beispiel #10
0
from utils import dice_coef, dice_batch, save_images, tqdm_, haussdorf, probs2one_hot, class2one_hot, numpy_haussdorf
# For every ground-truth file in <root>/gt_npy, write the number of distinct
# values `label` finds on the class-1 channel of its one-hot encoding to
# <root>/characteristics/characteristics.csv (and echo it to stdout).
root = '/media/eljurros/Transcend/Decathlone/Task02_Heart/nifty/FOLD_4/npy/val'
fieldnames = ['SLICE_ID', 'dice', 'haus', 'c_error']
folder_path = Path(root, 'characteristics')
n_classes = 2
folder_path.mkdir(parents=True, exist_ok=True)
fold_clean_H1 = open(os.path.join(folder_path, 'characteristics.csv'), "w")
fold_clean_H1.write("file, size, cc, \n")

for _, _, files in os.walk(os.path.join(root, 'gt_npy')):

    print('walking into', os.path.join(root, 'gt_npy'))
    for file in files:
        gt = np.load(os.path.join(root, 'gt_npy', file))

        # Only the class-1 channel of the one-hot ground truth is counted.
        gt_label = len(
            np.unique(label(class2one_hot(torch.tensor(gt), n_classes)[0][1])))

        # np.float was removed in NumPy 1.24; it was an alias of the builtin float.
        # NOTE(review): all three value columns repeat gt_label -- confirm intent.
        row = f"{file}, {float(gt_label)}, {float(gt_label)},{float(gt_label)} \n"
        print(row)

        fold_clean_H1.write(row)
        fold_clean_H1.flush()  # keep the CSV current if the run is interrupted
Beispiel #11
0
def save_slices(
        ct_paths,
        cbf_paths,
        cbv_paths,
        mtt_paths,
        tmax_paths,
        gt_paths,
        dest_dir: Path,
        shape: Tuple[int],
        n_augment: int,
        ct_dir: str = "ct",
        cbf_dir="cbf",
        cbv_dir="cbv",
        mtt_dir="mtt",
        tmax_dir="tmax",
        gt_dir: str = "gt",
        in_npy_dir="in_npy",
        gt_npy_dir='gt_npy'
) -> Tuple[int, int, Dict, Tuple[float, float, float]]:
    """Slice one patient's CT, CBF, CBV, MTT, Tmax and GT volumes into 2D files.

    Each modality is normalised over its whole 3D volume. A 3D distance map of
    the one-hot ground truth is pre-computed. Then every slice along the last
    axis, plus ``n_augment`` augmented copies of it, is written as:

    * one PNG per modality / ground truth under ``dest_dir/<*_dir>``;
    * the stacked modalities under ``dest_dir/in_npy_dir`` and the ground
      truth under ``dest_dir/gt_npy_dir`` as ``.npy``;
    * the distance-map slice under ``dest_dir/3d_distmap`` as ``.npy``.

    Args:
        ct_paths, cbf_paths, cbv_paths, mtt_paths, tmax_paths, gt_paths:
            Volume paths; all must map to the same id through ``get_p_id``.
        dest_dir: Output root directory.
        shape: Unused; kept for interface compatibility.
        n_augment: Number of augmented copies written per slice, in addition
            to the original slice.
        ct_dir, cbf_dir, cbv_dir, mtt_dir, tmax_dir, gt_dir, in_npy_dir,
        gt_npy_dir: Names of the output sub-directories.

    Returns:
        ``(neg, pos, space_dict, (dx, dy, dz))``:
        * ``neg`` / ``pos``: number of ground-truth voxels equal to 0 / 1.
        * ``space_dict``: maps each output file stem to ``(dx, dy)``.
        * ``(dx, dy, dz)``: the CT header zooms.
    """
    p_id: str = get_p_id(ct_paths)
    # All inputs must belong to the same patient id.
    assert len(
        set(
            map_(get_p_id, [
                ct_paths, cbf_paths, cbv_paths, mtt_paths, tmax_paths, gt_paths
            ]))) == 1

    space_dict: Dict[str, Tuple[float, float]] = {}

    # Load the data
    dx, dy, dz = nib.load(str(ct_paths)).header.get_zooms()
    assert dx == dy
    ct = np.asarray(nib.load(str(ct_paths)).dataobj)
    cbf = np.asarray(nib.load(str(cbf_paths)).dataobj)
    cbv = np.asarray(nib.load(str(cbv_paths)).dataobj)
    mtt = np.asarray(nib.load(str(mtt_paths)).dataobj)
    tmax = np.asarray(nib.load(str(tmax_paths)).dataobj)
    gt = np.asarray(nib.load(str(gt_paths)).dataobj)

    assert len(set(map_(np.shape, [ct, cbf, cbv, mtt, tmax, gt]))) == 1
    assert ct.dtype in [np.int32], ct.dtype
    assert cbf.dtype in [np.uint16], cbf.dtype
    assert cbv.dtype in [np.uint16], cbv.dtype
    assert mtt.dtype in [np.float64], mtt.dtype
    assert tmax.dtype in [np.float64], tmax.dtype
    assert gt.dtype in [np.uint8], gt.dtype

    pos: int = (gt == 1).sum()
    neg: int = (gt == 0).sum()

    x, y, z = ct.shape

    # Normalize and check data content. We need to normalize the whole 3D
    # image, not 2D slices, so every slice of a modality shares one scale.
    norm_ct = norm_arr(ct)
    norm_cbf = norm_arr(cbf)
    norm_cbv = norm_arr(cbv)
    norm_mtt = norm_arr(mtt)
    norm_tmax = norm_arr(tmax)
    for norm in (norm_ct, norm_cbf, norm_cbv, norm_mtt, norm_tmax):
        assert 0 == norm.min() and norm.max() == 1, (norm.min(), norm.max())

    one_hot_gt: Tensor = class2one_hot(torch.tensor(gt[None, ...],
                                                    dtype=torch.int64),
                                       K=2)[0]
    assert one_hot_gt.shape == (2, 256, 256, z), one_hot_gt.shape
    distmap: np.ndarray = one_hot2dist(one_hot_gt.numpy(),
                                       resolution=(dx, dy, dz),
                                       dtype=np.float32)

    save_dir_ct: Path = Path(dest_dir, ct_dir)
    save_dir_cbf: Path = Path(dest_dir, cbf_dir)
    save_dir_cbv: Path = Path(dest_dir, cbv_dir)
    save_dir_mtt: Path = Path(dest_dir, mtt_dir)
    save_dir_tmax: Path = Path(dest_dir, tmax_dir)
    save_dir_gt: Path = Path(dest_dir, gt_dir)
    save_dir_in_npy: Path = Path(dest_dir, in_npy_dir)
    save_dir_gt_npy: Path = Path(dest_dir, gt_npy_dir)
    save_dir_distmap_npy: Path = Path(dest_dir, "3d_distmap")
    # Order matches `slices` below: five modalities, then the ground truth.
    save_dirs = [
        save_dir_ct, save_dir_cbf, save_dir_cbv, save_dir_mtt, save_dir_tmax,
        save_dir_gt
    ]

    for j in range(z):
        ct_s = norm_ct[:, :, j]
        cbf_s = norm_cbf[:, :, j]
        cbv_s = norm_cbv[:, :, j]
        mtt_s = norm_mtt[:, :, j]
        tmax_s = norm_tmax[:, :, j]
        gt_s = gt[:, :, j]
        dist_s = distmap[:, :, :, j]
        slices = [ct_s, cbf_s, cbv_s, mtt_s, tmax_s, gt_s]
        # The whole chain is the condition; the shapes are the failure message.
        assert (ct_s.shape == cbf_s.shape == cbv_s.shape == mtt_s.shape
                == tmax_s.shape == gt_s.shape), [s.shape for s in slices]
        assert gt_s.shape == dist_s[0, ...].shape, ((x, y, z), gt_s.shape,
                                                    dist_s.shape)
        assert set(np.unique(gt_s)).issubset([0, 1])

        for k in range(n_augment + 1):
            # k == 0 writes the untouched slice; k > 0 writes augmented copies.
            if k == 0:
                to_save = slices
            else:
                to_save = map_(np.asarray, augment_arr(*slices))
                assert to_save[0].shape == slices[0].shape, (to_save[0].shape,
                                                             slices[0].shape)

            filename = f"{p_id}_{k:02d}_{j:04d}"
            space_dict[filename] = (dx, dy)
            for save_dir, data in zip(save_dirs, to_save):
                save_dir.mkdir(parents=True, exist_ok=True)

                # Modalities are in [0, 1] and are scaled to 8 bit; the ground
                # truth keeps its raw 0/1 labels. Compare with the gt directory
                # itself: a substring test on the full path misfires when
                # dest_dir contains "gt".
                if save_dir != save_dir_gt:
                    img = (data * 255).astype(np.uint8)
                else:
                    img = data.astype(np.uint8)

                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    imsave(str(Path(save_dir, filename).with_suffix(".png")),
                           img)

            multimodal = np.stack(
                to_save[:-1])  # Do not include the ground truth
            assert 0 <= multimodal.min() and multimodal.max() <= 1
            save_dir_in_npy.mkdir(parents=True, exist_ok=True)
            save_dir_gt_npy.mkdir(parents=True, exist_ok=True)
            np.save(
                Path(save_dir_in_npy, filename).with_suffix(".npy"),
                multimodal)
            np.save(
                Path(save_dir_gt_npy, filename).with_suffix(".npy"),
                to_save[-1])

            save_dir_distmap_npy.mkdir(parents=True, exist_ok=True)
            np.save(
                Path(save_dir_distmap_npy, filename).with_suffix(".npy"),
                dist_s)

    return neg, pos, space_dict, (dx, dy, dz)
Beispiel #12
0
        image = np.load(os.path.join(root, 'in_npy', file))
        gt = np.load(os.path.join(root, 'gt_npy', file))
        if len(np.unique(gt)) > 0:
            #print('infering {} of shape {} and classes {}, max {} and min {} '.format( file, image.shape, np.unique(gt), image.max(), image.min()))
            image = image.reshape(-1, 1, 256, 256) / 255.00
            image = torch.tensor(image, dtype=torch.float)
            image = Variable(image, requires_grad=True)
            pred = net(image)
            pred = F.softmax(pred, dim=1).to('cpu')
            predicted_output = probs2one_hot(pred.detach())
            #print(predicted_output.to('cpu')[:,:2:].shape,class2one_hot(torch.tensor(gt).to('cpu'), n_classes).shape )
            #np.save(os.path.join(path, 'predictions', '{}'.format(file)), pred.to('cpu').detach().numpy())
            #dice = dice_coef(predicted_output.to('cpu'), class2one_hot(torch.tensor(gt).to('cpu'), n_classes))[:,n,]
            dice = dice_coef(
                predicted_output.to('cpu'),
                class2one_hot(torch.tensor(gt).to('cpu'), n_classes))[:, n, ]

            hauss = haussdorf(predicted_output,
                              class2one_hot(torch.tensor(gt), n_classes))[:,
                                                                          n, ]
            '''
            fig, ax = plt.subplots()
            ax.imshow(np.argmax(predicted_output.detach().numpy(), axis=1)[0], cmap=plt.cm.gray)
            #r, contours= Get_contour_characteristics(np.argmax(predicted_output.detach().numpy(), axis=1)[0])
            g, contours = Get_contour_characteristics(np.array(gt).round())
            total_summ = 0
            
            for n, contour in enumerate(contours):
                ax.plot(contour[:, 1].astype(int), contour[:, 0].astype(int),color='red', linewidth=-1)

                ax.axis('image')
        print(file)
        image = np.array(Image.open(os.path.join(root, 'img', file)))
        gt = np.array(Image.open(os.path.join(root, 'gt', file)))
        if len(np.unique(gt)) > 0:
            #print('infering {} of shape {} and classes {}, max {} and min {} '.format( file, image.shape, np.unique(gt), image.max(), image.min()))
            image = image.reshape(-1, 5, 256, 256) / 255.00
            image = torch.tensor(image, dtype=torch.float)
            image = Variable(image, requires_grad=True)
            pred = net(image)
            pred = F.softmax(pred, dim=1).to('cpu')
            predicted_output = probs2one_hot(pred.detach())
            np.save(os.path.join(path, 'predictions', '{}'.format(file)),
                    pred.to('cpu').detach().numpy())
            dice = dice_coef(
                predicted_output.to('cpu'),
                class2one_hot(torch.tensor(gt).to('cpu'), n_classes))
            hauss = haussdorf(
                predicted_output,
                class2one_hot(torch.tensor(gt).to('cpu'), n_classes))
            plt.imsave(
                os.path.join(path, 'predictions',
                             '{}.png'.format(file.split('.npy')[0])),
                np.argmax(predicted_output, 1)[0])
            plt.imsave(
                os.path.join(path, 'gt',
                             '{}.png'.format(file.split('.npy')[0])), gt)

            pred_label = len(
                np.unique(label(np.array(
                    pred.argmax(axis=1).detach().numpy()))))
            gt_label = len(np.unique(label(gt)))
Beispiel #14
0
def save_slices(
        flair_path,
        t1_path,
        gt_path,
        dest_dir: Path,
        shape: Tuple[int],
        n_augment: int,
        discard_negatives: bool,
        flair_dir: str = "flair",
        t1_dir="t1",
        gt_dir: str = "gt",
        in_npy_dir="in_npy",
        gt_npy_dir='gt_npy'
) -> Tuple[int, int, Dict, Tuple[float, float, float]]:
    """Resize one patient's FLAIR / T1 / GT volumes to 256x256 and slice them.

    Each modality is normalised over its whole 3D volume and resized in-plane
    to 256x256. Ground-truth label 2 is mapped to background. A 3D distance map
    is computed on the resized one-hot ground truth. Then every slice along the
    last axis, plus ``n_augment`` augmented copies of it, is written as:

    * one PNG per modality / ground truth under ``dest_dir/<*_dir>``;
    * the stacked modalities under ``dest_dir/in_npy_dir`` and the ground
      truth under ``dest_dir/gt_npy_dir`` as ``.npy``;
    * the distance-map slice under ``dest_dir/3d_distmap`` as ``.npy``.

    Args:
        flair_path, t1_path, gt_path: Volume paths; all must map to the same id
            through ``get_p_id``.
        dest_dir: Output root directory.
        shape: Unused; kept for interface compatibility.
        n_augment: Number of augmented copies written per slice, in addition
            to the original slice.
        discard_negatives: When true, slices whose resized ground truth is all
            zero are skipped.
        flair_dir, t1_dir, gt_dir, in_npy_dir, gt_npy_dir: Names of the output
            sub-directories.

    Returns:
        ``(neg, pos, space_dict, (rx, ry, rz))``:
        * ``neg``: number of original-resolution ground-truth voxels equal to 0 or 2.
        * ``pos``: number of such voxels equal to 1.
        * ``space_dict``: maps each output file stem to ``(rx, ry)``.
        * ``(rx, ry, rz)``: the voxel spacing of the resized grid.
    """
    p_id: str = get_p_id(flair_path)
    assert len(set(map_(get_p_id, [flair_path, t1_path, gt_path]))) == 1
    print(p_id)

    space_dict: Dict[str, Tuple[float, float]] = {}

    # Load the data
    dx, dy, dz = nib.load(str(flair_path)).header.get_zooms()
    flair = np.asarray(nib.load(str(flair_path)).dataobj)
    w, h, z = flair.shape
    t1 = np.asarray(nib.load(str(t1_path)).dataobj)
    gt = np.asarray(nib.load(str(gt_path)).dataobj)
    assert set(np.unique(gt)) <= {0., 1., 2.}

    # Label 2 counts as negative: it is mapped to background further down.
    pos: int = (gt == 1).sum()
    neg: int = ((gt == 0) | (gt == 2)).sum()

    assert len(set(map_(np.shape, [flair, t1, gt]))) == 1
    assert flair.dtype in [np.float32], flair.dtype
    assert t1.dtype in [np.uint16], t1.dtype
    assert gt.dtype in [np.float32], gt.dtype

    # Normalize and check data content. We need to normalize the whole 3D
    # image, not 2D slices, so every slice of a modality shares one scale.
    norm_flair = norm_arr(flair)
    norm_t1 = norm_arr(t1)
    norm_gt = gt.astype(np.uint8)
    for norm in (norm_flair, norm_t1):
        assert 0 == norm.min() and norm.max() == 1, (norm.min(), norm.max())
    assert np.array_equal(np.unique(gt), np.unique(norm_gt))

    # In-plane resize to 256x256 with the slice count unchanged. The labels use
    # order=0 so they stay integer-valued (nearest-neighbour in skimage's
    # resize -- confirm that is the imported `resize`).
    resize_kwargs = dict(mode='constant', preserve_range=True, anti_aliasing=False)
    resized_flair = resize(norm_flair, (256, 256, z),
                           **resize_kwargs).astype(np.float32)
    resized_t1 = resize(norm_t1, (256, 256, z),
                        **resize_kwargs).astype(np.float32)
    resized_gt = resize(norm_gt, (256, 256, z), order=0,
                        **resize_kwargs).astype(np.uint8)
    resized_gt[resized_gt == 2] = 0  # Count those labels as background

    # Voxel spacing of the resized 256x256 grid, used for the 3D distance map.
    rx = dx * w / 256
    ry = dy * h / 256
    rz = dz

    one_hot_gt: Tensor = class2one_hot(torch.tensor(resized_gt[None, ...],
                                                    dtype=torch.int64),
                                       K=2)[0]
    assert one_hot_gt.shape == (2, 256, 256, z), one_hot_gt.shape
    distmap: np.ndarray = one_hot2dist(one_hot_gt.numpy(),
                                       resolution=(rx, ry, rz),
                                       dtype=np.float32)

    save_dir_flair: Path = Path(dest_dir, flair_dir)
    save_dir_t1: Path = Path(dest_dir, t1_dir)
    save_dir_gt: Path = Path(dest_dir, gt_dir)
    save_dir_in_npy: Path = Path(dest_dir, in_npy_dir)
    save_dir_gt_npy: Path = Path(dest_dir, gt_npy_dir)
    save_dir_distmap_npy: Path = Path(dest_dir, "3d_distmap")
    # Order matches `slices` below: modalities first, ground truth last.
    save_dirs = [save_dir_flair, save_dir_t1, save_dir_gt]

    for j in range(z):
        flair_s = resized_flair[:, :, j]
        t1_s = resized_t1[:, :, j]
        gt_s = resized_gt[:, :, j]
        dist_s = distmap[:, :, :, j]

        slices = [flair_s, t1_s, gt_s]
        assert flair_s.shape == t1_s.shape == gt_s.shape == dist_s[
            0, ...].shape, ((w, h, z), flair_s.shape, dist_s.shape)
        assert set(np.unique(gt_s)).issubset([0, 1]), np.unique(gt_s)

        if discard_negatives and (gt_s.sum() == 0):
            continue

        for k in range(n_augment + 1):
            # k == 0 writes the untouched slice; k > 0 writes augmented copies.
            if k == 0:
                to_save = slices
            else:
                to_save = map_(np.asarray, augment_arr(*slices))
                assert to_save[0].shape == slices[0].shape, (to_save[0].shape,
                                                             slices[0].shape)

            filename = f"{p_id}_{k:02d}_{j:04d}"
            space_dict[filename] = (rx, ry)
            for save_dir, data in zip(save_dirs, to_save):
                save_dir.mkdir(parents=True, exist_ok=True)

                # Modalities are in [0, 1] and are scaled to 8 bit; the ground
                # truth keeps its raw 0/1 labels. Compare with the gt directory
                # itself: a substring test on the full path misfires when
                # dest_dir contains "gt".
                if save_dir != save_dir_gt:
                    img = (data * 255).astype(np.uint8)
                else:
                    img = data.astype(np.uint8)

                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    imsave(str(Path(save_dir, filename).with_suffix(".png")),
                           img)

            multimodal = np.stack(
                to_save[:-1])  # Do not include the ground truth
            assert 0 <= multimodal.min(), multimodal.min()
            assert multimodal.max() <= 1, multimodal.max()
            save_dir_in_npy.mkdir(parents=True, exist_ok=True)
            save_dir_gt_npy.mkdir(parents=True, exist_ok=True)
            np.save(
                Path(save_dir_in_npy, filename).with_suffix(".npy"),
                multimodal)
            np.save(
                Path(save_dir_gt_npy, filename).with_suffix(".npy"),
                to_save[-1])

            save_dir_distmap_npy.mkdir(parents=True, exist_ok=True)
            np.save(
                Path(save_dir_distmap_npy, filename).with_suffix(".npy"),
                dist_s)

    return neg, pos, space_dict, (rx, ry, rz)