def get_batch_generator(self, batch_size=1):
        """
        Create an ordered batch generator over a single subject for inference.

        The input volume is taken from ``self.data`` if it is set; otherwise it is loaded
        for ``self.subject`` (from .npy files for ``Config.TYPE == "combined"``, else via
        ``load_training_data``).

        Args:
            batch_size: Number of samples per batch.

        Returns:
            The generator returned by ``self._augment_data``, wrapping a
            ``BatchGenerator2D_data_ordered_standalone`` (``Config.DIM == "2D"``) or a
            ``BatchGenerator3D_data_ordered_standalone`` (otherwise).

        Raises:
            ValueError: If neither ``self.data`` nor ``self.subject`` is set.
        """
        if self.data is not None:
            # NOTE(review): print_verbose is called with Config.VERBOSE elsewhere in this file;
            # confirm which first argument it expects.
            exp_utils.print_verbose(self.Config, "Loading data from PREDICT_IMG input file")
            data = np.nan_to_num(self.data)
            # Use dummy mask in case we only want to predict on some data (where we do not have ground truth)
            input_size = self.Config.INPUT_DIM[0]
            seg = np.zeros((input_size, input_size, input_size,
                            self.Config.NR_OF_CLASSES)).astype(self.Config.LABELS_TYPE)
        elif self.subject is not None:
            if self.Config.TYPE == "combined":
                # Load from Npy file for Fusion
                subject_dir = join(C.DATA_PATH, self.Config.DATASET_FOLDER, self.subject)
                data = np.load(join(subject_dir, self.Config.FEATURES_FILENAME + ".npy"), mmap_mode="r")
                seg = np.load(join(subject_dir, self.Config.LABELS_FILENAME + ".npy"), mmap_mode="r")
                data = np.nan_to_num(data)
                seg = np.nan_to_num(seg)
                # Merge the last two axes of the 5D feature array into a single channel axis
                data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2],
                                         data.shape[3] * data.shape[4]))
            else:
                from tractseg.data.data_loader_training import load_training_data
                data, seg = load_training_data(self.Config, self.subject)
        else:
            raise ValueError("Neither 'data' nor 'subject' set.")

        if self.Config.DIM == "2D":
            batch_gen = BatchGenerator2D_data_ordered_standalone((data, seg), batch_size=batch_size)
        else:
            batch_gen = BatchGenerator3D_data_ordered_standalone((data, seg), batch_size=batch_size)
        batch_gen.Config = self.Config

        # Pass an explicit None instead of the builtin ``type`` object.
        return self._augment_data(batch_gen, type=None)
Exemple #2
0
def run_tractseg(data,
                 output_type="tract_segmentation",
                 single_orientation=False,
                 dropout_sampling=False,
                 threshold=0.5,
                 bundle_specific_postprocessing=True,
                 get_probs=False,
                 peak_threshold=0.1,
                 postprocess=False,
                 peak_regression_part="All",
                 input_type="peaks",
                 blob_size_thr=50,
                 nr_cpus=-1,
                 verbose=False,
                 manual_exp_name=None,
                 inference_batch_size=1,
                 tract_definition="TractQuerier+",
                 bedpostX_input=False,
                 tract_segmentations_path=None,
                 TOM_dilation=1,
                 unit_test=False):
    """
    Run TractSeg

    Args:
        data: input peaks (4D numpy array with shape [x,y,z,9])
        output_type: TractSeg can segment not only bundles, but also the end regions of bundles.
            Moreover it can create Tract Orientation Maps (TOM).
            'tract_segmentation' [DEFAULT]: Segmentation of bundles (72 bundles).
            'endings_segmentation': Segmentation of bundle end regions (72 bundles).
            'TOM': Tract Orientation Maps (72 bundles).
        single_orientation: Do not run model 3 times along x/y/z orientation with subsequent mean fusion.
        dropout_sampling: Create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
        threshold: Threshold for converting probability map to binary map
        bundle_specific_postprocessing: Set threshold to lower and use hole closing for CA and FX if incomplete.
            Only applied for 'tract_segmentation' without dropout sampling.
        get_probs: Output raw probability map instead of binary map (disables all postprocessing)
        peak_threshold: All peaks shorter than peak_threshold will be set to zero.
            NOTE: currently not used by this function.
        postprocess: Simple postprocessing of segmentations: Remove small blobs and fill holes
        peak_regression_part: Only relevant for output type 'TOM'. If set to 'All' (default) it will return all
            72 bundles. If set to 'Part1'-'Part4' it will only run for a subset of the bundles to reduce memory
            load.
        input_type: "peaks" (default). Any other value selects the T1-based pretrained weights.
        blob_size_thr: If setting postprocess to True, all blobs having a smaller number of voxels than specified in
            this threshold will be removed.
        nr_cpus: Number of CPUs to use. -1 means all available CPUs.
        verbose: Show debugging infos
        manual_exp_name: Name of experiment if do not want to use pretrained model but your own one
        inference_batch_size: batch size (higher: a bit faster but needs more RAM)
        tract_definition: Select which tract definitions to use. 'TractQuerier+' defines tracts mainly by their
            cortical start and end region. 'xtract' defines tracts mainly by ROIs in white matter.
        bedpostX_input: Input peaks are generated by bedpostX. NOTE: currently not used by this function.
        tract_segmentations_path: path to the bundle_segmentations (only needed for peak regression to remove peaks
            outside of the segmentation mask)
        TOM_dilation: Dilation applied to the tract segmentations before using them to mask the TOMs.
        unit_test: Forwarded to ``trainer.predict_img`` (only with ``single_orientation`` when
            probabilities are predicted).

    Returns:
        4D numpy array with the output of tractseg
        for tract_segmentation:     [x, y, z, nr_of_bundles]
        for endings_segmentation:   [x, y, z, 2*nr_of_bundles]
        for TOM:                    [x, y, z, 3*nr_of_bundles]

    Raises:
        ValueError: if 'xtract' is combined with an unsupported output type, or if the
            experiment type of the loaded config is not supported.
    """
    start_time = time.time()

    if manual_exp_name is None:
        config_name = get_config_name(input_type,
                                      output_type,
                                      dropout_sampling=dropout_sampling,
                                      tract_definition=tract_definition)
        Config = getattr(
            importlib.import_module("tractseg.experiments.pretrained_models." + config_name),
            "Config")()
    else:
        Config = exp_utils.load_config_from_txt(
            join(C.EXP_PATH,
                 exp_utils.get_manual_exp_name_peaks(manual_exp_name, "Part1"),
                 "Hyperparameters.txt"))

    # Do not do any postprocessing if returning probabilities (because postprocessing only works on binary)
    if get_probs:
        bundle_specific_postprocessing = False
        postprocess = False

    Config = exp_utils.get_correct_labels_type(Config)
    Config.VERBOSE = verbose
    Config.TRAIN = False
    Config.TEST = False
    Config.SEGMENT = False
    Config.GET_PROBS = get_probs
    Config.LOAD_WEIGHTS = True
    Config.DROPOUT_SAMPLING = dropout_sampling
    Config.THRESHOLD = threshold
    Config.NR_CPUS = nr_cpus
    Config.INPUT_DIM = dataset_specific_utils.get_correct_input_dim(Config)
    Config.RESET_LAST_LAYER = False

    # Predict probabilities so the bundle-specific postprocessing below can apply its own thresholds.
    if Config.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing:
        Config.GET_PROBS = True

    if manual_exp_name is not None and Config.EXPERIMENT_TYPE != "peak_regression":
        Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
            join(C.EXP_PATH, manual_exp_name), True)
    elif tract_definition == "TractQuerier+":
        if input_type == "peaks":
            if Config.EXPERIMENT_TYPE == "tract_segmentation":
                # NOTE(review): the same weights file is used with and without dropout sampling; confirm.
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_v3.npz")
            elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v4.npz")
            elif Config.EXPERIMENT_TYPE == "dm_regression":
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_dm_regression_v2.npz")
        else:  # T1
            if Config.EXPERIMENT_TYPE == "tract_segmentation":
                Config.WEIGHTS_PATH = join(
                    C.NETWORK_DRIVE,
                    "hcp_exp_nodes/x_Pretrained_TractSeg_Models",
                    "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
            elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v1.npz")
            elif Config.EXPERIMENT_TYPE == "peak_regression":
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_peak_regression_v1.npz")
    else:  # xtract
        if Config.EXPERIMENT_TYPE == "tract_segmentation":
            Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_xtract_v1.npz")
        elif Config.EXPERIMENT_TYPE == "dm_regression":
            Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_dm_regression_xtract_v1.npz")
        else:
            raise ValueError(
                "bundle_definition xtract not supported in combination with this output type"
            )

    if Config.VERBOSE:
        print("Hyperparameters:")
        exp_utils.print_Configs(Config)

    data = np.nan_to_num(data)

    # runtime on HCP data: 0.9s
    data, _, bbox, original_shape = data_utils.crop_to_nonzero(data)
    # runtime on HCP data: 0.5s
    data, transformation = data_utils.pad_and_scale_img_to_square_img(
        data, target_size=Config.INPUT_DIM[0], nr_cpus=nr_cpus)

    if Config.EXPERIMENT_TYPE in ("tract_segmentation", "endings_segmentation", "dm_regression"):
        print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
        Config.NR_OF_CLASSES = len(dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
        utils.download_pretrained_weights(
            experiment_type=Config.EXPERIMENT_TYPE,
            dropout_sampling=Config.DROPOUT_SAMPLING,
            tract_definition=tract_definition)
        model = BaseModel(Config, inference=True)

        # Keep probabilities instead of thresholding to a binary map
        use_probs = bool(Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression"
                         or Config.GET_PROBS)

        if single_orientation:  # mainly needed for testing because of less RAM requirements
            data_loader_inference = DataLoaderInference(Config, data=data)
            extra_kwargs = {}
            if use_probs:
                # unit_test is only forwarded on the probability path
                extra_kwargs["unit_test"] = unit_test
            seg, _ = trainer.predict_img(Config,
                                         model,
                                         data_loader_inference,
                                         probs=use_probs,
                                         scale_to_world_shape=False,
                                         only_prediction=True,
                                         batch_size=inference_batch_size,
                                         **extra_kwargs)
        else:
            seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(
                Config,
                model,
                data=data,
                scale_to_world_shape=False,
                only_prediction=True,
                batch_size=inference_batch_size)
            seg = direction_merger.mean_fusion(Config.THRESHOLD, seg_xyz, probs=use_probs)

    elif Config.EXPERIMENT_TYPE == "peak_regression":
        weights = {
            "Part1": "pretrained_weights_peak_regression_part1_v2.npz",
            "Part2": "pretrained_weights_peak_regression_part2_v2.npz",
            "Part3": "pretrained_weights_peak_regression_part3_v2.npz",
            "Part4": "pretrained_weights_peak_regression_part4_v2.npz",
        }
        if peak_regression_part == "All":
            parts = ["Part1", "Part2", "Part3", "Part4"]
            seg_all = np.zeros((data.shape[0], data.shape[1], data.shape[2],
                                Config.NR_OF_CLASSES * 3))
        else:
            parts = [peak_regression_part]
            Config.CLASSES = "All_" + peak_regression_part
            Config.NR_OF_CLASSES = 3 * len(
                dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])

        for idx, part in enumerate(parts):
            if manual_exp_name is not None:
                manual_exp_name_peaks = exp_utils.get_manual_exp_name_peaks(manual_exp_name, part)
                Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
                    join(C.EXP_PATH, manual_exp_name_peaks), True)
            else:
                # Same directory as all other pretrained weights
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, weights[part])
            print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
            Config.CLASSES = "All_" + part
            Config.NR_OF_CLASSES = 3 * len(
                dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
            utils.download_pretrained_weights(
                experiment_type=Config.EXPERIMENT_TYPE,
                dropout_sampling=Config.DROPOUT_SAMPLING,
                part=part,
                tract_definition=tract_definition)
            model = BaseModel(Config, inference=True)

            if single_orientation:
                data_loader_inference = DataLoaderInference(Config, data=data)
                seg, _ = trainer.predict_img(Config,
                                             model,
                                             data_loader_inference,
                                             probs=True,
                                             scale_to_world_shape=False,
                                             only_prediction=True,
                                             batch_size=inference_batch_size)
            else:
                # 3 dir for Peaks -> bad results
                seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(
                    Config,
                    model,
                    data=data,
                    scale_to_world_shape=False,
                    only_prediction=True,
                    batch_size=inference_batch_size)
                seg = direction_merger.mean_fusion_peaks(seg_xyz, nr_cpus=nr_cpus)

            if peak_regression_part == "All":
                # Each part fills its own contiguous block of channels
                n_part = Config.NR_OF_CLASSES
                seg_all[:, :, :, idx * n_part:(idx + 1) * n_part] = seg

        if peak_regression_part == "All":
            Config.CLASSES = "All"
            Config.NR_OF_CLASSES = 3 * len(
                dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
            seg = seg_all

    else:
        raise ValueError("Unsupported experiment type: {}".format(Config.EXPERIMENT_TYPE))

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing and not dropout_sampling:
        # Runtime ~4s
        seg = img_utils.bundle_specific_postprocessing(
            seg,
            dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])

    # runtime on HCP data: 5.1s
    seg = data_utils.cut_and_scale_img_back_to_original_img(seg, transformation, nr_cpus=nr_cpus)
    # runtime on HCP data: 1.6s
    seg = data_utils.add_original_zero_padding_again(seg, bbox, original_shape, Config.NR_OF_CLASSES)

    if Config.EXPERIMENT_TYPE == "peak_regression":
        seg = peak_utils.mask_and_normalize_peaks(
            seg,
            tract_segmentations_path,
            dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
            TOM_dilation,
            nr_cpus=nr_cpus)

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and postprocess and not dropout_sampling:
        # Runtime ~7s for 1.25mm resolution
        # Runtime ~1.5s for  2mm resolution
        seg = img_utils.postprocess_segmentations(
            seg,
            dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
            blob_thr=blob_size_thr,
            hole_closing=None)

    exp_utils.print_verbose(
        Config.VERBOSE, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg
Exemple #3
0
    def __init__(self, Config, inference=False):
        """
        Build the network, loss function, optimizer and (optionally) LR scheduler from ``Config``.

        Args:
            Config: Experiment configuration. This constructor writes ``NR_OF_GRADIENTS`` back
                to it for non single-direction peak input.
            inference: If True, skip cudnn benchmark mode (uses more memory) and the
                fp16/fp32 info messages.

        Raises:
            ValueError: If ``Config.OPTIMIZER`` is neither "Adamax" nor "Adam".
        """
        self.Config = Config

        # Do not use during inference because uses a lot more memory
        if not inference:
            torch.backends.cudnn.benchmark = True

        if self.Config.NR_CPUS > 0:
            torch.set_num_threads(self.Config.NR_CPUS)

        # Number of input channels of the network; every branch must bind it because it is
        # needed to construct the network below.
        if self.Config.SEG_INPUT == "Peaks" and self.Config.TYPE == "single_direction":
            nr_of_gradients = self.Config.NR_OF_GRADIENTS
        elif self.Config.SEG_INPUT == "Peaks" and self.Config.TYPE == "combined":
            nr_of_gradients = self.Config.NR_OF_GRADIENTS = 3 * self.Config.NR_OF_CLASSES
        else:
            nr_of_gradients = self.Config.NR_OF_GRADIENTS = 33

        if self.Config.LOSS_FUNCTION == "soft_sample_dice":
            self.criterion = pytorch_utils.soft_sample_dice
        elif self.Config.LOSS_FUNCTION == "soft_batch_dice":
            self.criterion = pytorch_utils.soft_batch_dice
        elif self.Config.EXPERIMENT_TYPE == "peak_regression":
            # NOTE(review): any other LOSS_FUNCTION leaves self.criterion unset.
            if self.Config.LOSS_FUNCTION == "angle_length_loss":
                self.criterion = pytorch_utils.angle_length_loss
            elif self.Config.LOSS_FUNCTION == "angle_loss":
                self.criterion = pytorch_utils.angle_loss
            elif self.Config.LOSS_FUNCTION == "l2_loss":
                self.criterion = pytorch_utils.l2_loss
        elif self.Config.EXPERIMENT_TYPE == "dm_regression":
            # Aggregate by sum (nn.MSELoss() would aggregate by mean)
            self.criterion = nn.MSELoss(reduction="sum")
        else:
            self.criterion = nn.BCEWithLogitsLoss()

        NetworkClass = getattr(
            importlib.import_module("tractseg.models." + self.Config.MODEL.lower()),
            self.Config.MODEL)
        self.net = NetworkClass(n_input_channels=nr_of_gradients,
                                n_classes=self.Config.NR_OF_CLASSES,
                                n_filt=self.Config.UNET_NR_FILT,
                                batchnorm=self.Config.BATCH_NORM,
                                dropout=self.Config.USE_DROPOUT,
                                upsample=self.Config.UPSAMPLE_TYPE)

        # MultiGPU (nn.DataParallel) is not used: max ~10% speedup, GPU and CPU utility low.

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = self.net.to(self.device)

        if self.Config.OPTIMIZER == "Adamax":
            optimizer_class = Adamax
        elif self.Config.OPTIMIZER == "Adam":
            optimizer_class = Adam
        else:
            raise ValueError("Optimizer not defined")
        self.optimizer = optimizer_class(net.parameters(),
                                         lr=self.Config.LEARNING_RATE,
                                         weight_decay=self.Config.WEIGHT_DECAY)

        use_fp16 = APEX_AVAILABLE and self.Config.FP16
        if use_fp16:
            # Use O0 to disable fp16 (might be a little faster on TitanX)
            self.net, self.optimizer = amp.initialize(self.net,
                                                      self.optimizer,
                                                      verbosity=0,
                                                      opt_level="O1")
        if not inference:
            if use_fp16:
                print("INFO: Using fp16 training")
            elif APEX_AVAILABLE:
                print("INFO: FP16 disabled in config, using fp32 training")
            else:
                print("INFO: Did not find APEX, defaulting to fp32 training")

        if self.Config.LR_SCHEDULE:
            self.scheduler = lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode=self.Config.LR_SCHEDULE_MODE,
                patience=self.Config.LR_SCHEDULE_PATIENCE)

        if self.Config.LOAD_WEIGHTS:
            weights_path = join(self.Config.EXP_PATH, self.Config.WEIGHTS_PATH)
            exp_utils.print_verbose(self.Config, "Loading weights ... ({})".format(weights_path))
            self.load_model(weights_path)

        # Reset weights of last layer for transfer learning.
        # NOTE(review): the new conv_5 is created after self.optimizer, so its parameters are not
        # part of the optimizer's parameter groups; confirm whether they are meant to be trained.
        if self.Config.RESET_LAST_LAYER:
            self.net.conv_5 = nn.Conv2d(self.Config.UNET_NR_FILT,
                                        self.Config.NR_OF_CLASSES,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0,
                                        bias=True).to(self.device)
Exemple #4
0
    def __init__(self, Config):
        """
        Build the network, loss function, optimizer and (optionally) LR scheduler from ``Config``.

        Args:
            Config: Experiment configuration. This constructor writes ``NR_OF_GRADIENTS`` back
                to it for non single-direction peak input.

        Raises:
            ValueError: If ``Config.OPTIMIZER`` is neither "Adamax" nor "Adam".
        """
        self.Config = Config

        # cudnn.benchmark is left off: it was not faster here.
        if self.Config.NR_CPUS > 0:
            torch.set_num_threads(self.Config.NR_CPUS)

        # Number of input channels of the network; every branch must bind it because it is
        # needed to construct the network below.
        if self.Config.SEG_INPUT == "Peaks" and self.Config.TYPE == "single_direction":
            nr_of_gradients = self.Config.NR_OF_GRADIENTS
        elif self.Config.SEG_INPUT == "Peaks" and self.Config.TYPE == "combined":
            nr_of_gradients = self.Config.NR_OF_GRADIENTS = 3 * self.Config.NR_OF_CLASSES
        else:
            nr_of_gradients = self.Config.NR_OF_GRADIENTS = 33

        if self.Config.LOSS_FUNCTION == "soft_sample_dice":
            self.criterion = pytorch_utils.soft_sample_dice
        elif self.Config.LOSS_FUNCTION == "soft_batch_dice":
            self.criterion = pytorch_utils.soft_batch_dice
        elif self.Config.EXPERIMENT_TYPE == "peak_regression":
            self.criterion = pytorch_utils.angle_length_loss
        else:
            self.criterion = nn.BCEWithLogitsLoss()

        NetworkClass = getattr(importlib.import_module("tractseg.models." + self.Config.MODEL.lower()),
                               self.Config.MODEL)
        self.net = NetworkClass(n_input_channels=nr_of_gradients, n_classes=self.Config.NR_OF_CLASSES,
                                n_filt=self.Config.UNET_NR_FILT, batchnorm=self.Config.BATCH_NORM,
                                dropout=self.Config.USE_DROPOUT, upsample=self.Config.UPSAMPLE_TYPE)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = self.net.to(self.device)

        if self.Config.OPTIMIZER == "Adamax":
            self.optimizer = Adamax(net.parameters(), lr=self.Config.LEARNING_RATE)
        elif self.Config.OPTIMIZER == "Adam":
            self.optimizer = Adam(net.parameters(), lr=self.Config.LEARNING_RATE)
        else:
            raise ValueError("Optimizer not defined")

        if self.Config.LR_SCHEDULE:
            # Decay the learning rate by 10x every 100 scheduler steps
            self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=100, gamma=0.1)

        if self.Config.LOAD_WEIGHTS:
            weights_path = join(self.Config.EXP_PATH, self.Config.WEIGHTS_PATH)
            exp_utils.print_verbose(self.Config, "Loading weights ... ({})".format(weights_path))
            self.load_model(weights_path)

        # NOTE(review): the new conv_5 is created after self.optimizer, so its parameters are not
        # part of the optimizer's parameter groups; confirm whether they are meant to be trained.
        if self.Config.RESET_LAST_LAYER:
            self.net.conv_5 = nn.Conv2d(self.Config.UNET_NR_FILT, self.Config.NR_OF_CLASSES, kernel_size=1,
                                        stride=1, padding=0, bias=True).to(self.device)
Exemple #5
0
def _set_pretrained_weights_path(Config, tract_definition, input_type, bedpostX_input):
    """Point ``Config.WEIGHTS_PATH`` at the pretrained weights file for this configuration.

    The choice depends on ``tract_definition``, ``input_type``, ``bedpostX_input`` and on
    ``Config.EXPERIMENT_TYPE`` / ``Config.DROPOUT_SAMPLING``. Combinations that are not listed
    below leave ``Config.WEIGHTS_PATH`` untouched (e.g. peak regression with peaks input, whose
    per-part weights are selected later in ``run_tractseg``).

    Args:
        Config: experiment configuration; ``WEIGHTS_PATH`` is set in place.
        tract_definition: "TractQuerier+"; any other value is handled as AutoPTX.
        input_type: "peaks"; any other value is handled as T1 input.
        bedpostX_input: select the bedpostX-specific tract segmentation weights.

    Raises:
        ValueError: AutoPTX combined with an experiment type other than
            tract_segmentation / dm_regression.
        FileNotFoundError: the selected AutoPTX weights file does not exist.
    """
    experiment_type = Config.EXPERIMENT_TYPE

    if tract_definition == "TractQuerier+":
        if input_type == "peaks":
            if experiment_type == "tract_segmentation" and Config.DROPOUT_SAMPLING:
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_dropout_v2.npz")
            elif experiment_type == "tract_segmentation":
                if bedpostX_input:
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "TractSeg_BXTensAg_best_weights_ep248.npz")
                else:
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_v2.npz")
            elif experiment_type == "endings_segmentation":
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v3.npz")
            elif experiment_type == "dm_regression":
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_dm_regression_v1.npz")
        else:  # T1
            if experiment_type == "tract_segmentation":
                Config.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models",
                                           "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
            elif experiment_type == "endings_segmentation":
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v1.npz")
            elif experiment_type == "peak_regression":
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_peak_regression_v1.npz")
        return

    # AutoPTX
    if experiment_type == "tract_segmentation":
        Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_aPTX_v1.npz")
    elif experiment_type == "dm_regression":
        Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_dm_regression_aPTX_v1.npz")
    else:
        raise ValueError("bundle_definition AutoPTX not supported in combination with this output type")
    # TODO: remove when aPTX weights are loaded automatically
    if not os.path.exists(Config.WEIGHTS_PATH):
        raise FileNotFoundError("Could not find weights file: {}".format(Config.WEIGHTS_PATH))


def run_tractseg(data,
                 output_type="tract_segmentation",
                 single_orientation=False,
                 dropout_sampling=False,
                 threshold=0.5,
                 bundle_specific_threshold=False,
                 get_probs=False,
                 peak_threshold=0.1,
                 postprocess=False,
                 peak_regression_part="All",
                 input_type="peaks",
                 blob_size_thr=50,
                 nr_cpus=-1,
                 verbose=False,
                 manual_exp_name=None,
                 inference_batch_size=1,
                 tract_definition="TractQuerier+",
                 bedpostX_input=False):
    """
    Run TractSeg

    Args:
        data: input peaks (4D numpy array with shape [x,y,z,9]). For input_type "T1" the data is
            reshaped to [x,y,z,1]. NaNs are replaced by 0 before prediction.
        output_type: TractSeg can segment not only bundles, but also the end regions of bundles.
            Moreover it can create Tract Orientation Maps (TOM).
            'tract_segmentation' [DEFAULT]: Segmentation of bundles (72 bundles).
            'endings_segmentation': Segmentation of bundle end regions (72 bundles).
            'TOM': Tract Orientation Maps (20 bundles).
        single_orientation: Do not run model 3 times along x/y/z orientation with subsequent mean fusion.
        dropout_sampling: Create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
        threshold: Threshold for converting probability map to binary map
        bundle_specific_threshold: Use a lower threshold for some bundles which need more sensitivity
            (CA, CST, FX). For TOM output this instead applies bundle-specific peak length thresholds
            (and peak_threshold is ignored).
        get_probs: Output raw probability map instead of binary map
        peak_threshold: All peaks shorter than peak_threshold will be set to zero
        postprocess: Simple postprocessing of segmentations: Remove small blobs and fill holes
        peak_regression_part: Only relevant for output type 'TOM'. If set to 'All' (default) it will return all
            72 bundles. If set to 'Part1'-'Part4' it will only run for a subset of the bundles to reduce memory
            load.
        input_type: Normally "peaks"; "T1" is also handled.
        blob_size_thr: If setting postprocess to True, all blobs having a smaller number of voxels than specified in
            this threshold will be removed.
        nr_cpus: Number of CPUs to use. -1 means all available CPUs.
        verbose: Show debugging infos
        manual_exp_name: Name of experiment if you do not want to use a pretrained model but your own one
        inference_batch_size: batch size (higher: a bit faster but needs more RAM)
        tract_definition: Select which tract definitions to use. 'TractQuerier+' defines tracts mainly by their
            cortical start and end region. 'AutoPTX' defines tracts mainly by ROIs in white matter.
        bedpostX_input: Input peaks are generated by bedpostX

    Returns:
        4D numpy array with the output of tractseg
        for tract_segmentation:     [x, y, z, nr_of_bundles]
        for endings_segmentation:   [x, y, z, 2*nr_of_bundles]
        for TOM:                    [x, y, z, 3*nr_of_bundles]

    Raises:
        ValueError: if the experiment type of the selected config is not supported, or if AutoPTX is
            requested together with an unsupported output type.
        FileNotFoundError: if the selected AutoPTX weights file does not exist.
    """
    start_time = time.time()

    if manual_exp_name is None:
        config = get_config_name(input_type,
                                 output_type,
                                 dropout_sampling=dropout_sampling,
                                 tract_definition=tract_definition,
                                 bedpostX_input=bedpostX_input)
        Config = getattr(importlib.import_module("tractseg.experiments.pretrained_models." + config), "Config")()
    else:
        Config = exp_utils.load_config_from_txt(join(C.EXP_PATH, manual_exp_name, "Hyperparameters.txt"))

    Config = exp_utils.get_correct_labels_type(Config)
    Config.VERBOSE = verbose
    Config.TRAIN = False
    Config.TEST = False
    Config.SEGMENT = False
    Config.GET_PROBS = get_probs
    Config.LOAD_WEIGHTS = True
    Config.DROPOUT_SAMPLING = dropout_sampling
    Config.THRESHOLD = threshold
    Config.NR_CPUS = nr_cpus
    Config.INPUT_DIM = exp_utils.get_correct_input_dim(Config)

    if bundle_specific_threshold:
        # Bundle-specific binarization happens further below and needs the raw probabilities.
        Config.GET_PROBS = True

    if manual_exp_name is not None and Config.EXPERIMENT_TYPE != "peak_regression":
        Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(join(C.EXP_PATH, manual_exp_name), True)
    else:
        _set_pretrained_weights_path(Config, tract_definition, input_type, bedpostX_input)

    if Config.VERBOSE:
        print("Hyperparameters:")
        exp_utils.print_Configs(Config)

    data = np.nan_to_num(data)

    if input_type == "T1":
        # Give T1 input an explicit single channel axis: [x, y, z] -> [x, y, z, 1]
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    data, _, bbox, original_shape = dataset_utils.crop_to_nonzero(data)
    data, transformation = dataset_utils.pad_and_scale_img_to_square_img(data, target_size=Config.INPUT_DIM[0])

    if Config.EXPERIMENT_TYPE in ("tract_segmentation", "endings_segmentation", "dm_regression"):
        print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
        Config.NR_OF_CLASSES = len(exp_utils.get_bundle_names(Config.CLASSES)[1:])
        utils.download_pretrained_weights(experiment_type=Config.EXPERIMENT_TYPE,
                                          dropout_sampling=Config.DROPOUT_SAMPLING)
        model = BaseModel(Config)
        # Keep continuous outputs (no thresholding) for uncertainty maps, regression and raw probabilities.
        use_probs = bool(Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" or Config.GET_PROBS)
        if single_orientation:  # mainly needed for testing because of less RAM requirements
            data_loader_inference = DataLoaderInference(Config, data=data)
            seg, _ = trainer.predict_img(Config,
                                         model,
                                         data_loader_inference,
                                         probs=use_probs,
                                         scale_to_world_shape=False,
                                         only_prediction=True,
                                         batch_size=inference_batch_size)
        else:
            seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(Config,
                                                                          model,
                                                                          data=data,
                                                                          scale_to_world_shape=False,
                                                                          only_prediction=True,
                                                                          batch_size=inference_batch_size)
            seg = direction_merger.mean_fusion(Config.THRESHOLD, seg_xyz, probs=use_probs)

    elif Config.EXPERIMENT_TYPE == "peak_regression":
        weights = {
            "Part1": "pretrained_weights_peak_regression_part1_v1.npz",
            "Part2": "pretrained_weights_peak_regression_part2_v1.npz",
            "Part3": "pretrained_weights_peak_regression_part3_v1.npz",
            "Part4": "pretrained_weights_peak_regression_part4_v1.npz",
        }
        run_all_parts = peak_regression_part == "All"
        parts = ["Part1", "Part2", "Part3", "Part4"] if run_all_parts else [peak_regression_part]
        part_segs = []  # per-part outputs, only collected when all parts are run

        for part in parts:
            if manual_exp_name is not None:
                manual_exp_name_peaks = exp_utils.get_manual_exp_name_peaks(manual_exp_name, part)
                Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(join(C.EXP_PATH, manual_exp_name_peaks), True)
            else:
                # Same directory as every other pretrained weights file used in this module.
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, weights[part])
            print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
            Config.CLASSES = "All_" + part
            Config.NR_OF_CLASSES = 3 * len(exp_utils.get_bundle_names(Config.CLASSES)[1:])
            utils.download_pretrained_weights(experiment_type=Config.EXPERIMENT_TYPE,
                                              dropout_sampling=Config.DROPOUT_SAMPLING,
                                              part=part)
            data_loader_inference = DataLoaderInference(Config, data=data)
            model = BaseModel(Config)
            seg, _ = trainer.predict_img(Config,
                                         model,
                                         data_loader_inference,
                                         probs=True,
                                         scale_to_world_shape=False,
                                         only_prediction=True,
                                         batch_size=inference_batch_size)
            if run_all_parts:
                part_segs.append(seg)

        if run_all_parts:
            Config.CLASSES = "All"
            Config.NR_OF_CLASSES = 3 * len(exp_utils.get_bundle_names(Config.CLASSES)[1:])
            # Stack the parts along the channel axis in Part1..Part4 order. Unlike a buffer pre-sized
            # from the loaded config, this does not depend on the config's original NR_OF_CLASSES.
            seg = np.concatenate(part_segs, axis=-1)

        # quite fast
        if bundle_specific_threshold:
            seg = img_utils.remove_small_peaks_bundle_specific(seg,
                                                               exp_utils.get_bundle_names(Config.CLASSES)[1:],
                                                               len_thr=0.3)
        else:
            seg = img_utils.remove_small_peaks(seg, len_thr=peak_threshold)

        # Note: running peak regression along 3 directions with mean fusion gave bad results,
        # so only a single orientation is used here.

    else:
        raise ValueError("Unsupported experiment type: {}".format(Config.EXPERIMENT_TYPE))

    if bundle_specific_threshold and Config.EXPERIMENT_TYPE == "tract_segmentation":
        seg = img_utils.probs_to_binary_bundle_specific(seg, exp_utils.get_bundle_names(Config.CLASSES)[1:])

    # Remove the following two calls to keep super resolution
    seg = dataset_utils.cut_and_scale_img_back_to_original_img(seg, transformation)  # quite slow
    seg = dataset_utils.add_original_zero_padding_again(seg, bbox, original_shape,
                                                        Config.NR_OF_CLASSES)  # quite slow

    if postprocess:
        seg = img_utils.postprocess_segmentations(seg, blob_thr=blob_size_thr, hole_closing=2)

    exp_utils.print_verbose(Config, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg
Exemple #6
0
    def __init__(self, Config, inference=False):
        """Build the network, loss, optimizer and (optionally) LR scheduler described by ``Config``.

        Args:
            Config: experiment configuration object. ``Config.NR_OF_GRADIENTS`` may be overwritten
                here depending on ``SEG_INPUT`` / ``TYPE``.
            inference: if True, cuDNN benchmark mode is not enabled and the fp16/fp32 info
                messages are not printed.

        Raises:
            ValueError: if ``Config.OPTIMIZER`` is neither "Adamax" nor "Adam".
        """
        self.Config = Config

        if not inference:
            torch.backends.cudnn.benchmark = True

        if self.Config.NR_CPUS > 0:
            torch.set_num_threads(self.Config.NR_CPUS)

        # Determine the number of network input channels. Single-direction peak input keeps the
        # configured value; the other input types overwrite Config.NR_OF_GRADIENTS. Reading the value
        # back afterwards binds the local on every branch (previously only the single_direction branch
        # bound it, so "combined" and non-peak inputs failed with a NameError below).
        is_peaks_input = self.Config.SEG_INPUT == "Peaks"
        if is_peaks_input and self.Config.TYPE == "combined":
            self.Config.NR_OF_GRADIENTS = 3 * self.Config.NR_OF_CLASSES
        elif not (is_peaks_input and self.Config.TYPE == "single_direction"):
            self.Config.NR_OF_GRADIENTS = 33
        nr_input_channels = self.Config.NR_OF_GRADIENTS

        if self.Config.LOSS_FUNCTION == "soft_sample_dice":
            self.criterion = pytorch_utils.soft_sample_dice
        elif self.Config.LOSS_FUNCTION == "soft_batch_dice":
            self.criterion = pytorch_utils.soft_batch_dice
        elif self.Config.EXPERIMENT_TYPE == "peak_regression":
            if self.Config.LOSS_FUNCTION == "angle_length_loss":
                self.criterion = pytorch_utils.angle_length_loss
            elif self.Config.LOSS_FUNCTION == "angle_loss":
                self.criterion = pytorch_utils.angle_loss
            elif self.Config.LOSS_FUNCTION == "l2_loss":
                self.criterion = pytorch_utils.l2_loss
        elif self.Config.EXPERIMENT_TYPE == "dm_regression":
            # Aggregate by sum (the deprecated size_average=False, reduce=True spelling of the same thing).
            self.criterion = nn.MSELoss(reduction="sum")
        else:
            self.criterion = nn.BCEWithLogitsLoss()

        NetworkClass = getattr(importlib.import_module("tractseg.models." + self.Config.MODEL.lower()),
                               self.Config.MODEL)
        self.net = NetworkClass(n_input_channels=nr_input_channels,
                                n_classes=self.Config.NR_OF_CLASSES,
                                n_filt=self.Config.UNET_NR_FILT,
                                batchnorm=self.Config.BATCH_NORM,
                                dropout=self.Config.USE_DROPOUT,
                                upsample=self.Config.UPSAMPLE_TYPE)

        # Note: nn.DataParallel was tried for multi-GPU training but gave at most ~10% speedup
        # (low GPU and CPU utilisation), so a single device is used.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = self.net.to(self.device)

        if self.Config.OPTIMIZER == "Adamax":
            self.optimizer = Adamax(net.parameters(), lr=self.Config.LEARNING_RATE)
        elif self.Config.OPTIMIZER == "Adam":
            self.optimizer = Adam(net.parameters(), lr=self.Config.LEARNING_RATE)
        else:
            raise ValueError("Optimizer not defined")

        if APEX_AVAILABLE and self.Config.FP16:
            # Use O0 to disable fp16 (might be a little faster on TitanX)
            self.net, self.optimizer = amp.initialize(self.net, self.optimizer, verbosity=0, opt_level="O1")
            if not inference:
                print("INFO: Using fp16 training")
        elif not inference:
            # Report the actual reason for falling back to fp32.
            if APEX_AVAILABLE:
                print("INFO: FP16 disabled in config, using fp32 training")
            else:
                print("INFO: Did not find APEX, defaulting to fp32 training")

        if self.Config.LR_SCHEDULE:
            # Slightly better results could be achieved when training for 500 epochs without reducing
            # the LR, but that takes too long; ReduceLROnPlateau helps when training for only 200 epochs.
            self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                            mode=self.Config.LR_SCHEDULE_MODE,
                                                            patience=self.Config.LR_SCHEDULE_PATIENCE)

        if self.Config.LOAD_WEIGHTS:
            exp_utils.print_verbose(self.Config, "Loading weights ... ({})".format(
                join(self.Config.EXP_PATH, self.Config.WEIGHTS_PATH)))
            self.load_model(join(self.Config.EXP_PATH, self.Config.WEIGHTS_PATH))

        if self.Config.RESET_LAST_LAYER:
            # Replace the final 1x1 conv so its output size matches Config.NR_OF_CLASSES.
            self.net.conv_5 = nn.Conv2d(self.Config.UNET_NR_FILT,
                                        self.Config.NR_OF_CLASSES,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0,
                                        bias=True).to(self.device)