def create_one_3D_file():
    '''
    Create one big file which contains all 3D Images (not slices).

    Loads the peak image of every subject returned by get_all_subjects(), replaces NaNs,
    scales it to the UNet input shape and saves all of them together as "data.npy" in the
    current working directory. Afterwards builds the multilabel mask of every subject,
    scales it the same way and saves all masks together as "seg.npy".
    '''
    class HP:
        DATASET = "HCP"
        RESOLUTION = "1.25mm"
        FEATURES_FILENAME = "270g_125mm_peaks"
        LABELS_TYPE = np.int16
        DATASET_FOLDER = "HCP"

    data_all = []
    seg_all = []

    print("\n\nProcessing Data...")
    for s in get_all_subjects():
        print("processing data subject {}".format(s))
        img = nib.load(join(C.HOME, HP.DATASET_FOLDER, s, HP.FEATURES_FILENAME + ".nii.gz"))
        # np.asanyarray(img.dataobj) is the drop-in replacement for the deprecated get_data()
        # (removed in nibabel 5.0): same array, same on-disk dtype.
        data = np.asanyarray(img.dataobj)
        data = np.nan_to_num(data)
        data = DatasetUtils.scale_input_to_unet_shape(data, HP.DATASET, HP.RESOLUTION)
        data_all.append(np.array(data))
    np.save("data.npy", data_all)
    del data_all  # free memory

    print("\n\nProcessing Segs...")
    for s in get_all_subjects():
        print("processing seg subject {}".format(s))
        seg = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
        if HP.RESOLUTION == "2.5mm":
            seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
        seg = DatasetUtils.scale_input_to_unet_shape(seg, HP.DATASET, HP.RESOLUTION)
        seg_all.append(np.array(seg))

    # A Python list has no .dtype (the old code crashed here before saving) -> convert first.
    # np.save would have performed the same list -> array conversion anyway.
    seg_all = np.array(seg_all)
    print("SEG TYPE: {}".format(seg_all.dtype))
    np.save("seg.npy", seg_all)
def scale_input_to_unet_shape(img4d, dataset, resolution="1.25mm"):
    '''
    Scale input image to right isotropic resolution and pad/cut image to make it square to fit UNet input shape

    :param img4d: (x, y, z, userdefined) (userdefined could be gradients or classes)
    :param dataset: "HCP" / "HCP_32g" / "HCP_2mm" / "HCP_2.5mm" / "TRACED"; which datasets are
                    supported depends on the resolution (see the branches below)
    :param resolution: "1.25mm" / "2mm" / "2.5mm" results in UNet input shape of (144,144,144) or (80,80,80)
    :return: img with dim 1.25mm: (144,144,144,none) or 2mm: (80,80,80,none) or 2.5mm: (80,80,80,none)
             (note: for 2.5mm the image is embedded in an 80^3 volume whose background is filled
             with the value of the corner voxel img4d[0,0,0,:])
    :raises ValueError: if the dataset/resolution combination is not supported
    '''
    def _embed_in_80_cube(img, bg_index, img_index):
        # Place img[img_index] at bg_index inside an (80,80,80,C) volume. The background gets the
        # corner voxel value of img (broadcast over the last dim) instead of plain zeros.
        bg = np.zeros((80, 80, 80, img.shape[3])).astype(img.dtype)
        bg = bg + img[0, 0, 0, :]
        bg[bg_index] = img[img_index]
        return bg

    if resolution == "1.25mm":
        if dataset == "HCP":  # (145,174,145)
            # no resize needed
            return img4d[1:, 15:159, 1:]  # (144,144,144)
        elif dataset == "HCP_32g":  # (73,87,73)
            # return img4d[1:, 15:159, 1:]  # (144,144,144)   #OLD when HCP_32g was still 125mm
            img4d = ImgUtils.resize_first_three_dims(img4d, zoom=2)  # (146,174,146,none)
            img4d = img4d[:-1, :, :-1]  # remove one voxel that came from upsampling  # (145,174,145)
            return img4d[1:, 15:159, 1:]  # (144,144,144)
        elif dataset == "TRACED":  # (78,93,75)
            raise ValueError("resolution '1.25mm' not supported for dataset 'TRACED'")

    elif resolution == "2mm":
        if dataset in ("HCP", "HCP_32g"):  # (145,174,145)
            img4d = ImgUtils.resize_first_three_dims(img4d, zoom=0.62)  # (90,108,90)
            return img4d[5:85, 14:94, 5:85, :]  # (80,80,80)
        elif dataset == "HCP_2mm":  # (90,108,90)
            # no resize needed
            return img4d[5:85, 14:94, 5:85, :]  # (80,80,80)
        elif dataset == "TRACED":  # (78,93,75)
            raise ValueError("resolution '2mm' not supported for dataset 'TRACED'")

    elif resolution == "2.5mm":
        if dataset == "HCP":  # (145,174,145)
            img4d = ImgUtils.resize_first_three_dims(img4d, zoom=0.5)  # (73,87,73,none)
            return _embed_in_80_cube(img4d, np.s_[4:77, :, 4:77], np.s_[:, 4:84, :, :])  # (80,80,80)
        elif dataset in ("HCP_2.5mm", "HCP_32g"):  # (73,87,73,none)
            # no resize needed
            return _embed_in_80_cube(img4d, np.s_[4:77, :, 4:77], np.s_[:, 4:84, :, :])  # (80,80,80)
        elif dataset == "TRACED":  # (78,93,75)
            # no resize needed
            return _embed_in_80_cube(img4d, np.s_[1:79, :, 3:78, :], np.s_[:, 7:87, :, :])  # (80,80,80)

    # Previously fell through and silently returned None
    raise ValueError("resolution '{}' not supported for dataset '{}'".format(resolution, dataset))
def precompute_batches(custom_type=None):
    '''
    Precompute training batches and store them as NIfTI files.

    9000 slices per epoch -> 200 batches (batchsize=44) per epoch
    => 200-1000 batches needed

    270g_125mm_bundle_peaks_Y: no DAug, no Norm, only Y
    All_sizes_DAug_XYZ: 12g, 90g, 270g, DAug (no rotation, no elastic deform), Norm, XYZ
    270g_125mm_bundle_peaks_XYZ: no DAug, Norm, XYZ

    Batches are written to C.HOME/<DATASET_DIR>/<split>/batch_<idx>_{data,seg}.nii.gz.

    :param custom_type: None to process "train", "validate" and "test"; otherwise the name of
                        a single one of these splits
    :raises ValueError: if custom_type is not None and not one of the known splits
    '''
    class HP:
        NORMALIZE_DATA = True
        DATA_AUGMENTATION = False
        CV_FOLD = 0
        INPUT_DIM = (144, 144)
        BATCH_SIZE = 44
        DATASET_FOLDER = "HCP"
        TYPE = "single_direction"
        EXP_PATH = "~"
        LABELS_FILENAME = "bundle_peaks"
        FEATURES_FILENAME = "270g_125mm_peaks"
        DATASET = "HCP"
        RESOLUTION = "1.25mm"
        LABELS_TYPE = np.float32

    HP.TRAIN_SUBJECTS, HP.VALIDATE_SUBJECTS, HP.TEST_SUBJECTS = ExpUtils.get_cv_fold(HP.CV_FOLD)

    num_batches_base = 5000
    num_batches = {
        "train": num_batches_base,
        "validate": int(num_batches_base / 3.),
        "test": int(num_batches_base / 3.),
    }

    if custom_type is None:
        splits = ["train", "validate", "test"]
    elif custom_type in num_batches:
        splits = [custom_type]
    else:
        raise ValueError("custom_type must be one of {}, got '{}'".format(sorted(num_batches), custom_type))

    # DATASET_DIR = "HCP_batches/270g_125mm_bundle_peaks_Y"
    # DATASET_DIR = "HCP_batches/All_sizes_DAug_XYZ"
    DATASET_DIR = "HCP_batches/270g_125mm_bundle_peaks_XYZ"

    # Loop-invariant: the affine only depends on the constant dataset/resolution
    affine = ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION)

    # `split` instead of `type` to avoid shadowing the builtin
    for split in splits:
        out_dir = join(C.HOME, DATASET_DIR, split)
        ExpUtils.make_dir(out_dir)  # once per split instead of once per batch

        dataManager = DataManagerTrainingNiftiImgs(HP)
        batch_gen = dataManager.get_batches(batch_size=HP.BATCH_SIZE,
                                            type=split,
                                            subjects=getattr(HP, split.upper() + "_SUBJECTS"),
                                            num_batches=num_batches[split])

        for idx, batch in enumerate(batch_gen):
            print("Processing: {}".format(idx))
            data = nib.Nifti1Image(batch["data"], affine)
            nib.save(data, join(out_dir, "batch_" + str(idx) + "_data.nii.gz"))
            seg = nib.Nifti1Image(batch["seg"], affine)
            nib.save(seg, join(out_dir, "batch_" + str(idx) + "_seg.nii.gz"))
def _create_prob_slices_file(HP, subjects, filename, bundle, shuffle=True):
    '''
    Build 2D training slices from the x/y/z UNet probability maps of each subject.

    For every subject the three probability maps (read from
    HP.MULTI_PARENT_PATH/UNet_{x,y,z}_<CV_FOLD>/probmaps/<subject>_probmap.nii.gz) are stacked
    on a new last axis, which is then merged into the channel axis (channels = 3 * nr_classes).
    The combined volume is written as NIfTI to HP.EXP_PATH/combined/, then the combined volume
    and the multilabel mask are scaled to the UNet shape and cut into z-slices.

    :param HP: hyperparameter object (reads MULTI_PARENT_PATH, CV_FOLD, DATASET, RESOLUTION,
               LABELS_TYPE, EXP_PATH, TRAIN)
    :param subjects: iterable of subject ids
    :param filename: output prefix; "<filename>_data.npy" and "<filename>_seg.npy" are written
                     only if HP.TRAIN is truthy
    :param bundle: unused (kept for interface compatibility)
    :param shuffle: shuffle data and mask slices together (random_state=9, reproducible)
    '''
    input_dir = HP.MULTI_PARENT_PATH

    def _load_probmap(direction, subject):
        # Probability map of the UNet trained on `direction` slices for the current CV fold.
        path = join(input_dir, "UNet_" + direction + "_" + str(HP.CV_FOLD), "probmaps",
                    subject + "_probmap.nii.gz")
        # np.asanyarray(img.dataobj) replaces the deprecated get_data() (removed in nibabel 5.0)
        return np.asanyarray(nib.load(path).dataobj)

    combined_slices = []
    mask_slices = []
    for s in subjects:
        print("processing subject {}".format(s))

        probs_x = _load_probmap("x", s)
        probs_y = _load_probmap("y", s)
        probs_z = _load_probmap("z", s)

        combined = np.stack((probs_x, probs_y, probs_z), axis=4)  # e.g. (73, 87, 73, 18, 3)
        # not working alone: one dim too much for UNet -> merge the direction axis into the channels
        combined = np.reshape(combined, (combined.shape[0], combined.shape[1], combined.shape[2],
                                         combined.shape[3] * combined.shape[4]))  # e.g. (73, 87, 73, 3*18)

        mask_data = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
        if HP.DATASET in ("HCP_2mm", "HCP_2.5mm"):
            # use "HCP" because for mask we need downscaling
            mask_dataset = "HCP"
        else:
            # Mask has same resolution as probmaps -> we can use same resizing
            mask_dataset = HP.DATASET
        mask_data = DatasetUtils.scale_input_to_unet_shape(mask_data, mask_dataset, HP.RESOLUTION)

        # Save as Img
        img = nib.Nifti1Image(combined, ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION))
        nib.save(img, join(HP.EXP_PATH, "combined", s + "_combinded_probmap.nii.gz"))

        combined = DatasetUtils.scale_input_to_unet_shape(combined, HP.DATASET, HP.RESOLUTION)
        assert (combined.shape[2] == mask_data.shape[2])

        # Save as Slices
        for z in range(combined.shape[2]):
            combined_slices.append(combined[:, :, z, :])
            mask_slices.append(mask_data[:, :, z, :])

    if shuffle:
        combined_slices, mask_slices = sk_shuffle(combined_slices, mask_slices, random_state=9)

    if HP.TRAIN:
        np.save(filename + "_data.npy", combined_slices)
        np.save(filename + "_seg.npy", mask_slices)
def move_to_MNI_space(input_file, bvals, bvecs, brain_mask, output_dir):
    '''
    Register a diffusion image to MNI space via its FA map.

    1. calc_FA computes an FA map of the input -> <output_dir>/FA.nii.gz
    2. flirt registers FA to the packaged MNI FA template (6 dof, mutual information)
       -> FA_MNI.nii.gz and the transform FA_2_MNI.mat
    3. flirt applies this transform to the diffusion image, resampling with the input's voxel
       spacing -> Diffusion_MNI.nii.gz
    4. bvals/bvecs are copied next to it and a new brain mask is created for it.

    External tools are called with argument lists (no shell), so paths with spaces or shell
    metacharacters work. As before, their exit codes are not checked.

    :return: (new_input_file, bvals, bvecs, brain_mask) for the image in MNI space
    '''
    import shutil
    import subprocess

    print("Moving input to MNI space...")

    fa_path = output_dir + "/FA.nii.gz"
    fa_2_mni = output_dir + "/FA_2_MNI.mat"

    subprocess.run(["calc_FA", "-i", input_file, "-o", fa_path,
                    "--bvals", bvals, "--bvecs", bvecs, "--brain_mask", brain_mask])

    dwi_spacing = ImgUtils.get_image_spacing(input_file)

    template_path = resource_filename('examples.resources', 'MNI_FA_template.nii.gz')

    subprocess.run(["flirt", "-ref", template_path, "-in", fa_path,
                    "-out", output_dir + "/FA_MNI.nii.gz", "-omat", fa_2_mni,
                    "-dof", "6", "-cost", "mutualinfo", "-searchcost", "mutualinfo"])

    subprocess.run(["flirt", "-ref", template_path, "-in", input_file,
                    "-out", output_dir + "/Diffusion_MNI.nii.gz",
                    "-applyisoxfm", str(dwi_spacing), "-init", fa_2_mni, "-dof", "6"])

    shutil.copy(bvals, output_dir + "/Diffusion_MNI.bvals")
    shutil.copy(bvecs, output_dir + "/Diffusion_MNI.bvecs")

    new_input_file = join(output_dir, "Diffusion_MNI.nii.gz")
    bvecs = join(output_dir, "Diffusion_MNI.bvecs")
    bvals = join(output_dir, "Diffusion_MNI.bvals")

    brain_mask = Mrtrix.create_brain_mask(new_input_file, output_dir)

    return new_input_file, bvals, bvecs, brain_mask
def cut_and_scale_img_back_to_original_img(data, t):
    '''
    Undo the transformations done with pad_and_scale_img_to_square_img

    :param data: 3D (x, y, z) or 4D (x, y, z, channels) image
    :param t: transformation dict as returned by pad_and_scale_img_to_square_img
              (uses the keys "zoom", "pad_x", "pad_y", "pad_z")
    :return: image with the zoom inverted and the padding cut away
    :raises ValueError: if data is neither 3D nor 4D
    '''
    nr_dims = len(data.shape)
    if nr_dims not in (3, 4):
        # Previously fell through and crashed with UnboundLocalError on new_data
        raise ValueError("Expected a 3D or 4D image, got {} dimensions".format(nr_dims))

    # Back to old size
    # use order=0, otherwise image values of a DWI will be quite different after downsampling and upsampling
    if nr_dims == 3:
        new_data = ndimage.zoom(data, (1. / t["zoom"]), order=0)
    else:
        new_data = ImgUtils.resize_first_three_dims(data, order=0, zoom=(1. / t["zoom"]))

    def _cut(pad, size):
        # A .5 padding means pad_and_scale put the extra voxel at the end -> cut one more there
        residual = 1 if pad - int(pad) == 0.5 else 0
        return slice(int(pad), size - int(pad) - residual)

    # Cut padding
    shape = new_data.shape
    return new_data[_cut(t["pad_x"], shape[0]),
                    _cut(t["pad_y"], shape[1]),
                    _cut(t["pad_z"], shape[2])]
def cut_and_scale_img_back_to_original_img(data, t):
    '''
    Undo the transformations done with pad_and_scale_img_to_square_img

    :param data: 3D (x, y, z) or 4D (x, y, z, channels) image
    :param t: transformation dict as returned by pad_and_scale_img_to_square_img
              (uses the keys "zoom", "pad_x", "pad_y", "pad_z")
    :return: image with the zoom inverted and the padding cut away
    :raises ValueError: if data is neither 3D nor 4D
    '''
    nr_dims = len(data.shape)
    # Explicit check instead of `assert`, which is stripped when running with -O
    if nr_dims not in (3, 4):
        raise ValueError("Expected a 3D or 4D image, got {} dimensions".format(nr_dims))

    # Back to old size
    # use order=0, otherwise image values of a DWI will be quite different after downsampling and upsampling
    if nr_dims == 3:
        new_data = ndimage.zoom(data, (1. / t["zoom"]), order=0)
    else:
        new_data = ImgUtils.resize_first_three_dims(data, order=0, zoom=(1. / t["zoom"]))

    def _cut(pad, size):
        # A .5 padding means pad_and_scale put the extra voxel at the end -> cut one more there
        residual = 1 if pad - int(pad) == 0.5 else 0
        return slice(int(pad), size - int(pad) - residual)

    # Cut padding
    shape = new_data.shape
    return new_data[_cut(t["pad_x"], shape[0]),
                    _cut(t["pad_y"], shape[1]),
                    _cut(t["pad_z"], shape[2])]
def save_fusion_nifti_as_npy():
    '''
    Convert the fusion probability maps and bundle masks of all subjects from NIfTI to .npy.

    For each subject:
      - loads C.NETWORK_DRIVE/HCP_fusion_<DIFFUSION_FOLDER>/<subject>_probmap.nii.gz, replaces
        NaNs, scales it to the UNet shape, cuts the last voxel in x and z and saves it as
        C.NETWORK_DRIVE/HCP_fusion_npy_<DIFFUSION_FOLDER>/<subject>/<DIFFUSION_FOLDER>_xyz.npy
      - loads C.NETWORK_DRIVE/HCP_for_training_COPY/<subject>/bundle_masks.nii.gz, scales it to
        the UNet shape and saves it as .../<subject>/bundle_masks.npy
    '''
    #Can leave this always the same (for 270g and 32g)
    class HP:
        DATASET = "HCP"
        RESOLUTION = "1.25mm"
        FEATURES_FILENAME = "270g_125mm_peaks"
        LABELS_TYPE = np.int16
        LABELS_FILENAME = "bundle_masks"
        DATASET_FOLDER = "HCP"

    #change this for 270g and 32g
    DIFFUSION_FOLDER = "32g_25mm"

    subjects = get_all_subjects()
    # fold0 = ['687163', '685058', '683256', '680957', '679568', '677968', '673455', '672756', '665254', '654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']
    # fold1 = ['992774', '991267', '987983', '984472', '983773', '979984', '978578', '965771', '965367', '959574', '958976', '957974', '951457', '932554', '930449', '922854', '917255', '912447', '910241', '907656', '904044']
    # fold2 = ['901442', '901139', '901038', '899885', '898176', '896879', '896778', '894673', '889579', '887373', '877269', '877168', '872764', '872158', '871964', '871762', '865363', '861456', '859671', '857263', '856766']
    # fold3 = ['849971', '845458', '837964', '837560', '833249', '833148', '826454', '826353', '816653', '814649', '802844', '792766', '792564', '789373', '786569', '784565', '782561', '779370', '771354', '770352', '765056']
    # fold4 = ['761957', '759869', '756055', '753251', '751348', '749361', '748662', '748258', '742549', '734045', '732243', '729557', '729254', '715647', '715041', '709551', '705341', '704238', '702133', '695768', '690152']
    # subjects = fold2 + fold3 + fold4
    # subjects = ['654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']

    fusion_dir = join(C.NETWORK_DRIVE, "HCP_fusion_" + DIFFUSION_FOLDER)
    npy_dir = join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER)
    labels_dir = join(C.NETWORK_DRIVE, "HCP_for_training_COPY")

    print("\n\nProcessing Data...")
    for s in subjects:
        print("processing data subject {}".format(s))
        start_time = time.time()
        # np.asanyarray(img.dataobj) replaces the deprecated get_data() (removed in nibabel 5.0)
        data = np.asanyarray(nib.load(join(fusion_dir, s + "_probmap.nii.gz")).dataobj)
        print("Done Loading")
        data = np.nan_to_num(data)
        data = DatasetUtils.scale_input_to_unet_shape(data, HP.DATASET, HP.RESOLUTION)
        # cut one pixel at the end, because in scale_input_to_world_shape we outputted 146 -> one too much at the end
        data = data[:-1, :, :-1, :]
        subject_dir = join(npy_dir, s)
        ExpUtils.make_dir(subject_dir)
        np.save(join(subject_dir, DIFFUSION_FOLDER + "_xyz.npy"), data)
        print("Took {}s".format(time.time() - start_time))

        print("processing seg subject {}".format(s))
        start_time = time.time()
        # seg = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
        seg = np.asanyarray(nib.load(join(labels_dir, s, HP.LABELS_FILENAME + ".nii.gz")).dataobj)
        if HP.RESOLUTION == "2.5mm":
            seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
        seg = DatasetUtils.scale_input_to_unet_shape(seg, HP.DATASET, HP.RESOLUTION)
        np.save(join(subject_dir, "bundle_masks.npy"), seg)
        print("Took {}s".format(time.time() - start_time))
def move_to_subject_space(output_dir):
    '''
    Transform <output_dir>/bundle_segmentations.nii.gz from MNI space back to subject space.

    Inverts <output_dir>/FA_2_MNI.mat into MNI_2_FA.mat and resamples the segmentation onto
    <output_dir>/FA.nii.gz with flirt, then thresholds at 0.5 and binarizes the result. It
    presumably expects FA.nii.gz and FA_2_MNI.mat as written by move_to_MNI_space; confirm
    against the caller. Output: <output_dir>/bundle_segmentations_subjectSpace.nii.gz.

    External tools are called with argument lists (no shell), so paths with spaces or shell
    metacharacters work. As before, their exit codes are not checked.
    '''
    import subprocess

    print("Moving input to subject space...")

    file_path_in = output_dir + "/bundle_segmentations.nii.gz"
    file_path_out = output_dir + "/bundle_segmentations_subjectSpace.nii.gz"
    dwi_spacing = ImgUtils.get_image_spacing(file_path_in)
    mni_2_fa = output_dir + "/MNI_2_FA.mat"

    subprocess.run(["convert_xfm", "-omat", mni_2_fa, "-inverse", output_dir + "/FA_2_MNI.mat"])
    subprocess.run(["flirt", "-ref", output_dir + "/FA.nii.gz", "-in", file_path_in,
                    "-out", file_path_out, "-applyisoxfm", str(dwi_spacing),
                    "-init", mni_2_fa, "-dof", "6"])
    # Re-binarize the resampled segmentation
    subprocess.run(["fslmaths", file_path_out, "-thr", "0.5", "-bin", file_path_out])
def create_one_3D_file():
    '''
    Create one big file which contains all 3D Images (not slices).

    Loads the peak image of every subject returned by get_all_subjects(), replaces NaNs,
    scales it to the UNet input shape and saves all of them together as "data.npy" in the
    current working directory. Afterwards builds the multilabel mask of every subject,
    scales it the same way and saves all masks together as "seg.npy".
    '''
    class HP:
        DATASET = "HCP"
        RESOLUTION = "1.25mm"
        FEATURES_FILENAME = "270g_125mm_peaks"
        LABELS_TYPE = np.int16
        DATASET_FOLDER = "HCP"

    data_all = []
    seg_all = []

    print("\n\nProcessing Data...")
    for s in get_all_subjects():
        print("processing data subject {}".format(s))
        img = nib.load(join(C.HOME, HP.DATASET_FOLDER, s, HP.FEATURES_FILENAME + ".nii.gz"))
        # np.asanyarray(img.dataobj) is the drop-in replacement for the deprecated get_data()
        # (removed in nibabel 5.0): same array, same on-disk dtype.
        data = np.asanyarray(img.dataobj)
        data = np.nan_to_num(data)
        data = DatasetUtils.scale_input_to_unet_shape(data, HP.DATASET, HP.RESOLUTION)
        data_all.append(np.array(data))
    np.save("data.npy", data_all)
    del data_all  # free memory

    print("\n\nProcessing Segs...")
    for s in get_all_subjects():
        print("processing seg subject {}".format(s))
        seg = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
        if HP.RESOLUTION == "2.5mm":
            seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
        seg = DatasetUtils.scale_input_to_unet_shape(seg, HP.DATASET, HP.RESOLUTION)
        seg_all.append(np.array(seg))

    # A Python list has no .dtype (the old code crashed here before saving) -> convert first.
    # np.save would have performed the same list -> array conversion anyway.
    seg_all = np.array(seg_all)
    print("SEG TYPE: {}".format(seg_all.dtype))
    np.save("seg.npy", seg_all)
def scale_input_to_world_shape(img4d, dataset, resolution="1.25mm"):
    '''
    Scale input image to original resolution and pad/cut image to make it original size
    (counterpart of scale_input_to_unet_shape).

    :param img4d: (x, y, z, userdefined) (userdefined could be gradients or classes)
    :param dataset: "HCP" / "HCP_32g" / "HCP_2mm" / "HCP_2.5mm" / "TRACED"; which datasets are
                    supported depends on the resolution (see the branches below)
    :param resolution: "1.25mm" / "2mm" / "2.5mm"
    :return: img with original size
    :raises ValueError: if the dataset/resolution combination is not supported
    '''
    if resolution == "1.25mm":
        if dataset in ("HCP", "HCP_32g"):  # (144,144,144)
            # no resize needed
            return ImgUtils.pad_4d_image_left(img4d, np.array([1, 15, 1, 0]),
                                              [146, 174, 146, img4d.shape[3]],
                                              pad_value=0)  # (146, 174, 146, none)
        elif dataset == "TRACED":  # (78,93,75)
            raise ValueError("resolution '1.25mm' not supported for dataset 'TRACED'")

    elif resolution == "2mm":
        if dataset in ("HCP", "HCP_32g", "HCP_2mm"):  # (80,80,80)
            return ImgUtils.pad_4d_image_left(img4d, np.array([5, 14, 5, 0]),
                                              [90, 108, 90, img4d.shape[3]],
                                              pad_value=0)  # (90, 108, 90, none)
        elif dataset == "TRACED":  # (78,93,75)
            raise ValueError("resolution '2mm' not supported for dataset 'TRACED'")

    elif resolution == "2.5mm":
        if dataset in ("HCP", "HCP_2.5mm", "HCP_32g"):  # (80,80,80)
            img4d = ImgUtils.pad_4d_image_left(img4d, np.array([0, 4, 0, 0]),
                                               [80, 87, 80, img4d.shape[3]],
                                               pad_value=0)  # (80,87,80,none)
            return img4d[4:77, :, 4:77, :]  # (73, 87, 73, none)
        elif dataset == "TRACED":  # (80,80,80)
            img4d = ImgUtils.pad_4d_image_left(img4d, np.array([0, 7, 0, 0]),
                                               [80, 93, 80, img4d.shape[3]],
                                               pad_value=0)  # (80,93,80,none)
            return img4d[1:79, :, 3:78, :]  # (78,93,75,none)

    # Previously fell through and silently returned None
    raise ValueError("resolution '{}' not supported for dataset '{}'".format(resolution, dataset))
def pad_and_scale_img_to_square_img(data, target_size=144):
    '''
    Expects 3D or 4D image as input.

    Does
    1. Pad image with 0 to make it square (if uneven padding -> adds one more px "behind" img;
       but resulting img shape will be correct). Only the three spatial dims are considered;
       the channel dim of a 4D image is left unchanged.
    2. Scale image to UNet size (target_size, target_size, target_size) using order=0

    :param data: 3D (x, y, z) or 4D (x, y, z, channels) image
    :param target_size: edge length of the output cube (default 144)
    :return: (new_img, transformation) where transformation contains "original_shape",
             "pad_x", "pad_y", "pad_z" (floats, may end in .5) and "zoom"; this dict is what
             cut_and_scale_img_back_to_original_img expects
    :raises ValueError: if data is neither 3D nor 4D
    '''
    nr_dims = len(data.shape)
    # Explicit check instead of `assert`, which is stripped when running with -O
    if nr_dims not in (3, 4):
        raise ValueError("Expected a 3D or 4D image, got {} dimensions".format(nr_dims))

    shape = data.shape
    # Only spatial dims define the cube; using max(shape) would let a large channel dim
    # (e.g. many gradients) blow up the padding.
    biggest_dim = max(shape[:3])

    # Pad to make square
    new_img = np.zeros((biggest_dim, biggest_dim, biggest_dim) + tuple(shape[3:]), dtype=data.dtype)
    pad1 = (biggest_dim - shape[0]) / 2.
    pad2 = (biggest_dim - shape[1]) / 2.
    pad3 = (biggest_dim - shape[2]) / 2.
    new_img[int(pad1):int(pad1) + shape[0],
            int(pad2):int(pad2) + shape[1],
            int(pad3):int(pad3) + shape[2]] = data

    # Scale to right size
    zoom = float(target_size) / biggest_dim
    if nr_dims == 4:
        #use order=0, otherwise image values of a DWI will be quite different after downsampling and upsampling
        new_img = ImgUtils.resize_first_three_dims(new_img, order=0, zoom=zoom)
    else:
        new_img = ndimage.zoom(new_img, zoom, order=0)

    transformation = {
        "original_shape": shape,
        "pad_x": pad1,
        "pad_y": pad2,
        "pad_z": pad3,
        "zoom": zoom
    }

    return new_img, transformation
def save_fusion_nifti_as_npy():
    '''
    Convert the fusion probability maps and bundle masks of all subjects from NIfTI to .npy.

    For each subject:
      - loads C.NETWORK_DRIVE/HCP_fusion_<DIFFUSION_FOLDER>/<subject>_probmap.nii.gz, replaces
        NaNs, scales it to the UNet shape, cuts the last voxel in x and z and saves it as
        C.NETWORK_DRIVE/HCP_fusion_npy_<DIFFUSION_FOLDER>/<subject>/<DIFFUSION_FOLDER>_xyz.npy
      - loads C.NETWORK_DRIVE/HCP_for_training_COPY/<subject>/bundle_masks.nii.gz, scales it to
        the UNet shape and saves it as .../<subject>/bundle_masks.npy
    '''
    #Can leave this always the same (for 270g and 32g)
    class HP:
        DATASET = "HCP"
        RESOLUTION = "1.25mm"
        FEATURES_FILENAME = "270g_125mm_peaks"
        LABELS_TYPE = np.int16
        LABELS_FILENAME = "bundle_masks"
        DATASET_FOLDER = "HCP"

    #change this for 270g and 32g
    DIFFUSION_FOLDER = "32g_25mm"

    subjects = get_all_subjects()
    # fold0 = ['687163', '685058', '683256', '680957', '679568', '677968', '673455', '672756', '665254', '654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']
    # fold1 = ['992774', '991267', '987983', '984472', '983773', '979984', '978578', '965771', '965367', '959574', '958976', '957974', '951457', '932554', '930449', '922854', '917255', '912447', '910241', '907656', '904044']
    # fold2 = ['901442', '901139', '901038', '899885', '898176', '896879', '896778', '894673', '889579', '887373', '877269', '877168', '872764', '872158', '871964', '871762', '865363', '861456', '859671', '857263', '856766']
    # fold3 = ['849971', '845458', '837964', '837560', '833249', '833148', '826454', '826353', '816653', '814649', '802844', '792766', '792564', '789373', '786569', '784565', '782561', '779370', '771354', '770352', '765056']
    # fold4 = ['761957', '759869', '756055', '753251', '751348', '749361', '748662', '748258', '742549', '734045', '732243', '729557', '729254', '715647', '715041', '709551', '705341', '704238', '702133', '695768', '690152']
    # subjects = fold2 + fold3 + fold4
    # subjects = ['654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']

    fusion_dir = join(C.NETWORK_DRIVE, "HCP_fusion_" + DIFFUSION_FOLDER)
    npy_dir = join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER)
    labels_dir = join(C.NETWORK_DRIVE, "HCP_for_training_COPY")

    print("\n\nProcessing Data...")
    for s in subjects:
        print("processing data subject {}".format(s))
        start_time = time.time()
        # np.asanyarray(img.dataobj) replaces the deprecated get_data() (removed in nibabel 5.0)
        data = np.asanyarray(nib.load(join(fusion_dir, s + "_probmap.nii.gz")).dataobj)
        print("Done Loading")
        data = np.nan_to_num(data)
        data = DatasetUtils.scale_input_to_unet_shape(data, HP.DATASET, HP.RESOLUTION)
        # cut one pixel at the end, because in scale_input_to_world_shape we outputted 146 -> one too much at the end
        data = data[:-1, :, :-1, :]
        subject_dir = join(npy_dir, s)
        ExpUtils.make_dir(subject_dir)
        np.save(join(subject_dir, DIFFUSION_FOLDER + "_xyz.npy"), data)
        print("Took {}s".format(time.time() - start_time))

        print("processing seg subject {}".format(s))
        start_time = time.time()
        # seg = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
        seg = np.asanyarray(nib.load(join(labels_dir, s, HP.LABELS_FILENAME + ".nii.gz")).dataobj)
        if HP.RESOLUTION == "2.5mm":
            seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
        seg = DatasetUtils.scale_input_to_unet_shape(seg, HP.DATASET, HP.RESOLUTION)
        np.save(join(subject_dir, "bundle_masks.npy"), seg)
        print("Took {}s".format(time.time() - start_time))
def pad_and_scale_img_to_square_img(data, target_size=144):
    '''
    Expects 3D or 4D image as input.

    Does
    1. Pad image with 0 to make it square (if uneven padding -> adds one more px "behind" img;
       but resulting img shape will be correct). Only the three spatial dims are considered;
       the channel dim of a 4D image is left unchanged.
    2. Scale image to UNet size (target_size, target_size, target_size) using order=0

    :param data: 3D (x, y, z) or 4D (x, y, z, channels) image
    :param target_size: edge length of the output cube (default 144)
    :return: (new_img, transformation) where transformation contains "original_shape",
             "pad_x", "pad_y", "pad_z" (floats, may end in .5) and "zoom"; this dict is what
             cut_and_scale_img_back_to_original_img expects
    :raises ValueError: if data is neither 3D nor 4D
    '''
    nr_dims = len(data.shape)
    # Explicit check instead of `assert`, which is stripped when running with -O
    if nr_dims not in (3, 4):
        raise ValueError("Expected a 3D or 4D image, got {} dimensions".format(nr_dims))

    shape = data.shape
    # Only spatial dims define the cube; using max(shape) would let a large channel dim
    # (e.g. many gradients) blow up the padding.
    biggest_dim = max(shape[:3])

    # Pad to make square
    new_img = np.zeros((biggest_dim, biggest_dim, biggest_dim) + tuple(shape[3:]), dtype=data.dtype)
    pad1 = (biggest_dim - shape[0]) / 2.
    pad2 = (biggest_dim - shape[1]) / 2.
    pad3 = (biggest_dim - shape[2]) / 2.
    new_img[int(pad1):int(pad1) + shape[0],
            int(pad2):int(pad2) + shape[1],
            int(pad3):int(pad3) + shape[2]] = data

    # Scale to right size
    zoom = float(target_size) / biggest_dim
    if nr_dims == 4:
        #use order=0, otherwise image values of a DWI will be quite different after downsampling and upsampling
        new_img = ImgUtils.resize_first_three_dims(new_img, order=0, zoom=zoom)
    else:
        new_img = ndimage.zoom(new_img, zoom, order=0)

    transformation = {
        "original_shape": shape,
        "pad_x": pad1,
        "pad_y": pad2,
        "pad_z": pad3,
        "zoom": zoom
    }

    return new_img, transformation
def scale_input_to_unet_shape(img4d, dataset, resolution="1.25mm"):
    '''
    Scale input image to right isotropic resolution and pad/cut image to make it square to fit UNet input shape

    :param img4d: (x, y, z, userdefined) (userdefined could be gradients or classes)
    :param dataset: "HCP" / "HCP_32g" / "HCP_2mm" / "HCP_2.5mm" / "TRACED" / "Schizo"; which
                    datasets are supported depends on the resolution (see the branches below)
    :param resolution: "1.25mm" / "2mm" / "2.5mm" results in UNet input shape of (144,144,144) or (80,80,80)
    :return: img with dim 1.25mm: (144,144,144,none) or 2mm: (80,80,80,none) or 2.5mm: (80,80,80,none)
             (note: for 2.5mm the image is embedded in an 80^3 volume whose background is filled
             with the value of the corner voxel img4d[0,0,0,:]; Schizo at 2mm yields (91,91,91,none))
    :raises ValueError: if the dataset/resolution combination is not supported
    '''
    def _embed_in_80_cube(img, bg_index, img_index):
        # Place img[img_index] at bg_index inside an (80,80,80,C) volume. The background gets the
        # corner voxel value of img (broadcast over the last dim) instead of plain zeros.
        bg = np.zeros((80, 80, 80, img.shape[3])).astype(img.dtype)
        bg = bg + img[0, 0, 0, :]
        bg[bg_index] = img[img_index]
        return bg

    if resolution == "1.25mm":
        if dataset == "HCP":  # (145,174,145)
            # no resize needed
            return img4d[1:, 15:159, 1:]  # (144,144,144)
        elif dataset == "HCP_32g":  # (73,87,73)
            # return img4d[1:, 15:159, 1:]  # (144,144,144)   #OLD when HCP_32g was still 125mm
            img4d = ImgUtils.resize_first_three_dims(img4d, zoom=2)  # (146,174,146,none)
            img4d = img4d[:-1, :, :-1]  # remove one voxel that came from upsampling  # (145,174,145)
            return img4d[1:, 15:159, 1:]  # (144,144,144)
        elif dataset == "TRACED":  # (78,93,75)
            raise ValueError("resolution '1.25mm' not supported for dataset 'TRACED'")
        elif dataset == "Schizo":  # (91,109,91)
            img4d = ImgUtils.resize_first_three_dims(img4d, zoom=1.60)  # (146,174,146)
            return img4d[1:145, 15:159, 1:145]  # (144,144,144)

    elif resolution == "2mm":
        if dataset in ("HCP", "HCP_32g"):  # (145,174,145)
            img4d = ImgUtils.resize_first_three_dims(img4d, zoom=0.62)  # (90,108,90)
            return img4d[5:85, 14:94, 5:85, :]  # (80,80,80)
        elif dataset == "HCP_2mm":  # (90,108,90)
            # no resize needed
            return img4d[5:85, 14:94, 5:85, :]  # (80,80,80)
        elif dataset == "TRACED":  # (78,93,75)
            raise ValueError("resolution '2mm' not supported for dataset 'TRACED'")
        elif dataset == "Schizo":  # (91,109,91)
            return img4d[:, 9:100, :]  # (91,91,91)

    elif resolution == "2.5mm":
        if dataset == "HCP":  # (145,174,145)
            img4d = ImgUtils.resize_first_three_dims(img4d, zoom=0.5)  # (73,87,73,none)
            return _embed_in_80_cube(img4d, np.s_[4:77, :, 4:77], np.s_[:, 4:84, :, :])  # (80,80,80)
        elif dataset in ("HCP_2.5mm", "HCP_32g"):  # (73,87,73,none)
            # no resize needed
            return _embed_in_80_cube(img4d, np.s_[4:77, :, 4:77], np.s_[:, 4:84, :, :])  # (80,80,80)
        elif dataset == "TRACED":  # (78,93,75)
            # no resize needed
            return _embed_in_80_cube(img4d, np.s_[1:79, :, 3:78, :], np.s_[:, 7:87, :, :])  # (80,80,80)

    # Previously fell through and silently returned None
    raise ValueError("resolution '{}' not supported for dataset '{}'".format(resolution, dataset))
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_threshold=False, get_probs=False):
    '''
    Run TractSeg

    Pipeline: load the pretrained-model config for (input_type, output_type), select the weights
    file, crop the input to its nonzero bounding box, pad/scale it to a cube of HP.INPUT_DIM[0],
    predict, then undo the scaling/padding/cropping so the result matches the input grid.

    :param data: input peaks (4D numpy array with shape [x,y,z,9]); for input_type "T1" the
                 data is reshaped to [x,y,z,1] (so presumably a 3D T1 image is expected)
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
                        (the code also handles the experiment type "peak_regression"; the mapping
                        from output_type to experiment type is done by get_config_name)
    :param input_type: "peaks" | "T1" (only these two select a weights file here)
    :param single_orientation: for segmentation/dm_regression only predict along one slicing
                               direction instead of fusing x/y/z (mainly for testing, less RAM)
    :param verbose: show debugging infos
    :param dropout_sampling: create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: Threshold for converting probability map to binary map
    :param bundle_specific_threshold: Threshold is lower for some bundles which need more sensitivity (CA, CST, FX).
                                      Forces probability output internally; only for
                                      tract_segmentation is it converted back to a binary map.
                                      For peak_regression it selects bundle-specific small-peak removal.
    :param get_probs: Output raw probability map instead of binary map
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:     [x,y,z,nr_of_bundles]
        for endings_segmentation:   [x,y,z,2*nr_of_bundles]
        for TOM:                    [x,y,z,3*nr_of_bundles]
        (for peak_regression HP.NR_OF_CLASSES is 3*nr_of_bundles)
    '''
    start_time = time.time()

    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    # Bundle-specific thresholding needs probabilities; it is applied further below
    if bundle_specific_threshold:
        HP.GET_PROBS = True

    # Select the weights file. For combinations not listed here HP.WEIGHTS_PATH stays whatever the
    # loaded config defines (NOTE(review): confirm the config always provides a default).
    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DAugAll_Dropout", "best_weights_ep114.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_12g90g270g_125mm_DAugAll", "best_weights_ep126.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888", "best_weights_ep247.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888_SchizoFineT_lr001", "best_weights_ep186.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v2.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "EndingsSeg_12g90g270g_125mm_DAugAll", "best_weights_ep16.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/Peaks20_12g90g270g_125mm_DAugSimp_constW5", "best_weights_ep441.npz")  #more oversegmentation with DAug
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "DmReg_12g90g270g_125mm_DAugAll_Ubuntu", "best_weights_ep80.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # NOTE(review): unlike every other branch this points to C.NETWORK_DRIVE (a development
            # location), not C.TRACT_SEG_HOME -- confirm this is intended for releases.
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    # Number of output channels: one per bundle (index 0 of get_bundle_names is skipped),
    # three per bundle (one 3D vector each) for peak regression
    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3*len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    else:
        HP.NR_OF_CLASSES = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)
    # brain_mask = ImgUtils.simple_brain_mask(data)
    # if HP.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    # T1 input gets an explicit channel axis so it goes through the same 4D pipeline
    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    # Crop to the nonzero bounding box, then pad/scale to the network's cube size.
    # bbox/original_shape/transformation are kept to undo both steps at the end.
    data, seg_None, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE == "tract_segmentation" or HP.EXPERIMENT_TYPE == "endings_segmentation" or HP.EXPERIMENT_TYPE == "dm_regression":
        if single_orientation:     # mainly needed for testing because of less RAM requirements
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            # dm_regression, dropout sampling and explicit probability requests all need raw probabilities
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
            else:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=False, scale_to_world_shape=False, only_prediction=True)
        else:
            # Predict along x, y and z slices and fuse the three predictions by their mean
            seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=False)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        # Peak regression always uses a single slicing direction (3-direction fusion is disabled below)
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:], len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=0.3)  # set lower for more sensitivity

        #3 dir for Peaks -> not working (?)
        # seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
        # seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)

    # NOTE(review): for any other HP.EXPERIMENT_TYPE `seg` is never assigned and the lines below
    # raise NameError.

    # Only tract_segmentation is binarized with bundle-specific thresholds; for endings_segmentation
    # / dm_regression with bundle_specific_threshold=True the output stays probabilistic.
    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    # Undo pad_and_scale_img_to_square_img and crop_to_nonzero -> back to the input grid
    #remove following two lines to keep super resolution
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)

    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg
def scale_input_to_world_shape(img4d, dataset, resolution="1.25mm"):
    '''
    Undo the UNet-shape scaling: pad/cut a network-shaped image back to the dataset's world size.

    The first three (spatial) axes are zero-padded "on the left" by a fixed offset per
    dataset/resolution and filled up to the target size. For some combinations the result is
    additionally cropped (2.5mm) or resized (Schizo). The last axis is left untouched.

    :param img4d: (x, y, z, userdefined) (userdefined could be gradients or classes)
    :param dataset: "HCP" | "HCP_32g" | "HCP_2mm" | "HCP_2.5mm" | "TRACED" | "Schizo"
    :param resolution: "1.25mm" / "2mm" / "2.5mm"
    :return: img with original size; None for dataset/resolution combinations not listed here
    :raises ValueError: for "TRACED" with "1.25mm" or "2mm"
    '''
    def pad_left(offset, target_xyz):
        # Pad by `offset` at the start of each axis up to target_xyz; the last dim keeps its size.
        return ImgUtils.pad_4d_image_left(img4d, np.array(offset),
                                          list(target_xyz) + [img4d.shape[3]], pad_value=0)

    if resolution == "1.25mm":
        # input (144,144,144)
        if dataset in ("HCP", "HCP_32g"):
            # NOTE(review): pads to (146,174,146), whereas the Schizo branch below (and the HCP
            # crop in scale_input_to_unet_shape) work with (145,174,145) — confirm the extra
            # voxel in x/z is intended.
            return pad_left([1, 15, 1, 0], [146, 174, 146])  # (146, 174, 146, none)
        if dataset == "TRACED":
            raise ValueError("resolution '1.25mm' not supported for dataset 'TRACED'")
        if dataset == "Schizo":
            full_size = pad_left([1, 15, 1, 0], [145, 174, 145])  # (145, 174, 145, none)
            return ImgUtils.resize_first_three_dims(full_size, zoom=0.62)  # (91,109,91)
        return None

    if resolution == "2mm":
        # input (80,80,80)
        if dataset in ("HCP", "HCP_32g", "HCP_2mm"):
            return pad_left([5, 14, 5, 0], [90, 108, 90])  # (90, 108, 90, none)
        if dataset == "TRACED":
            raise ValueError("resolution '2mm' not supported for dataset 'TRACED'")
        return None

    if resolution == "2.5mm":
        # input (80,80,80)
        if dataset in ("HCP", "HCP_2.5mm", "HCP_32g"):
            padded = pad_left([0, 4, 0, 0], [80, 87, 80])  # (80,87,80,none)
            return padded[4:77, :, 4:77, :]  # (73, 87, 73, none)
        if dataset == "TRACED":
            padded = pad_left([0, 7, 0, 0], [80, 93, 80])  # (80,93,80,none)
            return padded[1:79, :, 3:78, :]  # (78,93,75,none)

    return None
def _create_prob_slices_file(HP, subjects, filename, bundle, shuffle=True): mask_dir = join(C.HOME, HP.DATASET_FOLDER) input_dir = HP.MULTI_PARENT_PATH combined_slices = [] mask_slices = [] for s in subjects: print("processing subject {}".format(s)) probs_x = nib.load( join(input_dir, "UNet_x_" + str(HP.CV_FOLD), "probmaps", s + "_probmap.nii.gz")).get_data() probs_y = nib.load( join(input_dir, "UNet_y_" + str(HP.CV_FOLD), "probmaps", s + "_probmap.nii.gz")).get_data() probs_z = nib.load( join(input_dir, "UNet_z_" + str(HP.CV_FOLD), "probmaps", s + "_probmap.nii.gz")).get_data() # probs_x = DatasetUtils.scale_input_to_unet_shape(probs_x, HP.DATASET, HP.RESOLUTION) # probs_y = DatasetUtils.scale_input_to_unet_shape(probs_y, HP.DATASET, HP.RESOLUTION) # probs_z = DatasetUtils.scale_input_to_unet_shape(probs_z, HP.DATASET, HP.RESOLUTION) combined = np.stack( (probs_x, probs_y, probs_z), axis=4 ) # (73, 87, 73, 18, 3) #not working alone: one dim too much for UNet -> reshape combined = np.reshape( combined, (combined.shape[0], combined.shape[1], combined.shape[2], combined.shape[3] * combined.shape[4])) # (73, 87, 73, 3*18) # print("combined shape after", combined.shape) mask_data = ImgUtils.create_multilabel_mask( HP, s, labels_type=HP.LABELS_TYPE) if HP.DATASET == "HCP_2mm": #use "HCP" because for mask we need downscaling mask_data = DatasetUtils.scale_input_to_unet_shape( mask_data, "HCP", HP.RESOLUTION) elif HP.DATASET == "HCP_2.5mm": # use "HCP" because for mask we need downscaling mask_data = DatasetUtils.scale_input_to_unet_shape( mask_data, "HCP", HP.RESOLUTION) else: # Mask has same resolution as probmaps -> we can use same resizing mask_data = DatasetUtils.scale_input_to_unet_shape( mask_data, HP.DATASET, HP.RESOLUTION) # Save as Img img = nib.Nifti1Image( combined, ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION)) nib.save( img, join(HP.EXP_PATH, "combined", s + "_combinded_probmap.nii.gz")) combined = DatasetUtils.scale_input_to_unet_shape( combined, 
HP.DATASET, HP.RESOLUTION) assert (combined.shape[2] == mask_data.shape[2]) #Save as Slices for z in range(combined.shape[2]): combined_slices.append(combined[:, :, z, :]) mask_slices.append(mask_data[:, :, z, :]) if shuffle: combined_slices, mask_slices = sk_shuffle(combined_slices, mask_slices, random_state=9) if HP.TRAIN: np.save(filename + "_data.npy", combined_slices) np.save(filename + "_seg.npy", mask_slices)
def track(bundle, peaks, output_dir, brain_mask, filter_by_endpoints=False):
    '''
    Track one bundle with MRtrix ``tckgen`` (FACT) on its TOM and convert the result to .trk.

    :param bundle: Bundle name
    :param peaks: unused (only needed by the probabilistic iFOD2 alternative noted below)
    :param output_dir: TractSeg output folder. Tracking runs on <output_dir>/TOM/<bundle>.nii.gz.
                       With filter_by_endpoints, bundle_segmentations/ and endings_segmentations/
                       are read as well.
    :param brain_mask: path to the brain mask; used as seed image and as reference affine for the .trk
    :param filter_by_endpoints: use results of endings_segmentation to filter out all fibers not
                                ending in those regions. Falls back to unfiltered tracking if the
                                beginnings or endings mask is empty.
    :return: None; writes <output_dir>/TOM_trackings/<bundle>.trk (the intermediate .tck is removed)
    '''
    import subprocess  # argument lists instead of shell strings: paths with spaces stay intact

    tracking_folder = "TOM_trackings"
    # tracking_folder = "TOM_trackings_FiltEP6_FiltMask3"
    smooth = None  # None / 10
    TOM_folder = "TOM"  # TOM / TOM_thr1

    tracking_dir = output_dir + "/" + tracking_folder
    tck_file = tracking_dir + "/" + bundle + ".tck"

    # A separate `os.system("export PATH=...")` only changed the PATH of a throw-away subshell.
    # Hand the extended PATH to tckgen explicitly instead.
    env = dict(os.environ)
    env["PATH"] = "/code/mrtrix3/bin:" + env.get("PATH", "")

    if not os.path.isdir(tracking_dir):  # mkdir -p
        os.makedirs(tracking_dir)

    use_filter = False
    if filter_by_endpoints:
        beginnings_mask_ok = nib.load(output_dir + "/endings_segmentations/" + bundle + "_b.nii.gz").get_data().max() > 0
        endings_mask_ok = nib.load(output_dir + "/endings_segmentations/" + bundle + "_e.nii.gz").get_data().max() > 0
        if not beginnings_mask_ok:
            print("WARNING: tract beginnings mask of {} empty".format(bundle))
        if not endings_mask_ok:
            print("WARNING: tract endings mask of {} empty".format(bundle))
        use_filter = beginnings_mask_ok and endings_mask_ok

    tmp_dir = tempfile.mkdtemp()
    try:
        cmd = ["tckgen", "-algorithm", "FACT",
               output_dir + "/" + TOM_folder + "/" + bundle + ".nii.gz", tck_file,
               "-seed_image", brain_mask]
        if use_filter:
            tmp_bundle_mask = tmp_dir + "/" + bundle + ".nii.gz"
            tmp_endings = tmp_dir + "/" + bundle + "_e.nii.gz"
            tmp_beginnings = tmp_dir + "/" + bundle + "_b.nii.gz"
            # dilation has to be quite high, because endings sometimes almost completely missing
            ImgUtils.dilate_binary_mask(output_dir + "/bundle_segmentations/" + bundle + ".nii.gz",
                                        tmp_bundle_mask, dilation=3)
            ImgUtils.dilate_binary_mask(output_dir + "/endings_segmentations/" + bundle + "_e.nii.gz",
                                        tmp_endings, dilation=6)
            ImgUtils.dilate_binary_mask(output_dir + "/endings_segmentations/" + bundle + "_b.nii.gz",
                                        tmp_beginnings, dilation=6)
            cmd += ["-mask", tmp_bundle_mask,
                    "-include", tmp_beginnings,
                    "-include", tmp_endings]
            # Probabilistic alternative without TOM:
            #   tckgen -algorithm iFOD2 <peaks> <tck> -seed_image <tmp bundle mask> -mask <tmp bundle mask>
            #          -include <tmp _b> -include <tmp _e> -minlength 40 -seeds 200000 -select 2000 -force
        cmd += ["-minlength", "40", "-select", "2000", "-force", "-quiet"]

        return_code = subprocess.call(cmd, env=env)
        if return_code != 0:
            print("WARNING: tckgen failed for {} (exit code {})".format(bundle, return_code))

        reference_affine = nib.load(brain_mask).affine
        FiberUtils.convert_tck_to_trk(tck_file, tracking_dir + "/" + bundle + ".trk",
                                      reference_affine, smooth=smooth)
        if os.path.exists(tck_file):  # rm -f
            os.remove(tck_file)
    finally:
        # always clean up the dilated masks, also when tracking/conversion raised
        shutil.rmtree(tmp_dir)
def precompute_batches(custom_type=None):
    '''
    Precompute training batches and store them as NIfTI files.

    9000 slices per epoch -> 200 batches (batchsize=44) per epoch
    => 200-1000 batches needed

    270g_125mm_bundle_peaks_Y: no DAug, no Norm, only Y
    All_sizes_DAug_XYZ: 12g, 90g, 270g, DAug (no rotation, no elastic deform), Norm, XYZ
    270g_125mm_bundle_peaks_XYZ: no DAug, Norm, XYZ

    :param custom_type: "train" | "validate" | "test" to only precompute one split; None for all three
    :return: None; writes <C.HOME>/<DATASET_DIR>/<split>/batch_<idx>_{data,seg}.nii.gz
    :raises ValueError: if custom_type is neither None nor one of the known splits
    '''
    num_batches_base = 5000
    num_batches = {
        "train": num_batches_base,
        "validate": int(num_batches_base / 3.),
        "test": int(num_batches_base / 3.),
    }

    # Validate before any expensive setup (CV fold loading, data managers).
    if custom_type is None:
        splits = ["train", "validate", "test"]
    elif custom_type in num_batches:
        splits = [custom_type]
    else:
        raise ValueError("custom_type must be None or one of {}, got {!r}".format(
            sorted(num_batches), custom_type))

    class HP:
        NORMALIZE_DATA = True
        DATA_AUGMENTATION = False
        CV_FOLD = 0
        INPUT_DIM = (144, 144)
        BATCH_SIZE = 44
        DATASET_FOLDER = "HCP"
        TYPE = "single_direction"
        EXP_PATH = "~"
        LABELS_FILENAME = "bundle_peaks"
        FEATURES_FILENAME = "270g_125mm_peaks"
        DATASET = "HCP"
        RESOLUTION = "1.25mm"
        LABELS_TYPE = np.float32

    HP.TRAIN_SUBJECTS, HP.VALIDATE_SUBJECTS, HP.TEST_SUBJECTS = ExpUtils.get_cv_fold(HP.CV_FOLD)

    # DATASET_DIR = "HCP_batches/270g_125mm_bundle_peaks_Y"
    # DATASET_DIR = "HCP_batches/All_sizes_DAug_XYZ"
    DATASET_DIR = "HCP_batches/270g_125mm_bundle_peaks_XYZ"

    # The affine only depends on HP.DATASET / HP.RESOLUTION -> compute it once, not per batch.
    affine = ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION)

    for split in splits:
        out_dir = join(C.HOME, DATASET_DIR, split)
        ExpUtils.make_dir(out_dir)

        dataManager = DataManagerTrainingNiftiImgs(HP)
        batch_gen = dataManager.get_batches(
            batch_size=HP.BATCH_SIZE,
            type=split,
            subjects=getattr(HP, split.upper() + "_SUBJECTS"),
            num_batches=num_batches[split])

        for idx, batch in enumerate(batch_gen):
            print("Processing: {}".format(idx))
            nib.save(nib.Nifti1Image(batch["data"], affine),
                     join(out_dir, "batch_" + str(idx) + "_data.nii.gz"))
            nib.save(nib.Nifti1Image(batch["seg"], affine),
                     join(out_dir, "batch_" + str(idx) + "_seg.nii.gz"))
def track(bundle, peaks, output_dir, brain_mask, filter_by_endpoints=False, output_format="trk"):
    '''
    Track one bundle with MRtrix ``tckgen`` (FACT) on its TOM.

    :param bundle: Bundle name
    :param peaks: unused (only needed by the probabilistic iFOD2 alternative noted below)
    :param output_dir: TractSeg output folder. Tracking runs on <output_dir>/TOM/<bundle>.nii.gz.
                       With filter_by_endpoints, bundle_segmentations/ and endings_segmentations/
                       are read as well.
    :param brain_mask: path to the brain mask; used as seed image (and as reference affine for .trk)
    :param filter_by_endpoints: use results of endings_segmentation to filter out all fibers not
                                ending in those regions. Falls back to unfiltered tracking if the
                                bundle, beginnings or endings mask is empty.
    :param output_format: "trk" converts the result to <bundle>.trk and removes the .tck;
                          any other value keeps only <bundle>.tck
    :return: None; writes into <output_dir>/TOM_trackings/
    '''
    import subprocess  # argument lists instead of shell strings: paths with spaces stay intact

    tracking_folder = "TOM_trackings"
    # tracking_folder = "TOM_trackings_FiltEP6_FiltMask3"
    smooth = None  # None / 10
    TOM_folder = "TOM"  # TOM / TOM_thr1

    tracking_dir = output_dir + "/" + tracking_folder
    tck_file = tracking_dir + "/" + bundle + ".tck"

    # A separate `os.system("export PATH=...")` only changed the PATH of a throw-away subshell.
    # Hand the extended PATH to tckgen explicitly instead.
    env = dict(os.environ)
    env["PATH"] = "/code/mrtrix3/bin:" + env.get("PATH", "")

    if not os.path.isdir(tracking_dir):  # mkdir -p
        os.makedirs(tracking_dir)

    use_filter = False
    if filter_by_endpoints:
        bundle_mask_ok = nib.load(output_dir + "/bundle_segmentations/" + bundle + ".nii.gz").get_data().max() > 0
        beginnings_mask_ok = nib.load(output_dir + "/endings_segmentations/" + bundle + "_b.nii.gz").get_data().max() > 0
        endings_mask_ok = nib.load(output_dir + "/endings_segmentations/" + bundle + "_e.nii.gz").get_data().max() > 0

        if not bundle_mask_ok:
            print("WARNING: tract mask of {} empty. Falling back to tracking without filtering by endpoints.".format(bundle))
        if not beginnings_mask_ok:
            print("WARNING: tract beginnings mask of {} empty. Falling back to tracking without filtering by endpoints.".format(bundle))
        if not endings_mask_ok:
            print("WARNING: tract endings mask of {} empty. Falling back to tracking without filtering by endpoints.".format(bundle))
        use_filter = bundle_mask_ok and beginnings_mask_ok and endings_mask_ok

    tmp_dir = tempfile.mkdtemp()
    try:
        cmd = ["tckgen", "-algorithm", "FACT",
               output_dir + "/" + TOM_folder + "/" + bundle + ".nii.gz", tck_file,
               "-seed_image", brain_mask]
        if use_filter:
            tmp_bundle_mask = tmp_dir + "/" + bundle + ".nii.gz"
            tmp_endings = tmp_dir + "/" + bundle + "_e.nii.gz"
            tmp_beginnings = tmp_dir + "/" + bundle + "_b.nii.gz"
            # dilation has to be quite high, because endings sometimes almost completely missing
            ImgUtils.dilate_binary_mask(output_dir + "/bundle_segmentations/" + bundle + ".nii.gz",
                                        tmp_bundle_mask, dilation=3)
            ImgUtils.dilate_binary_mask(output_dir + "/endings_segmentations/" + bundle + "_e.nii.gz",
                                        tmp_endings, dilation=6)
            ImgUtils.dilate_binary_mask(output_dir + "/endings_segmentations/" + bundle + "_b.nii.gz",
                                        tmp_beginnings, dilation=6)
            cmd += ["-mask", tmp_bundle_mask,
                    "-include", tmp_beginnings,
                    "-include", tmp_endings]
            # Probabilistic alternative without TOM:
            #   tckgen -algorithm iFOD2 <peaks> <tck> -seed_image <tmp bundle mask> -mask <tmp bundle mask>
            #          -include <tmp _b> -include <tmp _e> -minlength 40 -seeds 200000 -select 2000 -force
        cmd += ["-minlength", "40", "-select", "2000", "-force", "-quiet"]

        return_code = subprocess.call(cmd, env=env)
        if return_code != 0:
            print("WARNING: tckgen failed for {} (exit code {})".format(bundle, return_code))

        if output_format == "trk":
            reference_affine = nib.load(brain_mask).affine
            FiberUtils.convert_tck_to_trk(tck_file, tracking_dir + "/" + bundle + ".trk",
                                          reference_affine, smooth=smooth)
            if os.path.exists(tck_file):  # rm -f
                os.remove(tck_file)
    finally:
        # always clean up the dilated masks, also when tracking/conversion raised
        shutil.rmtree(tmp_dir)
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks", single_orientation=False, verbose=False,
                 dropout_sampling=False, threshold=0.5, bundle_specific_threshold=False, get_probs=False, peak_threshold=0.1):
    '''
    Run TractSeg on one subject and return the prediction at the original (pre-crop) image size.

    Steps:
    - Instantiate the HP class of the pretrained-model config chosen by get_config_name(input_type, output_type).
    - Select the weights file.
    - Crop the input to its nonzero bounding box and scale it to the square network input (HP.INPUT_DIM[0]).
    - Predict.
    - Scale the prediction back and re-add the zero padding that was cropped away.

    :param data: input peaks (4D numpy array with shape [x,y,z,9]). NaNs are replaced by 0. For
                 input_type "T1" a trailing channel axis of size 1 is appended, so a 3D array is
                 presumably expected there (TODO confirm).
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks" | "T1"
    :param single_orientation: predict with a single data manager/trainer instead of fusing the
                               predictions of all three directions (mainly for testing, less RAM)
    :param verbose: show debugging infos (also prints all hyperparameters)
    :param dropout_sampling: create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: Threshold for converting probability map to binary map (stored as HP.THRESHOLD)
    :param bundle_specific_threshold: Threshold is lower for some bundles which need more sensitivity (CA, CST, FX).
                                      Forces HP.GET_PROBS = True. For peak_regression it uses a fixed
                                      len_thr=0.3 and ignores peak_threshold.
    :param get_probs: Output raw probability map instead of binary map
    :param peak_threshold: all peaks shorter than peak_threshold will be set to zero
                           (peak_regression only, and only without bundle_specific_threshold)
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:     [x,y,z,nr_of_bundles]
        for endings_segmentation:   [x,y,z,2*nr_of_bundles]
        for TOM:                    [x,y,z,3*nr_of_bundles]
    '''
    start_time = time.time()

    # Load the hyperparameter class of the matching pretrained model and configure it for inference
    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    # Bundle specific thresholding works on probabilities, so the network must not binarize itself
    if bundle_specific_threshold:
        HP.GET_PROBS = True

    # Select the weights file. Other input types keep whatever WEIGHTS_PATH the config class
    # provides (if any) — NOTE(review): confirm.
    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DAugAll_Dropout", "best_weights_ep114.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/TractSeg_T1_12g90g270g_125mm_DAugAll", "best_weights_ep392.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888", "best_weights_ep247.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888_SchizoFineT_lr001", "best_weights_ep186.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DS_DAugAll_RotMir", "best_weights_ep200.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v3.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "EndingsSeg_12g90g270g_125mm_DS_DAugAll", "best_weights_ep234.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/Peaks20_12g90g270g_125mm_DAugSimp_constW5", "best_weights_ep441.npz")  #more oversegmentation with DAug
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "DmReg_12g90g270g_125mm_DAugAll_Ubuntu", "best_weights_ep80.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # NOTE(review): points to a training run on C.NETWORK_DRIVE instead of C.TRACT_SEG_HOME;
            # confirm this is intended outside the development setup.
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    # Output channels: one per bundle (skipping the first bundle name), three (x,y,z) per bundle for peak regression
    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3*len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    else:
        HP.NR_OF_CLASSES = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)
    # brain_mask = ImgUtils.simple_brain_mask(data)
    # if HP.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    # T1 input gets an explicit single channel axis
    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    # Crop to the nonzero bounding box and scale to the square network input; bbox, original_shape
    # and transformation are needed to map the prediction back at the end.
    data, seg_None, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    # NOTE(review): for any HP.EXPERIMENT_TYPE not handled below, `seg` is never assigned and the
    # post-processing after this if/elif raises NameError.
    if HP.EXPERIMENT_TYPE == "tract_segmentation" or HP.EXPERIMENT_TYPE == "endings_segmentation" or HP.EXPERIMENT_TYPE == "dm_regression":
        if single_orientation:     # mainly needed for testing because of less RAM requirements
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            # Probabilities are kept for dropout sampling, dm regression and when explicitly requested
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
            else:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=False, scale_to_world_shape=False, only_prediction=True)
        else:
            # Predict along all three directions and fuse them by averaging (thresholded unless probs=True)
            seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=False)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        # Peak regression always uses a single orientation and raw (non-binarized) output
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:], len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=peak_threshold)
        #3 dir for Peaks -> not working (?)
        # seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
        # seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)

    # NOTE(review): only tract_segmentation is binarized here. With bundle_specific_threshold,
    # endings_segmentation output therefore stays probabilistic (GET_PROBS was forced above); confirm intended.
    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    #remove following two lines to keep super resolution
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)

    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg
def _create_slices_file(HP, subjects, filename, slice, shuffle=True):
    '''
    Create a 2D-slice training set (x-, y- and z-slices of every subject) and save it as .npy.

    Data comes from <C.HOME>/<HP.DATASET_FOLDER>/<subject>/<HP.FEATURES_FILENAME>.nii.gz and
    masks from ImgUtils.create_multilabel_mask. Both are scaled to the UNet input shape first.
    Data and seg slices are produced in the same order: per subject all x-slices, then y-, then z-slices.

    :param HP: hyperparameter object. Reads DATASET_FOLDER, FEATURES_FILENAME, DATASET,
               RESOLUTION and LABELS_TYPE.
    :param subjects: iterable of subject ids
    :param filename: output prefix; writes <filename>_data.npy and <filename>_seg.npy
    :param slice: unused (slices from all three directions are always used); kept for compatibility
    :param shuffle: shuffle data and seg slices with the same random permutation
    :return: None
    :raises ValueError: if shuffle is set and the number of seg slices differs from the number of data slices
    '''
    data_dir = join(C.HOME, HP.DATASET_FOLDER)
    dwi_slices = []
    mask_slices = []

    print("\n\nProcessing Data...")
    for s in subjects:
        print("processing dwi subject {}".format(s))

        dwi = nib.load(join(data_dir, s, HP.FEATURES_FILENAME + ".nii.gz"))
        dwi_data = dwi.get_data()
        dwi_data = np.nan_to_num(dwi_data)
        dwi_data = DatasetUtils.scale_input_to_unet_shape(dwi_data, HP.DATASET, HP.RESOLUTION)

        # Use slices from all directions in one dataset. moveaxis returns a view, so iterating it
        # yields views such as dwi_data[:, z, :, :] without copying.
        for axis in range(3):
            dwi_slices.extend(np.moveaxis(dwi_data, axis, 0))

    dwi_slices = np.array(dwi_slices)
    n_data_slices = len(dwi_slices)

    random_idxs = None
    if shuffle:
        # A permutation draws without replacement. np.random.choice(n, n) would draw with
        # replacement, duplicating some slices and dropping roughly a third of them.
        random_idxs = np.random.permutation(n_data_slices)
        dwi_slices = dwi_slices[random_idxs]

    np.save(filename + "_data.npy", dwi_slices)
    del dwi_slices  # free memory

    print("\n\nProcessing Segs...")
    for s in subjects:
        print("processing seg subject {}".format(s))

        mask_data = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
        if HP.RESOLUTION == "2.5mm":
            mask_data = ImgUtils.resize_first_three_dims(mask_data, order=0, zoom=0.5)
        mask_data = DatasetUtils.scale_input_to_unet_shape(mask_data, HP.DATASET, HP.RESOLUTION)

        # Same slicing order as for the data, driven by the mask's own shape
        for axis in range(3):
            mask_slices.extend(np.moveaxis(mask_data, axis, 0))

    mask_slices = np.array(mask_slices)
    print("SEG TYPE: {}".format(mask_slices.dtype))

    if shuffle:
        # The shared permutation only keeps data/seg aligned if both have the same number of slices
        if len(mask_slices) != n_data_slices:
            raise ValueError("number of seg slices ({}) differs from number of data slices ({})".format(
                len(mask_slices), n_data_slices))
        mask_slices = mask_slices[random_idxs]

    np.save(filename + "_seg.npy", mask_slices)
def _create_slices_file(HP, subjects, filename, slice, shuffle=True):
    '''
    Create a 2D-slice training set (x-, y- and z-slices of every subject) and save it as .npy.

    Data comes from <C.HOME>/<HP.DATASET_FOLDER>/<subject>/<HP.FEATURES_FILENAME>.nii.gz and
    masks from ImgUtils.create_multilabel_mask. Both are scaled to the UNet input shape first.
    Data and seg slices are produced in the same order: per subject all x-slices, then y-, then z-slices.

    :param HP: hyperparameter object. Reads DATASET_FOLDER, FEATURES_FILENAME, DATASET,
               RESOLUTION and LABELS_TYPE.
    :param subjects: iterable of subject ids
    :param filename: output prefix; writes <filename>_data.npy and <filename>_seg.npy
    :param slice: unused (slices from all three directions are always used); kept for compatibility
    :param shuffle: shuffle data and seg slices with the same random permutation
    :return: None
    :raises ValueError: if shuffle is set and the number of seg slices differs from the number of data slices
    '''
    data_dir = join(C.HOME, HP.DATASET_FOLDER)
    dwi_slices = []
    mask_slices = []

    print("\n\nProcessing Data...")
    for s in subjects:
        print("processing dwi subject {}".format(s))

        dwi = nib.load(join(data_dir, s, HP.FEATURES_FILENAME + ".nii.gz"))
        dwi_data = dwi.get_data()
        dwi_data = np.nan_to_num(dwi_data)
        dwi_data = DatasetUtils.scale_input_to_unet_shape(dwi_data, HP.DATASET, HP.RESOLUTION)

        # Use slices from all directions in one dataset. moveaxis returns a view, so iterating it
        # yields views such as dwi_data[:, z, :, :] without copying.
        for axis in range(3):
            dwi_slices.extend(np.moveaxis(dwi_data, axis, 0))

    dwi_slices = np.array(dwi_slices)
    n_data_slices = len(dwi_slices)

    random_idxs = None
    if shuffle:
        # A permutation draws without replacement. np.random.choice(n, n) would draw with
        # replacement, duplicating some slices and dropping roughly a third of them.
        random_idxs = np.random.permutation(n_data_slices)
        dwi_slices = dwi_slices[random_idxs]

    np.save(filename + "_data.npy", dwi_slices)
    del dwi_slices  # free memory

    print("\n\nProcessing Segs...")
    for s in subjects:
        print("processing seg subject {}".format(s))

        mask_data = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
        if HP.RESOLUTION == "2.5mm":
            mask_data = ImgUtils.resize_first_three_dims(mask_data, order=0, zoom=0.5)
        mask_data = DatasetUtils.scale_input_to_unet_shape(mask_data, HP.DATASET, HP.RESOLUTION)

        # Same slicing order as for the data, driven by the mask's own shape
        for axis in range(3):
            mask_slices.extend(np.moveaxis(mask_data, axis, 0))

    mask_slices = np.array(mask_slices)
    print("SEG TYPE: {}".format(mask_slices.dtype))

    if shuffle:
        # The shared permutation only keeps data/seg aligned if both have the same number of slices
        if len(mask_slices) != n_data_slices:
            raise ValueError("number of seg slices ({}) differs from number of data slices ({})".format(
                len(mask_slices), n_data_slices))
        mask_slices = mask_slices[random_idxs]

    np.save(filename + "_seg.npy", mask_slices)