def create_one_3D_file():
    '''
    Create one big file which contains all 3D Images (not slices).

    Loads each subject's peak features and multilabel segmentation mask,
    scales them to the UNet input shape and saves all volumes as two
    numpy files ("data.npy" and "seg.npy") in the working directory.
    '''
    class HP:
        DATASET = "HCP"
        RESOLUTION = "1.25mm"
        FEATURES_FILENAME = "270g_125mm_peaks"
        LABELS_TYPE = np.int16
        DATASET_FOLDER = "HCP"

    data_all = []
    seg_all = []

    print("\n\nProcessing Data...")
    for s in get_all_subjects():
        print("processing data subject {}".format(s))
        data = nib.load(join(C.HOME, HP.DATASET_FOLDER, s,
                             HP.FEATURES_FILENAME + ".nii.gz")).get_data()
        data = np.nan_to_num(data)
        data = DatasetUtils.scale_input_to_unet_shape(data, HP.DATASET, HP.RESOLUTION)
        data_all.append(np.array(data))
    np.save("data.npy", data_all)
    del data_all  # free memory before loading all segmentations

    print("\n\nProcessing Segs...")
    for s in get_all_subjects():
        print("processing seg subject {}".format(s))
        seg = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
        if HP.RESOLUTION == "2.5mm":
            # masks are created at HCP resolution -> downsample for 2.5mm
            seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
        seg = DatasetUtils.scale_input_to_unet_shape(seg, HP.DATASET, HP.RESOLUTION)
        seg_all.append(np.array(seg))

    # FIX: a Python list has no .dtype attribute -> convert to an ndarray
    # first (np.save would have done this implicitly anyway).
    seg_all = np.array(seg_all)
    print("SEG TYPE: {}".format(seg_all.dtype))
    np.save("seg.npy", seg_all)
def save_fusion_nifti_as_npy():
    """Convert fusion probability maps (and their masks) from NIfTI to .npy.

    For every subject, loads the fused x/y/z probability map from the
    network drive, scales it to the UNet input shape, trims one voxel and
    stores it as "<DIFFUSION_FOLDER>_xyz.npy"; the corresponding bundle
    masks are scaled and stored as "bundle_masks.npy" next to it.
    """
    # Can leave this always the same (for 270g and 32g)
    class HP:
        DATASET = "HCP"
        RESOLUTION = "1.25mm"
        FEATURES_FILENAME = "270g_125mm_peaks"
        LABELS_TYPE = np.int16
        LABELS_FILENAME = "bundle_masks"
        DATASET_FOLDER = "HCP"

    # change this for 270g and 32g
    DIFFUSION_FOLDER = "32g_25mm"

    subjects = get_all_subjects()
    # Historical cross-validation fold subsets, kept for reference:
    # fold0 = ['687163', '685058', '683256', '680957', '679568', '677968', '673455', '672756', '665254', '654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']
    # fold1 = ['992774', '991267', '987983', '984472', '983773', '979984', '978578', '965771', '965367', '959574', '958976', '957974', '951457', '932554', '930449', '922854', '917255', '912447', '910241', '907656', '904044']
    # fold2 = ['901442', '901139', '901038', '899885', '898176', '896879', '896778', '894673', '889579', '887373', '877269', '877168', '872764', '872158', '871964', '871762', '865363', '861456', '859671', '857263', '856766']
    # fold3 = ['849971', '845458', '837964', '837560', '833249', '833148', '826454', '826353', '816653', '814649', '802844', '792766', '792564', '789373', '786569', '784565', '782561', '779370', '771354', '770352', '765056']
    # fold4 = ['761957', '759869', '756055', '753251', '751348', '749361', '748662', '748258', '742549', '734045', '732243', '729557', '729254', '715647', '715041', '709551', '705341', '704238', '702133', '695768', '690152']
    # subjects = fold2 + fold3 + fold4
    # subjects = ['654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']

    print("\n\nProcessing Data...")
    for s in subjects:
        print("processing data subject {}".format(s))
        start_time = time.time()
        data = nib.load(join(C.NETWORK_DRIVE, "HCP_fusion_" + DIFFUSION_FOLDER,
                             s + "_probmap.nii.gz")).get_data()
        print("Done Loading")
        data = np.nan_to_num(data)
        data = DatasetUtils.scale_input_to_unet_shape(data, HP.DATASET, HP.RESOLUTION)
        # cut one pixel at the end, because in scale_input_to_world_shape we
        # outputted 146 -> one too much at the end
        data = data[:-1, :, :-1, :]
        ExpUtils.make_dir(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s))
        np.save(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s,
                     DIFFUSION_FOLDER + "_xyz.npy"), data)
        print("Took {}s".format(time.time() - start_time))

        print("processing seg subject {}".format(s))
        start_time = time.time()
        # seg = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
        seg = nib.load(join(C.NETWORK_DRIVE, "HCP_for_training_COPY", s,
                            HP.LABELS_FILENAME + ".nii.gz")).get_data()
        if HP.RESOLUTION == "2.5mm":
            # masks come at HCP resolution -> downsample for 2.5mm
            seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
        seg = DatasetUtils.scale_input_to_unet_shape(seg, HP.DATASET, HP.RESOLUTION)
        np.save(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s,
                     "bundle_masks.npy"), seg)
        print("Took {}s".format(time.time() - start_time))
def _create_prob_slices_file(HP, subjects, filename, bundle, shuffle=True):
    """Build fusion training data from per-axis UNet probability maps.

    For each subject, stacks the x/y/z probability maps into one combined
    volume, reshapes it to 4D for the UNet, saves it as a NIfTI image and
    collects 2D z-slices (plus matching multilabel mask slices) into
    "<filename>_data.npy" / "<filename>_seg.npy".

    NOTE(review): `mask_dir` and `bundle` are assigned/accepted but never
    used in this body — presumably legacy parameters; confirm with callers.
    """
    mask_dir = join(C.HOME, HP.DATASET_FOLDER)
    input_dir = HP.MULTI_PARENT_PATH

    combined_slices = []
    mask_slices = []
    for s in subjects:
        print("processing subject {}".format(s))

        # Probability maps of the three per-axis UNets for this CV fold
        probs_x = nib.load(join(input_dir, "UNet_x_" + str(HP.CV_FOLD), "probmaps",
                                s + "_probmap.nii.gz")).get_data()
        probs_y = nib.load(join(input_dir, "UNet_y_" + str(HP.CV_FOLD), "probmaps",
                                s + "_probmap.nii.gz")).get_data()
        probs_z = nib.load(join(input_dir, "UNet_z_" + str(HP.CV_FOLD), "probmaps",
                                s + "_probmap.nii.gz")).get_data()
        # probs_x = DatasetUtils.scale_input_to_unet_shape(probs_x, HP.DATASET, HP.RESOLUTION)
        # probs_y = DatasetUtils.scale_input_to_unet_shape(probs_y, HP.DATASET, HP.RESOLUTION)
        # probs_z = DatasetUtils.scale_input_to_unet_shape(probs_z, HP.DATASET, HP.RESOLUTION)

        # (73, 87, 73, 18, 3)  -- not working alone: one dim too much for
        # UNet -> merge the last two dims below
        combined = np.stack((probs_x, probs_y, probs_z), axis=4)
        combined = np.reshape(combined, (combined.shape[0], combined.shape[1],
                                         combined.shape[2],
                                         combined.shape[3] * combined.shape[4]))  # (73, 87, 73, 3*18)
        # print("combined shape after", combined.shape)

        mask_data = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
        if HP.DATASET == "HCP_2mm":
            # use "HCP" because for mask we need downscaling
            mask_data = DatasetUtils.scale_input_to_unet_shape(mask_data, "HCP", HP.RESOLUTION)
        elif HP.DATASET == "HCP_2.5mm":
            # use "HCP" because for mask we need downscaling
            mask_data = DatasetUtils.scale_input_to_unet_shape(mask_data, "HCP", HP.RESOLUTION)
        else:
            # Mask has same resolution as probmaps -> we can use same resizing
            mask_data = DatasetUtils.scale_input_to_unet_shape(mask_data, HP.DATASET, HP.RESOLUTION)

        # Save as Img
        img = nib.Nifti1Image(combined, ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION))
        nib.save(img, join(HP.EXP_PATH, "combined", s + "_combinded_probmap.nii.gz"))

        combined = DatasetUtils.scale_input_to_unet_shape(combined, HP.DATASET, HP.RESOLUTION)
        assert (combined.shape[2] == mask_data.shape[2])

        # Save as Slices (along z), data and mask kept index-aligned
        for z in range(combined.shape[2]):
            combined_slices.append(combined[:, :, z, :])
            mask_slices.append(mask_data[:, :, z, :])

    if shuffle:
        # sk_shuffle shuffles both lists with the same permutation
        combined_slices, mask_slices = sk_shuffle(combined_slices, mask_slices,
                                                  random_state=9)

    if HP.TRAIN:
        np.save(filename + "_data.npy", combined_slices)
        np.save(filename + "_seg.npy", mask_slices)
def save_fusion_nifti_as_npy():
    """Convert fusion probability maps (and their masks) from NIfTI to .npy.

    NOTE(review): this is a duplicate definition — it shadows the earlier
    `save_fusion_nifti_as_npy` in this file; consider removing one copy.
    """
    # Can leave this always the same (for 270g and 32g)
    class HP:
        DATASET = "HCP"
        RESOLUTION = "1.25mm"
        FEATURES_FILENAME = "270g_125mm_peaks"
        LABELS_TYPE = np.int16
        LABELS_FILENAME = "bundle_masks"
        DATASET_FOLDER = "HCP"

    # change this for 270g and 32g
    DIFFUSION_FOLDER = "32g_25mm"

    subjects = get_all_subjects()
    # Historical cross-validation fold subsets, kept for reference:
    # fold0 = ['687163', '685058', '683256', '680957', '679568', '677968', '673455', '672756', '665254', '654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']
    # fold1 = ['992774', '991267', '987983', '984472', '983773', '979984', '978578', '965771', '965367', '959574', '958976', '957974', '951457', '932554', '930449', '922854', '917255', '912447', '910241', '907656', '904044']
    # fold2 = ['901442', '901139', '901038', '899885', '898176', '896879', '896778', '894673', '889579', '887373', '877269', '877168', '872764', '872158', '871964', '871762', '865363', '861456', '859671', '857263', '856766']
    # fold3 = ['849971', '845458', '837964', '837560', '833249', '833148', '826454', '826353', '816653', '814649', '802844', '792766', '792564', '789373', '786569', '784565', '782561', '779370', '771354', '770352', '765056']
    # fold4 = ['761957', '759869', '756055', '753251', '751348', '749361', '748662', '748258', '742549', '734045', '732243', '729557', '729254', '715647', '715041', '709551', '705341', '704238', '702133', '695768', '690152']
    # subjects = fold2 + fold3 + fold4
    # subjects = ['654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']

    print("\n\nProcessing Data...")
    for s in subjects:
        print("processing data subject {}".format(s))
        start_time = time.time()
        data = nib.load(join(C.NETWORK_DRIVE, "HCP_fusion_" + DIFFUSION_FOLDER,
                             s + "_probmap.nii.gz")).get_data()
        print("Done Loading")
        data = np.nan_to_num(data)
        data = DatasetUtils.scale_input_to_unet_shape(data, HP.DATASET, HP.RESOLUTION)
        # cut one pixel at the end, because in scale_input_to_world_shape we
        # outputted 146 -> one too much at the end
        data = data[:-1, :, :-1, :]
        ExpUtils.make_dir(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s))
        np.save(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s,
                     DIFFUSION_FOLDER + "_xyz.npy"), data)
        print("Took {}s".format(time.time() - start_time))

        print("processing seg subject {}".format(s))
        start_time = time.time()
        # seg = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
        seg = nib.load(join(C.NETWORK_DRIVE, "HCP_for_training_COPY", s,
                            HP.LABELS_FILENAME + ".nii.gz")).get_data()
        if HP.RESOLUTION == "2.5mm":
            # masks come at HCP resolution -> downsample for 2.5mm
            seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
        seg = DatasetUtils.scale_input_to_unet_shape(seg, HP.DATASET, HP.RESOLUTION)
        np.save(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s,
                     "bundle_masks.npy"), seg)
        print("Took {}s".format(time.time() - start_time))
def create_one_3D_file():
    '''
    Create one big file which contains all 3D Images (not slices).

    NOTE(review): duplicate of the earlier `create_one_3D_file` in this
    file — this definition shadows the first one; consider removing one.
    '''
    class HP:
        DATASET = "HCP"
        RESOLUTION = "1.25mm"
        FEATURES_FILENAME = "270g_125mm_peaks"
        LABELS_TYPE = np.int16
        DATASET_FOLDER = "HCP"

    data_all = []
    seg_all = []

    print("\n\nProcessing Data...")
    for s in get_all_subjects():
        print("processing data subject {}".format(s))
        data = nib.load(join(C.HOME, HP.DATASET_FOLDER, s,
                             HP.FEATURES_FILENAME + ".nii.gz")).get_data()
        data = np.nan_to_num(data)
        data = DatasetUtils.scale_input_to_unet_shape(data, HP.DATASET, HP.RESOLUTION)
        data_all.append(np.array(data))
    np.save("data.npy", data_all)
    del data_all  # free memory before loading all segmentations

    print("\n\nProcessing Segs...")
    for s in get_all_subjects():
        print("processing seg subject {}".format(s))
        seg = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
        if HP.RESOLUTION == "2.5mm":
            # masks are created at HCP resolution -> downsample for 2.5mm
            seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
        seg = DatasetUtils.scale_input_to_unet_shape(seg, HP.DATASET, HP.RESOLUTION)
        seg_all.append(np.array(seg))

    # FIX: a Python list has no .dtype attribute -> convert to an ndarray
    # first (np.save would have done this implicitly anyway).
    seg_all = np.array(seg_all)
    print("SEG TYPE: {}".format(seg_all.dtype))
    np.save("seg.npy", seg_all)
def generate_train_batch(self):
    """Sample one training batch of 2D slices (5-slice window) from a random subject.

    Picks one subject at random, loads peaks (randomly 270g or 90g) and the
    segmentation, scales both to UNet shape, then extracts BATCH_SIZE slices
    along a random axis. Each input sample is a 5-slice window around the
    sampled slice, flattened into the channel dimension.

    Returns:
        dict with "data" (batch_size, 5*channels, x, y) float32 and
        "seg" (batch_size, nr_of_classes, x, y) of LABELS_TYPE.
    """
    subjects = self._data[0]
    subject_idx = int(random.uniform(0, len(subjects)))  # len(subjects)-1 not needed because int always rounds to floor

    # Retry loading up to 20 times (network filesystem may be flaky)
    for i in range(20):
        try:
            # Randomly alternate between the 270g and 90g peak images
            if np.random.random() < 0.5:
                data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                     subjects[subject_idx],
                                     "270g_125mm_peaks.nii.gz")).get_data()
            else:
                data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                     subjects[subject_idx],
                                     "90g_125mm_peaks.nii.gz")).get_data()

            # rnd_choice = np.random.random()
            # if rnd_choice < 0.33:
            #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
            # elif rnd_choice < 0.66:
            #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
            # else:
            #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()

            seg = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                subjects[subject_idx],
                                self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
            break
        except IOError:
            ExpUtils.print_and_save(self.HP, "\n\nWARNING: Could not load file. Trying again in 20s (Try number: " + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
    # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

    data = np.nan_to_num(data)  # Needed otherwise not working
    seg = np.nan_to_num(seg)

    data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, channels)
    if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
        # By using "HCP" but lower resolution scale_input_to_unet_shape will
        # automatically downsample the HCP sized seg_mask to the lower resolution
        seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
    else:
        seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

    # BATCH_SIZE distinct slice indices (replace=False)
    slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False, None)

    # Randomly sample slice orientation (0=x, 1=y, 2=z)
    slice_direction = int(round(random.uniform(0, 2)))

    if slice_direction == 0:
        y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
        y = np.array(y).transpose(0, 3, 1, 2)  # nr_classes channel has to be before width and height for DataAugmentation (bs, nr_of_classes, x, y)
    elif slice_direction == 1:
        y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
        y = np.array(y).transpose(1, 3, 0, 2)
    elif slice_direction == 2:
        y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
        y = np.array(y).transpose(2, 3, 0, 1)

    sw = 5  # slice_window (only odd numbers allowed)
    pad = int((sw - 1) / 2)

    # Zero-pad spatial dims by `pad` on each side so the window never runs
    # off the volume at the borders.
    data_pad = np.zeros((data.shape[0] + sw - 1, data.shape[1] + sw - 1,
                         data.shape[2] + sw - 1,
                         data.shape[3])).astype(data.dtype)
    data_pad[pad:-pad, pad:-pad, pad:-pad, :] = data  # padded with two slices of zeros on all sides

    batch = []
    for s_idx in slice_idxs:
        if slice_direction == 0:
            # (s_idx+2)-2:(s_idx+2)+3 = s_idx:s_idx+5 in padded coordinates
            # (the trailing ':' in the slice is a harmless empty step)
            x = data_pad[s_idx:s_idx + sw:, pad:-pad, pad:-pad, :].astype(np.float32)  # (5, y, z, channels)
            x = np.array(x).transpose(0, 3, 1, 2)  # channels dim has to be before width and height for Unet (but after batches)
            x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
            batch.append(x)
        elif slice_direction == 1:
            x = data_pad[pad:-pad, s_idx:s_idx + sw, pad:-pad, :].astype(np.float32)  # (5, y, z, channels)
            x = np.array(x).transpose(1, 3, 0, 2)  # channels dim has to be before width and height for Unet (but after batches)
            x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
            batch.append(x)
        elif slice_direction == 2:
            x = data_pad[pad:-pad, pad:-pad, s_idx:s_idx + sw, :].astype(np.float32)  # (5, y, z, channels)
            x = np.array(x).transpose(2, 3, 0, 1)  # channels dim has to be before width and height for Unet (but after batches)
            x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
            batch.append(x)

    data_dict = {"data": np.array(batch),  # (batch_size, channels, x, y, [z])
                 "seg": y}  # (batch_size, channels, x, y, [z])
    return data_dict
def generate_train_batch(self):
    """Sample one training batch of 2D slices from a random subject.

    Feature loading depends on HP.FEATURES_FILENAME:
      - "12g90g270g": randomly pick 12g/90g/270g peaks,
      - "T1_Peaks270g": concatenate 270g peaks with T1,
      - "T1_Peaks12g90g270g": random peaks + T1,
      - otherwise: load HP.FEATURES_FILENAME directly.

    Returns:
        dict with "data" (batch_size, channels, x, y) float32 and
        "seg" (batch_size, nr_of_classes, x, y) of LABELS_TYPE.

    NOTE(review): duplicate method name — this definition shadows the
    earlier `generate_train_batch` if both live in the same class.
    """
    subjects = self._data[0]
    subject_idx = int(random.uniform(0, len(subjects)))  # len(subjects)-1 not needed because int always rounds to floor

    # Retry loading up to 20 times (network filesystem may be flaky)
    for i in range(20):
        try:
            if self.HP.FEATURES_FILENAME == "12g90g270g":
                # if np.random.random() < 0.5:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                # else:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
                rnd_choice = np.random.random()
                if rnd_choice < 0.33:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                         subjects[subject_idx],
                                         "270g_125mm_peaks.nii.gz")).get_data()
                elif rnd_choice < 0.66:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                         subjects[subject_idx],
                                         "90g_125mm_peaks.nii.gz")).get_data()
                else:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                         subjects[subject_idx],
                                         "12g_125mm_peaks.nii.gz")).get_data()
            elif self.HP.FEATURES_FILENAME == "T1_Peaks270g":
                peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                      subjects[subject_idx],
                                      "270g_125mm_peaks.nii.gz")).get_data()
                t1 = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                   subjects[subject_idx],
                                   "T1.nii.gz")).get_data()
                data = np.concatenate((peaks, t1), axis=3)  # T1 appended as extra channel
            elif self.HP.FEATURES_FILENAME == "T1_Peaks12g90g270g":
                rnd_choice = np.random.random()
                if rnd_choice < 0.33:
                    peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                          subjects[subject_idx],
                                          "270g_125mm_peaks.nii.gz")).get_data()
                elif rnd_choice < 0.66:
                    peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                          subjects[subject_idx],
                                          "90g_125mm_peaks.nii.gz")).get_data()
                else:
                    peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                          subjects[subject_idx],
                                          "12g_125mm_peaks.nii.gz")).get_data()
                t1 = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                   subjects[subject_idx],
                                   "T1.nii.gz")).get_data()
                data = np.concatenate((peaks, t1), axis=3)
            else:
                data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                     subjects[subject_idx],
                                     self.HP.FEATURES_FILENAME + ".nii.gz")).get_data()

            seg = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                subjects[subject_idx],
                                self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
            break
        except IOError:
            ExpUtils.print_and_save(self.HP, "\n\nWARNING: Could not load file. Trying again in 20s (Try number: " + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
    # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

    data = np.nan_to_num(data)  # Needed otherwise not working
    seg = np.nan_to_num(seg)

    data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, channels)

    # The "*_808080" label files are already at UNet shape -> no rescaling
    if self.HP.LABELS_FILENAME not in ["bundle_peaks_11_808080", "bundle_peaks_20_808080",
                                       "bundle_peaks_808080", "bundle_masks_20_808080",
                                       "bundle_masks_72_808080"]:
        if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
            # By using "HCP" but lower resolution scale_input_to_unet_shape will
            # automatically downsample the HCP sized seg_mask to the lower resolution
            seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
        else:
            seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

    # BATCH_SIZE distinct slice indices (replace=False)
    slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False, None)

    # Randomly sample slice orientation (0=x, 1=y, 2=z)
    if self.HP.TRAINING_SLICE_DIRECTION == "xyz":
        slice_direction = int(round(random.uniform(0, 2)))
    else:
        slice_direction = 1  # always use Y

    if slice_direction == 0:
        x = data[slice_idxs, :, :].astype(np.float32)  # (batch_size, y, z, channels)
        y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
        x = np.array(x).transpose(0, 3, 1, 2)  # depth-channel has to be before width and height for Unet (but after batches)
        y = np.array(y).transpose(0, 3, 1, 2)  # nr_classes channel has to be before width and height for DataAugmentation (bs, nr_of_classes, x, y)
    elif slice_direction == 1:
        x = data[:, slice_idxs, :].astype(np.float32)  # (x, batch_size, z, channels)
        y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
        x = np.array(x).transpose(1, 3, 0, 2)
        y = np.array(y).transpose(1, 3, 0, 2)
    elif slice_direction == 2:
        x = data[:, :, slice_idxs].astype(np.float32)  # (x, y, batch_size, channels)
        y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
        x = np.array(x).transpose(2, 3, 0, 1)
        y = np.array(y).transpose(2, 3, 0, 1)

    data_dict = {"data": x,  # (batch_size, channels, x, y, [z])
                 "seg": y}  # (batch_size, channels, x, y, [z])
    return data_dict
def get_batches(self, batch_size=1):
    """Build a batch generator over one subject's slices (for inference/eval).

    For HP.TYPE == "combined", wraps the pre-fused npy data; otherwise loads
    the feature NIfTI (and GT mask if available, else a dummy zero mask),
    scales to UNet shape, and wraps everything in a SlicesBatchGenerator
    plus the configured transform pipeline.

    Returns:
        MultiThreadedAugmenter yielding data (batch_size, channels, x, y)
        and seg (batch_size, x, y, channels).
    """
    num_processes = 1  # do not use more than 1 if you want to keep original slice order (Threads do return in random order)

    if self.HP.TYPE == "combined":
        # Load from Npy file for Fusion
        data = self.subject
        seg = []
        nr_of_samples = len([self.subject]) * self.HP.INPUT_DIM[0]
        num_batches = int(nr_of_samples / batch_size / num_processes)
        batch_gen = SlicesBatchGeneratorNpyImg_fusion((data, seg),
                                                      BATCH_SIZE=batch_size,
                                                      num_batches=num_batches,
                                                      seed=None)
    else:
        # Load Features
        if self.HP.FEATURES_FILENAME == "12g90g270g":
            data_img = nib.load(join(self.data_dir, "270g_125mm_peaks.nii.gz"))
        else:
            data_img = nib.load(join(self.data_dir, self.HP.FEATURES_FILENAME + ".nii.gz"))
        data = data_img.get_data()
        data = np.nan_to_num(data)
        data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)
        # data = DatasetUtils.scale_input_to_unet_shape(data, "HCP_32g", "1.25mm")  # If we want to test HCP_32g on HighRes net

        # Load Segmentation
        if self.use_gt_mask:
            seg = nib.load(join(self.data_dir, self.HP.LABELS_FILENAME + ".nii.gz")).get_data()

            # The "*_808080" label files are already at UNet shape -> no rescaling
            if self.HP.LABELS_FILENAME not in ["bundle_peaks_11_808080", "bundle_peaks_20_808080",
                                               "bundle_peaks_808080", "bundle_masks_20_808080",
                                               "bundle_masks_72_808080", "bundle_peaks_Part1_808080",
                                               "bundle_peaks_Part2_808080", "bundle_peaks_Part3_808080",
                                               "bundle_peaks_Part4_808080"]:
                if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
                    # By using "HCP" but lower resolution scale_input_to_unet_shape
                    # will automatically downsample the HCP sized seg_mask
                    seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
                else:
                    seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)
        else:
            # Use dummy mask in case we only want to predict on some data (where we do not have Ground Truth))
            seg = np.zeros((self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[0],
                            self.HP.INPUT_DIM[0],
                            self.HP.NR_OF_CLASSES)).astype(self.HP.LABELS_TYPE)

        # NOTE(review): lowercase `batch_size=` here, while the sibling
        # get_batches and the fusion branch pass `BATCH_SIZE=` — verify
        # SlicesBatchGenerator's expected keyword.
        batch_gen = SlicesBatchGenerator((data, seg), batch_size=batch_size)

    batch_gen.HP = self.HP
    tfs = []  # transforms

    if self.HP.NORMALIZE_DATA:
        tfs.append(ZeroMeanUnitVarianceTransform(per_channel=False))

    if self.HP.TEST_TIME_DAUG:
        center_dist_from_border = int(self.HP.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
        tfs.append(SpatialTransform(self.HP.INPUT_DIM,
                                    patch_center_dist_from_border=center_dist_from_border,
                                    do_elastic_deform=True, alpha=(90., 120.), sigma=(9., 11.),
                                    do_rotation=True, angle_x=(-0.8, 0.8),
                                    angle_y=(-0.8, 0.8), angle_z=(-0.8, 0.8),
                                    do_scale=True, scale=(0.9, 1.5),
                                    border_mode_data='constant', border_cval_data=0,
                                    order_data=3,
                                    border_mode_seg='constant', border_cval_seg=0,
                                    order_seg=0, random_crop=True))
        # tfs.append(ResampleTransform(zoom_range=(0.5, 1)))
        # tfs.append(GaussianNoiseTransform(noise_variance=(0, 0.05)))
        tfs.append(ContrastAugmentationTransform(contrast_range=(0.7, 1.3),
                                                 preserve_range=True,
                                                 per_channel=False))
        tfs.append(BrightnessMultiplicativeTransform(multiplier_range=(0.7, 1.3),
                                                     per_channel=False))

    tfs.append(ReorderSegTransform())
    batch_gen = MultiThreadedAugmenter(batch_gen, Compose(tfs),
                                       num_processes=num_processes,
                                       num_cached_per_queue=2,
                                       seeds=None)  # Only use num_processes=1, otherwise global_idx of SlicesBatchGenerator not working
    return batch_gen  # data: (batch_size, channels, x, y), seg: (batch_size, x, y, channels)
def get_batches(self, batch_size=1):
    """Build a batch generator over one subject's slices (for inference/eval).

    NOTE(review): near-duplicate of the other `get_batches` in this file
    (differs only in the "*_808080" list and the SlicesBatchGenerator
    keyword casing) — this definition shadows the earlier one if both live
    in the same class.
    """
    num_processes = 1  # do not use more than 1 if you want to keep original slice order (Threads do return in random order)

    if self.HP.TYPE == "combined":
        # Load from Npy file for Fusion
        data = self.subject
        seg = []
        nr_of_samples = len([self.subject]) * self.HP.INPUT_DIM[0]
        num_batches = int(nr_of_samples / batch_size / num_processes)
        batch_gen = SlicesBatchGeneratorNpyImg_fusion((data, seg),
                                                      BATCH_SIZE=batch_size,
                                                      num_batches=num_batches,
                                                      seed=None)
    else:
        # Load Features
        if self.HP.FEATURES_FILENAME == "12g90g270g":
            data_img = nib.load(join(self.data_dir, "270g_125mm_peaks.nii.gz"))
        else:
            data_img = nib.load(join(self.data_dir, self.HP.FEATURES_FILENAME + ".nii.gz"))
        data = data_img.get_data()
        data = np.nan_to_num(data)
        data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)
        # data = DatasetUtils.scale_input_to_unet_shape(data, "HCP_32g", "1.25mm")  # If we want to test HCP_32g on HighRes net

        # Load Segmentation
        if self.use_gt_mask:
            seg = nib.load(join(self.data_dir, self.HP.LABELS_FILENAME + ".nii.gz")).get_data()

            # The "*_808080" label files are already at UNet shape -> no rescaling
            if self.HP.LABELS_FILENAME not in ["bundle_peaks_11_808080", "bundle_peaks_20_808080",
                                               "bundle_peaks_808080", "bundle_masks_20_808080",
                                               "bundle_masks_72_808080"]:
                if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
                    # By using "HCP" but lower resolution scale_input_to_unet_shape
                    # will automatically downsample the HCP sized seg_mask
                    seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
                else:
                    seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)
        else:
            # Use dummy mask in case we only want to predict on some data (where we do not have Ground Truth))
            seg = np.zeros((self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[0],
                            self.HP.INPUT_DIM[0],
                            self.HP.NR_OF_CLASSES)).astype(self.HP.LABELS_TYPE)

        batch_gen = SlicesBatchGenerator((data, seg), BATCH_SIZE=batch_size)

    batch_gen.HP = self.HP
    tfs = []  # transforms

    if self.HP.NORMALIZE_DATA:
        tfs.append(ZeroMeanUnitVarianceTransform(per_channel=False))

    if self.HP.TEST_TIME_DAUG:
        center_dist_from_border = int(self.HP.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
        tfs.append(SpatialTransform(self.HP.INPUT_DIM,
                                    patch_center_dist_from_border=center_dist_from_border,
                                    do_elastic_deform=True, alpha=(90., 120.), sigma=(9., 11.),
                                    do_rotation=True, angle_x=(-0.8, 0.8),
                                    angle_y=(-0.8, 0.8), angle_z=(-0.8, 0.8),
                                    do_scale=True, scale=(0.9, 1.5),
                                    border_mode_data='constant', border_cval_data=0,
                                    order_data=3,
                                    border_mode_seg='constant', border_cval_seg=0,
                                    order_seg=0, random_crop=True))
        # tfs.append(ResampleTransform(zoom_range=(0.5, 1)))
        # tfs.append(GaussianNoiseTransform(noise_variance=(0, 0.05)))
        tfs.append(ContrastAugmentationTransform(contrast_range=(0.7, 1.3),
                                                 preserve_range=True,
                                                 per_channel=False))
        tfs.append(BrightnessMultiplicativeTransform(multiplier_range=(0.7, 1.3),
                                                     per_channel=False))

    tfs.append(ReorderSegTransform())
    batch_gen = MultiThreadedAugmenter(batch_gen, Compose(tfs),
                                       num_processes=num_processes,
                                       num_cached_per_queue=2,
                                       seeds=None)  # Only use num_processes=1, otherwise global_idx of SlicesBatchGenerator not working
    return batch_gen  # data: (batch_size, channels, x, y), seg: (batch_size, x, y, channels)
def _create_slices_file(HP, subjects, filename, slice, shuffle=True):
    """Create 2D-slice training files "<filename>_data.npy" / "_seg.npy".

    Loads each subject's DWI features and multilabel mask, scales them to
    the UNet input shape, and collects 2D slices along all three axes into
    one dataset. Data and seg are shuffled with the SAME permutation so
    they stay index-aligned.

    Args:
        HP: hyperparameter object (DATASET_FOLDER, FEATURES_FILENAME,
            DATASET, RESOLUTION, LABELS_TYPE).
        subjects: iterable of subject IDs.
        filename: output file prefix.
        slice: unused; kept for backward compatibility (historically
            selected a single slicing axis).
        shuffle: if True, apply one random permutation to both outputs.
    """
    data_dir = join(C.HOME, HP.DATASET_FOLDER)

    dwi_slices = []
    mask_slices = []

    print("\n\nProcessing Data...")
    for s in subjects:
        print("processing dwi subject {}".format(s))
        dwi = nib.load(join(data_dir, s, HP.FEATURES_FILENAME + ".nii.gz"))
        dwi_data = dwi.get_data()
        dwi_data = np.nan_to_num(dwi_data)
        dwi_data = DatasetUtils.scale_input_to_unet_shape(dwi_data, HP.DATASET, HP.RESOLUTION)

        # Use slices from all directions in one dataset
        for z in range(dwi_data.shape[0]):
            dwi_slices.append(dwi_data[z, :, :, :])
        for z in range(dwi_data.shape[1]):
            dwi_slices.append(dwi_data[:, z, :, :])
        for z in range(dwi_data.shape[2]):
            dwi_slices.append(dwi_data[:, :, z, :])

    dwi_slices = np.array(dwi_slices)
    random_idxs = None
    if shuffle:
        # FIX: use a permutation. np.random.choice(n, n) samples WITH
        # replacement, so some slices were duplicated and others dropped.
        random_idxs = np.random.permutation(len(dwi_slices))
        dwi_slices = dwi_slices[random_idxs]

    np.save(filename + "_data.npy", dwi_slices)
    del dwi_slices  # free memory

    print("\n\nProcessing Segs...")
    for s in subjects:
        print("processing seg subject {}".format(s))
        mask_data = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
        if HP.RESOLUTION == "2.5mm":
            # masks are created at HCP resolution -> downsample for 2.5mm
            mask_data = ImgUtils.resize_first_three_dims(mask_data, order=0, zoom=0.5)
        mask_data = DatasetUtils.scale_input_to_unet_shape(mask_data, HP.DATASET, HP.RESOLUTION)

        # Use slices from all directions in one dataset.
        # FIX: iterate over mask_data's own shape instead of the stale
        # dwi_data left over from the previous loop (identical after
        # scale_input_to_unet_shape, but correct and self-contained).
        for z in range(mask_data.shape[0]):
            mask_slices.append(mask_data[z, :, :, :])
        for z in range(mask_data.shape[1]):
            mask_slices.append(mask_data[:, z, :, :])
        for z in range(mask_data.shape[2]):
            mask_slices.append(mask_data[:, :, z, :])

    mask_slices = np.array(mask_slices)
    print("SEG TYPE: {}".format(mask_slices.dtype))
    if shuffle:
        # same permutation as the data -> pairs stay aligned
        mask_slices = mask_slices[random_idxs]
    np.save(filename + "_seg.npy", mask_slices)
def _create_slices_file(HP, subjects, filename, slice, shuffle=True):
    """Create 2D-slice training files "<filename>_data.npy" / "_seg.npy".

    NOTE(review): duplicate of the earlier `_create_slices_file` in this
    file — this definition shadows the first one; consider removing one.

    Args:
        HP: hyperparameter object (DATASET_FOLDER, FEATURES_FILENAME,
            DATASET, RESOLUTION, LABELS_TYPE).
        subjects: iterable of subject IDs.
        filename: output file prefix.
        slice: unused; kept for backward compatibility (historically
            selected a single slicing axis).
        shuffle: if True, apply one random permutation to both outputs.
    """
    data_dir = join(C.HOME, HP.DATASET_FOLDER)

    dwi_slices = []
    mask_slices = []

    print("\n\nProcessing Data...")
    for s in subjects:
        print("processing dwi subject {}".format(s))
        dwi = nib.load(join(data_dir, s, HP.FEATURES_FILENAME + ".nii.gz"))
        dwi_data = dwi.get_data()
        dwi_data = np.nan_to_num(dwi_data)
        dwi_data = DatasetUtils.scale_input_to_unet_shape(dwi_data, HP.DATASET, HP.RESOLUTION)

        # Use slices from all directions in one dataset
        for z in range(dwi_data.shape[0]):
            dwi_slices.append(dwi_data[z, :, :, :])
        for z in range(dwi_data.shape[1]):
            dwi_slices.append(dwi_data[:, z, :, :])
        for z in range(dwi_data.shape[2]):
            dwi_slices.append(dwi_data[:, :, z, :])

    dwi_slices = np.array(dwi_slices)
    random_idxs = None
    if shuffle:
        # FIX: use a permutation. np.random.choice(n, n) samples WITH
        # replacement, so some slices were duplicated and others dropped.
        random_idxs = np.random.permutation(len(dwi_slices))
        dwi_slices = dwi_slices[random_idxs]

    np.save(filename + "_data.npy", dwi_slices)
    del dwi_slices  # free memory

    print("\n\nProcessing Segs...")
    for s in subjects:
        print("processing seg subject {}".format(s))
        mask_data = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
        if HP.RESOLUTION == "2.5mm":
            # masks are created at HCP resolution -> downsample for 2.5mm
            mask_data = ImgUtils.resize_first_three_dims(mask_data, order=0, zoom=0.5)
        mask_data = DatasetUtils.scale_input_to_unet_shape(mask_data, HP.DATASET, HP.RESOLUTION)

        # Use slices from all directions in one dataset.
        # FIX: iterate over mask_data's own shape instead of the stale
        # dwi_data left over from the previous loop (identical after
        # scale_input_to_unet_shape, but correct and self-contained).
        for z in range(mask_data.shape[0]):
            mask_slices.append(mask_data[z, :, :, :])
        for z in range(mask_data.shape[1]):
            mask_slices.append(mask_data[:, z, :, :])
        for z in range(mask_data.shape[2]):
            mask_slices.append(mask_data[:, :, z, :])

    mask_slices = np.array(mask_slices)
    print("SEG TYPE: {}".format(mask_slices.dtype))
    if shuffle:
        # same permutation as the data -> pairs stay aligned
        mask_slices = mask_slices[random_idxs]
    np.save(filename + "_seg.npy", mask_slices)
def _create_prob_slices_file(HP, subjects, filename, bundle, shuffle=True):
    """Fuse the x/y/z UNet probability maps of each subject into one
    multi-channel volume, save it as nifti, and additionally save its z-slices
    (plus the matching multilabel masks) as npy training files.

    Args:
        HP: hyperparameter object; reads MULTI_PARENT_PATH, CV_FOLD, DATASET,
            RESOLUTION, LABELS_TYPE, EXP_PATH, TRAIN.
        subjects: iterable of subject id strings.
        filename: output file prefix for the npy files (only written when
            HP.TRAIN is truthy).
        bundle: unused; kept for backward compatibility with callers.
        shuffle: jointly shuffle data/seg slices (fixed random_state=9 for
            reproducibility).
    """
    input_dir = HP.MULTI_PARENT_PATH

    combined_slices = []
    mask_slices = []

    for s in subjects:
        print("processing subject {}".format(s))

        probs_x = nib.load(join(input_dir, "UNet_x_" + str(HP.CV_FOLD), "probmaps",
                                s + "_probmap.nii.gz")).get_data()
        probs_y = nib.load(join(input_dir, "UNet_y_" + str(HP.CV_FOLD), "probmaps",
                                s + "_probmap.nii.gz")).get_data()
        probs_z = nib.load(join(input_dir, "UNet_z_" + str(HP.CV_FOLD), "probmaps",
                                s + "_probmap.nii.gz")).get_data()

        combined = np.stack((probs_x, probs_y, probs_z), axis=4)  # (73, 87, 73, 18, 3)
        # One dim too much for UNet -> flatten (classes, directions) into channels
        combined = np.reshape(combined, (combined.shape[0], combined.shape[1], combined.shape[2],
                                         combined.shape[3] * combined.shape[4]))  # (73, 87, 73, 3*18)

        mask_data = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
        if HP.DATASET in ("HCP_2mm", "HCP_2.5mm"):
            # use "HCP" because for mask we need downscaling
            # (previously two byte-identical branches; merged)
            mask_data = DatasetUtils.scale_input_to_unet_shape(mask_data, "HCP", HP.RESOLUTION)
        else:
            # Mask has same resolution as probmaps -> we can use same resizing
            mask_data = DatasetUtils.scale_input_to_unet_shape(mask_data, HP.DATASET, HP.RESOLUTION)

        # Save fused probmap as image.
        # NOTE: the "combinded" typo in the filename is kept on purpose --
        # downstream code may read files under exactly this name.
        img = nib.Nifti1Image(combined, ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION))
        nib.save(img, join(HP.EXP_PATH, "combined", s + "_combinded_probmap.nii.gz"))

        combined = DatasetUtils.scale_input_to_unet_shape(combined, HP.DATASET, HP.RESOLUTION)
        assert (combined.shape[2] == mask_data.shape[2])

        # Save as z-slices
        for z in range(combined.shape[2]):
            combined_slices.append(combined[:, :, z, :])
            mask_slices.append(mask_data[:, :, z, :])

    if shuffle:
        combined_slices, mask_slices = sk_shuffle(combined_slices, mask_slices, random_state=9)

    if HP.TRAIN:
        np.save(filename + "_data.npy", combined_slices)
        np.save(filename + "_seg.npy", mask_slices)
def generate_train_batch(self):
    """Sample one training batch of 2D slices (with a 5-slice context window)
    from one randomly chosen subject.

    Loads either the 270g or the 90g peak image (50/50) plus the label image,
    rescales both to the UNet input shape, picks a random slice axis and
    BATCH_SIZE random slice indices, and for every slice stacks a window of
    ``sw`` neighbouring slices, flattening (window, channels) into one channel
    axis.

    Returns:
        dict with "data": (batch_size, sw*channels, x, y) float32 array and
        "seg": (batch_size, nr_of_classes, x, y) array of HP.LABELS_TYPE.
    """
    subjects = self._data[0]
    subject_idx = int(random.uniform(0, len(subjects)))  # int() floors, so len(subjects) is (almost surely) never hit

    for i in range(20):
        try:
            # Randomly alternate between the two gradient schemes
            if np.random.random() < 0.5:
                data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                     "270g_125mm_peaks.nii.gz")).get_data()
            else:
                data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                     "90g_125mm_peaks.nii.gz")).get_data()
            seg = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
            break
        except IOError:
            # Storage may be temporarily unavailable -> retry a few times before giving up
            ExpUtils.print_and_save(self.HP, "\n\nWARNING: Could not load file. Trying again in 20s (Try number: " + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)

    data = np.nan_to_num(data)  # Needed otherwise not working
    seg = np.nan_to_num(seg)

    data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, channels)
    if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
        # By using "HCP" but lower resolution scale_input_to_unet_shape will automatically
        # downsample the HCP sized seg_mask to the lower resolution
        seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
    else:
        seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

    # NOTE(review): slice indices are sampled from shape[0] regardless of the
    # chosen direction -- valid only if all three spatial dims are equal after
    # scaling; confirm against scale_input_to_unet_shape.
    slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False, None)

    # Randomly sample slice orientation.
    # BUGFIX: int(round(random.uniform(0, 2))) picked direction 1 with
    # probability 0.5 and directions 0/2 with only 0.25 each; randint is uniform.
    slice_direction = random.randint(0, 2)

    if slice_direction == 0:
        y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
        y = np.array(y).transpose(0, 3, 1, 2)  # nr_classes channel has to be before width and height for DataAugmentation (bs, nr_of_classes, x, y)
    elif slice_direction == 1:
        y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
        y = np.array(y).transpose(1, 3, 0, 2)
    elif slice_direction == 2:
        y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
        y = np.array(y).transpose(2, 3, 0, 1)

    sw = 5  # slice_window (only odd numbers allowed)
    pad = int((sw - 1) / 2)

    # Zero-pad all three spatial dims so the window never runs off the volume
    data_pad = np.zeros((data.shape[0] + sw - 1, data.shape[1] + sw - 1, data.shape[2] + sw - 1,
                         data.shape[3])).astype(data.dtype)
    data_pad[pad:-pad, pad:-pad, pad:-pad, :] = data  # padded with two slices of zeros on all sides

    batch = []
    for s_idx in slice_idxs:
        if slice_direction == 0:
            # (s_idx+2)-2:(s_idx+2)+3 = s_idx:s_idx+5 in padded coordinates
            x = data_pad[s_idx:s_idx + sw:, pad:-pad, pad:-pad, :].astype(np.float32)  # (5, y, z, channels)
            x = np.array(x).transpose(0, 3, 1, 2)  # channels dim has to be before width and height for Unet (but after batches)
            x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
            batch.append(x)
        elif slice_direction == 1:
            x = data_pad[pad:-pad, s_idx:s_idx + sw, pad:-pad, :].astype(np.float32)  # (5, y, z, channels)
            x = np.array(x).transpose(1, 3, 0, 2)  # channels dim has to be before width and height for Unet (but after batches)
            x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
            batch.append(x)
        elif slice_direction == 2:
            x = data_pad[pad:-pad, pad:-pad, s_idx:s_idx + sw, :].astype(np.float32)  # (5, y, z, channels)
            x = np.array(x).transpose(2, 3, 0, 1)  # channels dim has to be before width and height for Unet (but after batches)
            x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
            batch.append(x)

    data_dict = {"data": np.array(batch),  # (batch_size, channels, x, y, [z])
                 "seg": y}                 # (batch_size, channels, x, y, [z])
    return data_dict
def generate_train_batch(self):
    """Sample one training batch of 2D slices from one randomly chosen subject.

    Depending on HP.FEATURES_FILENAME the input is a randomly chosen peak image
    (12g/90g/270g), a peak image concatenated with T1, or a single fixed
    features file. Data and (unless using the 808080 label variants) seg are
    rescaled to the UNet input shape, then BATCH_SIZE slices are cut along a
    random axis (or always y, per HP.TRAINING_SLICE_DIRECTION).

    Returns:
        dict with "data": (batch_size, channels, x, y) float32 array and
        "seg": (batch_size, nr_of_classes, x, y) array of HP.LABELS_TYPE.
    """
    subjects = self._data[0]
    subject_idx = int(random.uniform(0, len(subjects)))  # int() floors, so len(subjects) is (almost surely) never hit

    for i in range(20):
        try:
            # Hoisted: all inputs for this subject live under one directory
            subject_dir = join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx])

            if self.HP.FEATURES_FILENAME == "12g90g270g":
                # Randomly alternate between the three gradient schemes
                rnd_choice = np.random.random()
                if rnd_choice < 0.33:
                    data = nib.load(join(subject_dir, "270g_125mm_peaks.nii.gz")).get_data()
                elif rnd_choice < 0.66:
                    data = nib.load(join(subject_dir, "90g_125mm_peaks.nii.gz")).get_data()
                else:
                    data = nib.load(join(subject_dir, "12g_125mm_peaks.nii.gz")).get_data()
            elif self.HP.FEATURES_FILENAME == "T1_Peaks270g":
                peaks = nib.load(join(subject_dir, "270g_125mm_peaks.nii.gz")).get_data()
                t1 = nib.load(join(subject_dir, "T1.nii.gz")).get_data()
                data = np.concatenate((peaks, t1), axis=3)
            elif self.HP.FEATURES_FILENAME == "T1_Peaks12g90g270g":
                rnd_choice = np.random.random()
                if rnd_choice < 0.33:
                    peaks = nib.load(join(subject_dir, "270g_125mm_peaks.nii.gz")).get_data()
                elif rnd_choice < 0.66:
                    peaks = nib.load(join(subject_dir, "90g_125mm_peaks.nii.gz")).get_data()
                else:
                    peaks = nib.load(join(subject_dir, "12g_125mm_peaks.nii.gz")).get_data()
                t1 = nib.load(join(subject_dir, "T1.nii.gz")).get_data()
                data = np.concatenate((peaks, t1), axis=3)
            else:
                data = nib.load(join(subject_dir, self.HP.FEATURES_FILENAME + ".nii.gz")).get_data()

            seg = nib.load(join(subject_dir, self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
            break
        except IOError:
            # Storage may be temporarily unavailable -> retry a few times before giving up
            ExpUtils.print_and_save(self.HP, "\n\nWARNING: Could not load file. Trying again in 20s (Try number: " + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)

    data = np.nan_to_num(data)  # Needed otherwise not working
    seg = np.nan_to_num(seg)

    data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, channels)

    if self.HP.LABELS_FILENAME not in ["bundle_peaks_11_808080", "bundle_peaks_20_808080", "bundle_peaks_808080",
                                       "bundle_masks_20_808080", "bundle_masks_72_808080"]:
        if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
            # By using "HCP" but lower resolution scale_input_to_unet_shape will automatically
            # downsample the HCP sized seg_mask to the lower resolution
            seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
        else:
            seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

    # NOTE(review): slice indices are sampled from shape[0] regardless of the
    # chosen direction -- valid only if all three spatial dims are equal after
    # scaling; confirm against scale_input_to_unet_shape.
    slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False, None)

    # Randomly sample slice orientation.
    # BUGFIX: int(round(random.uniform(0, 2))) picked direction 1 with
    # probability 0.5 and directions 0/2 with only 0.25 each; randint is uniform.
    if self.HP.TRAINING_SLICE_DIRECTION == "xyz":
        slice_direction = random.randint(0, 2)
    else:
        slice_direction = 1  # always use Y

    if slice_direction == 0:
        x = data[slice_idxs, :, :].astype(np.float32)  # (batch_size, y, z, channels)
        y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
        x = np.array(x).transpose(0, 3, 1, 2)  # depth-channel has to be before width and height for Unet (but after batches)
        y = np.array(y).transpose(0, 3, 1, 2)  # nr_classes channel has to be before width and height for DataAugmentation (bs, nr_of_classes, x, y)
    elif slice_direction == 1:
        x = data[:, slice_idxs, :].astype(np.float32)  # (x, batch_size, z, channels)
        y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
        x = np.array(x).transpose(1, 3, 0, 2)
        y = np.array(y).transpose(1, 3, 0, 2)
    elif slice_direction == 2:
        x = data[:, :, slice_idxs].astype(np.float32)  # (x, y, batch_size, channels)
        y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
        x = np.array(x).transpose(2, 3, 0, 1)
        y = np.array(y).transpose(2, 3, 0, 1)

    data_dict = {"data": x,  # (batch_size, channels, x, y, [z])
                 "seg": y}   # (batch_size, channels, x, y, [z])
    return data_dict