Esempio n. 1
0
def loadSubject(pid: int,
                leavebckg: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load all volumes of one POEM subject and crop them around its mask.

    The ``.nii`` files under ``POEM/`` whose names contain ``pid`` are sorted
    by path and unpacked in the order (dtx, dty, mask, gt, fat, wat).

    Args:
        pid: subject identifier, matched as a substring of the file names.
        leavebckg: width (in voxels) of the background border to keep around
            the bounding box of the subject mask. With 0 the images are cut as
            tightly around the subject as possible.

    Returns:
        Tuple ``(inputs, gt, mask)``:
          * ``inputs`` -- the six channels (wat, fat, dtx, dty, artificial z,
            border distance) stacked on axis 0, cropped and multiplied by the
            mask;
          * ``gt`` -- one-hot ground truth (7 classes) of ``gt * mask``,
            cropped the same way;
          * ``mask`` -- the cropped subject mask.

    Raises:
        ValueError: if the search does not return exactly six files.
    """
    found = sorted(Path('POEM').rglob(f"*{pid}*.nii"))
    if len(found) != 6:
        # The unpacking below would fail with a cryptic message; say why.
        raise ValueError(f"Expected 6 .nii files matching '*{pid}*' under "
                         f"POEM/, found {len(found)}: {found}")
    dx, dy, m, g, f, w = found  # sorted order: dtx, dty, mask, gt, fat, wat

    wat = nib.load(str(w)).get_fdata()
    fat = nib.load(str(f)).get_fdata()
    gt = nib.load(str(g)).get_fdata()
    maska = nib.load(str(m)).get_fdata()
    x = nib.load(str(dx)).get_fdata()
    y = nib.load(str(dy)).get_fdata()
    # Multiply by the mask so that segmentations only exist inside the subject.
    gt = get_one_hot(gt * maska, 7)

    def first_last_nonzero(axes):
        """Return the first and last index along the remaining axis where the
        mask, summed over ``axes``, is nonzero."""
        nonzero = np.nonzero(maska.sum(axis=axes))[0]
        return nonzero[0], nonzero[-1]

    # Artificial z channel: distance to slice ``startz`` (the first slice that
    # contains the subject), rescaled so that startz -> -1 and endz -> +1.
    startz, endz = first_last_nonzero((0, 1))
    tmp_z = np.ones(maska.shape)
    tmp_z[:, :, startz] = 0
    # max(..., 1) avoids 0/0 (an all-NaN channel) for a single-slice subject.
    tmp_z = 2. * dt_edt(tmp_z) / max(endz - startz, 1) - 1.
    z = maska * tmp_z
    # Artificial border channel: distance transform of the mask, scaled to max 1.
    bd = dt_edt(maska)
    bd = bd / np.max(bd)

    allin = np.stack([wat, fat, x, y, z, bd], axis=0)

    startx, endx = first_last_nonzero((1, 2))
    starty, endy = first_last_nonzero((0, 2))

    # Widen the bounding box by the requested border, clipped to the volume;
    # +1 turns the inclusive last index into an exclusive slice end.
    size_x, size_y, size_z = maska.shape
    crop = (slice(max(0, startx - leavebckg), min(size_x, endx + leavebckg + 1)),
            slice(max(0, starty - leavebckg), min(size_y, endy + leavebckg + 1)),
            slice(max(0, startz - leavebckg), min(size_z, endz + leavebckg + 1)))
    channel_crop = (slice(None), ) + crop

    maska = maska[crop]
    # Multiply by the mask so the inputs are zero outside the subject.
    return allin[channel_crop] * maska, gt[channel_crop], maska
Esempio n. 2
0
def cutPOEM2D(patch_size,
              outpath,
              make_subsampled=True,
              add_dts=True,
              sliced=1,
              sampling=None):
    """Cut 2D training patches from every POEM subject and save them as .npy.

    Args:
        patch_size: side of the saved ``in1``/``gt`` patches. The patches span
            ``2 * (patch_size // 2)`` pixels, i.e. ``patch_size - 1`` for odd
            values.
        outpath: output root; patches go to ``{outpath}/TRAIN/{gt,in1,in2}``
            and a ``datainfo.txt`` summary to ``{outpath}/TRAIN``.
        make_subsampled: also save into ``in2`` a context patch extended by 16
            pixels on every side, keeping every 3rd pixel.
        add_dts: besides wat/fat, also stack the x/y distance maps and the
            artificial z and border-distance channels.
        sliced: axis (0, 1 or 2) along which the volumes are sliced.
        sampling: ``None`` to draw ``to_cut`` random in-mask centres (with
            replacement) from every slice, or a sequence of 7 counts giving
            how many centres to draw per class (inside the mask) per subject.
    """
    # Prepare folders for saving.
    outpath = f"{outpath}/TRAIN"
    pathlib.Path(outpath).mkdir(parents=True, exist_ok=True)
    for i in ['gt', 'in1', 'in2']:
        pathlib.Path(outpath, i).mkdir(parents=True, exist_ok=True)

    gt_paths = sorted(glob("POEM/segms/CroppedSegmNew*"))
    wat_paths = sorted(glob("POEM/watfat/cropped*_wat*"))
    fat_paths = sorted(glob("POEM/watfat/cropped*_fat*"))
    dtx_paths = sorted(glob("POEM/distmaps/*x.nii"))
    dty_paths = sorted(glob("POEM/distmaps/*y.nii"))
    mask_paths = sorted(glob("POEM/masks/cropped*_mask.nii"))
    assert len(gt_paths) == len(wat_paths) == len(fat_paths) == len(
        dtx_paths) == len(dty_paths) == len(mask_paths)

    nb_class = 7
    patch = patch_size // 2
    pad = patch + 16  # half a patch plus the 16 px of extra in2 context
    to_cut = 2  # random centres per slice when sampling is None

    # Only used for the log line; the slicing itself is done with np.take.
    slicing = ":," * sliced + "slajs" + ",:" * (2 - sliced) + "]"
    print(f"\nSLICING: [{slicing}\n")
    for w, f, g, dx, dy, m in zip(wat_paths, fat_paths, gt_paths, dtx_paths,
                                  dty_paths, mask_paths):
        PIDs = [getpid(ppp) for ppp in [w, f, g, dx, dy, m]]
        assert len(np.unique(PIDs)) == 1
        PID = PIDs[0]
        print(f"Slicing nr {PID}...")
        wat = nib.load(w).get_fdata()
        fat = nib.load(f).get_fdata()
        gt = nib.load(g).get_fdata()
        x = nib.load(dx).get_fdata()
        y = nib.load(dy).get_fdata()
        maska = nib.load(m).get_fdata()

        # Artificial z channel: distance to the first slice containing the
        # subject, rescaled so that startz -> -1 and endz -> +1.
        nonzero_z = np.nonzero(maska.sum(axis=(0, 1)))[0]
        startz, endz = nonzero_z[0], nonzero_z[-1]
        tmp_z = np.ones(maska.shape)
        tmp_z[:, :, startz] = 0
        # max(..., 1) avoids 0/0 (an all-NaN channel) for a single-slice subject.
        tmp_z = 2. * dt_edt(tmp_z) / max(endz - startz, 1) - 1.
        z = maska * tmp_z
        # Artificial border channel: distance transform of the mask, scaled to max 1.
        bd = dt_edt(maska)
        bd = bd / np.max(bd)

        gt = get_one_hot(gt, nb_class)  # new size C x H x W x D

        # Maps slice index -> list of in-slice 2D patch centres.
        dict_tmp = {}
        if sampling is None:
            # No sampling given: pick to_cut random in-mask centres per slice.
            for slajs in range(maska.shape[sliced]):
                kjeso = np.argwhere(np.take(maska, slajs, axis=sliced) == 1)
                if len(kjeso) > to_cut:
                    dict_tmp[slajs] = [
                        random.choice(kjeso) for _ in range(to_cut)
                    ]
        else:
            assert len(
                sampling
            ) == nb_class, f"Sampling variable should be an array of length 7!"
            # Draw up to nr_samples voxels of each class inside the mask and
            # group them by the slice they lie in.
            for organ, nr_samples in enumerate(sampling):
                possible = np.argwhere((gt[organ, ...] * maska) == 1)
                Ll = len(possible)
                samp = random.sample(range(Ll), min(nr_samples, Ll))
                for onesample in possible[samp, ...]:
                    dict_tmp.setdefault(onesample[sliced], []).append([
                        onesample[left] for left in range(3) if left != sliced
                    ])

        channels = [wat, fat] + ([x, y, z, bd] if add_dts else [])
        for slajs, indexes in tqdm(dict_tmp.items()):
            # Slice and zero-pad every channel so patch windows stay in bounds.
            padded = [
                np.pad(np.squeeze(np.take(chan, slajs, axis=sliced)), (pad, ),
                       mode='constant') for chan in channels
            ]
            gt_tmp = np.pad(np.squeeze(np.take(gt, slajs, axis=sliced + 1)),
                            ((0, 0), (pad, pad), (pad, pad)),
                            mode='constant')

            for counter, index in enumerate(indexes):
                # index is in unpadded coordinates; in the padded arrays the
                # window [index + 16, index + 16 + 2*patch) is centred on it.
                startx = index[0] + 16
                endx = startx + 2 * patch
                starty = index[1] + 16
                endy = starty + 2 * patch
                window = (slice(startx, endx), slice(starty, endy))

                allin = np.stack([chan[window] for chan in padded], axis=0)
                gt_part = gt_tmp[(slice(None), ) + window]

                np.save(f"{outpath}/in1/subj{PID}_{slajs}_{counter}", allin)
                np.save(f"{outpath}/gt/subj{PID}_{slajs}_{counter}", gt_part)

                if make_subsampled:
                    # 16 extra pixels of context per side, every 3rd pixel.
                    wide = (slice(startx - 16, endx + 16, 3),
                            slice(starty - 16, endy + 16, 3))
                    allin = np.stack([chan[wide] for chan in padded], axis=0)
                    np.save(f"{outpath}/in2/subj{PID}_{slajs}_{counter}",
                            allin)

    with open(f"{outpath}/datainfo.txt", "w") as info_file:
        info_file.write(f"""Sliced by dim {sliced}. \nPatch size: {patch_size}
                                    \nDTs: {add_dts}\nsubsmpl: {make_subsampled}
                                    \nsampling: {sampling}""")
Esempio n. 3
0
def cutEval(patch_size, pid_list=None):
    """Cut evaluation data for POEM subjects: whole 2D slices and tiled 3D patches.

    Args:
        patch_size: side of the 3D ``in1``/GT patches. Patches start
            ``patch_size - 16`` voxels apart, so neighbours overlap by 16
            voxels. The subsampled ``in2`` patches cover 16 extra voxels on
            every side and keep every 3rd voxel.
        pid_list: which subjects to cut. If None, all subjects found in
            ``POEM/segms`` are cut.

    Output (under ``POEM_eval``): ``TwoD/in1``, ``TwoD/in2`` and ``GTs_2D``
    get one file per slice along axis 1; ``TriD/in1``, ``TriD/in2`` and
    ``GTs_3D`` get the 3D patches. On completion a marker
    ``size{patch_size}.txt`` is written, so a later call with the same patch
    size whose subjects are all already cut returns without recutting.

    Raises:
        ValueError: if ``patch_size <= 16`` (the patch stride would not be
            positive).

    Returns:
        None
    """
    if patch_size <= 16:
        raise ValueError(f"patch_size must be > 16, got {patch_size}")

    outpath2 = pathlib.Path('POEM_eval', 'TwoD')
    outpath3 = pathlib.Path('POEM_eval', 'TriD')
    GTs2 = pathlib.Path('POEM_eval', 'GTs_2D')
    GTs3 = pathlib.Path('POEM_eval', 'GTs_3D')
    GTs2.mkdir(parents=True, exist_ok=True)
    GTs3.mkdir(parents=True, exist_ok=True)
    for i in ['in1', 'in2']:
        pathlib.Path(outpath2, i).mkdir(parents=True, exist_ok=True)
        pathlib.Path(outpath3, i).mkdir(parents=True, exist_ok=True)

    # Check if everything already exists, to not cut twice.
    if pid_list is None:  # set it to all available pids
        pid_list = [
            getpid(filli) for filli in glob("POEM/segms/CroppedSegmNew*")
        ]
    existing_pid_list = [getpid(filli) for filli in glob("POEM_eval/GTs_2D/*")]
    # Every requested subject must already have been cut.
    allfilesexist = set(pid_list).issubset(existing_pid_list)
    marker = pathlib.Path('POEM_eval', f'size{patch_size}.txt')
    exists = marker.is_file()
    if exists and allfilesexist:  # everything exists, do not recut
        print('Files already exist. Cutting stopped.')
        return None

    # Otherwise remove all existing .npy/.txt files starting with "s" (cut
    # data and any old size marker), unless the existing ones are of the
    # correct patch size - then only the missing subjects get (re)cut.
    if not exists:
        for filename in pathlib.Path("POEM_eval").rglob("s*.[nt][xp][yt]"):
            filename.unlink()

    def select(pattern):
        """Sorted paths matching ``pattern`` that belong to a requested pid."""
        return sorted(g for g in glob(pattern) if getpid(g) in pid_list)

    gt_paths = select("POEM/segms/CroppedSegmNew*")
    wat_paths = select("POEM/watfat/cropped*_wat*")
    fat_paths = select("POEM/watfat/cropped*_fat*")
    dtx_paths = select("POEM/distmaps/*x.nii")
    dty_paths = select("POEM/distmaps/*y.nii")
    mask_paths = select("POEM/masks/cropped*_mask.nii")
    assert len(gt_paths) == len(wat_paths) == len(fat_paths) == len(
        dtx_paths) == len(dty_paths) == len(mask_paths)

    nb_class = 7
    step = patch_size - 16  # neighbouring 3D patches overlap by 16 voxels
    # in2 spans patch_size + 2*16 voxels keeping every 3rd one.
    sub_size = -(-(patch_size + 32) // 3)  # ceil((patch_size + 32) / 3)

    def pad_end(arr, size):
        """Zero-pad the end of the spatial axes 1..3 of ``arr`` up to ``size``."""
        return np.pad(arr, [(0, 0)] + [(0, size - s) for s in arr.shape[1:]],
                      mode='constant')

    for w, f, g, dx, dy, m in zip(wat_paths, fat_paths, gt_paths, dtx_paths,
                                  dty_paths, mask_paths):
        PIDs = [getpid(ppp) for ppp in [w, f, g, dx, dy, m]]
        # Check that all paths lead to the same subject.
        assert len(np.unique(PIDs)) == 1
        PID = PIDs[0]

        print(f"Slicing nr {PID}...")
        wat = nib.load(w).get_fdata()
        fat = nib.load(f).get_fdata()
        gt = nib.load(g).get_fdata()
        x = nib.load(dx).get_fdata()
        y = nib.load(dy).get_fdata()
        maska = nib.load(m).get_fdata()

        # Artificial z channel: distance to the first slice containing the
        # subject, rescaled so that startz -> -1 and endz -> +1.
        nonzero_z = np.nonzero(maska.sum(axis=(0, 1)))[0]
        startz, endz = nonzero_z[0], nonzero_z[-1]
        tmp_z = np.ones(maska.shape)
        tmp_z[:, :, startz] = 0
        # max(..., 1) avoids 0/0 (an all-NaN channel) for a single-slice subject.
        tmp_z = 2. * dt_edt(tmp_z) / max(endz - startz, 1) - 1.
        z = maska * tmp_z
        # Artificial border channel: distance transform of the mask, scaled to max 1.
        bd = dt_edt(maska)
        bd = bd / np.max(bd)

        allin = np.stack([wat, fat, x, y, z, bd], axis=0)
        gt = get_one_hot(gt, nb_class)  # new size C x H x W x D

        # Save 2D slices (along axis 1).
        for s in range(wat.shape[1]):
            np.save(pathlib.Path(outpath2, 'in1', f"subj{PID}_{s}.npy"),
                    np.squeeze(allin[:, :, s, :]))
            np.save(pathlib.Path(outpath2, 'in2', f"subj{PID}_{s}.npy"),
                    np.squeeze(allin[:, 0::3, s, 0::3]))
            np.save(pathlib.Path(GTs2, f"subj{PID}_{s}.npy"),
                    np.squeeze(gt[:, :, s, :]))

        # Save 3D patches. Pad with 16 zeros first so the in2 context window
        # (16 voxels before each patch) never starts at a negative index.
        allin = np.pad(allin, ((0, ), (16, ), (16, ), (16, )), mode='constant')
        gt = np.pad(gt, ((0, ), (16, ), (16, ), (16, )), mode='constant')
        for i in range(16, wat.shape[0] + 16, step):
            for j in range(16, wat.shape[1] + 16, step):
                for k in range(16, wat.shape[2] + 16, step):
                    tmp_in1 = allin[:, i:i + patch_size, j:j + patch_size,
                                    k:k + patch_size]
                    tmp_in2 = allin[:, i - 16:i + patch_size + 16:3,
                                    j - 16:j + patch_size + 16:3,
                                    k - 16:k + patch_size + 16:3]
                    tmp_gt = gt[:, i:i + patch_size, j:j + patch_size,
                                k:k + patch_size]

                    # Patches at the far edges are shorter; pad to full size.
                    tmp_in1 = pad_end(tmp_in1, patch_size)
                    tmp_gt = pad_end(tmp_gt, patch_size)
                    tmp_in2 = pad_end(tmp_in2, sub_size)

                    np.save(
                        pathlib.Path(outpath3, 'in1',
                                     f"subj{PID}_{i}_{j}_{k}.npy"), tmp_in1)
                    np.save(
                        pathlib.Path(outpath3, 'in2',
                                     f"subj{PID}_{i}_{j}_{k}.npy"), tmp_in2)
                    np.save(pathlib.Path(GTs3, f"subj{PID}_{i}_{j}_{k}.npy"),
                            tmp_gt)

    # Record that cutting with this patch size finished (checked above).
    marker.write_text(f"{patch_size}\n")
    return None


# %%
#cutPOEM2D(50, 'POEM50', sampling=[5, 3, 4, 3, 5, 4, 4])
#cutPOEM2D(50, 'POEM50_2', sliced=2,sampling=[5, 3, 4, 3, 5, 4, 4])
#cutPOEM3D(50, 'POEM50_3D', sampling=[5,3,4,3,5,4,4])
Esempio n. 4
0
def cutPOEMslices():
    """Save every foreground-containing slice (along axis 1) of each POEM subject.

    This was just for tryouts (presumably the same data as in the BL project).
    For each subject, the slices ``[:, s, :]`` that contain at least one
    non-background GT voxel are saved as:
      * ``POEM_slices/in1`` -- wat, fat, dtx, dty stacked on axis 0;
      * ``POEM_slices/in2`` -- the same, keeping every 3rd pixel;
      * ``POEM_slices/gt``  -- the one-hot GT slice (7 classes).
    """
    # NOTE: POEM_slices/TRAIN is created, but the files are written directly
    # under POEM_slices/{in1,in2,gt} (kept as before).
    outpath = f"POEM_slices/TRAIN"
    pathlib.Path(outpath).mkdir(parents=True, exist_ok=True)
    for i in ['gt', 'in1', 'in2']:
        pathlib.Path(f"POEM_slices/{i}").mkdir(parents=True, exist_ok=True)

    gt_paths = glob(
        "/home/eva/Desktop/research/PROJEKT2-DeepLearning/procesiranDataset/POEM_segment_all/converted/CroppedSegmNew*"
    )
    wat_paths = glob(
        "/home/eva/Desktop/research/PROJEKT2-DeepLearning/procesiranDataset/POEM_segmentation_data_fatwat/converted/cropped*_wat*"
    )
    fat_paths = glob(
        "/home/eva/Desktop/research/PROJEKT2-DeepLearning/procesiranDataset/POEM_segmentation_data_fatwat/converted/cropped*_fat*"
    )
    dtx_paths = glob(
        "/home/eva/Desktop/research/PROJEKT2-DeepLearning/distmaps/*x.nii")
    dty_paths = glob(
        "/home/eva/Desktop/research/PROJEKT2-DeepLearning/distmaps/*y.nii")

    #gt_paths = glob("POEM/segms/CroppedSegmNew*")
    #wat_paths = glob("POEM/watfat/cropped*_wat*")
    #fat_paths = glob("POEM/watfat/cropped*_fat*")
    #dtx_paths = glob("POEM/distmaps/*x.nii")
    #dty_paths = glob("POEM/distmaps/*y.nii")
    #mask_paths = glob("POEM/masks/cropped*_mask.nii")

    gt_paths.sort()
    wat_paths.sort()
    fat_paths.sort()
    dtx_paths.sort()
    dty_paths.sort()
    # zip() would silently truncate / misalign subjects if a file is missing.
    assert len(gt_paths) == len(wat_paths) == len(fat_paths) == len(
        dtx_paths) == len(dty_paths)

    for w, f, g, dx, dy in zip(wat_paths, fat_paths, gt_paths, dtx_paths,
                               dty_paths):
        PIDs = [getpid(ppp) for ppp in [w, f, g, dx, dy]]
        # Check that all paths lead to the same subject.
        assert len(np.unique(PIDs)) == 1
        PID = PIDs[0]
        print(f"Slicing nr {PID}...")
        wat = nib.load(w).get_fdata()
        fat = nib.load(f).get_fdata()
        gt = nib.load(g).get_fdata()
        x = nib.load(dx).get_fdata()
        y = nib.load(dy).get_fdata()

        gt = get_one_hot(gt, 7)  # new size C x H x W x D

        # Foreground voxel count per axis-1 slice (classes 1..6 only).
        slajsi_where = gt[1:, ...].sum(axis=(0, 1, 3))
        slajsi = np.arange(wat.shape[1])
        slajsi = slajsi[slajsi_where > 0]

        for slajs in tqdm(slajsi):
            allin = [
                wat[:, slajs, ...], fat[:, slajs, ...], x[:, slajs, ...],
                y[:, slajs, ...]
            ]
            allin = np.stack(allin, axis=0)
            quasidownsmp = allin[:, 0::3, 0::3]
            gt_part = gt[:, :, slajs, :]

            np.save(f"POEM_slices/in1/subj{PID}_{slajs}_0", allin)
            np.save(f"POEM_slices/in2/subj{PID}_{slajs}_0", quasidownsmp)
            np.save(f"POEM_slices/gt/subj{PID}_{slajs}_0", gt_part)
Esempio n. 5
0
def cutPOEM3D(patch_size,
              outpath,
              make_subsampled=True,
              add_dts=True,
              sampling=None):
    """Cut 3D training patches from every POEM subject and save them as .npy.

    Args:
        patch_size: side of the saved ``in1``/``gt`` cubes. The cubes span
            ``2 * (patch_size // 2)`` voxels, i.e. ``patch_size - 1`` for odd
            values.
        outpath: output root; patches go to ``{outpath}/TRAIN/{gt,in1,in2}``
            and a ``datainfo.txt`` summary to ``{outpath}/TRAIN``.
        make_subsampled: also save into ``in2`` a context cube extended by 16
            voxels on every side, keeping every 3rd voxel.
        add_dts: besides wat/fat, also stack the x/y distance maps and the
            artificial z and border-distance channels.
        sampling: how many patches per class to sample from each subject
            (sequence of 7 counts, drawn inside the mask). If not given,
            ``to_cut`` centres are drawn uniformly from the mask, so patches
            may contain lots of background.
    """
    # Prepare folders for saving.
    outpath = f"{outpath}/TRAIN"
    pathlib.Path(outpath).mkdir(parents=True, exist_ok=True)
    for i in ['gt', 'in1', 'in2']:
        pathlib.Path(outpath, i).mkdir(parents=True, exist_ok=True)

    gt_paths = sorted(glob("POEM/segms/CroppedSegmNew*"))
    wat_paths = sorted(glob("POEM/watfat/cropped*_wat*"))
    fat_paths = sorted(glob("POEM/watfat/cropped*_fat*"))
    dtx_paths = sorted(glob("POEM/distmaps/*x.nii"))
    dty_paths = sorted(glob("POEM/distmaps/*y.nii"))
    mask_paths = sorted(glob("POEM/masks/cropped*_mask.nii"))
    assert len(gt_paths) == len(wat_paths) == len(fat_paths) == len(
        dtx_paths) == len(dty_paths) == len(mask_paths)

    nb_class = 7
    patch = patch_size // 2
    pad = patch + 16  # half a patch plus the 16 voxels of extra in2 context
    to_cut = 5  # random centres per subject when sampling is None

    for w, f, g, dx, dy, m in zip(wat_paths, fat_paths, gt_paths, dtx_paths,
                                  dty_paths, mask_paths):
        PIDs = [getpid(ppp) for ppp in [w, f, g, dx, dy, m]]
        assert len(np.unique(PIDs)) == 1
        PID = PIDs[0]
        print(f"Slicing nr {PID}...")
        wat = nib.load(w).get_fdata()
        fat = nib.load(f).get_fdata()
        gt = nib.load(g).get_fdata()
        x = nib.load(dx).get_fdata()
        y = nib.load(dy).get_fdata()
        maska = nib.load(m).get_fdata()

        # Artificial z channel: distance to the first slice containing the
        # subject, rescaled so that startz -> -1 and endz -> +1.
        nonzero_z = np.nonzero(maska.sum(axis=(0, 1)))[0]
        startz, endz = nonzero_z[0], nonzero_z[-1]
        tmp_z = np.ones(maska.shape)
        tmp_z[:, :, startz] = 0
        # max(..., 1) avoids 0/0 (an all-NaN channel) for a single-slice subject.
        tmp_z = 2. * dt_edt(tmp_z) / max(endz - startz, 1) - 1.
        z = maska * tmp_z
        # Artificial border channel: distance transform of the mask, scaled to max 1.
        bd = dt_edt(maska)
        bd = bd / np.max(bd)

        gt = get_one_hot(gt, nb_class)  # new size C x H x W x D

        if sampling is None:
            # No sampling given: draw up to to_cut distinct voxels of the mask
            # (from the whole subject, not per slice).
            kjeso = np.argwhere(maska == 1)
            if len(kjeso) > to_cut:
                kjeso = kjeso[
                    np.random.choice(kjeso.shape[0], to_cut, replace=False),
                    ...]
        else:
            assert len(
                sampling
            ) == nb_class, f"Sampling variable should be an array of length 7!"
            # Draw up to nr_samples voxels of each class inside the mask.
            kjeso = []
            for organ, nr_samples in enumerate(sampling):
                possible = np.argwhere((gt[organ, ...] * maska) == 1)
                Ll = len(possible)
                kjeso.append(possible[random.sample(range(Ll),
                                                    min(nr_samples, Ll)), ...])
            kjeso = np.vstack(kjeso)

        # Zero-pad the saved channels so patch windows stay in bounds.
        channels = [wat, fat] + ([x, y, z, bd] if add_dts else [])
        padded = [np.pad(chan, (pad, ), mode='constant') for chan in channels]
        gt_tmp = np.pad(gt, ((0, 0), (pad, pad), (pad, pad), (pad, pad)),
                        mode='constant')

        for idx, center in tqdm(enumerate(kjeso)):
            # center is in unpadded coordinates; in the padded arrays the
            # window [center + 16, center + 16 + 2*patch) is centred on it.
            starts = [coord + 16 for coord in center]
            window = tuple(slice(s, s + 2 * patch) for s in starts)

            allin = np.stack([chan[window] for chan in padded], axis=0)
            gt_part = gt_tmp[(slice(None), ) + window]

            np.save(f"{outpath}/in1/subj{PID}_{idx}_0", allin)
            np.save(f"{outpath}/gt/subj{PID}_{idx}_0", gt_part)

            if make_subsampled:
                # 16 extra voxels of context per side, every 3rd voxel.
                wide = tuple(
                    slice(s - 16, s + 2 * patch + 16, 3) for s in starts)
                allin = np.stack([chan[wide] for chan in padded], axis=0)
                np.save(f"{outpath}/in2/subj{PID}_{idx}_0", allin)

    with open(f"{outpath}/datainfo.txt", "w") as info_file:
        info_file.write(f"""Sliced 3D patches. \nPatch size: {patch_size}
                            \nDTs: {add_dts}\nsubsmpl: {make_subsampled}
                            \nsampling: {sampling}""")
Esempio n. 6
0
def plotOutput(params, datafolder, pids, doeval=True, take20=None):
    """Run a trained net on saved samples and plot GT next to its output.

    After training a net and saving its state to ``RESULTS/{params}`` (with
    its arguments in ``RESULTS/{params}_args.txt``), run inference on the
    samples of ``pids`` found in ``datafolder`` and plot results + GT together
    with per-class Dice values. Each pid is matched as a substring of the file
    names, so a ``subject_slice`` pid selects all ``subject_slice_x`` files.
    E.g. plotOutput('First_unet', 'POEM', '500177_30').

    Args:
        params: name of the saved network state inside ``RESULTS/``.
        datafolder: searched as ``./{datafolder}/*/{gt,in1,in2}/*{pid}*.npy``.
        pids: a single pid string or a list of pids.
        doeval: put the network in eval mode before inference.
        take20: indices of the files to keep when a pid matches more than 20
            files. If None, 20 random indices are drawn the first time this
            happens.

    Returns:
        ``take20`` (possibly newly drawn), so the same subset can be re-plotted.
    """
    # Default settings:
    Arg = {
        'network': None,
        'n_class': 7,
        'in_channels': 2,
        'lower_in_channels': 2,
        'extractor_net': 'resnet34'
    }

    with open(f"RESULTS/{params}_args.txt", "r") as ft:
        args = ft.read().splitlines()
    # "--key=value" lines -> [key, value].
    # NOTE(review): values parsed here are strings (e.g. n_class would become
    # '7' if present in the file) - confirm Networks/consumers accept that.
    tmpargs = [
        i.strip('--').split("=") for i in args if ('--' in i and '=' in i)
    ]
    chan1, whichin1 = getchn(args, 'in_chan')
    chan2, whichin2 = getchn(args, 'lower_in_chan')
    tmpargs += [['in_channels', chan1], ['lower_in_channels', chan2]]
    in3D = '--in3D' in args
    args = dict(tmpargs)

    # Overwrite defaults if given in file.
    Arg.update(args)
    device = torch.device('cpu')

    net = getattr(Networks, Arg['network'])(Arg['in_channels'], Arg['n_class'],
                                            Arg['lower_in_channels'],
                                            Arg['extractor_net'], in3D)
    net = net.float()
    # Now we can load the learned parameters:
    loaded = torch.load(f"RESULTS/{params}",
                        map_location=lambda storage, loc: storage)
    net.load_state_dict(loaded['state_dict'])

    if doeval:
        net.eval()

    net = net.to(device)
    # Load data & GT.
    if isinstance(pids, str):
        pids = [pids]

    use_in2 = Arg['network'] == 'DeepMedic'
    allL = 0
    allfindgts = []
    allfindin1 = []
    allfindin2 = []
    for pid in pids:
        findgts = sorted(glob.glob(f"./{datafolder}/*/gt/*{pid}*.npy"))
        findin1 = sorted(glob.glob(f"./{datafolder}/*/in1/*{pid}*.npy"))
        findin2 = sorted(glob.glob(f"./{datafolder}/*/in2/*{pid}*.npy"))

        # All subslices in one image.
        L = len(findgts)
        if L > 20:  # ugly but needed to avoid too long compute
            # NOTE(review): take20 drawn for one pid is reused for later pids;
            # a later pid with fewer matching files could raise IndexError.
            if take20 is None:
                take20 = random.sample(range(L), 20)
            findgts = [findgts[tk] for tk in take20]
            findin1 = [findin1[tk] for tk in take20]
            if use_in2:
                findin2 = [findin2[tk] for tk in take20]
            L = len(take20)

        allL += L
        allfindgts.extend(findgts)
        allfindin1.extend(findin1)
        allfindin2.extend(findin2)

    organs = [
        'Bckg', 'Bladder', 'KidneyL', 'Liver', 'Pancreas', 'Spleen', 'KidneyR'
    ]
    if len(organs) != Arg['n_class']:  # in case not POEM dataset used
        organs = [str(zblj) for zblj in range(Arg['n_class'])]

    entmpgt = np.load(allfindgts[0])
    tgtonehot = entmpgt.shape[0] == 7  # are targets one hot encoded?
    # 3D data: one-hot targets with 4 dims, or label maps with 3 dims.
    in3d = tgtonehot * (entmpgt.ndim == 4) + (not tgtonehot) * (entmpgt.ndim
                                                                == 3)
    if in3d:
        # Set the right cropping function to use.
        TensorCropping = CenterCropTensor3d
    else:
        TensorCropping = CenterCropTensor

    data = torch.stack([
        torch.from_numpy(np.load(i1)).float().to(device) for i1 in allfindin1
    ],
                       dim=0)
    data = [data[:, whichin1, ...]]
    target = [
        flatten_one_hot(np.load(g)) if tgtonehot else np.load(g)
        for g in allfindgts
    ]
    target_oh = torch.stack([
        torch.from_numpy(np.load(g)).to(device) if tgtonehot else
        torch.from_numpy(get_one_hot(np.load(g), 7)).to(device)
        for g in allfindgts
    ],
                            dim=0)

    if use_in2:
        in2 = torch.stack([
            torch.from_numpy(np.load(i2)).float().to(device)
            for i2 in allfindin2
        ],
                          dim=0)
        data.append(in2[:, whichin2, ...])

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        out = net(*data)
        target_oh, out = TensorCropping(target_oh, out)
        dices = AllDices(out, target_oh)  # DicePerClass(out, target_oh)
    outs = [flatten_one_hot(o.detach().squeeze().numpy()) for o in out]

    fig, ax_tuple = plt.subplots(nrows=allL,
                                 ncols=2,
                                 figsize=(10, allL * 6 + 1),
                                 tight_layout=True)
    # With a single row subplots returns a 1D array; make it 2D.
    if ax_tuple.ndim < 2:
        ax_tuple = ax_tuple[np.newaxis, ...]
    plt.suptitle(params)
    for ind in range(len(outs)):
        # Crop to be more comparable.
        targetind, outsind = TensorCropping(target[ind], outs[ind])
        if in3d:
            # Show the middle slice of the second-to-last axis.
            sl = targetind.shape[-2] // 2
            targetind, outsind = targetind[..., sl, :], outsind[..., sl, :]

        ax1 = ax_tuple[ind, 0]
        ax1.set_title('GT')
        ax1.axis('off')
        ax1.imshow(targetind, cmap='Spectral', vmin=0, vmax=Arg['n_class'])

        ax2 = ax_tuple[ind, 1]
        ax2.set_title('OUT')
        ax2.axis('off')
        im = ax2.imshow(outsind, cmap='Spectral', vmin=0, vmax=Arg['n_class'])

        values = np.arange(Arg['n_class'])
        colors = [im.cmap(im.norm(value)) for value in values]
        # Create a patch (proxy artist) for every color.
        patches = [
            mpatches.Patch(color=colors[i], label=organs[i])
            for i in range(len(values))
        ]
        # Put those patches as legend handles into the legend.
        ax2.legend(handles=patches,
                   bbox_to_anchor=(1.05, 1.),
                   loc=2,
                   borderaxespad=0.)

        # Write out also the Dices.
        dajci = dices[ind].detach().squeeze().numpy()
        ax2.text(1.08,
                 0.5,
                 'Dices:',
                 size='medium',
                 horizontalalignment='center',
                 verticalalignment='center',
                 transform=ax2.transAxes)
        for d in range(7):
            ax2.text(1.1,
                     0.45 - d * 0.05,
                     f"{organs[d]}: {dajci[d]:.3f}",
                     size='small',
                     transform=ax2.transAxes)

    plt.show()
    return take20