Esempio n. 1
0
def run_tractseg(data,
                 output_type="tract_segmentation",
                 single_orientation=False,
                 dropout_sampling=False,
                 threshold=0.5,
                 bundle_specific_postprocessing=True,
                 get_probs=False,
                 peak_threshold=0.1,
                 postprocess=False,
                 peak_regression_part="All",
                 input_type="peaks",
                 blob_size_thr=50,
                 nr_cpus=-1,
                 verbose=False,
                 manual_exp_name=None,
                 inference_batch_size=1,
                 tract_definition="TractQuerier+",
                 bedpostX_input=False,
                 tract_segmentations_path=None,
                 TOM_dilation=1,
                 unit_test=False):
    """
    Run TractSeg

    Args:
        data: input peaks (4D numpy array with shape [x,y,z,9])
        output_type: TractSeg can segment not only bundles, but also the end regions of bundles.
            Moreover it can create Tract Orientation Maps (TOM).
            'tract_segmentation' [DEFAULT]: Segmentation of bundles (72 bundles).
            'endings_segmentation': Segmentation of bundle end regions (72 bundles).
            'TOM': Tract Orientation Maps (20 bundles).
        single_orientation: Do not run model 3 times along x/y/z orientation with subsequent mean fusion.
        dropout_sampling: Create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
        threshold: Threshold for converting probability map to binary map
        bundle_specific_postprocessing: Set threshold to lower and use hole closing for CA and FX if incomplete.
            Only applied to 'tract_segmentation' without dropout sampling; disabled if get_probs is True.
        get_probs: Output raw probability map instead of binary map (disables all postprocessing)
        peak_threshold: Currently not used by this function; kept for backward compatibility of the signature.
        postprocess: Simple postprocessing of segmentations: Remove small blobs and fill holes.
            Only applied to 'tract_segmentation' without dropout sampling; disabled if get_probs is True.
        peak_regression_part: Only relevant for output type 'TOM'. If set to 'All' (default) it will return all
            72 bundles. If set to 'Part1'-'Part4' it will only run for a subset of the bundles to reduce memory
            load.
        input_type: Always set to "peaks"
        blob_size_thr: If setting postprocess to True, all blobs having a smaller number of voxels than specified in
            this threshold will be removed.
        nr_cpus: Number of CPUs to use. -1 means all available CPUs.
        verbose: Show debugging infos
        manual_exp_name: Name of experiment if do not want to use pretrained model but your own one
        inference_batch_size: batch size (higher: a bit faster but needs more RAM)
        tract_definition: Select which tract definitions to use. 'TractQuerier+' defines tracts mainly by their
            cortical start and end region. 'xtract' defines tracts mainly by ROIs in white matter.
        bedpostX_input: Input peaks are generated by bedpostX. Currently not used by this function.
        tract_segmentations_path: path to the bundle_segmentations (only needed for peak regression to remove peaks
            outside of the segmentation mask)
        TOM_dilation: Dilation applied to the tract segmentations before using them to mask the TOMs.
        unit_test: Forwarded to trainer.predict_img when running a single orientation and returning
            probabilities.

    Returns:
        4D numpy array with the output of tractseg
        for tract_segmentation:     [x, y, z, nr_of_bundles]
        for endings_segmentation:   [x, y, z, 2*nr_of_bundles]
        for TOM:                    [x, y, z, 3*nr_of_bundles]

    Raises:
        ValueError: If tract_definition is not 'TractQuerier+' and the output type is not supported for it,
            or if the experiment type of the loaded config is not handled by this function.
    """
    start_time = time.time()

    # Load the config: a pretrained model config by default, otherwise the Hyperparameters.txt of the
    # user's own experiment (located via its "Part1" experiment name).
    if manual_exp_name is None:
        config = get_config_name(input_type,
                                 output_type,
                                 dropout_sampling=dropout_sampling,
                                 tract_definition=tract_definition)
        Config = getattr(
            importlib.import_module("tractseg.experiments.pretrained_models." +
                                    config), "Config")()
    else:
        Config = exp_utils.load_config_from_txt(
            join(C.EXP_PATH,
                 exp_utils.get_manual_exp_name_peaks(manual_exp_name, "Part1"),
                 "Hyperparameters.txt"))

    # Do not do any postprocessing if returning probabilities (because postprocessing only works on binary)
    if get_probs:
        bundle_specific_postprocessing = False
        postprocess = False

    Config = exp_utils.get_correct_labels_type(Config)
    Config.VERBOSE = verbose
    Config.TRAIN = False
    Config.TEST = False
    Config.SEGMENT = False
    Config.GET_PROBS = get_probs
    Config.LOAD_WEIGHTS = True
    Config.DROPOUT_SAMPLING = dropout_sampling
    Config.THRESHOLD = threshold
    Config.NR_CPUS = nr_cpus
    Config.INPUT_DIM = dataset_specific_utils.get_correct_input_dim(Config)
    Config.RESET_LAST_LAYER = False

    # Bundle specific postprocessing works on the probability map, so the network must return probabilities.
    if Config.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing:
        Config.GET_PROBS = True

    # Select the weights file. Peak regression weights are selected per part further below.
    if manual_exp_name is not None and Config.EXPERIMENT_TYPE != "peak_regression":
        Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
            join(C.EXP_PATH, manual_exp_name), True)
    else:
        if tract_definition == "TractQuerier+":
            if input_type == "peaks":
                if Config.EXPERIMENT_TYPE == "tract_segmentation" and Config.DROPOUT_SAMPLING:
                    # NOTE(review): same weights file as the non-dropout branch below; presumably the v3
                    # weights are usable for MC dropout sampling as well -- confirm.
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_tract_segmentation_v3.npz")
                elif Config.EXPERIMENT_TYPE == "tract_segmentation":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_tract_segmentation_v3.npz")
                elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_endings_segmentation_v4.npz")
                elif Config.EXPERIMENT_TYPE == "dm_regression":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_dm_regression_v2.npz")
            else:  # T1
                if Config.EXPERIMENT_TYPE == "tract_segmentation":
                    Config.WEIGHTS_PATH = join(
                        C.NETWORK_DRIVE,
                        "hcp_exp_nodes/x_Pretrained_TractSeg_Models",
                        "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
                elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_endings_segmentation_v1.npz")
                elif Config.EXPERIMENT_TYPE == "peak_regression":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_peak_regression_v1.npz")
        else:  # xtract
            if Config.EXPERIMENT_TYPE == "tract_segmentation":
                Config.WEIGHTS_PATH = join(
                    C.WEIGHTS_DIR,
                    "pretrained_weights_tract_segmentation_xtract_v1.npz")
            elif Config.EXPERIMENT_TYPE == "dm_regression":
                Config.WEIGHTS_PATH = join(
                    C.WEIGHTS_DIR,
                    "pretrained_weights_dm_regression_xtract_v1.npz")
            else:
                raise ValueError(
                    "bundle_definition xtract not supported in combination with this output type"
                )

    if Config.VERBOSE:
        print("Hyperparameters:")
        exp_utils.print_Configs(Config)

    data = np.nan_to_num(data)

    # Crop to the non-zero bounding box, then pad/scale to the square network input size.
    # The bbox / transformation are used at the end to map the prediction back to the input space.
    # runtime on HCP data: 0.9s
    data, _, bbox, original_shape = data_utils.crop_to_nonzero(data)
    # runtime on HCP data: 0.5s
    data, transformation = data_utils.pad_and_scale_img_to_square_img(
        data, target_size=Config.INPUT_DIM[0], nr_cpus=nr_cpus)

    if Config.EXPERIMENT_TYPE in ("tract_segmentation", "endings_segmentation", "dm_regression"):
        print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
        Config.NR_OF_CLASSES = len(
            dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
        utils.download_pretrained_weights(
            experiment_type=Config.EXPERIMENT_TYPE,
            dropout_sampling=Config.DROPOUT_SAMPLING,
            tract_definition=tract_definition)
        model = BaseModel(Config, inference=True)
        # Dropout sampling, distance-map regression and explicit probability requests keep the
        # continuous output; otherwise the prediction is thresholded to a binary map.
        return_probs = bool(Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" or
                            Config.GET_PROBS)
        if single_orientation:  # mainly needed for testing because of less RAM requirements
            data_loader_inference = DataLoaderInference(Config, data=data)
            if return_probs:
                seg, _ = trainer.predict_img(Config,
                                             model,
                                             data_loader_inference,
                                             probs=True,
                                             scale_to_world_shape=False,
                                             only_prediction=True,
                                             batch_size=inference_batch_size,
                                             unit_test=unit_test)
            else:
                seg, _ = trainer.predict_img(Config,
                                             model,
                                             data_loader_inference,
                                             probs=False,
                                             scale_to_world_shape=False,
                                             only_prediction=True,
                                             batch_size=inference_batch_size)
        else:
            seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(
                Config,
                model,
                data=data,
                scale_to_world_shape=False,
                only_prediction=True,
                batch_size=inference_batch_size)
            seg = direction_merger.mean_fusion(Config.THRESHOLD,
                                               seg_xyz,
                                               probs=return_probs)

    elif Config.EXPERIMENT_TYPE == "peak_regression":
        weights = {
            "Part1": "pretrained_weights_peak_regression_part1_v2.npz",
            "Part2": "pretrained_weights_peak_regression_part2_v2.npz",
            "Part3": "pretrained_weights_peak_regression_part3_v2.npz",
            "Part4": "pretrained_weights_peak_regression_part4_v2.npz",
        }
        if peak_regression_part == "All":
            parts = ["Part1", "Part2", "Part3", "Part4"]
            seg_all = np.zeros((data.shape[0], data.shape[1], data.shape[2],
                                Config.NR_OF_CLASSES * 3))
        else:
            parts = [peak_regression_part]
            Config.CLASSES = "All_" + peak_regression_part
            Config.NR_OF_CLASSES = 3 * len(
                dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])

        # One model per part; each predicts 3 channels per bundle of that part.
        for idx, part in enumerate(parts):
            if manual_exp_name is not None:
                manual_exp_name_peaks = exp_utils.get_manual_exp_name_peaks(
                    manual_exp_name, part)
                Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
                    join(C.EXP_PATH, manual_exp_name_peaks), True)
            else:
                # NOTE(review): resolved relative to C.TRACT_SEG_HOME, unlike the other experiment
                # types which use C.WEIGHTS_DIR -- confirm this is intended.
                Config.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, weights[part])
            print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
            Config.CLASSES = "All_" + part
            Config.NR_OF_CLASSES = 3 * len(
                dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
            utils.download_pretrained_weights(
                experiment_type=Config.EXPERIMENT_TYPE,
                dropout_sampling=Config.DROPOUT_SAMPLING,
                part=part,
                tract_definition=tract_definition)
            model = BaseModel(Config, inference=True)

            if single_orientation:
                data_loader_inference = DataLoaderInference(Config, data=data)
                # Peaks are regressed values, so the raw (non-thresholded) output is always kept.
                seg, _ = trainer.predict_img(Config,
                                             model,
                                             data_loader_inference,
                                             probs=True,
                                             scale_to_world_shape=False,
                                             only_prediction=True,
                                             batch_size=inference_batch_size)
            else:
                # 3 dir for Peaks -> bad results
                seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(
                    Config,
                    model,
                    data=data,
                    scale_to_world_shape=False,
                    only_prediction=True,
                    batch_size=inference_batch_size)
                seg = direction_merger.mean_fusion_peaks(seg_xyz,
                                                         nr_cpus=nr_cpus)

            if peak_regression_part == "All":
                # Parts are written side by side into consecutive channel ranges.
                nr_part_channels = Config.NR_OF_CLASSES
                start = idx * nr_part_channels
                seg_all[:, :, :, start:start + nr_part_channels] = seg

        if peak_regression_part == "All":
            Config.CLASSES = "All"
            Config.NR_OF_CLASSES = 3 * len(
                dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
            seg = seg_all

    else:
        # Without this, `seg` would be unbound and fail later with an obscure UnboundLocalError.
        raise ValueError(
            "Unsupported experiment type: {}".format(Config.EXPERIMENT_TYPE))

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing and not dropout_sampling:
        # Runtime ~4s
        seg = img_utils.bundle_specific_postprocessing(
            seg,
            dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])

    # runtime on HCP data: 5.1s
    seg = data_utils.cut_and_scale_img_back_to_original_img(seg,
                                                            transformation,
                                                            nr_cpus=nr_cpus)
    # runtime on HCP data: 1.6s
    seg = data_utils.add_original_zero_padding_again(seg, bbox, original_shape,
                                                     Config.NR_OF_CLASSES)

    if Config.EXPERIMENT_TYPE == "peak_regression":
        seg = peak_utils.mask_and_normalize_peaks(
            seg,
            tract_segmentations_path,
            dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
            TOM_dilation,
            nr_cpus=nr_cpus)

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and postprocess and not dropout_sampling:
        # Runtime ~7s for 1.25mm resolution
        # Runtime ~1.5s for  2mm resolution
        seg = img_utils.postprocess_segmentations(
            seg,
            dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
            blob_thr=blob_size_thr,
            hole_closing=None)

    exp_utils.print_verbose(
        Config.VERBOSE, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg
Esempio n. 2
0
def run_tractseg(data,
                 output_type="tract_segmentation",
                 single_orientation=False,
                 dropout_sampling=False,
                 threshold=0.5,
                 bundle_specific_threshold=False,
                 get_probs=False,
                 peak_threshold=0.1,
                 postprocess=False,
                 peak_regression_part="All",
                 input_type="peaks",
                 blob_size_thr=50,
                 nr_cpus=-1,
                 verbose=False,
                 manual_exp_name=None,
                 inference_batch_size=1,
                 tract_definition="TractQuerier+",
                 bedpostX_input=False):
    """
    Run TractSeg

    Args:
        data: input peaks (4D numpy array with shape [x,y,z,9])
        output_type: TractSeg can segment not only bundles, but also the end regions of bundles.
            Moreover it can create Tract Orientation Maps (TOM).
            'tract_segmentation' [DEFAULT]: Segmentation of bundles (72 bundles).
            'endings_segmentation': Segmentation of bundle end regions (72 bundles).
            'TOM': Tract Orientation Maps (20 bundles).
        single_orientation: Do not run model 3 times along x/y/z orientation with subsequent mean fusion.
        dropout_sampling: Create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
        threshold: Threshold for converting probability map to binary map
        bundle_specific_threshold: Set threshold to lower for some bundles which need more sensitivity (CA, CST, FX).
            For tract segmentation the probability map is binarized with bundle specific thresholds (skipped if
            get_probs or dropout_sampling is True). For peak regression ('TOM') small peaks are removed with a
            bundle specific length threshold instead of peak_threshold.
        get_probs: Output raw probability map instead of binary map (also disables postprocess)
        peak_threshold: All peaks shorter than peak_threshold will be set to zero (peak regression only; ignored
            if bundle_specific_threshold is True)
        postprocess: Simple postprocessing of segmentations: Remove small blobs and fill holes. Only applied to
            binary tract/endings segmentations (not with get_probs or dropout_sampling).
        peak_regression_part: Only relevant for output type 'TOM'. If set to 'All' (default) it will return all
            72 bundles. If set to 'Part1'-'Part4' it will only run for a subset of the bundles to reduce memory
            load.
        input_type: Always set to "peaks"
        blob_size_thr: If setting postprocess to True, all blobs having a smaller number of voxels than specified in
            this threshold will be removed.
        nr_cpus: Number of CPUs to use. -1 means all available CPUs.
        verbose: Show debugging infos
        manual_exp_name: Name of experiment if do not want to use pretrained model but your own one
        inference_batch_size: batch size (higher: a bit faster but needs more RAM)
        tract_definition: Select which tract definitions to use. 'TractQuerier+' defines tracts mainly by their
            cortical start and end region. 'AutoPTX' defines tracts mainly by ROIs in white matter.
        bedpostX_input: Input peaks are generated by bedpostX

    Returns:
        4D numpy array with the output of tractseg
        for tract_segmentation:     [x, y, z, nr_of_bundles]
        for endings_segmentation:   [x, y, z, 2*nr_of_bundles]
        for TOM:                    [x, y, z, 3*nr_of_bundles]

    Raises:
        ValueError: If AutoPTX is combined with an unsupported output type, or if the experiment type of the
            loaded config is not handled by this function.
        FileNotFoundError: If the AutoPTX weights file does not exist.
    """
    start_time = time.time()

    if manual_exp_name is None:
        config = get_config_name(input_type,
                                 output_type,
                                 dropout_sampling=dropout_sampling,
                                 tract_definition=tract_definition,
                                 bedpostX_input=bedpostX_input)
        Config = getattr(
            importlib.import_module("tractseg.experiments.pretrained_models." +
                                    config), "Config")()
    else:
        Config = exp_utils.load_config_from_txt(
            join(C.EXP_PATH, manual_exp_name, "Hyperparameters.txt"))

    Config = exp_utils.get_correct_labels_type(Config)
    Config.VERBOSE = verbose
    Config.TRAIN = False
    Config.TEST = False
    Config.SEGMENT = False
    Config.GET_PROBS = get_probs
    Config.LOAD_WEIGHTS = True
    Config.DROPOUT_SAMPLING = dropout_sampling
    Config.THRESHOLD = threshold
    Config.NR_CPUS = nr_cpus
    Config.INPUT_DIM = exp_utils.get_correct_input_dim(Config)

    # The bundle specific thresholds are applied to the probability map of tract_segmentation only (see below),
    # so probabilities are only forced there. Forcing them for e.g. endings_segmentation would silently return
    # an unthresholded map, because nothing binarizes it afterwards.
    if bundle_specific_threshold and Config.EXPERIMENT_TYPE == "tract_segmentation":
        Config.GET_PROBS = True

    # Select the weights file. Peak regression weights are selected per part further below.
    if manual_exp_name is not None and Config.EXPERIMENT_TYPE != "peak_regression":
        Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
            join(C.EXP_PATH, manual_exp_name), True)
    else:
        if tract_definition == "TractQuerier+":
            if input_type == "peaks":
                if Config.EXPERIMENT_TYPE == "tract_segmentation" and Config.DROPOUT_SAMPLING:
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_tract_segmentation_dropout_v2.npz")
                elif Config.EXPERIMENT_TYPE == "tract_segmentation":
                    if bedpostX_input:
                        Config.WEIGHTS_PATH = join(
                            C.WEIGHTS_DIR,
                            "TractSeg_BXTensAg_best_weights_ep248.npz")
                    else:
                        Config.WEIGHTS_PATH = join(
                            C.WEIGHTS_DIR,
                            "pretrained_weights_tract_segmentation_v2.npz")
                elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_endings_segmentation_v3.npz")
                elif Config.EXPERIMENT_TYPE == "dm_regression":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_dm_regression_v1.npz")
            else:  # T1
                if Config.EXPERIMENT_TYPE == "tract_segmentation":
                    Config.WEIGHTS_PATH = join(
                        C.NETWORK_DRIVE,
                        "hcp_exp_nodes/x_Pretrained_TractSeg_Models",
                        "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
                elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_endings_segmentation_v1.npz")
                elif Config.EXPERIMENT_TYPE == "peak_regression":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_peak_regression_v1.npz")
        else:  # AutoPTX
            if Config.EXPERIMENT_TYPE == "tract_segmentation":
                Config.WEIGHTS_PATH = join(
                    C.WEIGHTS_DIR,
                    "pretrained_weights_tract_segmentation_aPTX_v1.npz")
            elif Config.EXPERIMENT_TYPE == "dm_regression":
                Config.WEIGHTS_PATH = join(
                    C.WEIGHTS_DIR,
                    "pretrained_weights_dm_regression_aPTX_v1.npz")
            else:
                raise ValueError(
                    "bundle_definition AutoPTX not supported in combination with this output type"
                )
            # TODO: remove when aPTX weights are loaded automatically
            if not os.path.exists(Config.WEIGHTS_PATH):
                raise FileNotFoundError(
                    "Could not find weights file: {}".format(
                        Config.WEIGHTS_PATH))

    if Config.VERBOSE:
        print("Hyperparameters:")
        exp_utils.print_Configs(Config)

    data = np.nan_to_num(data)

    if input_type == "T1":
        # Add a singleton channel axis so T1 input has the same 4D layout as peak input.
        data = np.reshape(data,
                          (data.shape[0], data.shape[1], data.shape[2], 1))
    # Crop to the non-zero bounding box, then pad/scale to the square network input size.
    # The bbox / transformation are used at the end to map the prediction back to the input space.
    data, _, bbox, original_shape = dataset_utils.crop_to_nonzero(data)
    data, transformation = dataset_utils.pad_and_scale_img_to_square_img(
        data, target_size=Config.INPUT_DIM[0])

    if Config.EXPERIMENT_TYPE in ("tract_segmentation", "endings_segmentation", "dm_regression"):
        print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
        Config.NR_OF_CLASSES = len(
            exp_utils.get_bundle_names(Config.CLASSES)[1:])
        utils.download_pretrained_weights(
            experiment_type=Config.EXPERIMENT_TYPE,
            dropout_sampling=Config.DROPOUT_SAMPLING)
        model = BaseModel(Config)
        # Dropout sampling, distance-map regression and probability requests keep the continuous output;
        # otherwise the prediction is thresholded to a binary map.
        return_probs = bool(Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" or
                            Config.GET_PROBS)
        if single_orientation:  # mainly needed for testing because of less RAM requirements
            data_loader_inference = DataLoaderInference(Config, data=data)
            seg, _ = trainer.predict_img(
                Config,
                model,
                data_loader_inference,
                probs=return_probs,
                scale_to_world_shape=False,
                only_prediction=True,
                batch_size=inference_batch_size)
        else:
            seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(
                Config,
                model,
                data=data,
                scale_to_world_shape=False,
                only_prediction=True,
                batch_size=inference_batch_size)
            seg = direction_merger.mean_fusion(Config.THRESHOLD,
                                               seg_xyz,
                                               probs=return_probs)

    elif Config.EXPERIMENT_TYPE == "peak_regression":
        weights = {
            "Part1": "pretrained_weights_peak_regression_part1_v1.npz",
            "Part2": "pretrained_weights_peak_regression_part2_v1.npz",
            "Part3": "pretrained_weights_peak_regression_part3_v1.npz",
            "Part4": "pretrained_weights_peak_regression_part4_v1.npz",
        }
        if peak_regression_part == "All":
            parts = ["Part1", "Part2", "Part3", "Part4"]
            seg_all = np.zeros((data.shape[0], data.shape[1], data.shape[2],
                                Config.NR_OF_CLASSES * 3))
        else:
            parts = [peak_regression_part]
            Config.CLASSES = "All_" + peak_regression_part
            Config.NR_OF_CLASSES = 3 * len(
                exp_utils.get_bundle_names(Config.CLASSES)[1:])

        # One model per part; each predicts 3 channels per bundle of that part.
        # Only a single orientation is used: running 3 directions for peaks gave bad results.
        for idx, part in enumerate(parts):
            if manual_exp_name is not None:
                manual_exp_name_peaks = exp_utils.get_manual_exp_name_peaks(
                    manual_exp_name, part)
                Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
                    join(C.EXP_PATH, manual_exp_name_peaks), True)
            else:
                # NOTE(review): resolved relative to C.TRACT_SEG_HOME, unlike the other experiment
                # types which use C.WEIGHTS_DIR -- confirm this is intended.
                Config.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, weights[part])
            print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
            Config.CLASSES = "All_" + part
            Config.NR_OF_CLASSES = 3 * len(
                exp_utils.get_bundle_names(Config.CLASSES)[1:])
            utils.download_pretrained_weights(
                experiment_type=Config.EXPERIMENT_TYPE,
                dropout_sampling=Config.DROPOUT_SAMPLING,
                part=part)
            data_loader_inference = DataLoaderInference(Config, data=data)
            model = BaseModel(Config)
            seg, _ = trainer.predict_img(Config,
                                         model,
                                         data_loader_inference,
                                         probs=True,
                                         scale_to_world_shape=False,
                                         only_prediction=True,
                                         batch_size=inference_batch_size)

            if peak_regression_part == "All":
                # Parts are written side by side into consecutive channel ranges.
                nr_part_channels = Config.NR_OF_CLASSES
                start = idx * nr_part_channels
                seg_all[:, :, :, start:start + nr_part_channels] = seg

        if peak_regression_part == "All":
            Config.CLASSES = "All"
            Config.NR_OF_CLASSES = 3 * len(
                exp_utils.get_bundle_names(Config.CLASSES)[1:])
            seg = seg_all

        # quite fast
        if bundle_specific_threshold:
            seg = img_utils.remove_small_peaks_bundle_specific(
                seg,
                exp_utils.get_bundle_names(Config.CLASSES)[1:],
                len_thr=0.3)
        else:
            seg = img_utils.remove_small_peaks(seg, len_thr=peak_threshold)

    else:
        # Without this, `seg` would be unbound and fail later with an obscure UnboundLocalError.
        raise ValueError(
            "Unsupported experiment type: {}".format(Config.EXPERIMENT_TYPE))

    # Keep the probability / uncertainty map if the caller explicitly asked for it.
    if bundle_specific_threshold and Config.EXPERIMENT_TYPE == "tract_segmentation" and \
            not get_probs and not dropout_sampling:
        seg = img_utils.probs_to_binary_bundle_specific(
            seg,
            exp_utils.get_bundle_names(Config.CLASSES)[1:])

    # remove following two lines to keep super resolution
    seg = dataset_utils.cut_and_scale_img_back_to_original_img(
        seg, transformation)  # quite slow
    seg = dataset_utils.add_original_zero_padding_again(
        seg, bbox, original_shape, Config.NR_OF_CLASSES)  # quite slow

    # Blob removal / hole closing only makes sense on binary segmentations: skip it for probability or
    # uncertainty maps and for regression outputs (peaks, distance maps).
    if postprocess and not get_probs and not dropout_sampling and \
            Config.EXPERIMENT_TYPE in ("tract_segmentation", "endings_segmentation"):
        seg = img_utils.postprocess_segmentations(seg,
                                                  blob_thr=blob_size_thr,
                                                  hole_closing=2)

    exp_utils.print_verbose(
        Config, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg