def create_preprocessed_files(subject):
    # Estimate the bounding box from this file, then apply it to all other files
    bb_file = "12g_125mm_peaks"  # todo: adapt

    # filenames_data = ["12g_125mm_peaks", "90g_125mm_peaks", "270g_125mm_peaks",
    #                   "12g_125mm_bedpostx_peaks_scaled", "90g_125mm_bedpostx_peaks_scaled",
    #                   "270g_125mm_bedpostx_peaks_scaled"]
    # filenames_seg = ["bundle_masks_72", "bundle_masks_dm", "endpoints_72_ordered",
    #                  "bundle_peaks_Part1", "bundle_peaks_Part2", "bundle_peaks_Part3",
    #                  "bundle_peaks_Part4", "bundle_masks_autoPTX_dm", "bundle_masks_autoPTX_thr001"]
    filenames_data = ["bundle_uncertainties"]
    filenames_seg = []

    print("idx: {}".format(subjects.index(subject)))
    exp_utils.make_dir(join(C.DATA_PATH, DATASET_FOLDER_PREPROC, subject))

    # Get the bounding box
    data = nib.load(join(C.NETWORK_DRIVE, DATASET_FOLDER, subject,
                         bb_file + ".nii.gz")).get_data()
    _, _, bbox, _ = data_utils.crop_to_nonzero(np.nan_to_num(data))

    for idx, filename in enumerate(filenames_data):
        path = join(C.NETWORK_DRIVE, DATASET_FOLDER, subject, filename + ".nii.gz")
        if os.path.exists(path):
            img = nib.load(path)
            data = img.get_data()
            affine = img.affine
            data = np.nan_to_num(data)
            # Add a channel dimension if it does not exist yet
            if len(data.shape) == 3:
                data = data[..., None]
            data, _, _, _ = data_utils.crop_to_nonzero(data, bbox=bbox)
            # np.save(join(C.DATA_PATH, DATASET_FOLDER_PREPROC, subject, filename + ".npy"), data)
            nib.save(nib.Nifti1Image(data, affine),
                     join(C.DATA_PATH, DATASET_FOLDER_PREPROC, subject, filename + ".nii.gz"))
        else:
            print("missing file: {}-{}".format(subject, idx))
            raise IOError("File missing")

    for filename in filenames_seg:
        img = nib.load(join(C.NETWORK_DRIVE, DATASET_FOLDER, subject, filename + ".nii.gz"))
        data = img.get_data()
        data, _, _, _ = data_utils.crop_to_nonzero(data, bbox=bbox)
        # np.save(join(C.DATA_PATH, DATASET_FOLDER_PREPROC, subject, filename + ".npy"), data)
        nib.save(nib.Nifti1Image(data, img.affine),
                 join(C.DATA_PATH, DATASET_FOLDER_PREPROC, subject, filename + ".nii.gz"))
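
# Usage sketch (hypothetical, not part of the original module): one way to drive
# this preprocessing is to parallelize it over the module-level `subjects` list
# referenced above, e.g. with a multiprocessing pool:
#
#   from multiprocessing import Pool
#
#   if __name__ == "__main__":
#       with Pool(processes=12) as p:
#           p.map(create_preprocessed_files, subjects)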


def preprocess_nifti_file(tom_fn, mask_fn, seeds_fn, ends_fn, beginning_fn, ending_fn):
    # Each argument is an (input_filename, output_filename) pair

    # Get the appropriate affine transformation from the mask
    mask_object = nib.load(mask_fn[0])
    mask_data = mask_object.get_data()
    affine = mask_object.affine

    # Compute the bounding box using the mask
    if np.sum(mask_data) != 0:
        bbox = data_utils.get_bbox_from_mask(np.nan_to_num(mask_data), 0)
    else:
        bbox = [[0, mask_data.shape[0]],
                [0, mask_data.shape[1]],
                [0, mask_data.shape[2]]]

    # Perform the cropping/padding for all the files
    for fn_in, fn_out in [tom_fn, mask_fn, seeds_fn, ends_fn, beginning_fn, ending_fn]:
        data = nib.load(fn_in).get_data()
        # Adjust the data to have 4 dims, since the cropping will iterate over this dim
        if len(data.shape) == 3:
            data = data[..., None]
        # Crop and pad to create a cube volume of size 144 x 144 x 144
        data, _, _, _ = data_utils.crop_to_nonzero(np.nan_to_num(data), bbox=bbox)
        data, transform = data_utils.pad_and_scale_img_to_square_img(data)
        # Save the new file
        nib.save(nib.Nifti1Image(data, affine), fn_out)
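
# Usage sketch (hypothetical file names): every argument is an
# (input_path, output_path) pair, as consumed by the loop above.
#
#   preprocess_nifti_file(
#       tom_fn=("CST_left_tom.nii.gz", "CST_left_tom_144.nii.gz"),
#       mask_fn=("CST_left_mask.nii.gz", "CST_left_mask_144.nii.gz"),
#       seeds_fn=("CST_left_seeds.nii.gz", "CST_left_seeds_144.nii.gz"),
#       ends_fn=("CST_left_ends.nii.gz", "CST_left_ends_144.nii.gz"),
#       beginning_fn=("CST_left_beg.nii.gz", "CST_left_beg_144.nii.gz"),
#       ending_fn=("CST_left_end.nii.gz", "CST_left_end_144.nii.gz"))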


def run_tractseg(data, output_type="tract_segmentation",
                 single_orientation=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_postprocessing=True, get_probs=False, peak_threshold=0.1,
                 postprocess=False, peak_regression_part="All", input_type="peaks",
                 blob_size_thr=50, nr_cpus=-1, verbose=False, manual_exp_name=None,
                 inference_batch_size=1, tract_definition="TractQuerier+", bedpostX_input=False,
                 tract_segmentations_path=None, TOM_dilation=1, unit_test=False):
    """
    Run TractSeg.

    Args:
        data: Input peaks (4D numpy array with shape [x, y, z, 9]).
        output_type: TractSeg can segment not only bundles, but also the end regions of bundles.
            Moreover it can create Tract Orientation Maps (TOM).
            'tract_segmentation' [DEFAULT]: Segmentation of bundles (72 bundles).
            'endings_segmentation': Segmentation of bundle end regions (72 bundles).
            'TOM': Tract Orientation Maps (72 bundles).
        single_orientation: Do not run the model three times along the x/y/z orientations with
            subsequent mean fusion.
        dropout_sampling: Create an uncertainty map by Monte Carlo dropout
            (https://arxiv.org/abs/1506.02142).
        threshold: Threshold for converting the probability map to a binary map.
        bundle_specific_postprocessing: Use a lower threshold and hole closing for CA and FX
            if incomplete.
        get_probs: Output the raw probability map instead of a binary map.
        peak_threshold: All peaks shorter than peak_threshold will be set to zero.
        postprocess: Simple postprocessing of segmentations: remove small blobs and fill holes.
        peak_regression_part: Only relevant for output type 'TOM'. If set to 'All' (default) it
            will return all 72 bundles. If set to 'Part1'-'Part4' it will only run for a subset
            of the bundles to reduce memory load.
        input_type: Always set to "peaks".
        blob_size_thr: If postprocess is True, all blobs with fewer voxels than this threshold
            will be removed.
        nr_cpus: Number of CPUs to use. -1 means all available CPUs.
        verbose: Show debugging info.
        manual_exp_name: Name of the experiment if you want to use your own model instead of a
            pretrained one.
        inference_batch_size: Batch size (higher: a bit faster, but needs more RAM).
        tract_definition: Select which tract definitions to use. 'TractQuerier+' defines tracts
            mainly by their cortical start and end region. 'xtract' defines tracts mainly by
            ROIs in white matter.
        bedpostX_input: Input peaks were generated by bedpostX.
        tract_segmentations_path: Path to the bundle segmentations (only needed for peak
            regression, to remove peaks outside of the segmentation mask).
        TOM_dilation: Dilation applied to the tract segmentations before using them to mask
            the TOMs.
        unit_test: Run in unit test mode.

    Returns:
        4D numpy array with the output of TractSeg:
            for tract_segmentation: [x, y, z, nr_of_bundles]
            for endings_segmentation: [x, y, z, 2 * nr_of_bundles]
            for TOM: [x, y, z, 3 * nr_of_bundles]
    """
    start_time = time.time()

    if manual_exp_name is None:
        config = get_config_name(input_type, output_type,
                                 dropout_sampling=dropout_sampling,
                                 tract_definition=tract_definition)
        Config = getattr(importlib.import_module(
            "tractseg.experiments.pretrained_models." + config), "Config")()
+ config), "Config")() else: Config = exp_utils.load_config_from_txt( join(C.EXP_PATH, exp_utils.get_manual_exp_name_peaks(manual_exp_name, "Part1"), "Hyperparameters.txt")) # Do not do any postprocessing if returning probabilities (because postprocessing only works on binary) if get_probs: bundle_specific_postprocessing = False postprocess = False Config = exp_utils.get_correct_labels_type(Config) Config.VERBOSE = verbose Config.TRAIN = False Config.TEST = False Config.SEGMENT = False Config.GET_PROBS = get_probs Config.LOAD_WEIGHTS = True Config.DROPOUT_SAMPLING = dropout_sampling Config.THRESHOLD = threshold Config.NR_CPUS = nr_cpus Config.INPUT_DIM = dataset_specific_utils.get_correct_input_dim(Config) Config.RESET_LAST_LAYER = False if Config.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing: Config.GET_PROBS = True if manual_exp_name is not None and Config.EXPERIMENT_TYPE != "peak_regression": Config.WEIGHTS_PATH = exp_utils.get_best_weights_path( join(C.EXP_PATH, manual_exp_name), True) else: if tract_definition == "TractQuerier+": if input_type == "peaks": if Config.EXPERIMENT_TYPE == "tract_segmentation" and Config.DROPOUT_SAMPLING: Config.WEIGHTS_PATH = join( C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_v3.npz") elif Config.EXPERIMENT_TYPE == "tract_segmentation": Config.WEIGHTS_PATH = join( C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_v3.npz") elif Config.EXPERIMENT_TYPE == "endings_segmentation": Config.WEIGHTS_PATH = join( C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v4.npz") elif Config.EXPERIMENT_TYPE == "dm_regression": Config.WEIGHTS_PATH = join( C.WEIGHTS_DIR, "pretrained_weights_dm_regression_v2.npz") else: # T1 if Config.EXPERIMENT_TYPE == "tract_segmentation": Config.WEIGHTS_PATH = join( C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz") elif Config.EXPERIMENT_TYPE == "endings_segmentation": Config.WEIGHTS_PATH = join( C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v1.npz") elif Config.EXPERIMENT_TYPE == "peak_regression": Config.WEIGHTS_PATH = join( C.WEIGHTS_DIR, "pretrained_weights_peak_regression_v1.npz") else: # xtract if Config.EXPERIMENT_TYPE == "tract_segmentation": Config.WEIGHTS_PATH = join( C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_xtract_v1.npz") elif Config.EXPERIMENT_TYPE == "dm_regression": Config.WEIGHTS_PATH = join( C.WEIGHTS_DIR, "pretrained_weights_dm_regression_xtract_v1.npz") else: raise ValueError( "bundle_definition xtract not supported in combination with this output type" ) if Config.VERBOSE: print("Hyperparameters:") exp_utils.print_Configs(Config) data = np.nan_to_num(data) #runtime on HCP data: 0.9s data, seg_None, bbox, original_shape = data_utils.crop_to_nonzero(data) # runtime on HCP data: 0.5s data, transformation = data_utils.pad_and_scale_img_to_square_img( data, target_size=Config.INPUT_DIM[0], nr_cpus=nr_cpus) if Config.EXPERIMENT_TYPE == "tract_segmentation" or Config.EXPERIMENT_TYPE == "endings_segmentation" or \ Config.EXPERIMENT_TYPE == "dm_regression": print("Loading weights from: {}".format(Config.WEIGHTS_PATH)) Config.NR_OF_CLASSES = len( dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:]) utils.download_pretrained_weights( experiment_type=Config.EXPERIMENT_TYPE, dropout_sampling=Config.DROPOUT_SAMPLING, tract_definition=tract_definition) model = BaseModel(Config, inference=True) if single_orientation: # mainly needed for testing because of less RAM requirements 
            data_loader_inference = DataLoaderInference(Config, data=data)
            if Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" \
                    or Config.GET_PROBS:
                seg, _ = trainer.predict_img(Config, model, data_loader_inference, probs=True,
                                             scale_to_world_shape=False, only_prediction=True,
                                             batch_size=inference_batch_size, unit_test=unit_test)
            else:
                seg, _ = trainer.predict_img(Config, model, data_loader_inference, probs=False,
                                             scale_to_world_shape=False, only_prediction=True,
                                             batch_size=inference_batch_size)
        else:
            seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(
                Config, model, data=data, scale_to_world_shape=False, only_prediction=True,
                batch_size=inference_batch_size)
            if Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" \
                    or Config.GET_PROBS:
                seg = direction_merger.mean_fusion(Config.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = direction_merger.mean_fusion(Config.THRESHOLD, seg_xyz, probs=False)

    elif Config.EXPERIMENT_TYPE == "peak_regression":
        weights = {
            "Part1": "pretrained_weights_peak_regression_part1_v2.npz",
            "Part2": "pretrained_weights_peak_regression_part2_v2.npz",
            "Part3": "pretrained_weights_peak_regression_part3_v2.npz",
            "Part4": "pretrained_weights_peak_regression_part4_v2.npz",
        }

        if peak_regression_part == "All":
            parts = ["Part1", "Part2", "Part3", "Part4"]
            seg_all = np.zeros((data.shape[0], data.shape[1], data.shape[2],
                                Config.NR_OF_CLASSES * 3))
        else:
            parts = [peak_regression_part]
            Config.CLASSES = "All_" + peak_regression_part
            Config.NR_OF_CLASSES = 3 * len(
                dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])

        for idx, part in enumerate(parts):
            if manual_exp_name is not None:
                manual_exp_name_peaks = exp_utils.get_manual_exp_name_peaks(manual_exp_name, part)
                Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
                    join(C.EXP_PATH, manual_exp_name_peaks), True)
            else:
                Config.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, weights[part])

            print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
            Config.CLASSES = "All_" + part
            Config.NR_OF_CLASSES = 3 * len(
                dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
            utils.download_pretrained_weights(experiment_type=Config.EXPERIMENT_TYPE,
                                              dropout_sampling=Config.DROPOUT_SAMPLING,
                                              part=part, tract_definition=tract_definition)
            model = BaseModel(Config, inference=True)

            if single_orientation:
                data_loader_inference = DataLoaderInference(Config, data=data)
                seg, _ = trainer.predict_img(Config, model, data_loader_inference, probs=True,
                                             scale_to_world_shape=False, only_prediction=True,
                                             batch_size=inference_batch_size)
            else:
                # Running 3 orientations for peaks leads to bad results
                seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(
                    Config, model, data=data, scale_to_world_shape=False, only_prediction=True,
                    batch_size=inference_batch_size)
                seg = direction_merger.mean_fusion_peaks(seg_xyz, nr_cpus=nr_cpus)

            if peak_regression_part == "All":
                seg_all[:, :, :, (idx * Config.NR_OF_CLASSES):
                        (idx * Config.NR_OF_CLASSES + Config.NR_OF_CLASSES)] = seg

        if peak_regression_part == "All":
            Config.CLASSES = "All"
            Config.NR_OF_CLASSES = 3 * len(
                dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
            seg = seg_all

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing \
            and not dropout_sampling:
        # Runtime ~4s
        seg = img_utils.bundle_specific_postprocessing(
            seg, dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])

    # Runtime on HCP data: 5.1s
    seg = data_utils.cut_and_scale_img_back_to_original_img(seg, transformation,
                                                            nr_cpus=nr_cpus)
    # Runtime on HCP data: 1.6s
    seg = data_utils.add_original_zero_padding_again(seg, bbox, original_shape,
                                                     Config.NR_OF_CLASSES)

    if Config.EXPERIMENT_TYPE == "peak_regression":
        seg = peak_utils.mask_and_normalize_peaks(
            seg, tract_segmentations_path,
            dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
            TOM_dilation, nr_cpus=nr_cpus)

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and postprocess and not dropout_sampling:
        # Runtime ~7s for 1.25mm resolution
        # Runtime ~1.5s for 2mm resolution
        seg = img_utils.postprocess_segmentations(
            seg, dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
            blob_thr=blob_size_thr, hole_closing=None)

    exp_utils.print_verbose(Config.VERBOSE,
                            "Took {}s".format(round(time.time() - start_time, 2)))
    return seg
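
# Usage sketch (hypothetical paths): the input must be peaks with shape
# [x, y, z, 9], as documented in run_tractseg above.
#
#   import nibabel as nib
#
#   img = nib.load("peaks.nii.gz")
#   seg = run_tractseg(img.get_data(), output_type="tract_segmentation")
#   nib.save(nib.Nifti1Image(seg.astype(np.uint8), img.affine),
#            "bundle_segmentations.nii.gz")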