def loadSubject(pid: int, leavebckg: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load one POEM subject and crop it tightly around the body mask.

    leavebckg == what border of background to leave around the subject
    (otherwise we cut the imgs so that they are as tight around the subject
    (mask==1) as possible).

    Returns (inputs, gt, mask), all cropped to the same bounding box:
      inputs: (6, X, Y, Z) stack [water, fat, dt-x, dt-y, dt-z, border-DT],
              zeroed outside the mask;
      gt:     (7, X, Y, Z) one-hot ground truth (masked before encoding);
      mask:   (X, Y, Z) body mask.
    """
    # NOTE(review): relies on alphabetical sort of the glob yielding the order
    # dtx, dty, mask, gt, fat, wat — verify the filename conventions guarantee it.
    dx, dy, m, g, f, w = sorted(Path('POEM').rglob(f"*{pid}*.nii"))  # order: dtx, dty, mask, gt, fat, wat
    wat = nib.load(str(w)).get_fdata()
    fat = nib.load(str(f)).get_fdata()
    gt = nib.load(str(g)).get_fdata()
    maska = nib.load(str(m)).get_fdata()
    x = nib.load(str(dx)).get_fdata()
    y = nib.load(str(dy)).get_fdata()
    # To make sure segms will only be done inside subj, multiply by mask
    # before one-hot encoding:
    gt = get_one_hot(gt * maska, 7)
    # z channel built artificially: distance transform from the first
    # mask-occupied z-slice, rescaled to [-1, 1] over the occupied z-range.
    tmp_z = np.ones(maska.shape)
    tmp = maska.sum(axis=(0, 1))
    startz, endz = np.nonzero(tmp)[0][0], np.nonzero(tmp)[0][-1]
    tmp_z[:, :, startz] = 0
    tmp_z = 2. * dt_edt(tmp_z) / (endz - startz) - 1.
    z = maska * tmp_z  #create artificially, simply DT from left to right
    bd = dt_edt(maska)  #create artificially, simply DT from border
    bd = bd / np.max(bd)  # normalise the border-distance channel to [0, 1]
    allin = np.stack([wat, fat, x, y, z, bd], axis=0)
    # Tight bounding box of the mask along x and y (z range found above).
    tmp = maska.sum(axis=(1, 2))
    startx, endx = np.nonzero(tmp)[0][0], np.nonzero(tmp)[0][-1]
    tmp = maska.sum(axis=(0, 2))
    starty, endy = np.nonzero(tmp)[0][0], np.nonzero(tmp)[0][-1]
    # New starts/ends based on the required border width. From here on x, y, z
    # hold the volume dimensions (the coordinate channels are already in allin).
    x, y, z = maska.shape
    startx = max(0, startx - leavebckg)
    starty = max(0, starty - leavebckg)
    startz = max(0, startz - leavebckg)
    endx = min(x, endx + leavebckg + 1)
    endy = min(y, endy + leavebckg + 1)
    endz = min(z, endz + leavebckg + 1)
    # print(("orig.sizes:", maska.shape))
    # print(("new slice:", (startx,endx,starty,endy,startz,endz)))
    maska = maska[startx:endx, starty:endy, startz:endz]
    allin = allin[:, startx:endx, starty:endy, startz:endz]
    # print(("new sizes:", maska.shape, allin.shape))
    #to make sure segms will only be done inside subj, lets multiply by mask:
    return allin * maska, gt[:, startx:endx, starty:endy, startz:endz], maska
def cutPOEM2D(patch_size, outpath, make_subsampled=True, add_dts=True, sliced=1, sampling=None):
    """Cut 2D training patches from the POEM dataset and save them as .npy.

    patch_size: side length of the saved in1 patches (voxels).
    outpath: root output folder; files go to OUTPATH/TRAIN/{gt,in1,in2}.
    make_subsampled: also save a 3x-subsampled, 16-voxel-wider context
        patch as in2 (DeepMedic-style second input).
    add_dts: include the distance-transform channels (x, y, z, border).
    sliced: 0, 1 or 2 — the axis index along which slices are taken.
    sampling: None to cut random in-mask patches from each slice, or a
        per-organ list (length nb_class) of how many patches to sample.
    """
    #prepare folders for saving:
    outpath = f"{outpath}/TRAIN"
    pathlib.Path(outpath).mkdir(parents=True, exist_ok=True)
    for i in ['gt', 'in1', 'in2']:
        pathlib.Path(outpath, i).mkdir(parents=True, exist_ok=True)
    #POEM SLICING: collect per-modality file lists; sorting aligns them by subject.
    gt_paths = glob("POEM/segms/CroppedSegmNew*")
    wat_paths = glob("POEM/watfat/cropped*_wat*")
    fat_paths = glob("POEM/watfat/cropped*_fat*")
    dtx_paths = glob("POEM/distmaps/*x.nii")
    dty_paths = glob("POEM/distmaps/*y.nii")
    mask_paths = glob("POEM/masks/cropped*_mask.nii")
    gt_paths.sort()
    wat_paths.sort()
    fat_paths.sort()
    dtx_paths.sort()
    dty_paths.sort()
    mask_paths.sort()
    assert len(gt_paths) == len(wat_paths) == len(fat_paths) == len(
        dtx_paths) == len(dty_paths) == len(mask_paths)

    nb_class = 7
    patch = patch_size // 2  # half size; patches span 2*patch voxels
    # Indexing string used via eval below to slice the chosen axis,
    # e.g. sliced=1 -> "wat[:,slajs,:]". Built internally only, no external input.
    slicing = ":," * sliced + "slajs" + ",:" * (2 - sliced) + "]"
    print(f"\nSLICING: [{slicing}\n")

    for w, f, g, dx, dy, m in zip(wat_paths, fat_paths, gt_paths, dtx_paths,
                                  dty_paths, mask_paths):
        PIDs = [getpid(ppp) for ppp in [w, f, g, dx, dy, m]]
        assert len(np.unique(PIDs)) == 1  # all six files belong to one subject
        PID = PIDs[0]
        print(f"Slicing nr {PID}...")
        wat = nib.load(w).get_fdata()
        fat = nib.load(f).get_fdata()
        gt = nib.load(g).get_fdata()
        x = nib.load(dx).get_fdata()
        y = nib.load(dy).get_fdata()
        maska = nib.load(m).get_fdata()
        # z channel: DT from the first occupied z-slice, rescaled to [-1, 1].
        tmp_z = np.ones(maska.shape)
        startz, endz = np.nonzero(maska.sum(axis=(0, 1)))[0][0], np.nonzero(
            maska.sum(axis=(0, 1)))[0][-1]
        tmp_z[:, :, startz] = 0
        tmp_z = 2. * dt_edt(tmp_z) / (endz - startz) - 1.
        z = maska * tmp_z  #create artificially, simply DT from left to right
        bd = dt_edt(maska)  #create artificially, simply DT from border
        bd = bd / np.max(bd)
        gt = get_one_hot(gt, nb_class)  #new size C x H x W x D

        to_cut = 2  # random patches cut per slice when no sampling is given
        dict_tmp = {}  # slice index -> list of in-plane patch-center coords
        if sampling is None:  # FIX: identity comparison with None (was ==)
            #if no sampling given, we cut randomly a few (to_cut) patches from EACH slice.
            for slajs in range(maska.shape[sliced]):
                kjeso = eval(f"np.argwhere(maska[{slicing}==1)")
                if len(kjeso) > to_cut:
                    dict_tmp[slajs] = [
                        random.choice(kjeso) for i in range(to_cut)
                    ]
        else:
            # FIX: interpolate nb_class (was a placeholder-less f-string).
            assert len(
                sampling
            ) == nb_class, f"Sampling variable should be an array of length {nb_class}!"
            #make a dict of all the slices (keys) and indices (value lists):
            for organ, nr_samples in enumerate(sampling):
                possible = np.argwhere((gt[organ, ...] * maska) == 1)
                Ll = len(possible)
                nr_sample = min(nr_samples, Ll)
                samp = random.sample(range(Ll), nr_sample)
                samples = possible[samp, ...]
                for onesample in samples:
                    if onesample[sliced] not in dict_tmp:
                        dict_tmp[onesample[sliced]] = []
                    # keep only the two in-plane coordinates
                    dict_tmp[onesample[sliced]].append([
                        onesample[left] for left in range(3) if left != sliced
                    ])

        for slajs, indexes in tqdm(dict_tmp.items()):
            # Pad each slice by patch+16 so any in-mask center yields a full
            # patch plus the 16-voxel margin needed for the subsampled in2.
            wat_tmp = np.pad(np.squeeze(eval(f"wat[{slicing}")),
                             (patch + 16, ), mode='constant')
            fat_tmp = np.pad(np.squeeze(eval(f"fat[{slicing}")),
                             (patch + 16, ), mode='constant')
            gt_tmp = np.pad(np.squeeze(eval(f"gt[:,{slicing}")),
                            ((0, 0), (patch + 16, patch + 16),
                             (patch + 16, patch + 16)),
                            mode='constant')
            x_tmp = np.pad(np.squeeze(eval(f"x[{slicing}")), (patch + 16, ),
                           mode='constant')
            y_tmp = np.pad(np.squeeze(eval(f"y[{slicing}")), (patch + 16, ),
                           mode='constant')
            z_tmp = np.pad(np.squeeze(eval(f"z[{slicing}")), (patch + 16, ),
                           mode='constant')
            bd_tmp = np.pad(np.squeeze(eval(f"bd[{slicing}")), (patch + 16, ),
                            mode='constant')
            for counter, index in enumerate(indexes):
                # In padded coordinates the patch around center `index` spans
                # [index + 16, index + 16 + 2*patch).
                startx = index[0] + 16
                endx = index[0] + 16 + 2 * patch
                starty = index[1] + 16
                endy = index[1] + 16 + 2 * patch
                allin = [
                    wat_tmp[startx:endx, starty:endy],
                    fat_tmp[startx:endx, starty:endy]
                ]
                if add_dts:
                    allin.append(x_tmp[startx:endx, starty:endy])
                    allin.append(y_tmp[startx:endx, starty:endy])
                    allin.append(z_tmp[startx:endx, starty:endy])
                    allin.append(bd_tmp[startx:endx, starty:endy])
                allin = np.stack(allin, axis=0)
                gt_part = gt_tmp[:, startx:endx, starty:endy]
                np.save(f"{outpath}/in1/subj{PID}_{slajs}_{counter}", allin)
                np.save(f"{outpath}/gt/subj{PID}_{slajs}_{counter}", gt_part)
                if make_subsampled:
                    # Wider context: grow 16 voxels per side, subsample by 3.
                    startx = startx - 16
                    endx = endx + 16
                    starty = starty - 16
                    endy = endy + 16
                    allin = [
                        wat_tmp[startx:endx:3, starty:endy:3],
                        fat_tmp[startx:endx:3, starty:endy:3]
                    ]
                    if add_dts:
                        allin.append(x_tmp[startx:endx:3, starty:endy:3])
                        allin.append(y_tmp[startx:endx:3, starty:endy:3])
                        allin.append(z_tmp[startx:endx:3, starty:endy:3])
                        allin.append(bd_tmp[startx:endx:3, starty:endy:3])
                    allin = np.stack(allin, axis=0)
                    np.save(f"{outpath}/in2/subj{PID}_{slajs}_{counter}",
                            allin)
    # Record how this dataset was produced.
    with open(f"{outpath}/datainfo.txt", "w") as info_file:
        info_file.write(f"""Sliced by dim {sliced}. \nPatch size: {patch_size} \nDTs: {add_dts}\nsubsmpl: {make_subsampled} \nsampling: {sampling}""")
def cutEval(patch_size, pid_list=None):
    """Cut evaluation data: all 2D slices plus a regular grid of 3D patches.

    patch_size = how big patches to cut. pid_list = which subjs to cut.
    If None, all are cut. Returns None; outputs are written under POEM_eval/.
    """
    outpath2 = pathlib.Path('POEM_eval', 'TwoD')
    outpath3 = pathlib.Path('POEM_eval', 'TriD')
    GTs2 = pathlib.Path('POEM_eval', 'GTs_2D')
    GTs3 = pathlib.Path('POEM_eval', 'GTs_3D')
    GTs2.mkdir(parents=True, exist_ok=True)
    GTs3.mkdir(parents=True, exist_ok=True)
    for i in ['in1', 'in2']:
        pathlib.Path(outpath2, i).mkdir(parents=True, exist_ok=True)
        pathlib.Path(outpath3, i).mkdir(parents=True, exist_ok=True)

    #check if everything already exists, to not cut twice:
    if pid_list is None:  # FIX: identity comparison with None (was ==)
        #set it to all available pids
        pid_list = [
            getpid(filli) for filli in glob("POEM/segms/CroppedSegmNew*")
        ]
    existing_pid_list = [getpid(filli) for filli in glob("POEM_eval/GTs_2D/*")]
    # FIX: "all files exist" means every requested pid is among the existing
    # ones (subset test). The previous union-size test was also true when NO
    # files existed at all, so cutting could be skipped incorrectly.
    allfilesexist = set(pid_list).issubset(existing_pid_list)
    # NOTE(review): size{patch_size}.txt is only checked here, never written in
    # this function — confirm it is created elsewhere after a successful cut.
    exists = pathlib.Path('POEM_eval', f'size{patch_size}.txt').is_file()
    if exists and allfilesexist:  #everything exists, do not recut
        print('Files already exist. Cutting stopped.')
        return None
    #otherwise remove all existing output files (.npy/.txt starting with 's'),
    #unless only a few/irrelevant ones exist but are of correct size:
    if not exists:
        for filename in pathlib.Path("POEM_eval").rglob("s*.[nt][xp][yt]"):
            filename.unlink()

    #POEM SLICING: per-modality file lists, restricted to the requested pids.
    gt_paths = [
        g for g in glob("POEM/segms/CroppedSegmNew*") if getpid(g) in pid_list
    ]
    wat_paths = [
        g for g in glob("POEM/watfat/cropped*_wat*") if getpid(g) in pid_list
    ]
    fat_paths = [
        g for g in glob("POEM/watfat/cropped*_fat*") if getpid(g) in pid_list
    ]
    dtx_paths = [
        g for g in glob("POEM/distmaps/*x.nii") if getpid(g) in pid_list
    ]
    dty_paths = [
        g for g in glob("POEM/distmaps/*y.nii") if getpid(g) in pid_list
    ]
    mask_paths = [
        g for g in glob("POEM/masks/cropped*_mask.nii") if getpid(g) in pid_list
    ]
    gt_paths.sort()
    wat_paths.sort()
    fat_paths.sort()
    dtx_paths.sort()
    dty_paths.sort()
    mask_paths.sort()
    assert len(gt_paths) == len(wat_paths) == len(fat_paths) == len(
        dtx_paths) == len(dty_paths) == len(mask_paths)

    nb_class = 7
    for w, f, g, dx, dy, m in zip(wat_paths, fat_paths, gt_paths, dtx_paths,
                                  dty_paths, mask_paths):
        PIDs = [getpid(ppp) for ppp in [w, f, g, dx, dy, m]]
        assert len(
            np.unique(PIDs)) == 1  #check that all paths lead to same subj
        PID = PIDs[0]
        print(f"Slicing nr {PID}...")
        wat = nib.load(w).get_fdata()
        fat = nib.load(f).get_fdata()
        gt = nib.load(g).get_fdata()
        x = nib.load(dx).get_fdata()
        y = nib.load(dy).get_fdata()
        maska = nib.load(m).get_fdata()
        # z channel: DT from the first occupied z-slice, rescaled to [-1, 1].
        tmp_z = np.ones(maska.shape)
        startz, endz = np.nonzero(maska.sum(axis=(0, 1)))[0][0], np.nonzero(
            maska.sum(axis=(0, 1)))[0][-1]
        tmp_z[:, :, startz] = 0
        tmp_z = 2. * dt_edt(tmp_z) / (endz - startz) - 1.
        z = maska * tmp_z  #create artificially, simply DT from left to right
        bd = dt_edt(maska)  #create artificially, simply DT from border
        bd = bd / np.max(bd)
        allin = np.stack([wat, fat, x, y, z, bd], axis=0)

        #SAVE GT
        gt = get_one_hot(gt, nb_class)  #new size C x H x W x D
        #np.save(pathlib.Path(GTs, f"subj{PID}.npy"), gt)

        #SAVE 2D SLICES (full slices along axis 1, plus 3x-subsampled in2)
        for s in range(wat.shape[1]):
            np.save(pathlib.Path(outpath2, 'in1', f"subj{PID}_{s}.npy"),
                    np.squeeze(allin[:, :, s, :]))
            np.save(pathlib.Path(outpath2, 'in2', f"subj{PID}_{s}.npy"),
                    np.squeeze(allin[:, 0::3, s, 0::3]))
            np.save(pathlib.Path(GTs2, f"subj{PID}_{s}.npy"),
                    np.squeeze(gt[:, :, s, :]))

        #SAVE 3D PATCHES
        #for easier subsampled data, first pad with 0s:
        allin = np.pad(allin, ((0, ), (16, ), (16, ), (16, )), mode='constant')
        gt = np.pad(gt, ((0, ), (16, ), (16, ), (16, )), mode='constant')
        # FIX: the grid below used hard-coded 50/66/28, which silently assumed
        # patch_size == 50; expressed in terms of patch_size instead
        # (identical behavior for patch_size=50).
        # in2 covers the patch plus a 16-voxel margin per side, subsampled by 3.
        sub_size = (patch_size + 32 + 2) // 3  # ceil((patch_size+32)/3); 28 for 50
        for i in range(16, wat.shape[0] + 16, (patch_size - 16)):
            for j in range(16, wat.shape[1] + 16, (patch_size - 16)):
                for k in range(16, wat.shape[2] + 16, (patch_size - 16)):
                    tmp_in1 = allin[:, i:i + patch_size, j:j + patch_size,
                                    k:k + patch_size]
                    tmp_in2 = allin[:, i - 16:i + patch_size + 16:3,
                                    j - 16:j + patch_size + 16:3,
                                    k - 16:k + patch_size + 16:3]
                    tmp_gt = gt[:, i:i + patch_size, j:j + patch_size,
                                k:k + patch_size]
                    # print(f"in1: {tmp_in1.shape}, in2: {tmp_in2.shape}")
                    _, s10, s11, s12 = tmp_in1.shape
                    _, s20, s21, s22 = tmp_in2.shape
                    # Border patches may come out short; zero-pad to full size.
                    tmp_in1 = np.pad(tmp_in1,
                                     ((0, 0), (0, patch_size - s10),
                                      (0, patch_size - s11),
                                      (0, patch_size - s12)),
                                     mode='constant')
                    tmp_gt = np.pad(tmp_gt,
                                    ((0, 0), (0, patch_size - s10),
                                     (0, patch_size - s11),
                                     (0, patch_size - s12)),
                                    mode='constant')
                    tmp_in2 = np.pad(tmp_in2,
                                     ((0, 0), (0, sub_size - s20),
                                      (0, sub_size - s21),
                                      (0, sub_size - s22)),
                                     mode='constant')
                    # print(f"NEW: \t {tmp_in1.shape}, in2: {tmp_in2.shape}")
                    np.save(
                        pathlib.Path(outpath3, 'in1',
                                     f"subj{PID}_{i}_{j}_{k}.npy"), tmp_in1)
                    np.save(
                        pathlib.Path(outpath3, 'in2',
                                     f"subj{PID}_{i}_{j}_{k}.npy"), tmp_in2)
                    np.save(pathlib.Path(GTs3, f"subj{PID}_{i}_{j}_{k}.npy"),
                            tmp_gt)
    return None


# %%
#cutPOEM2D(50, 'POEM50', sampling=[5, 3, 4, 3, 5, 4, 4])
#cutPOEM2D(50, 'POEM50_2', sliced=2,sampling=[5, 3, 4, 3, 5, 4, 4])
#cutPOEM3D(50, 'POEM50_3D', sampling=[5,3,4,3,5,4,4])
def cutPOEMslices():
    """Cut whole 2D slices (axial direction only) from the POEM dataset.

    By default cuts only in the axial direction. This was just for tryouts;
    same data as in the BL project. Saves, per slice containing any organ:
    in1 (wat/fat/dt-x/dt-y stack), in2 (3x-subsampled in1) and gt (one-hot).
    """
    outpath = f"POEM_slices/TRAIN"
    pathlib.Path(outpath).mkdir(parents=True, exist_ok=True)
    for i in ['gt', 'in1', 'in2']:
        pathlib.Path(f"POEM_slices/{i}").mkdir(parents=True, exist_ok=True)
    #POEM SLICING
    # NOTE(review): hard-coded absolute paths from the original experiments;
    # the relative-path variants used by the other cutters are kept commented below.
    gt_paths = glob(
        "/home/eva/Desktop/research/PROJEKT2-DeepLearning/procesiranDataset/POEM_segment_all/converted/CroppedSegmNew*"
    )
    wat_paths = glob(
        "/home/eva/Desktop/research/PROJEKT2-DeepLearning/procesiranDataset/POEM_segmentation_data_fatwat/converted/cropped*_wat*"
    )
    fat_paths = glob(
        "/home/eva/Desktop/research/PROJEKT2-DeepLearning/procesiranDataset/POEM_segmentation_data_fatwat/converted/cropped*_fat*"
    )
    dtx_paths = glob(
        "/home/eva/Desktop/research/PROJEKT2-DeepLearning/distmaps/*x.nii")
    dty_paths = glob(
        "/home/eva/Desktop/research/PROJEKT2-DeepLearning/distmaps/*y.nii")
    #gt_paths = glob("POEM/segms/CroppedSegmNew*")
    #wat_paths = glob("POEM/watfat/cropped*_wat*")
    #fat_paths = glob("POEM/watfat/cropped*_fat*")
    #dtx_paths = glob("POEM/distmaps/*x.nii")
    #dty_paths = glob("POEM/distmaps/*y.nii")
    #mask_paths = glob("POEM/masks/cropped*_mask.nii")
    # Sorting aligns the per-modality lists by subject id.
    gt_paths.sort()
    wat_paths.sort()
    fat_paths.sort()
    dtx_paths.sort()
    dty_paths.sort()
    for w, f, g, dx, dy in zip(wat_paths, fat_paths, gt_paths, dtx_paths,
                               dty_paths):
        PID = getpid(w)
        print(f"Slicing nr {PID}...")
        wat = nib.load(w).get_fdata()
        fat = nib.load(f).get_fdata()
        gt = nib.load(g).get_fdata()
        x = nib.load(dx).get_fdata()
        y = nib.load(dy).get_fdata()
        gt = get_one_hot(gt, 7)  #new size C x H x W x D
        # Keep only slices (along axis 1) that contain at least one non-bckg label.
        slajsi_where = gt[1:, ...].sum(axis=(0, 1, 3))
        slajsi = np.arange(wat.shape[1])
        slajsi = slajsi[slajsi_where > 0]
        for slajs in tqdm(slajsi):
            allin = [
                wat[:, slajs, ...], fat[:, slajs, ...], x[:, slajs, ...],
                y[:, slajs, ...]
            ]
            allin = np.stack(allin, axis=0)
            # "quasi" downsampling: plain stride-3 subsampling, no filtering.
            quasidownsmp = allin[:, 0::3, 0::3]
            gt_part = gt[:, :, slajs, :]
            np.save(f"POEM_slices/in1/subj{PID}_{slajs}_0", allin)
            np.save(f"POEM_slices/in2/subj{PID}_{slajs}_0", quasidownsmp)
            np.save(f"POEM_slices/gt/subj{PID}_{slajs}_0", gt_part)
def cutPOEM3D(patch_size, outpath, make_subsampled=True, add_dts=True, sampling=None):
    """Cut 3D training patches from the POEM dataset and save them as .npy.

    patch_size: side length of the saved in1 patches (voxels).
    outpath: root output folder; files go to OUTPATH/TRAIN/{gt,in1,in2}.
    make_subsampled: also save a 3x-subsampled, 16-voxel-wider context
        patch as in2.
    add_dts: include the distance-transform channels (x, y, z, border).
    sampling: how many patches per class to sample from each subject.
        If not given, sampling is random (i.e. may contain lots of bckg!).
    """
    #prepare folders for saving:
    outpath = f"{outpath}/TRAIN"
    pathlib.Path(outpath).mkdir(parents=True, exist_ok=True)
    for i in ['gt', 'in1', 'in2']:
        pathlib.Path(outpath, i).mkdir(parents=True, exist_ok=True)
    #POEM SLICING: per-modality file lists; sorting aligns them by subject.
    gt_paths = glob("POEM/segms/CroppedSegmNew*")
    wat_paths = glob("POEM/watfat/cropped*_wat*")
    fat_paths = glob("POEM/watfat/cropped*_fat*")
    dtx_paths = glob("POEM/distmaps/*x.nii")
    dty_paths = glob("POEM/distmaps/*y.nii")
    mask_paths = glob("POEM/masks/cropped*_mask.nii")
    gt_paths.sort()
    wat_paths.sort()
    fat_paths.sort()
    dtx_paths.sort()
    dty_paths.sort()
    mask_paths.sort()
    assert len(gt_paths) == len(wat_paths) == len(fat_paths) == len(
        dtx_paths) == len(dty_paths) == len(mask_paths)

    nb_class = 7
    patch = patch_size // 2  # half size; patches span 2*patch voxels
    for w, f, g, dx, dy, m in zip(wat_paths, fat_paths, gt_paths, dtx_paths,
                                  dty_paths, mask_paths):
        PIDs = [getpid(ppp) for ppp in [w, f, g, dx, dy, m]]
        assert len(np.unique(PIDs)) == 1  # all six files belong to one subject
        PID = PIDs[0]
        print(f"Slicing nr {PID}...")
        wat = nib.load(w).get_fdata()
        fat = nib.load(f).get_fdata()
        gt = nib.load(g).get_fdata()
        x = nib.load(dx).get_fdata()
        y = nib.load(dy).get_fdata()
        maska = nib.load(m).get_fdata()
        # z channel: DT from the first occupied z-slice, rescaled to [-1, 1].
        tmp_z = np.ones(maska.shape)
        startz, endz = np.nonzero(maska.sum(axis=(0, 1)))[0][0], np.nonzero(
            maska.sum(axis=(0, 1)))[0][-1]
        tmp_z[:, :, startz] = 0
        tmp_z = 2. * dt_edt(tmp_z) / (endz - startz) - 1.
        z = maska * tmp_z  #create artificially, simply DT from left to right
        bd = dt_edt(maska)  #create artificially, simply DT from border
        bd = bd / np.max(bd)
        gt = get_one_hot(gt, nb_class)  #new size C x H x W x D

        to_cut = 5  # random patch centers per subject when no sampling given
        if sampling is None:  # FIX: identity comparison with None (was ==)
            #if no sampling given, cut a few random in-mask patches per subject.
            # print((maska.shape, wat.shape))
            kjeso = np.argwhere(maska == 1)
            if len(kjeso) > to_cut:
                kjeso = kjeso[
                    np.random.choice(kjeso.shape[0], to_cut, replace=False),
                    ...]
        else:
            # FIX: interpolate nb_class (was a placeholder-less f-string).
            assert len(
                sampling
            ) == nb_class, f"Sampling variable should be an array of length {nb_class}!"
            #collect per-organ sampled patch centers:
            kjeso = []
            for organ, nr_samples in enumerate(sampling):
                possible = np.argwhere((gt[organ, ...] * maska) == 1)
                Ll = len(possible)
                nr_sample = min(nr_samples, Ll)
                kjeso.append(possible[random.sample(range(Ll), nr_sample),
                                      ...])
            kjeso = np.vstack(kjeso)

        # Pad by patch+16 so any in-mask center yields a full patch plus the
        # 16-voxel margin needed for the subsampled in2.
        wat_tmp = np.pad(wat, (patch + 16, ), mode='constant')
        fat_tmp = np.pad(fat, (patch + 16, ), mode='constant')
        gt_tmp = np.pad(gt, ((0, 0), (patch + 16, patch + 16),
                             (patch + 16, patch + 16),
                             (patch + 16, patch + 16)),
                        mode='constant')
        x_tmp = np.pad(x, (patch + 16, ), mode='constant')
        y_tmp = np.pad(y, (patch + 16, ), mode='constant')
        z_tmp = np.pad(z, (patch + 16, ), mode='constant')
        bd_tmp = np.pad(bd, (patch + 16, ), mode='constant')
        for idx, center in tqdm(enumerate(kjeso)):
            # In padded coordinates the patch around `center` spans
            # [center + 16, center + 16 + 2*patch) along each axis.
            startx = center[0] + 16
            endx = center[0] + 16 + 2 * patch
            starty = center[1] + 16
            endy = center[1] + 16 + 2 * patch
            startz = center[2] + 16
            endz = center[2] + 16 + 2 * patch
            allin = [
                wat_tmp[startx:endx, starty:endy, startz:endz],
                fat_tmp[startx:endx, starty:endy, startz:endz]
            ]
            if add_dts:
                allin.append(x_tmp[startx:endx, starty:endy, startz:endz])
                allin.append(y_tmp[startx:endx, starty:endy, startz:endz])
                allin.append(z_tmp[startx:endx, starty:endy, startz:endz])
                allin.append(bd_tmp[startx:endx, starty:endy, startz:endz])
            allin = np.stack(allin, axis=0)
            gt_part = gt_tmp[:, startx:endx, starty:endy, startz:endz]
            np.save(f"{outpath}/in1/subj{PID}_{idx}_0", allin)
            np.save(f"{outpath}/gt/subj{PID}_{idx}_0", gt_part)
            if make_subsampled:
                # Wider context: grow 16 voxels per side, subsample by 3.
                startx = startx - 16
                endx = endx + 16
                starty = starty - 16
                endy = endy + 16
                startz = startz - 16
                endz = endz + 16
                allin = [
                    wat_tmp[startx:endx:3, starty:endy:3, startz:endz:3],
                    fat_tmp[startx:endx:3, starty:endy:3, startz:endz:3]
                ]
                if add_dts:
                    allin.append(x_tmp[startx:endx:3, starty:endy:3,
                                       startz:endz:3])
                    allin.append(y_tmp[startx:endx:3, starty:endy:3,
                                       startz:endz:3])
                    allin.append(z_tmp[startx:endx:3, starty:endy:3,
                                       startz:endz:3])
                    allin.append(bd_tmp[startx:endx:3, starty:endy:3,
                                        startz:endz:3])
                allin = np.stack(allin, axis=0)
                np.save(f"{outpath}/in2/subj{PID}_{idx}_0", allin)
    # Record how this dataset was produced.
    with open(f"{outpath}/datainfo.txt", "w") as info_file:
        info_file.write(f"""Sliced 3D patches. \nPatch size: {patch_size} \nDTs: {add_dts}\nsubsmpl: {make_subsampled} \nsampling: {sampling}""")
def plotOutput(params, datafolder, pids, doeval=True, take20=None):
    """After training a net and saving its state to name PARAMS, run inference on
    subject PID from DATAFOLDER, and plot results+GT. PID should be
    subject_slice, and all subject_slice_x will be run.
    E.g. plotOutput('First_unet', 'POEM', '500177_30').

    doeval: put the net in eval mode before inference.
    take20: optional fixed subset of indices to plot (at most 20 per pid);
        returned so a later call can reuse the same subset.
    """
    #default settings:
    Arg = {
        'network': None,
        'n_class': 7,
        'in_channels': 2,
        'lower_in_channels': 2,
        'extractor_net': 'resnet34'
    }
    # Parse the saved command-line args file ("--key=value" lines).
    with open(f"RESULTS/{params}_args.txt", "r") as ft:
        args = ft.read().splitlines()
    tmpargs = [
        i.strip('--').split("=") for i in args if ('--' in i and '=' in i)
    ]
    chan1, whichin1 = getchn(args, 'in_chan')
    chan2, whichin2 = getchn(args, 'lower_in_chan')
    tmpargs += [['in_channels', chan1], ['lower_in_channels', chan2]]
    in3D = '--in3D' in args
    args = dict(tmpargs)
    #overwrite if given in file:
    # NOTE(review): values parsed from the txt file are strings, while the
    # defaults above are ints — verify downstream uses (e.g. range(Arg['n_class']))
    # are not hit by an overwritten string value.
    Arg.update(args)
    device = torch.device('cpu')
    net = getattr(Networks, Arg['network'])(Arg['in_channels'],
                                            Arg['n_class'],
                                            Arg['lower_in_channels'],
                                            Arg['extractor_net'], in3D)
    net = net.float()
    #now we can load learned params:
    loaded = torch.load(f"RESULTS/{params}",
                        map_location=lambda storage, loc: storage)
    net.load_state_dict(loaded['state_dict'])
    if doeval:
        net.eval()
    net = net.to(device)
    #load data>
    if isinstance(pids, str):
        pids = [pids]
    use_in2 = Arg['network'] == 'DeepMedic'  # DeepMedic takes a second, subsampled input
    allL = 0
    allfindgts = []
    allfindin1 = []
    allfindin2 = []
    # NOTE(review): uses glob.glob(...) while the cutters in this file call
    # glob(...) directly — confirm which import form the file actually has.
    for pid in pids:
        findgts = glob.glob(f"./{datafolder}/*/gt/*{pid}*.npy")
        #findgts = glob.glob(f"./{datafolder}/GTs_2D/*{pid}*.npy")
        findin1 = glob.glob(f"./{datafolder}/*/in1/*{pid}*.npy")
        findin2 = glob.glob(f"./{datafolder}/*/in2/*{pid}*.npy")
        findgts.sort(), findin1.sort(), findin2.sort()
        # print(findgts)
        #all subslices in one image.
        L = len(findgts)
        if L > 20:  #ugly but needed to avoid too long compute
            if take20 == None:
                take20 = random.sample(range(L), 20)
            findgts = [findgts[tk] for tk in take20]
            findin1 = [findin1[tk] for tk in take20]
            if use_in2:
                findin2 = [findin2[tk] for tk in take20]
            L = len(take20)
        allL += L
        allfindgts.extend(findgts)
        allfindin1.extend(findin1)
        allfindin2.extend(findin2)
    organs = [
        'Bckg', 'Bladder', 'KidneyL', 'Liver', 'Pancreas', 'Spleen', 'KidneyR'
    ]
    if len(organs) != Arg['n_class']:  #in case not POEM dataset used
        organs = [str(zblj) for zblj in range(Arg['n_class'])]
    # Inspect the first GT file to detect one-hot encoding and 2D vs 3D data.
    entmpgt = np.load(allfindgts[0])
    tgtonehot = entmpgt.shape[0] == 7  #are targets one hot encoded?
    in3d = tgtonehot * (entmpgt.ndim == 4) + (not tgtonehot) * (entmpgt.ndim
                                                                == 3)
    if in3d:  #set the right function to use
        TensorCropping = CenterCropTensor3d
    else:
        TensorCropping = CenterCropTensor
    # Stack all inputs into one batch; keep only the channels the net was trained on.
    data = torch.stack([
        torch.from_numpy(np.load(i1)).float().to(device) for i1 in allfindin1
    ],
                       dim=0)
    data = [data[:, whichin1, ...]]
    # target: label maps for plotting; target_oh: one-hot tensors for Dice.
    target = [
        flatten_one_hot(np.load(g)) if tgtonehot else np.load(g)
        for g in allfindgts
    ]
    target_oh = torch.stack([
        torch.from_numpy(np.load(g)).to(device) if tgtonehot else
        torch.from_numpy(get_one_hot(np.load(g), 7)).to(device)
        for g in allfindgts
    ],
                            dim=0)
    if use_in2:
        in2 = torch.stack([
            torch.from_numpy(np.load(i2)).float().to(device)
            for i2 in allfindin2
        ],
                          dim=0)
        data.append(in2[:, whichin2, ...])
    out = net(*data)
    # Crop GT and output to a common center region before scoring.
    target_oh, out = TensorCropping(target_oh, out)
    dices = AllDices(out, target_oh)  #DicePerClass(out, target_oh)
    # print((out.shape, target_oh.shape))
    outs = [flatten_one_hot(o.detach().squeeze().numpy()) for o in out]
    fig, ax_tuple = plt.subplots(nrows=allL,
                                 ncols=2,
                                 figsize=(10, allL * 6 + 1),
                                 tight_layout=True)
    #for compatibility reasons (a single row comes back as a 1-D axes array):
    if ax_tuple.ndim < 2:
        ax_tuple = ax_tuple[np.newaxis, ...]
    plt.suptitle(params)
    for ind in range(len(outs)):
        #now plot :)
        targetind, outsind = TensorCropping(
            target[ind], outs[ind])  #crop to be more comparable
        # print((outsind.shape, targetind.shape))
        if in3d:
            # 3D volumes: show only the middle slice along the second-to-last axis.
            sl = targetind.shape[-2] // 2
            targetind, outsind = targetind[..., sl, :], outsind[..., sl, :]
        ax1 = ax_tuple[ind, 0]
        ax1.set_title('GT')
        ax1.axis('off')
        ax1.imshow(targetind, cmap='Spectral', vmin=0, vmax=Arg['n_class'])
        ax2 = ax_tuple[ind, 1]
        ax2.set_title('OUT')
        ax2.axis('off')
        im = ax2.imshow(outsind, cmap='Spectral', vmin=0, vmax=Arg['n_class'])
        values = np.arange(Arg['n_class'])
        colors = [im.cmap(im.norm(value)) for value in values]
        # create a patch (proxy artist) for every color
        patches = [
            mpatches.Patch(color=colors[i], label=organs[i])
            for i in range(len(values))
        ]
        # put those patched as legend-handles into the legend
        ax2.legend(handles=patches,
                   bbox_to_anchor=(1.05, 1.),
                   loc=2,
                   borderaxespad=0.)
        #write out also Dices:
        dajci = dices[ind].detach().squeeze().numpy()
        # NOTE(review): present_classes is computed but never used below.
        present_classes = [i for i in range(7) if i in target[ind]]
        t = ax2.text(1.08,
                     0.5,
                     'Dices:',
                     size='medium',
                     horizontalalignment='center',
                     verticalalignment='center',
                     transform=ax2.transAxes)
        for d in range(7):
            t = ax2.text(1.1,
                         0.45 - d * 0.05,
                         f"{organs[d]}: {dajci[d]:.3f}",
                         size='small',
                         transform=ax2.transAxes)
    plt.show()
    #plt.savefig('foo.png')
    #print(dices)
    return take20