import importlib
import time
from os.path import join

import numpy as np

from tractseg.libs import data_utils
from tractseg.libs import dataset_specific_utils
from tractseg.libs import direction_merger
from tractseg.libs import exp_utils
from tractseg.libs import img_utils
from tractseg.libs import peak_utils
from tractseg.libs import trainer
from tractseg.libs import utils
from tractseg.libs.system_config import SystemConfig as C
from tractseg.libs.system_config import get_config_name
from tractseg.data.data_loader_inference import DataLoaderInference
from tractseg.models.base_model import BaseModel


def run_tractseg(data, output_type="tract_segmentation", single_orientation=False,
                 dropout_sampling=False, threshold=0.5, bundle_specific_postprocessing=True,
                 get_probs=False, peak_threshold=0.1, postprocess=False,
                 peak_regression_part="All", input_type="peaks", blob_size_thr=50,
                 nr_cpus=-1, verbose=False, manual_exp_name=None, inference_batch_size=1,
                 tract_definition="TractQuerier+", bedpostX_input=False,
                 tract_segmentations_path=None, TOM_dilation=1, unit_test=False):
    """
    Run TractSeg

    Args:
        data: input peaks (4D numpy array with shape [x, y, z, 9])
        output_type: TractSeg can segment not only bundles, but also the end regions of bundles.
            Moreover it can create Tract Orientation Maps (TOM).
            'tract_segmentation' [DEFAULT]: Segmentation of bundles (72 bundles).
            'endings_segmentation': Segmentation of bundle end regions (72 bundles).
            'TOM': Tract Orientation Maps (20 bundles).
        single_orientation: Do not run the model 3 times along the x/y/z orientations with
            subsequent mean fusion.
        dropout_sampling: Create an uncertainty map by Monte Carlo dropout
            (https://arxiv.org/abs/1506.02142)
        threshold: Threshold for converting the probability map to a binary map
        bundle_specific_postprocessing: Set a lower threshold and use hole closing for CA and FX
            if incomplete
        get_probs: Output the raw probability map instead of a binary map
        peak_threshold: All peaks shorter than peak_threshold will be set to zero
        postprocess: Simple postprocessing of segmentations: remove small blobs and fill holes
        peak_regression_part: Only relevant for output type 'TOM'. If set to 'All' (default) it
            returns all 72 bundles. If set to 'Part1'-'Part4' it only runs for a subset of the
            bundles to reduce memory load.
        input_type: Always set to "peaks"
        blob_size_thr: If postprocess is True, all blobs with fewer voxels than this threshold
            will be removed.
        nr_cpus: Number of CPUs to use. -1 means all available CPUs.
        verbose: Show debugging info
        manual_exp_name: Name of the experiment if you do not want to use a pretrained model
            but your own one
        inference_batch_size: Batch size (higher: a bit faster but needs more RAM)
        tract_definition: Select which tract definitions to use. 'TractQuerier+' defines tracts
            mainly by their cortical start and end region. 'xtract' defines tracts mainly by
            ROIs in white matter.
        bedpostX_input: Input peaks were generated by bedpostX
        tract_segmentations_path: Path to the bundle_segmentations (only needed for peak
            regression to remove peaks outside of the segmentation mask)
        TOM_dilation: Dilation applied to the tract segmentations before using them to mask
            the TOMs.
        unit_test: Only needed for unit testing.

    Returns:
        4D numpy array with the output of tractseg
        for tract_segmentation:   [x, y, z, nr_of_bundles]
        for endings_segmentation: [x, y, z, 2*nr_of_bundles]
        for TOM:                  [x, y, z, 3*nr_of_bundles]
    """
    start_time = time.time()

    if manual_exp_name is None:
        config = get_config_name(input_type, output_type, dropout_sampling=dropout_sampling,
                                 tract_definition=tract_definition)
        Config = getattr(importlib.import_module("tractseg.experiments.pretrained_models." + config),
                         "Config")()
    else:
        Config = exp_utils.load_config_from_txt(
            join(C.EXP_PATH, exp_utils.get_manual_exp_name_peaks(manual_exp_name, "Part1"),
                 "Hyperparameters.txt"))

    # Do not do any postprocessing if returning probabilities (because postprocessing only
    # works on binary maps)
    if get_probs:
        bundle_specific_postprocessing = False
        postprocess = False

    Config = exp_utils.get_correct_labels_type(Config)
    Config.VERBOSE = verbose
    Config.TRAIN = False
    Config.TEST = False
    Config.SEGMENT = False
    Config.GET_PROBS = get_probs
    Config.LOAD_WEIGHTS = True
    Config.DROPOUT_SAMPLING = dropout_sampling
    Config.THRESHOLD = threshold
    Config.NR_CPUS = nr_cpus
    Config.INPUT_DIM = dataset_specific_utils.get_correct_input_dim(Config)
    Config.RESET_LAST_LAYER = False

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing:
        Config.GET_PROBS = True

    if manual_exp_name is not None and Config.EXPERIMENT_TYPE != "peak_regression":
        Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(join(C.EXP_PATH, manual_exp_name), True)
    else:
        if tract_definition == "TractQuerier+":
            if input_type == "peaks":
                if Config.EXPERIMENT_TYPE == "tract_segmentation" and Config.DROPOUT_SAMPLING:
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_v3.npz")
                elif Config.EXPERIMENT_TYPE == "tract_segmentation":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_v3.npz")
                elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v4.npz")
                elif Config.EXPERIMENT_TYPE == "dm_regression":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_dm_regression_v2.npz")
            else:  # T1
                if Config.EXPERIMENT_TYPE == "tract_segmentation":
                    Config.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models",
                                               "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
                elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v1.npz")
                elif Config.EXPERIMENT_TYPE == "peak_regression":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_peak_regression_v1.npz")
        else:  # xtract
            if Config.EXPERIMENT_TYPE == "tract_segmentation":
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_xtract_v1.npz")
            elif Config.EXPERIMENT_TYPE == "dm_regression":
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_dm_regression_xtract_v1.npz")
            else:
                raise ValueError("bundle_definition xtract not supported in combination with this output type")

    if Config.VERBOSE:
        print("Hyperparameters:")
        exp_utils.print_Configs(Config)

    data = np.nan_to_num(data)

    # runtime on HCP data: 0.9s
    data, seg_None, bbox, original_shape = data_utils.crop_to_nonzero(data)
    # runtime on HCP data: 0.5s
    data, transformation = data_utils.pad_and_scale_img_to_square_img(data, target_size=Config.INPUT_DIM[0],
                                                                      nr_cpus=nr_cpus)

    if Config.EXPERIMENT_TYPE in ("tract_segmentation", "endings_segmentation", "dm_regression"):
        print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
        Config.NR_OF_CLASSES = len(dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
        utils.download_pretrained_weights(experiment_type=Config.EXPERIMENT_TYPE,
                                          dropout_sampling=Config.DROPOUT_SAMPLING,
                                          tract_definition=tract_definition)
        model = BaseModel(Config, inference=True)

        if single_orientation:  # mainly needed for testing because of lower RAM requirements
            data_loader_inference = DataLoaderInference(Config, data=data)
            if Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" or Config.GET_PROBS:
                seg, _ = trainer.predict_img(Config, model, data_loader_inference, probs=True,
                                             scale_to_world_shape=False, only_prediction=True,
                                             batch_size=inference_batch_size, unit_test=unit_test)
            else:
                seg, _ = trainer.predict_img(Config, model, data_loader_inference, probs=False,
                                             scale_to_world_shape=False, only_prediction=True,
                                             batch_size=inference_batch_size)
        else:
            seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(Config, model, data=data,
                                                                          scale_to_world_shape=False,
                                                                          only_prediction=True,
                                                                          batch_size=inference_batch_size)
            if Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" or Config.GET_PROBS:
                seg = direction_merger.mean_fusion(Config.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = direction_merger.mean_fusion(Config.THRESHOLD, seg_xyz, probs=False)

    elif Config.EXPERIMENT_TYPE == "peak_regression":
        weights = {
            "Part1": "pretrained_weights_peak_regression_part1_v2.npz",
            "Part2": "pretrained_weights_peak_regression_part2_v2.npz",
            "Part3": "pretrained_weights_peak_regression_part3_v2.npz",
            "Part4": "pretrained_weights_peak_regression_part4_v2.npz",
        }
        if peak_regression_part == "All":
            parts = ["Part1", "Part2", "Part3", "Part4"]
            seg_all = np.zeros((data.shape[0], data.shape[1], data.shape[2], Config.NR_OF_CLASSES * 3))
        else:
            parts = [peak_regression_part]
            Config.CLASSES = "All_" + peak_regression_part
            Config.NR_OF_CLASSES = 3 * len(dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])

        for idx, part in enumerate(parts):
            if manual_exp_name is not None:
                manual_exp_name_peaks = exp_utils.get_manual_exp_name_peaks(manual_exp_name, part)
                Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(join(C.EXP_PATH, manual_exp_name_peaks),
                                                                      True)
            else:
                Config.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, weights[part])

            print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
            Config.CLASSES = "All_" + part
            Config.NR_OF_CLASSES = 3 * len(dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
            utils.download_pretrained_weights(experiment_type=Config.EXPERIMENT_TYPE,
                                              dropout_sampling=Config.DROPOUT_SAMPLING, part=part,
                                              tract_definition=tract_definition)
            model = BaseModel(Config, inference=True)

            if single_orientation:
                data_loader_inference = DataLoaderInference(Config, data=data)
                seg, _ = trainer.predict_img(Config, model, data_loader_inference, probs=True,
                                             scale_to_world_shape=False, only_prediction=True,
                                             batch_size=inference_batch_size)
            else:
                # 3 dir for Peaks -> bad results
                seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(Config, model, data=data,
                                                                              scale_to_world_shape=False,
                                                                              only_prediction=True,
                                                                              batch_size=inference_batch_size)
                seg = direction_merger.mean_fusion_peaks(seg_xyz, nr_cpus=nr_cpus)

            if peak_regression_part == "All":
                seg_all[:, :, :, (idx * Config.NR_OF_CLASSES):((idx + 1) * Config.NR_OF_CLASSES)] = seg

        if peak_regression_part == "All":
            Config.CLASSES = "All"
            Config.NR_OF_CLASSES = 3 * len(dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
            seg = seg_all

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing and not dropout_sampling:
        # Runtime ~4s
        seg = img_utils.bundle_specific_postprocessing(seg,
                                                       dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])

    # runtime on HCP data: 5.1s
    seg = data_utils.cut_and_scale_img_back_to_original_img(seg, transformation, nr_cpus=nr_cpus)
    # runtime on HCP data: 1.6s
    seg = data_utils.add_original_zero_padding_again(seg, bbox, original_shape, Config.NR_OF_CLASSES)

    if Config.EXPERIMENT_TYPE == "peak_regression":
        seg = peak_utils.mask_and_normalize_peaks(seg, tract_segmentations_path,
                                                  dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
                                                  TOM_dilation, nr_cpus=nr_cpus)

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and postprocess and not dropout_sampling:
        # Runtime ~7s for 1.25mm resolution
        # Runtime ~1.5s for 2mm resolution
        seg = img_utils.postprocess_segmentations(seg, dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
                                                  blob_thr=blob_size_thr, hole_closing=None)

    exp_utils.print_verbose(Config.VERBOSE, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg
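

# Example usage: a minimal sketch, not part of the TractSeg API itself. It assumes
# "peaks.nii.gz" is a placeholder for a peaks file produced by TractSeg's usual
# preprocessing (CSD peaks in the expected orientation, shape [x, y, z, 9]) and
# that nibabel is installed.
if __name__ == "__main__":
    import nibabel as nib

    img = nib.load("peaks.nii.gz")
    peaks = img.get_fdata()  # 4D numpy array with shape [x, y, z, 9]

    # Binary bundle segmentation: output shape [x, y, z, 72]
    seg = run_tractseg(peaks, output_type="tract_segmentation")

    nib.save(nib.Nifti1Image(seg.astype(np.uint8), img.affine),
             "bundle_segmentations.nii.gz")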