def get_batch_generator(self, batch_size=1):
    """Create an ordered batch generator over a single image.

    The image comes from ``self.data`` (an in-memory array) when it is set.
    Otherwise it is loaded for ``self.subject``: from ``.npy`` files for
    ``Config.TYPE == "combined"``, or via ``load_training_data`` for other types.

    Args:
        batch_size: Number of samples per batch passed to the batch generator.

    Returns:
        The generator returned by ``self._augment_data`` wrapping a 2D or 3D
        ordered batch generator (selected by ``Config.DIM``).

    Raises:
        ValueError: If neither ``self.data`` nor ``self.subject`` is set.
    """
    if self.data is not None:
        exp_utils.print_verbose(self.Config, "Loading data from PREDICT_IMG input file")
        data = np.nan_to_num(self.data)
        # Use dummy mask in case we only want to predict on some data (where we do not have Ground Truth)
        # NOTE(review): all three spatial dims use INPUT_DIM[0], i.e. a cubic mask is assumed.
        seg = np.zeros((self.Config.INPUT_DIM[0], self.Config.INPUT_DIM[0],
                        self.Config.INPUT_DIM[0], self.Config.NR_OF_CLASSES)).astype(self.Config.LABELS_TYPE)
    elif self.subject is not None:
        if self.Config.TYPE == "combined":
            # Load from Npy file for Fusion (memory-mapped, read-only)
            subject_dir = join(C.DATA_PATH, self.Config.DATASET_FOLDER, self.subject)
            data = np.load(join(subject_dir, self.Config.FEATURES_FILENAME + ".npy"), mmap_mode="r")
            seg = np.load(join(subject_dir, self.Config.LABELS_FILENAME + ".npy"), mmap_mode="r")
            data = np.nan_to_num(data)
            seg = np.nan_to_num(seg)
            # Merge axes 3 and 4 into a single channel axis (expects a 5D array)
            data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2],
                                     data.shape[3] * data.shape[4]))
        else:
            from tractseg.data.data_loader_training import load_training_data
            data, seg = load_training_data(self.Config, self.subject)
    else:
        raise ValueError("Neither 'data' nor 'subject' set.")

    if self.Config.DIM == "2D":
        batch_gen = BatchGenerator2D_data_ordered_standalone((data, seg), batch_size=batch_size)
    else:
        batch_gen = BatchGenerator3D_data_ordered_standalone((data, seg), batch_size=batch_size)
    batch_gen.Config = self.Config

    # Previously this passed `type=type`, i.e. the Python builtin `type` object
    # (no local of that name exists here). Pass None explicitly instead.
    batch_gen = self._augment_data(batch_gen, type=None)
    return batch_gen
def run_tractseg(data, output_type="tract_segmentation",
                 single_orientation=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_postprocessing=True, get_probs=False, peak_threshold=0.1,
                 postprocess=False, peak_regression_part="All", input_type="peaks",
                 blob_size_thr=50, nr_cpus=-1, verbose=False, manual_exp_name=None,
                 inference_batch_size=1, tract_definition="TractQuerier+", bedpostX_input=False,
                 tract_segmentations_path=None, TOM_dilation=1, unit_test=False):
    """
    Run TractSeg

    Args:
        data: input peaks (4D numpy array with shape [x,y,z,9])
        output_type: TractSeg can segment not only bundles, but also the end regions of bundles.
            Moreover it can create Tract Orientation Maps (TOM).
            'tract_segmentation' [DEFAULT]: Segmentation of bundles (72 bundles).
            'endings_segmentation': Segmentation of bundle end regions (72 bundles).
            'TOM': Tract Orientation Maps (20 bundles).
        single_orientation: Do not run model 3 times along x/y/z orientation with subsequent mean fusion.
        dropout_sampling: Create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
        threshold: Threshold for converting probability map to binary map
        bundle_specific_postprocessing: Set threshold to lower and use hole closing for CA and FX if
            incomplete. Only applied for 'tract_segmentation' without dropout_sampling.
        get_probs: Output raw probability map instead of binary map. Setting this disables
            bundle_specific_postprocessing and postprocess.
        peak_threshold: Currently not used by this function (kept for backward compatibility).
        postprocess: Simple postprocessing of segmentations: Remove small blobs and fill holes.
            Only applied for 'tract_segmentation' without dropout_sampling.
        peak_regression_part: Only relevant for output type 'TOM'. If set to 'All' (default) it
            will return all 72 bundles. If set to 'Part1'-'Part4' it will only run for a subset of
            the bundles to reduce memory load.
        input_type: Always set to "peaks"
        blob_size_thr: If setting postprocess to True, all blobs having a smaller number of voxels
            than specified in this threshold will be removed.
        nr_cpus: Number of CPUs to use. -1 means all available CPUs.
        verbose: Show debugging infos
        manual_exp_name: Name of experiment if do not want to use pretrained model but your own one
        inference_batch_size: batch size (higher: a bit faster but needs more RAM)
        tract_definition: Select which tract definitions to use. 'TractQuerier+' defines tracts
            mainly by their cortical start and end region. 'xtract' defines tracts mainly by ROIs
            in white matter.
        bedpostX_input: Input peaks are generated by bedpostX. Currently not used by this function.
        tract_segmentations_path: path to the bundle_segmentations (only needed for peak regression
            to remove peaks outside of the segmentation mask)
        TOM_dilation: Dilation applied to the tract segmentations before using them to mask the TOMs.
        unit_test: Forwarded to trainer.predict_img for single-orientation segmentation runs that
            return probabilities.

    Returns:
        4D numpy array with the output of tractseg
        for tract_segmentation:     [x, y, z, nr_of_bundles]
        for endings_segmentation:   [x, y, z, 2*nr_of_bundles]
        for TOM:                    [x, y, z, 3*nr_of_bundles]

    Raises:
        ValueError: if tract_definition is 'xtract' combined with an unsupported output type, or
            the loaded config has an experiment type this function cannot run.
    """
    start_time = time.time()

    if manual_exp_name is None:
        config = get_config_name(input_type, output_type, dropout_sampling=dropout_sampling,
                                 tract_definition=tract_definition)
        Config = getattr(importlib.import_module("tractseg.experiments.pretrained_models." + config),
                         "Config")()
    else:
        # Hyperparameters are read from the experiment name resolved for "Part1"
        # (presumably only differs from manual_exp_name for peak regression experiments).
        Config = exp_utils.load_config_from_txt(
            join(C.EXP_PATH, exp_utils.get_manual_exp_name_peaks(manual_exp_name, "Part1"),
                 "Hyperparameters.txt"))

    # Do not do any postprocessing if returning probabilities (because postprocessing only works on binary)
    if get_probs:
        bundle_specific_postprocessing = False
        postprocess = False

    Config = exp_utils.get_correct_labels_type(Config)
    Config.VERBOSE = verbose
    Config.TRAIN = False
    Config.TEST = False
    Config.SEGMENT = False
    Config.GET_PROBS = get_probs
    Config.LOAD_WEIGHTS = True
    Config.DROPOUT_SAMPLING = dropout_sampling
    Config.THRESHOLD = threshold
    Config.NR_CPUS = nr_cpus
    Config.INPUT_DIM = dataset_specific_utils.get_correct_input_dim(Config)
    Config.RESET_LAST_LAYER = False

    # Bundle-specific postprocessing works on the probability map, so request probabilities.
    if Config.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing:
        Config.GET_PROBS = True

    if manual_exp_name is not None and Config.EXPERIMENT_TYPE != "peak_regression":
        Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(join(C.EXP_PATH, manual_exp_name), True)
    else:
        if tract_definition == "TractQuerier+":
            if input_type == "peaks":
                # The same tract_segmentation weights are used with and without dropout sampling.
                if Config.EXPERIMENT_TYPE == "tract_segmentation":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_v3.npz")
                elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v4.npz")
                elif Config.EXPERIMENT_TYPE == "dm_regression":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_dm_regression_v2.npz")
            else:  # T1
                if Config.EXPERIMENT_TYPE == "tract_segmentation":
                    Config.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models",
                                               "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
                elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v1.npz")
                elif Config.EXPERIMENT_TYPE == "peak_regression":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_peak_regression_v1.npz")
        else:  # xtract
            if Config.EXPERIMENT_TYPE == "tract_segmentation":
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_xtract_v1.npz")
            elif Config.EXPERIMENT_TYPE == "dm_regression":
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_dm_regression_xtract_v1.npz")
            else:
                raise ValueError("bundle_definition xtract not supported in combination with this output type")

    if Config.VERBOSE:
        print("Hyperparameters:")
        exp_utils.print_Configs(Config)

    data = np.nan_to_num(data)

    # runtime on HCP data: 0.9s
    data, _, bbox, original_shape = data_utils.crop_to_nonzero(data)
    # runtime on HCP data: 0.5s
    data, transformation = data_utils.pad_and_scale_img_to_square_img(
        data, target_size=Config.INPUT_DIM[0], nr_cpus=nr_cpus)

    if Config.EXPERIMENT_TYPE in ("tract_segmentation", "endings_segmentation", "dm_regression"):
        print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
        Config.NR_OF_CLASSES = len(dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
        utils.download_pretrained_weights(experiment_type=Config.EXPERIMENT_TYPE,
                                          dropout_sampling=Config.DROPOUT_SAMPLING,
                                          tract_definition=tract_definition)
        model = BaseModel(Config, inference=True)
        if single_orientation:  # mainly needed for testing because of less RAM requirements
            data_loader_inference = DataLoaderInference(Config, data=data)
            if Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" or Config.GET_PROBS:
                seg, _ = trainer.predict_img(Config, model, data_loader_inference, probs=True,
                                             scale_to_world_shape=False, only_prediction=True,
                                             batch_size=inference_batch_size, unit_test=unit_test)
            else:
                # NOTE(review): unit_test is not forwarded here, unlike the probs=True call above.
                seg, _ = trainer.predict_img(Config, model, data_loader_inference, probs=False,
                                             scale_to_world_shape=False, only_prediction=True,
                                             batch_size=inference_batch_size)
        else:
            # Predict along x/y/z and fuse the three results by averaging.
            seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(
                Config, model, data=data, scale_to_world_shape=False, only_prediction=True,
                batch_size=inference_batch_size)
            if Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" or Config.GET_PROBS:
                seg = direction_merger.mean_fusion(Config.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = direction_merger.mean_fusion(Config.THRESHOLD, seg_xyz, probs=False)

    elif Config.EXPERIMENT_TYPE == "peak_regression":
        weights = {
            "Part1": "pretrained_weights_peak_regression_part1_v2.npz",
            "Part2": "pretrained_weights_peak_regression_part2_v2.npz",
            "Part3": "pretrained_weights_peak_regression_part3_v2.npz",
            "Part4": "pretrained_weights_peak_regression_part4_v2.npz",
        }
        if peak_regression_part == "All":
            parts = ["Part1", "Part2", "Part3", "Part4"]
            seg_all = np.zeros((data.shape[0], data.shape[1], data.shape[2], Config.NR_OF_CLASSES * 3))
        else:
            parts = [peak_regression_part]
            Config.CLASSES = "All_" + peak_regression_part
            Config.NR_OF_CLASSES = 3 * len(dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])

        # One model per part; each part's output is written into its own channel range of seg_all.
        for idx, part in enumerate(parts):
            if manual_exp_name is not None:
                manual_exp_name_peaks = exp_utils.get_manual_exp_name_peaks(manual_exp_name, part)
                Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
                    join(C.EXP_PATH, manual_exp_name_peaks), True)
            else:
                Config.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, weights[part])
            print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
            Config.CLASSES = "All_" + part
            Config.NR_OF_CLASSES = 3 * len(dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
            utils.download_pretrained_weights(experiment_type=Config.EXPERIMENT_TYPE,
                                              dropout_sampling=Config.DROPOUT_SAMPLING,
                                              part=part, tract_definition=tract_definition)
            model = BaseModel(Config, inference=True)

            if single_orientation:
                data_loader_inference = DataLoaderInference(Config, data=data)
                seg, _ = trainer.predict_img(Config, model, data_loader_inference, probs=True,
                                             scale_to_world_shape=False, only_prediction=True,
                                             batch_size=inference_batch_size)
            else:
                # 3 dir for Peaks -> bad results
                seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(
                    Config, model, data=data, scale_to_world_shape=False, only_prediction=True,
                    batch_size=inference_batch_size)
                seg = direction_merger.mean_fusion_peaks(seg_xyz, nr_cpus=nr_cpus)

            if peak_regression_part == "All":
                seg_all[:, :, :, (idx * Config.NR_OF_CLASSES):(idx * Config.NR_OF_CLASSES + Config.NR_OF_CLASSES)] = seg

        if peak_regression_part == "All":
            Config.CLASSES = "All"
            Config.NR_OF_CLASSES = 3 * len(dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
            seg = seg_all

    else:
        # Previously this fell through and failed later with an UnboundLocalError on `seg`.
        raise ValueError("Unsupported experiment type: {}".format(Config.EXPERIMENT_TYPE))

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing and not dropout_sampling:
        # Runtime ~4s
        seg = img_utils.bundle_specific_postprocessing(
            seg, dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])

    # runtime on HCP data: 5.1s
    seg = data_utils.cut_and_scale_img_back_to_original_img(seg, transformation, nr_cpus=nr_cpus)
    # runtime on HCP data: 1.6s
    seg = data_utils.add_original_zero_padding_again(seg, bbox, original_shape, Config.NR_OF_CLASSES)

    if Config.EXPERIMENT_TYPE == "peak_regression":
        seg = peak_utils.mask_and_normalize_peaks(seg, tract_segmentations_path,
                                                  dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
                                                  TOM_dilation, nr_cpus=nr_cpus)

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and postprocess and not dropout_sampling:
        # Runtime ~7s for 1.25mm resolution
        # Runtime ~1.5s for 2mm resolution
        seg = img_utils.postprocess_segmentations(seg, dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
                                                  blob_thr=blob_size_thr, hole_closing=None)

    exp_utils.print_verbose(Config.VERBOSE, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg
def __init__(self, Config, inference=False):
    """Build network, loss, optimizer and optional LR scheduler from ``Config``.

    Args:
        Config: Experiment configuration. Many attributes are read from it, and
            ``NR_OF_GRADIENTS`` is written back to it and then used as the number of
            network input channels.
        inference: If True, ``torch.backends.cudnn.benchmark`` is not enabled and the
            fp16/fp32 INFO messages are not printed.

    Raises:
        ValueError: If ``Config.OPTIMIZER`` is neither "Adamax" nor "Adam".
    """
    self.Config = Config

    # Do not use during inference because uses a lot more memory
    if not inference:
        torch.backends.cudnn.benchmark = True

    if self.Config.NR_CPUS > 0:
        torch.set_num_threads(self.Config.NR_CPUS)

    # Number of input channels. Previously only the "single_direction" branch bound the
    # local used below, so the other branches crashed with UnboundLocalError.
    if self.Config.SEG_INPUT == "Peaks" and self.Config.TYPE == "single_direction":
        pass  # keep the configured NR_OF_GRADIENTS
    elif self.Config.SEG_INPUT == "Peaks" and self.Config.TYPE == "combined":
        self.Config.NR_OF_GRADIENTS = 3 * self.Config.NR_OF_CLASSES
    else:
        self.Config.NR_OF_GRADIENTS = 33
    nr_of_gradients = self.Config.NR_OF_GRADIENTS

    if self.Config.LOSS_FUNCTION == "soft_sample_dice":
        self.criterion = pytorch_utils.soft_sample_dice
    elif self.Config.LOSS_FUNCTION == "soft_batch_dice":
        self.criterion = pytorch_utils.soft_batch_dice
    elif self.Config.EXPERIMENT_TYPE == "peak_regression":
        if self.Config.LOSS_FUNCTION == "angle_length_loss":
            self.criterion = pytorch_utils.angle_length_loss
        elif self.Config.LOSS_FUNCTION == "angle_loss":
            self.criterion = pytorch_utils.angle_loss
        elif self.Config.LOSS_FUNCTION == "l2_loss":
            self.criterion = pytorch_utils.l2_loss
        # NOTE(review): any other LOSS_FUNCTION leaves self.criterion unset here.
    elif self.Config.EXPERIMENT_TYPE == "dm_regression":
        # Aggregate by sum (equivalent to the deprecated size_average=False, reduce=True);
        # use reduction="mean" to aggregate by mean instead.
        self.criterion = nn.MSELoss(reduction="sum")
    else:
        self.criterion = nn.BCEWithLogitsLoss()

    NetworkClass = getattr(importlib.import_module("tractseg.models." + self.Config.MODEL.lower()),
                           self.Config.MODEL)
    self.net = NetworkClass(n_input_channels=nr_of_gradients, n_classes=self.Config.NR_OF_CLASSES,
                            n_filt=self.Config.UNET_NR_FILT, batchnorm=self.Config.BATCH_NORM,
                            dropout=self.Config.USE_DROPOUT, upsample=self.Config.UPSAMPLE_TYPE)

    # MultiGPU setup (nn.DataParallel) was tried: max ~10% speedup, GPU and CPU utility low.

    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = self.net.to(self.device)

    if self.Config.OPTIMIZER == "Adamax":
        self.optimizer = Adamax(net.parameters(), lr=self.Config.LEARNING_RATE,
                                weight_decay=self.Config.WEIGHT_DECAY)
    elif self.Config.OPTIMIZER == "Adam":
        self.optimizer = Adam(net.parameters(), lr=self.Config.LEARNING_RATE,
                              weight_decay=self.Config.WEIGHT_DECAY)
    else:
        raise ValueError("Optimizer not defined")

    if APEX_AVAILABLE and self.Config.FP16:
        # Use O0 to disable fp16 (might be a little faster on TitanX)
        self.net, self.optimizer = amp.initialize(self.net, self.optimizer, verbosity=0, opt_level="O1")
        if not inference:
            print("INFO: Using fp16 training")
    else:
        if not inference:
            print("INFO: Did not find APEX, defaulting to fp32 training")

    if self.Config.LR_SCHEDULE:
        self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                        mode=self.Config.LR_SCHEDULE_MODE,
                                                        patience=self.Config.LR_SCHEDULE_PATIENCE)

    if self.Config.LOAD_WEIGHTS:
        weights_path = join(self.Config.EXP_PATH, self.Config.WEIGHTS_PATH)
        exp_utils.print_verbose(self.Config, "Loading weights ... ({})".format(weights_path))
        self.load_model(weights_path)

    # Reset weights of last layer for transfer learning
    # NOTE(review): always a 2D convolution, regardless of the model's dimensionality — confirm.
    if self.Config.RESET_LAST_LAYER:
        self.net.conv_5 = nn.Conv2d(self.Config.UNET_NR_FILT, self.Config.NR_OF_CLASSES,
                                    kernel_size=1, stride=1, padding=0, bias=True).to(self.device)
def __init__(self, Config):
    """Build network, loss, optimizer and optional LR scheduler from ``Config``.

    Args:
        Config: Experiment configuration. Many attributes are read from it, and
            ``NR_OF_GRADIENTS`` is written back to it and then used as the number of
            network input channels.

    Raises:
        ValueError: If ``Config.OPTIMIZER`` is neither "Adamax" nor "Adam".
    """
    self.Config = Config
    # torch.backends.cudnn.benchmark = True  # not faster

    if self.Config.NR_CPUS > 0:
        torch.set_num_threads(self.Config.NR_CPUS)

    # Number of input channels. Previously only the "single_direction" branch bound the
    # local used below, so the other branches crashed with UnboundLocalError.
    if self.Config.SEG_INPUT == "Peaks" and self.Config.TYPE == "single_direction":
        pass  # keep the configured NR_OF_GRADIENTS
    elif self.Config.SEG_INPUT == "Peaks" and self.Config.TYPE == "combined":
        self.Config.NR_OF_GRADIENTS = 3 * self.Config.NR_OF_CLASSES
    else:
        self.Config.NR_OF_GRADIENTS = 33
    nr_of_gradients = self.Config.NR_OF_GRADIENTS

    if self.Config.LOSS_FUNCTION == "soft_sample_dice":
        self.criterion = pytorch_utils.soft_sample_dice
    elif self.Config.LOSS_FUNCTION == "soft_batch_dice":
        self.criterion = pytorch_utils.soft_batch_dice
    elif self.Config.EXPERIMENT_TYPE == "peak_regression":
        self.criterion = pytorch_utils.angle_length_loss
    else:
        self.criterion = nn.BCEWithLogitsLoss()

    NetworkClass = getattr(importlib.import_module("tractseg.models." + self.Config.MODEL.lower()),
                           self.Config.MODEL)
    self.net = NetworkClass(n_input_channels=nr_of_gradients, n_classes=self.Config.NR_OF_CLASSES,
                            n_filt=self.Config.UNET_NR_FILT, batchnorm=self.Config.BATCH_NORM,
                            dropout=self.Config.USE_DROPOUT, upsample=self.Config.UPSAMPLE_TYPE)

    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = self.net.to(self.device)

    if self.Config.OPTIMIZER == "Adamax":
        self.optimizer = Adamax(net.parameters(), lr=self.Config.LEARNING_RATE)
    elif self.Config.OPTIMIZER == "Adam":
        self.optimizer = Adam(net.parameters(), lr=self.Config.LEARNING_RATE)
    else:
        raise ValueError("Optimizer not defined")

    if self.Config.LR_SCHEDULE:
        # Divide the learning rate by 10 every 100 scheduler steps.
        self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=100, gamma=0.1)

    if self.Config.LOAD_WEIGHTS:
        weights_path = join(self.Config.EXP_PATH, self.Config.WEIGHTS_PATH)
        exp_utils.print_verbose(self.Config, "Loading weights ... ({})".format(weights_path))
        self.load_model(weights_path)

    # NOTE(review): always a 2D convolution, regardless of the model's dimensionality — confirm.
    if self.Config.RESET_LAST_LAYER:
        self.net.conv_5 = nn.Conv2d(self.Config.UNET_NR_FILT, self.Config.NR_OF_CLASSES,
                                    kernel_size=1, stride=1, padding=0, bias=True).to(self.device)
def run_tractseg(data, output_type="tract_segmentation",
                 single_orientation=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_threshold=False, get_probs=False, peak_threshold=0.1,
                 postprocess=False, peak_regression_part="All", input_type="peaks",
                 blob_size_thr=50, nr_cpus=-1, verbose=False, manual_exp_name=None,
                 inference_batch_size=1, tract_definition="TractQuerier+", bedpostX_input=False):
    """
    Run TractSeg

    Args:
        data: input peaks (4D numpy array with shape [x,y,z,9])
        output_type: TractSeg can segment not only bundles, but also the end regions of bundles.
            Moreover it can create Tract Orientation Maps (TOM).
            'tract_segmentation' [DEFAULT]: Segmentation of bundles (72 bundles).
            'endings_segmentation': Segmentation of bundle end regions (72 bundles).
            'TOM': Tract Orientation Maps (20 bundles).
        single_orientation: Do not run model 3 times along x/y/z orientation with subsequent mean
            fusion. (Peak regression always runs a single orientation.)
        dropout_sampling: Create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
        threshold: Threshold for converting probability map to binary map
        bundle_specific_threshold: Set threshold to lower for some bundles which need more
            sensitivity (CA, CST, FX). For 'tract_segmentation' (without dropout_sampling) the
            output is binarized with bundle-specific thresholds; for 'TOM' small peaks are removed
            with remove_small_peaks_bundle_specific (len_thr=0.3) instead of peak_threshold.
        get_probs: Output raw probability map instead of binary map
        peak_threshold: All peaks shorter than peak_threshold will be set to zero (TOM only, and
            only when bundle_specific_threshold is False)
        postprocess: Simple postprocessing of segmentations: Remove small blobs and fill holes
        peak_regression_part: Only relevant for output type 'TOM'. If set to 'All' (default) it
            will return all 72 bundles. If set to 'Part1'-'Part4' it will only run for a subset of
            the bundles to reduce memory load.
        input_type: Always set to "peaks"
        blob_size_thr: If setting postprocess to True, all blobs having a smaller number of voxels
            than specified in this threshold will be removed.
        nr_cpus: Number of CPUs to use. -1 means all available CPUs.
        verbose: Show debugging infos
        manual_exp_name: Name of experiment if do not want to use pretrained model but your own one
        inference_batch_size: batch size (higher: a bit faster but needs more RAM)
        tract_definition: Select which tract definitions to use. 'TractQuerier+' defines tracts
            mainly by their cortical start and end region. 'AutoPTX' defines tracts mainly by ROIs
            in white matter.
        bedpostX_input: Input peaks are generated by bedpostX

    Returns:
        4D numpy array with the output of tractseg
        for tract_segmentation:     [x, y, z, nr_of_bundles]
        for endings_segmentation:   [x, y, z, 2*nr_of_bundles]
        for TOM:                    [x, y, z, 3*nr_of_bundles]

    Raises:
        ValueError: AutoPTX combined with an unsupported output type, or an experiment type this
            function cannot run.
        FileNotFoundError: AutoPTX weights file does not exist locally.
    """
    start_time = time.time()

    if manual_exp_name is None:
        config = get_config_name(input_type, output_type, dropout_sampling=dropout_sampling,
                                 tract_definition=tract_definition, bedpostX_input=bedpostX_input)
        Config = getattr(importlib.import_module("tractseg.experiments.pretrained_models." + config),
                         "Config")()
    else:
        Config = exp_utils.load_config_from_txt(join(C.EXP_PATH, manual_exp_name, "Hyperparameters.txt"))

    Config = exp_utils.get_correct_labels_type(Config)
    Config.VERBOSE = verbose
    Config.TRAIN = False
    Config.TEST = False
    Config.SEGMENT = False
    Config.GET_PROBS = get_probs
    Config.LOAD_WEIGHTS = True
    Config.DROPOUT_SAMPLING = dropout_sampling
    Config.THRESHOLD = threshold
    Config.NR_CPUS = nr_cpus
    Config.INPUT_DIM = exp_utils.get_correct_input_dim(Config)

    # Bundle-specific binarization needs the probability map. Only tract_segmentation is
    # binarized again afterwards; forcing probabilities for other output types (as was done
    # before) returned raw, never-thresholded probability maps to the caller.
    if bundle_specific_threshold and Config.EXPERIMENT_TYPE == "tract_segmentation":
        Config.GET_PROBS = True

    if manual_exp_name is not None and Config.EXPERIMENT_TYPE != "peak_regression":
        Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(join(C.EXP_PATH, manual_exp_name), True)
    else:
        if tract_definition == "TractQuerier+":
            if input_type == "peaks":
                if Config.EXPERIMENT_TYPE == "tract_segmentation" and Config.DROPOUT_SAMPLING:
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR,
                                               "pretrained_weights_tract_segmentation_dropout_v2.npz")
                elif Config.EXPERIMENT_TYPE == "tract_segmentation":
                    if bedpostX_input:
                        Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "TractSeg_BXTensAg_best_weights_ep248.npz")
                    else:
                        Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_v2.npz")
                elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v3.npz")
                elif Config.EXPERIMENT_TYPE == "dm_regression":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_dm_regression_v1.npz")
            else:  # T1
                if Config.EXPERIMENT_TYPE == "tract_segmentation":
                    Config.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models",
                                               "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
                elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v1.npz")
                elif Config.EXPERIMENT_TYPE == "peak_regression":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_peak_regression_v1.npz")
        else:  # AutoPTX
            if Config.EXPERIMENT_TYPE == "tract_segmentation":
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_aPTX_v1.npz")
            elif Config.EXPERIMENT_TYPE == "dm_regression":
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_dm_regression_aPTX_v1.npz")
            else:
                raise ValueError("bundle_definition AutoPTX not supported in combination with this output type")
            # TODO: remove when aPTX weights are loaded automatically
            if not os.path.exists(Config.WEIGHTS_PATH):
                raise FileNotFoundError("Could not find weights file: {}".format(Config.WEIGHTS_PATH))

    if Config.VERBOSE:
        print("Hyperparameters:")
        exp_utils.print_Configs(Config)

    data = np.nan_to_num(data)

    if input_type == "T1":
        # Add a trailing channel axis: [x, y, z] -> [x, y, z, 1]
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    data, _, bbox, original_shape = dataset_utils.crop_to_nonzero(data)
    data, transformation = dataset_utils.pad_and_scale_img_to_square_img(data, target_size=Config.INPUT_DIM[0])

    if Config.EXPERIMENT_TYPE in ("tract_segmentation", "endings_segmentation", "dm_regression"):
        print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
        Config.NR_OF_CLASSES = len(exp_utils.get_bundle_names(Config.CLASSES)[1:])
        utils.download_pretrained_weights(experiment_type=Config.EXPERIMENT_TYPE,
                                          dropout_sampling=Config.DROPOUT_SAMPLING)
        model = BaseModel(Config)
        if single_orientation:  # mainly needed for testing because of less RAM requirements
            data_loader_inference = DataLoaderInference(Config, data=data)
            if Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" or Config.GET_PROBS:
                seg, _ = trainer.predict_img(Config, model, data_loader_inference, probs=True,
                                             scale_to_world_shape=False, only_prediction=True,
                                             batch_size=inference_batch_size)
            else:
                seg, _ = trainer.predict_img(Config, model, data_loader_inference, probs=False,
                                             scale_to_world_shape=False, only_prediction=True,
                                             batch_size=inference_batch_size)
        else:
            # Predict along x/y/z and fuse the three results by averaging.
            seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(
                Config, model, data=data, scale_to_world_shape=False, only_prediction=True,
                batch_size=inference_batch_size)
            if Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" or Config.GET_PROBS:
                seg = direction_merger.mean_fusion(Config.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = direction_merger.mean_fusion(Config.THRESHOLD, seg_xyz, probs=False)

    elif Config.EXPERIMENT_TYPE == "peak_regression":
        weights = {
            "Part1": "pretrained_weights_peak_regression_part1_v1.npz",
            "Part2": "pretrained_weights_peak_regression_part2_v1.npz",
            "Part3": "pretrained_weights_peak_regression_part3_v1.npz",
            "Part4": "pretrained_weights_peak_regression_part4_v1.npz",
        }
        if peak_regression_part == "All":
            parts = ["Part1", "Part2", "Part3", "Part4"]
            seg_all = np.zeros((data.shape[0], data.shape[1], data.shape[2], Config.NR_OF_CLASSES * 3))
        else:
            parts = [peak_regression_part]
            Config.CLASSES = "All_" + peak_regression_part
            Config.NR_OF_CLASSES = 3 * len(exp_utils.get_bundle_names(Config.CLASSES)[1:])

        # One model per part; each part's output is written into its own channel range of seg_all.
        # (Running 3 orientations for peaks gave bad results, so only one orientation is used.)
        for idx, part in enumerate(parts):
            if manual_exp_name is not None:
                manual_exp_name_peaks = exp_utils.get_manual_exp_name_peaks(manual_exp_name, part)
                Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
                    join(C.EXP_PATH, manual_exp_name_peaks), True)
            else:
                Config.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, weights[part])
            print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
            Config.CLASSES = "All_" + part
            Config.NR_OF_CLASSES = 3 * len(exp_utils.get_bundle_names(Config.CLASSES)[1:])
            utils.download_pretrained_weights(experiment_type=Config.EXPERIMENT_TYPE,
                                              dropout_sampling=Config.DROPOUT_SAMPLING, part=part)
            data_loader_inference = DataLoaderInference(Config, data=data)
            model = BaseModel(Config)
            seg, _ = trainer.predict_img(Config, model, data_loader_inference, probs=True,
                                         scale_to_world_shape=False, only_prediction=True,
                                         batch_size=inference_batch_size)
            if peak_regression_part == "All":
                seg_all[:, :, :, (idx * Config.NR_OF_CLASSES):(idx * Config.NR_OF_CLASSES + Config.NR_OF_CLASSES)] = seg

        if peak_regression_part == "All":
            Config.CLASSES = "All"
            Config.NR_OF_CLASSES = 3 * len(exp_utils.get_bundle_names(Config.CLASSES)[1:])
            seg = seg_all

        # quite fast
        if bundle_specific_threshold:
            seg = img_utils.remove_small_peaks_bundle_specific(
                seg, exp_utils.get_bundle_names(Config.CLASSES)[1:], len_thr=0.3)
        else:
            seg = img_utils.remove_small_peaks(seg, len_thr=peak_threshold)

    else:
        # Previously this fell through and failed later with an UnboundLocalError on `seg`.
        raise ValueError("Unsupported experiment type: {}".format(Config.EXPERIMENT_TYPE))

    # Skip with dropout sampling: that output is an uncertainty map, not a probability map to binarize.
    if bundle_specific_threshold and Config.EXPERIMENT_TYPE == "tract_segmentation" and not dropout_sampling:
        seg = img_utils.probs_to_binary_bundle_specific(seg, exp_utils.get_bundle_names(Config.CLASSES)[1:])

    # remove following two lines to keep super resolution
    seg = dataset_utils.cut_and_scale_img_back_to_original_img(seg, transformation)  # quite slow
    seg = dataset_utils.add_original_zero_padding_again(seg, bbox, original_shape, Config.NR_OF_CLASSES)  # quite slow

    # NOTE(review): this runs for every output type (including TOM peaks) and also on
    # probability maps when get_probs is set — confirm callers only enable it for binary
    # segmentations.
    if postprocess:
        seg = img_utils.postprocess_segmentations(seg, blob_thr=blob_size_thr, hole_closing=2)

    exp_utils.print_verbose(Config, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg
def __init__(self, Config, inference=False):
    """Build network, loss, optimizer and optional LR scheduler from ``Config``.

    Args:
        Config: Experiment configuration. Many attributes are read from it, and
            ``NR_OF_GRADIENTS`` is written back to it and then used as the number of
            network input channels.
        inference: If True, ``torch.backends.cudnn.benchmark`` is not enabled and the
            fp16/fp32 INFO messages are not printed.

    Raises:
        ValueError: If ``Config.OPTIMIZER`` is neither "Adamax" nor "Adam".
    """
    self.Config = Config

    if not inference:
        torch.backends.cudnn.benchmark = True

    if self.Config.NR_CPUS > 0:
        torch.set_num_threads(self.Config.NR_CPUS)

    # Number of input channels. Previously only the "single_direction" branch bound the
    # local used below, so the other branches crashed with UnboundLocalError.
    if self.Config.SEG_INPUT == "Peaks" and self.Config.TYPE == "single_direction":
        pass  # keep the configured NR_OF_GRADIENTS
    elif self.Config.SEG_INPUT == "Peaks" and self.Config.TYPE == "combined":
        self.Config.NR_OF_GRADIENTS = 3 * self.Config.NR_OF_CLASSES
    else:
        self.Config.NR_OF_GRADIENTS = 33
    nr_of_gradients = self.Config.NR_OF_GRADIENTS

    if self.Config.LOSS_FUNCTION == "soft_sample_dice":
        self.criterion = pytorch_utils.soft_sample_dice
    elif self.Config.LOSS_FUNCTION == "soft_batch_dice":
        self.criterion = pytorch_utils.soft_batch_dice
    elif self.Config.EXPERIMENT_TYPE == "peak_regression":
        if self.Config.LOSS_FUNCTION == "angle_length_loss":
            self.criterion = pytorch_utils.angle_length_loss
        elif self.Config.LOSS_FUNCTION == "angle_loss":
            self.criterion = pytorch_utils.angle_loss
        elif self.Config.LOSS_FUNCTION == "l2_loss":
            self.criterion = pytorch_utils.l2_loss
        # NOTE(review): any other LOSS_FUNCTION leaves self.criterion unset here.
    elif self.Config.EXPERIMENT_TYPE == "dm_regression":
        # Aggregate by sum (equivalent to the deprecated size_average=False, reduce=True);
        # use reduction="mean" to aggregate by mean instead.
        self.criterion = nn.MSELoss(reduction="sum")
    else:
        self.criterion = nn.BCEWithLogitsLoss()

    NetworkClass = getattr(importlib.import_module("tractseg.models." + self.Config.MODEL.lower()),
                           self.Config.MODEL)
    self.net = NetworkClass(n_input_channels=nr_of_gradients, n_classes=self.Config.NR_OF_CLASSES,
                            n_filt=self.Config.UNET_NR_FILT, batchnorm=self.Config.BATCH_NORM,
                            dropout=self.Config.USE_DROPOUT, upsample=self.Config.UPSAMPLE_TYPE)

    # MultiGPU (nn.DataParallel) was tried: max ~10% speedup, GPU/CPU utility stayed low
    # (worse with bigger batch sizes).

    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = self.net.to(self.device)

    if self.Config.OPTIMIZER == "Adamax":
        self.optimizer = Adamax(net.parameters(), lr=self.Config.LEARNING_RATE)
    elif self.Config.OPTIMIZER == "Adam":
        self.optimizer = Adam(net.parameters(), lr=self.Config.LEARNING_RATE)
    else:
        raise ValueError("Optimizer not defined")

    if APEX_AVAILABLE and self.Config.FP16:
        # Use O0 to disable fp16 (might be a little faster on TitanX)
        self.net, self.optimizer = amp.initialize(self.net, self.optimizer, verbosity=0, opt_level="O1")
        if not inference:
            print("INFO: Using fp16 training")
    else:
        if not inference:
            print("INFO: Did not find APEX, defaulting to fp32 training")

    if self.Config.LR_SCHEDULE:
        # Slightly better results could be achieved by training 500 epochs without LR reduction,
        # but that takes too long; ReduceLROnPlateau helps when training for only 200 epochs.
        self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                        mode=self.Config.LR_SCHEDULE_MODE,
                                                        patience=self.Config.LR_SCHEDULE_PATIENCE)

    if self.Config.LOAD_WEIGHTS:
        weights_path = join(self.Config.EXP_PATH, self.Config.WEIGHTS_PATH)
        exp_utils.print_verbose(self.Config, "Loading weights ... ({})".format(weights_path))
        self.load_model(weights_path)

    # NOTE(review): always a 2D convolution, regardless of the model's dimensionality — confirm.
    if self.Config.RESET_LAST_LAYER:
        self.net.conv_5 = nn.Conv2d(self.Config.UNET_NR_FILT, self.Config.NR_OF_CLASSES,
                                    kernel_size=1, stride=1, padding=0, bias=True).to(self.device)