Example #1
    def get_cv_fold(fold, dataset="HCP"):
        '''
        Need a train-test-validate split because of best-model selection and because of training the combined net
        :return:
        '''

        #For CV
        if fold == 0:
            train, validate, test = [0, 1, 2], [3], [4]
            # train, validate, test = [0, 1, 2, 3, 4], [3], [4]
        elif fold == 1:
            train, validate, test = [1, 2, 3], [4], [0]
        elif fold == 2:
            train, validate, test = [2, 3, 4], [0], [1]
        elif fold == 3:
            train, validate, test = [3, 4, 0], [1], [2]
        elif fold == 4:
            train, validate, test = [4, 0, 1], [2], [3]

        subjects = get_all_subjects(dataset)

        if dataset.startswith("HCP"):
            # subjects = list(Utils.chunks(subjects[:100], 10))   #10 folds
            subjects = list(Utils.chunks(subjects, 21))  # 5 folds of 21 subjects each
            # => 5-fold CV is fine (score only ~1 percentage point worse than 10-fold CV, i.e. 80 vs 60 training subjects); 10-fold CV is impractical
        elif dataset.startswith("Schizo"):
            # 410 subjects
            subjects = list(Utils.chunks(subjects, 82))  # 5 folds of 82 subjects each
        else:
            raise ValueError("Invalid dataset name")

        subjects = np.array(subjects)
        return list(subjects[train].flatten()), list(subjects[validate].flatten()), list(subjects[test].flatten())
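
The fold logic above relies on a chunking helper that is not shown in this example. A minimal sketch of what Utils.chunks presumably does (splitting a list into consecutive pieces of a fixed size), together with a hypothetical call to get_cv_fold, assuming 105 HCP subjects:

    def chunks(lst, n):
        # Assumed behaviour: yield successive chunks of n elements from lst
        # (the last chunk may be shorter if len(lst) is not divisible by n)
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    # Hypothetical usage: 105 HCP subjects -> 5 folds of 21 subjects each
    # train_subjects, val_subjects, test_subjects = get_cv_fold(0, dataset="HCP")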
Example #2
    def compress_streamlines(streamlines, error_threshold=0.1):
        nr_processes = psutil.cpu_count()
        number_streamlines = len(streamlines)

        if nr_processes >= number_streamlines:
            nr_processes = number_streamlines - 1
            if nr_processes < 1:
                nr_processes = 1

        chunk_size = int(number_streamlines / nr_processes)

        if chunk_size < 1:
            # logging.warning("\nReturning early because chunk_size=0")
            return streamlines
        fiber_batches = list(Utils.chunks(streamlines, chunk_size))

        global COMPRESSION_ERROR_THRESHOLD
        global FIBER_BATCHES
        COMPRESSION_ERROR_THRESHOLD = error_threshold
        FIBER_BATCHES = fiber_batches

        # logging.debug("Main program using: {} GB".format(round(Utils.mem_usage(print_usage=False), 3)))
        pool = multiprocessing.Pool(processes=nr_processes)

        # Do not pass the data itself to the workers (that would double the memory needed), but only the index into
        # shared memory. This needs only as much memory as the single-threaded version (only the main process holds
        # the data, the workers almost none). The shared-memory version is also faster (around 20-30%?).
        # Needed to avoid memory problems when processing the raw tracking output (>10 GB on disk, >20 GB in memory).
        result = pool.map(compress_fibers_worker_shared_mem, range(0, len(fiber_batches)))
        pool.close()
        pool.join()

        streamlines_c = Utils.flatten(result)
        return streamlines_c
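
The worker passed to pool.map is not part of this example. A minimal sketch of what compress_fibers_worker_shared_mem could look like, assuming dipy's compress_streamlines does the per-fiber compression and that each worker reads its batch through the module-level globals set above (both the import and the tol_error argument are assumptions, not taken from this example):

    from dipy.tracking.streamline import compress_streamlines as compress_sl

    def compress_fibers_worker_shared_mem(idx):
        # Fetch the batch via the shared globals instead of receiving the data as an
        # argument, which is the memory-saving idea described in the comments above
        batch = FIBER_BATCHES[idx]
        return [compress_sl(fiber, tol_error=COMPRESSION_ERROR_THRESHOLD) for fiber in batch]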
Example #3
    def get_cv_fold(fold, dataset="HCP"):
        '''
        Need a train-test-validate split because of best-model selection and because of training the combined net
        :return:
        '''

        #For CV
        if fold == 0:
            train, validate, test = [0, 1, 2], [3], [4]
            # train, validate, test = [0, 1, 2, 3, 4], [3], [4]
        elif fold == 1:
            train, validate, test = [1, 2, 3], [4], [0]
        elif fold == 2:
            train, validate, test = [2, 3, 4], [0], [1]
        elif fold == 3:
            train, validate, test = [3, 4, 0], [1], [2]
        elif fold == 4:
            train, validate, test = [4, 0, 1], [2], [3]

        subjects = get_all_subjects(dataset)

        if dataset.startswith("HCP"):
            # subjects = list(Utils.chunks(subjects[:100], 10))   #10 folds
            subjects = list(Utils.chunks(subjects, 21))  # 5 folds of 21 subjects each
            # => 5-fold CV is fine (score only ~1 percentage point worse than 10-fold CV, i.e. 80 vs 60 training subjects); 10-fold CV is impractical
        elif dataset.startswith("Schizo"):
            # 410 subjects
            subjects = list(Utils.chunks(subjects, 82))  # 5 folds of 82 subjects each
        else:
            raise ValueError("Invalid dataset name")

        subjects = np.array(subjects)
        return list(subjects[train].flatten()), list(subjects[validate].flatten()), list(subjects[test].flatten())
Example #4
    def compress_streamlines(streamlines, error_threshold=0.1):
        nr_processes = psutil.cpu_count()
        number_streamlines = len(streamlines)

        if nr_processes >= number_streamlines:
            nr_processes = number_streamlines - 1
            if nr_processes < 1:
                nr_processes = 1

        chunk_size = int(number_streamlines / nr_processes)

        if chunk_size < 1:
            # logging.warning("\nReturning early because chunk_size=0")
            return streamlines
        fiber_batches = list(Utils.chunks(streamlines, chunk_size))

        global COMPRESSION_ERROR_THRESHOLD
        global FIBER_BATCHES
        COMPRESSION_ERROR_THRESHOLD = error_threshold
        FIBER_BATCHES = fiber_batches

        # logging.debug("Main program using: {} GB".format(round(Utils.mem_usage(print_usage=False), 3)))
        pool = multiprocessing.Pool(processes=nr_processes)

        # Do not pass the data itself to the workers (that would double the memory needed), but only the index into
        # shared memory. This needs only as much memory as the single-threaded version (only the main process holds
        # the data, the workers almost none). The shared-memory version is also faster (around 20-30%?).
        # Needed to avoid memory problems when processing the raw tracking output (>10 GB on disk, >20 GB in memory).
        result = pool.map(compress_fibers_worker_shared_mem, range(0, len(fiber_batches)))
        pool.close()
        pool.join()

        streamlines_c = Utils.flatten(result)
        return streamlines_c
Example #5
    def register_mask(mask_data,
                      mask_affine,
                      reference_img,
                      elastic_transform=None,
                      binary_img=True,
                      use_inverse=False):
        '''
        Transform a mask (binary image) with the given elastic_transform

        :param mask_data:           data of the mask that should be transformed
        :param mask_affine:         affine of the mask that should be transformed
        :param reference_img:       a nibabel image to get shape and affine from for the affine transformation
        :param elastic_transform:   elastic (diffeomorphic) transform to apply after the affine transformation
        :param binary_img:          whether the input is a binary image (e.g. a mask) rather than a float image (e.g. T1)
        :param use_inverse:         apply the inverse of the elastic transform instead of the forward transform

        :return: transformed mask (a binary image)
        '''

        logging.debug("mask original shape: {}".format(mask_data.shape))

        # Apply affine for mask image (to t1 space)
        affine_map_inv = AffineMap(np.eye(4),
                                   reference_img.get_data().shape,
                                   Utils.invert_x_and_y(reference_img.get_affine()),
                                   mask_data.shape,
                                   Utils.invert_x_and_y(mask_affine))
        # If invert_x_and_y is not used for source and target, the result is identical
        mask_data_reg = affine_map_inv.transform(mask_data)
        if binary_img:
            mask_data_reg = mask_data_reg > 0
        logging.debug("mask registered shape: {}".format(mask_data_reg.shape))

        if elastic_transform:

            # img = nib.Nifti1Image(mask_data_reg.astype(np.uint8), reference_img.get_affine())
            # nib.save(img, "ROI_registered_before.nii.gz")

            if use_inverse:
                mask_data_reg = elastic_transform.transform_inverse(
                    mask_data_reg)
            else:
                mask_data_reg = elastic_transform.transform(mask_data_reg)

            if binary_img:
                mask_data_reg = mask_data_reg > 0

            # img = nib.Nifti1Image(mask_data_reg.astype(np.uint8), reference_img.get_affine())
            # nib.save(img, "ROI_registered_after.nii.gz")

        else:
            logging.warning(
                "Elastic Transform deactivated; only using Affine Transform")

        if binary_img:
            mask_data_reg = mask_data_reg > 0
        return mask_data_reg
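
A hypothetical call to register_mask, loading the mask and the reference image with nibabel (the file names are placeholders; elastic_transform would come from get_elastic_transform, shown in a later example):

    import nibabel as nib

    mask_img = nib.load("ROI_atlas_space.nii.gz")   # placeholder path
    reference_img = nib.load("subject_T1.nii.gz")   # placeholder path
    mask_reg = register_mask(mask_img.get_data(), mask_img.get_affine(),
                             reference_img, elastic_transform=None, binary_img=True)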
Example #6
    def get_cv_fold(fold):
        '''
        Need a train-test-validate split because of best-model selection and because of training the combined net
        :return:
        '''

        #For CV
        if fold == 0:
            train, validate, test = [0, 1, 2], [3], [4]
            # train, validate, test = [0, 1, 2, 3, 4], [3], [4]
        elif fold == 1:
            train, validate, test = [1, 2, 3], [4], [0]
        elif fold == 2:
            train, validate, test = [2, 3, 4], [0], [1]
        elif fold == 3:
            train, validate, test = [3, 4, 0], [1], [2]
        elif fold == 4:
            train, validate, test = [4, 0, 1], [2], [3]

        # subjects = list(Utils.chunks(get_all_subjects()[:100], 10))   #10 folds
        subjects = list(Utils.chunks(get_all_subjects(), 21))  # 5 folds of 21 subjects each
        # => 5-fold CV is fine (score only ~1 percentage point worse than 10-fold CV, i.e. 80 vs 60 training subjects); 10-fold CV is impractical

        subjects = np.array(subjects)
        return list(subjects[train].flatten()), list(subjects[validate].flatten()), list(subjects[test].flatten())
Example #7
def get_subjects_chunk(nr_batches, batch_number):
    nr_batches = int(nr_batches)
    batch_number = int(batch_number)

    batch_size = int(math.ceil(len(all_subjects_RAW) / float(nr_batches)))
    res = list(Utils.chunks(all_subjects_RAW, batch_size))
    final_subjects = res[batch_number]
    return final_subjects
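
For illustration, a hypothetical module-level subject list of 10 IDs split into 4 batches gives a batch size of ceil(10/4) = 3, so the last batch holds the single remaining subject:

    # Placeholder subject list (normally all_subjects_RAW is defined elsewhere in this module)
    all_subjects_RAW = ["subj%02d" % i for i in range(10)]

    print(get_subjects_chunk(4, 0))  # -> ['subj00', 'subj01', 'subj02']
    print(get_subjects_chunk(4, 3))  # -> ['subj09']  (last batch is smaller)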
Example #8
    def get_elastic_transform(subject_fa, atlas_fa, subject_path=".."):
        '''
        :param subject_fa:   the FA image (nibabel img) of the subject (static image)
        :param atlas_fa:     the FA image (nibabel img) of the atlas, which will be warped onto the subject (moving image)
        :param subject_path: directory where the computed transform is cached as FAReg_elastic_transform.pklz

        :return: elastic transformation map
        '''

        if isfile(subject_path + "/FAReg_elastic_transform.pklz"):
            logging.debug("Load existing elastic transform...")
            return Utils.load_pkl_compressed(subject_path +
                                             "/FAReg_elastic_transform.pklz")

        static_img = subject_fa
        static = static_img.get_data()
        moving_img = atlas_fa
        moving = moving_img.get_data()

        # Optional: affine transformation of the moving image into the static coordinate system (needed if the two images are in very different spaces)
        affine_map = AffineMap(np.eye(4), static.shape,
                               static_img.get_affine(), moving.shape,
                               moving_img.get_affine())
        moving = affine_map.transform(moving)

        start_time = time.time()
        metric = CCMetric(3)
        level_iters = [10, 10, 5]  # better
        # level_iters = [2, 2, 2] #fast -> not much
        sdr = SymmetricDiffeomorphicRegistration(metric, level_iters)
        mapping = sdr.optimize(static, moving)
        # mapping = sdr.optimize(static, moving, Utils.invert_x_and_y(static_img.get_affine()), Utils.invert_x_and_y(moving_img.get_affine())) #not needed
        logging.debug("elastic transform took {0:.2f}s".format(time.time() -
                                                               start_time))

        logging.debug("write elastic transform...")
        Utils.save_pkl_compressed(
            subject_path + "/FAReg_elastic_transform.pklz", mapping)
        return mapping
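
A hypothetical end-to-end use of the registration helpers: compute (or load the cached) elastic transform between a subject FA and an atlas FA, then warp an atlas mask into subject space with register_mask from the earlier example (all file names are placeholders):

    import nibabel as nib

    subject_fa = nib.load("subject/FA.nii.gz")   # placeholder path (static)
    atlas_fa = nib.load("atlas/FA.nii.gz")       # placeholder path (moving)
    mapping = get_elastic_transform(subject_fa, atlas_fa, subject_path="subject")

    atlas_mask = nib.load("atlas/CST_mask.nii.gz")   # placeholder path
    mask_subject_space = register_mask(atlas_mask.get_data(), atlas_mask.get_affine(),
                                       subject_fa, elastic_transform=mapping, binary_img=True)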
Example #9
def main():
    '''
    This can be used in Shell scripts to get subjects
    '''
    args = sys.argv[1:]
    nr_batches = int(args[0])  # Number of batches
    batch_number = int(args[1])  # Which batch do we want     (idx starts at 0)

    batch_size = int(math.ceil(len(all_subjects_RAW) / float(nr_batches)))
    res = list(Utils.chunks(all_subjects_RAW, batch_size))

    # Note: cannot print anything else here, because the output is passed as a parameter to the calling script
    # print("Nr of Batches: {} (last batch might be smaller)".format(len(res)))
    # print("Nr of subjects in batch: {}".format(batch_size))
    final_subjects = res[batch_number]
    # print("Subjects: {}".format(final_subjects))

    # Print as a single space-separated string (this is what the calling script captures)
    print(" ".join(final_subjects))
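
Because main() only prints a space-separated subject string, a caller can capture and split it again. A hedged sketch of how another Python script might consume it (the script name is a placeholder):

    import subprocess

    # e.g. request batch 0 of 5 from a hypothetical get_subjects.py that calls main()
    out = subprocess.check_output(["python", "get_subjects.py", "5", "0"])
    subjects = out.decode("utf-8").strip().split(" ")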
Example #10
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_threshold=False, get_probs=False):
    '''
    Run TractSeg

    :param data: input peaks (4D numpy array with shape [x,y,z,9])
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks"
    :param verbose: show debugging info
    :param dropout_sampling: create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: Threshold for converting probability map to binary map
    :param bundle_specific_threshold: Threshold is lower for some bundles which need more sensitivity (CA, CST, FX)
    :param get_probs: Output raw probability map instead of binary map
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:     [x,y,z,nr_of_bundles]
        for endings_segmentation:   [x,y,z,2*nr_of_bundles]
        for TOM:                    [x,y,z,3*nr_of_bundles]
    '''
    start_time = time.time()

    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    if bundle_specific_threshold:
        HP.GET_PROBS = True

    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DAugAll_Dropout", "best_weights_ep114.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_12g90g270g_125mm_DAugAll", "best_weights_ep126.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888", "best_weights_ep247.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888_SchizoFineT_lr001", "best_weights_ep186.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v2.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "EndingsSeg_12g90g270g_125mm_DAugAll", "best_weights_ep16.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/Peaks20_12g90g270g_125mm_DAugSimp_constW5", "best_weights_ep441.npz")  #more oversegmentation with DAug
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "DmReg_12g90g270g_125mm_DAugAll_Ubuntu", "best_weights_ep80.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3*len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    else:
        HP.NR_OF_CLASSES = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)
    # brain_mask = ImgUtils.simple_brain_mask(data)
    # if HP.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    data, seg_None, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE == "tract_segmentation" or HP.EXPERIMENT_TYPE == "endings_segmentation" or HP.EXPERIMENT_TYPE == "dm_regression":
        if single_orientation:     # mainly needed for testing because it requires less RAM
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
            else:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=False, scale_to_world_shape=False, only_prediction=True)
        else:
            seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=False)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:], len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=0.3)  # set lower for more sensitivity
        # 3-direction fusion for peaks -> not working (?)
        # seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
        # seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)

    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    # Remove the following two lines to keep the super-resolution output
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)
    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))

    return seg
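
A hypothetical call to run_tractseg on a peaks image (the 4D [x,y,z,9] array mentioned in the docstring); the file name is a placeholder and the input affine is reused when saving:

    import nibabel as nib
    import numpy as np

    peaks_img = nib.load("peaks.nii.gz")   # placeholder path, shape [x, y, z, 9]
    seg = run_tractseg(peaks_img.get_data(), output_type="tract_segmentation")
    nib.save(nib.Nifti1Image(seg.astype(np.uint8), peaks_img.get_affine()),
             "bundle_segmentations.nii.gz")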
Example #11
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_threshold=False, get_probs=False, peak_threshold=0.1):
    '''
    Run TractSeg

    :param data: input peaks (4D numpy array with shape [x,y,z,9])
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks"
    :param verbose: show debugging info
    :param dropout_sampling: create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: Threshold for converting probability map to binary map
    :param bundle_specific_threshold: Threshold is lower for some bundles which need more sensitivity (CA, CST, FX)
    :param get_probs: Output raw probability map instead of binary map
    :param peak_threshold: all peaks shorter than peak_threshold will be set to zero
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:     [x,y,z,nr_of_bundles]
        for endings_segmentation:   [x,y,z,2*nr_of_bundles]
        for TOM:                    [x,y,z,3*nr_of_bundles]
    '''
    start_time = time.time()

    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    if bundle_specific_threshold:
        HP.GET_PROBS = True

    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DAugAll_Dropout", "best_weights_ep114.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/TractSeg_T1_12g90g270g_125mm_DAugAll", "best_weights_ep392.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888", "best_weights_ep247.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888_SchizoFineT_lr001", "best_weights_ep186.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DS_DAugAll_RotMir", "best_weights_ep200.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v3.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "EndingsSeg_12g90g270g_125mm_DS_DAugAll", "best_weights_ep234.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/Peaks20_12g90g270g_125mm_DAugSimp_constW5", "best_weights_ep441.npz")  #more oversegmentation with DAug
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "DmReg_12g90g270g_125mm_DAugAll_Ubuntu", "best_weights_ep80.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3*len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    else:
        HP.NR_OF_CLASSES = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)
    # brain_mask = ImgUtils.simple_brain_mask(data)
    # if HP.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    data, seg_None, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE == "tract_segmentation" or HP.EXPERIMENT_TYPE == "endings_segmentation" or HP.EXPERIMENT_TYPE == "dm_regression":
        if single_orientation:     # mainly needed for testing because it requires less RAM
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
            else:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=False, scale_to_world_shape=False, only_prediction=True)
        else:
            seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=False)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:], len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=peak_threshold)
        # 3-direction fusion for peaks -> not working (?)
        # seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
        # seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)

    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    # Remove the following two lines to keep the super-resolution output
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)
    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))

    return seg
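
Compared to the previous example, this version adds the peak_threshold parameter for the TOM (peak regression) output. A hypothetical call that raises it for stricter filtering of short peaks (the file name is a placeholder):

    import nibabel as nib

    peaks_img = nib.load("peaks.nii.gz")   # placeholder path
    tom = run_tractseg(peaks_img.get_data(), output_type="TOM", peak_threshold=0.3)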