def get_seg_single_img_3_directions(HP, model, subject=None, data=None, scale_to_world_shape=True, only_prediction=False):
    '''
    Returns probs

    :param HP:
    :param model:
    :param subject:
    :param data:
    :param scale_to_world_shape:
    :return:
    '''
    from tractseg.libs.Trainer import Trainer

    prob_slices = []
    directions = ["x", "y", "z"]
    for idx, direction in enumerate(directions):
        HP.SLICE_DIRECTION = direction
        print("Processing direction ({} of 3)".format(idx + 1))
        # print("Processing direction " + HP.SLICE_DIRECTION)

        if subject:
            dataManagerSingle = DataManagerSingleSubjectById(HP, subject=subject)
        else:
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)

        trainerSingle = Trainer(model, dataManagerSingle)
        img_probs, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=scale_to_world_shape,
                                                            only_prediction=only_prediction)  # (x, y, z, nr_classes)
        prob_slices.append(img_probs)

    probs_x, probs_y, probs_z = prob_slices
    new_shape = probs_x.shape + (1,)  # (x, y, z, nr_classes) -> (x, y, z, nr_classes, 1)
    probs_x = np.reshape(probs_x, new_shape)
    probs_y = np.reshape(probs_y, new_shape)
    probs_z = np.reshape(probs_z, new_shape)

    probs_combined = np.concatenate((probs_x, probs_y, probs_z), axis=4)  # e.g. (146, 174, 146, 45, 3)
    return probs_combined, img_y
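# A minimal numpy-only sketch of how the (x, y, z, nr_classes, 3) stack returned above
# could be mean-fused, assuming DirectionMerger.mean_fusion simply averages the three
# per-orientation predictions along the last axis and, when probs=False, binarizes with
# the given threshold. This illustrates that assumption; it is not the actual
# DirectionMerger implementation.
import numpy as np

def mean_fusion_sketch(threshold, probs_xyz, probs=True):
    fused = probs_xyz.mean(axis=4)                    # (x, y, z, nr_classes)
    if probs:
        return fused                                  # raw probability map
    return (fused > threshold).astype(np.uint8)       # binary segmentation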
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False,
                 threshold=0.5, bundle_specific_threshold=False, get_probs=False):
    '''
    Run TractSeg

    :param data: input peaks (4D numpy array with shape [x, y, z, 9])
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks"
    :param verbose: show debugging info
    :param dropout_sampling: create an uncertainty map by Monte Carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: threshold for converting the probability map to a binary map
    :param bundle_specific_threshold: use a lower threshold for bundles which need more sensitivity (CA, CST, FX)
    :param get_probs: output the raw probability map instead of a binary map
    :return: 4D numpy array with the output of tractseg
             for tract_segmentation:   [x, y, z, nr_of_bundles]
             for endings_segmentation: [x, y, z, 2 * nr_of_bundles]
             for TOM:                  [x, y, z, 3 * nr_of_bundles]
    '''
    start_time = time.time()

    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold
    if bundle_specific_threshold:
        HP.GET_PROBS = True

    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DAugAll_Dropout", "best_weights_ep114.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_12g90g270g_125mm_DAugAll", "best_weights_ep126.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888", "best_weights_ep247.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888_SchizoFineT_lr001", "best_weights_ep186.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v2.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "EndingsSeg_12g90g270g_125mm_DAugAll", "best_weights_ep16.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/Peaks20_12g90g270g_125mm_DAugSimp_constW5", "best_weights_ep441.npz")  # more oversegmentation with DAug
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "DmReg_12g90g270g_125mm_DAugAll_Ubuntu", "best_weights_ep80.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")

    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3 * len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    else:
        HP.NR_OF_CLASSES = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)
    # brain_mask = ImgUtils.simple_brain_mask(data)
    # if HP.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    data, seg_None, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE in ["tract_segmentation", "endings_segmentation", "dm_regression"]:
        if single_orientation:  # mainly needed for testing because of lower RAM requirements
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
            else:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=False, scale_to_world_shape=False, only_prediction=True)
        else:
            seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=False)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:], len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=0.3)  # set lower for more sensitivity
        # 3 directions for peaks -> not working (?)
        # seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
        # seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)

    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    # remove the following two lines to keep the super-resolution output
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)

    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False,
                 threshold=0.5, bundle_specific_threshold=False, get_probs=False,
                 peak_threshold=0.1):
    '''
    Run TractSeg

    :param data: input peaks (4D numpy array with shape [x, y, z, 9])
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks"
    :param verbose: show debugging info
    :param dropout_sampling: create an uncertainty map by Monte Carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: threshold for converting the probability map to a binary map
    :param bundle_specific_threshold: use a lower threshold for bundles which need more sensitivity (CA, CST, FX)
    :param get_probs: output the raw probability map instead of a binary map
    :param peak_threshold: all peaks shorter than peak_threshold will be set to zero
    :return: 4D numpy array with the output of tractseg
             for tract_segmentation:   [x, y, z, nr_of_bundles]
             for endings_segmentation: [x, y, z, 2 * nr_of_bundles]
             for TOM:                  [x, y, z, 3 * nr_of_bundles]
    '''
    start_time = time.time()

    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold
    if bundle_specific_threshold:
        HP.GET_PROBS = True

    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DAugAll_Dropout", "best_weights_ep114.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/TractSeg_T1_12g90g270g_125mm_DAugAll", "best_weights_ep392.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888", "best_weights_ep247.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888_SchizoFineT_lr001", "best_weights_ep186.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DS_DAugAll_RotMir", "best_weights_ep200.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v3.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "EndingsSeg_12g90g270g_125mm_DS_DAugAll", "best_weights_ep234.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/Peaks20_12g90g270g_125mm_DAugSimp_constW5", "best_weights_ep441.npz")  # more oversegmentation with DAug
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "DmReg_12g90g270g_125mm_DAugAll_Ubuntu", "best_weights_ep80.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")

    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3 * len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    else:
        HP.NR_OF_CLASSES = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)
    # brain_mask = ImgUtils.simple_brain_mask(data)
    # if HP.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    data, seg_None, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE in ["tract_segmentation", "endings_segmentation", "dm_regression"]:
        if single_orientation:  # mainly needed for testing because of lower RAM requirements
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
            else:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=False, scale_to_world_shape=False, only_prediction=True)
        else:
            seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=False)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:], len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=peak_threshold)
        # 3 directions for peaks -> not working (?)
        # seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
        # seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)

    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    # remove the following two lines to keep the super-resolution output
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)

    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg