Exemple #1
0
def create_multilabel_mask(Config,
                           subject,
                           labels_type=np.int16,
                           dataset_folder="HCP",
                           labels_folder="bundle_masks"):
    '''
    One-hot encoding of all bundle masks of one subject in a single 4D image.

    Channel 0 is the background (1 wherever no loaded mask equals 1); channel i
    (i >= 1) holds the mask of the i-th name returned by
    exp_utils.get_bundle_names(Config.CLASSES).

    :param Config: object providing the CLASSES attribute
    :param subject: subject folder name below <C.HOME>/<dataset_folder>
    :param labels_type: dtype of the returned array
    :param dataset_folder: dataset directory below C.HOME
    :param labels_folder: directory (inside the subject folder) holding one
                          <bundle>.nii.gz per bundle
    :return: array of shape (145, 174, 145, len(bundle names)) with dtype labels_type
    '''
    bundles = exp_utils.get_bundle_names(Config.CLASSES)

    # Masks are always in HCP high-res space (downsampling only happens later).
    spatial_shape = (145, 174, 145)

    # Allocate directly in the requested dtype: building the full volume as
    # float64 and converting at the end needs several times more peak memory.
    mask_ml = np.zeros(spatial_shape + (len(bundles),), dtype=labels_type)
    background = np.ones(spatial_shape)  # everything that contains no bundle

    # First name is the background -> handled via `background`, so start at 1.
    for channel, bundle in enumerate(bundles[1:], start=1):
        mask = nib.load(
            join(C.HOME, dataset_folder, subject, labels_folder,
                 bundle + ".nii.gz"))
        # NOTE(review): original comment states these masks are uint8 - confirm.
        mask_data = mask.get_data()
        mask_ml[:, :, :, channel] = mask_data
        background[mask_data == 1] = 0  # remove this bundle from background

    mask_ml[:, :, :, 0] = background
    return mask_ml
Exemple #2
0
def save_multilabel_img_as_multiple_files_endings(Config, img, affine, path):
    '''
    Save every channel of a 4D image as a separate NIfTI file.

    Channel i of img is written to <path>/endings_segmentations/<name>.nii.gz,
    where <name> is the i-th entry of exp_utils.get_bundle_names(Config.CLASSES)
    after dropping the first entry.

    :param Config: object providing the CLASSES attribute
    :param img: image indexed as img[:, :, :, channel]
    :param affine: affine passed on to nib.Nifti1Image
    :param path: parent directory of the "endings_segmentations" folder
    '''
    bundle_names = exp_utils.get_bundle_names(Config.CLASSES)[1:]
    for channel, bundle_name in enumerate(bundle_names):
        volume = nib.Nifti1Image(img[:, :, :, channel], affine)
        out_dir = join(path, "endings_segmentations")
        exp_utils.make_dir(out_dir)
        nib.save(volume, join(out_dir, bundle_name + ".nii.gz"))
Exemple #3
0
def calc_peak_dice_onlySeg(Config, y_pred, y_true):
    '''
    Threshold predicted and true peaks into binary masks and score them per bundle.

    For each bundle the 3 consecutive channels belonging to it are reduced to
    one value per voxel (sum of absolute components) and thresholded; the two
    resulting masks are compared with f1_score(average="binary").

    :param y_pred: predicted peaks, indexed as [:, :, :, 3 * nr_of_bundles]
    :param y_true: reference peaks with the same layout
    :return: dict mapping bundle name -> F1 score
    '''
    # 0.1 keeps some outliers but already produces some holes; 0.2 is also ok
    # (still looks like e.g. CST). Resulting scores for 0.1 and 0.2 are very similar.
    pred_threshold = 0.2
    true_threshold = 1e-3

    scores = {}
    bundle_names = exp_utils.get_bundle_names(Config.CLASSES)[1:]
    for bundle_nr, bundle_name in enumerate(bundle_names):
        channels = slice(bundle_nr * 3, bundle_nr * 3 + 3)
        pred_peaks = y_pred[:, :, :, channels]
        true_peaks = y_true[:, :, :, channels]

        pred_mask = np.abs(pred_peaks).sum(axis=-1) > pred_threshold
        true_mask = np.abs(true_peaks).sum(axis=-1) > true_threshold

        scores[bundle_name] = f1_score(true_mask.flatten(),
                                       pred_mask.flatten(),
                                       average="binary")

    return scores
Exemple #4
0
def save_multilabel_img_as_multiple_files_endings_OLD(Config,
                                                      img,
                                                      affine,
                                                      path,
                                                      multilabel=True):
    '''
    Save the two channels of every bundle as one 3D label image per bundle.

    Each bundle owns 2 consecutive channels of img. Voxels > 0 in the first
    channel get label 1. Voxels > 0 in the second channel get label 2 when
    multilabel is True and label 1 otherwise (second channel wins on overlap).
    Files go to <path>/endings/<bundle>.nii.gz.

    :param Config: object providing the CLASSES attribute
    :param img: image indexed as [:, :, :, 2 * nr_of_bundles]
    :param affine: affine passed on to nib.Nifti1Image
    :param path: parent directory of the "endings" folder
    :param multilabel: True  -> labels 1 and 2 without a fourth dimension
                       False -> both channels combined into label 1
    '''
    second_label = 2 if multilabel else 1

    bundle_names = exp_utils.get_bundle_names(Config.CLASSES)[1:]
    for bundle_nr, bundle_name in enumerate(bundle_names):
        channels = slice(bundle_nr * 2, bundle_nr * 2 + 2)
        is_set = img[:, :, :, channels] > 0

        labels = np.zeros(is_set.shape[:3])
        labels[is_set[:, :, :, 0]] = 1
        labels[is_set[:, :, :, 1]] = second_label

        label_img = nib.Nifti1Image(labels, affine)
        exp_utils.make_dir(join(path, "endings"))
        nib.save(label_img, join(path, "endings", bundle_name + ".nii.gz"))
Exemple #5
0
 def test_tractseg_output(self):
     '''
     Each newly created bundle segmentation must equal its reference file exactly.
     '''
     ref_dir = "tests/reference_files/bundle_segmentations/"
     new_dir = "examples/tractseg_output/bundle_segmentations/"
     for bundle in exp_utils.get_bundle_names("All")[1:]:
         file_name = bundle + ".nii.gz"
         expected = nib.load(ref_dir + file_name).get_data()
         actual = nib.load(new_dir + file_name).get_data()
         self.assertTrue(np.array_equal(expected, actual),
                         "Tract segmentations are not correct (bundle: " + bundle + ")")
Exemple #6
0
def calc_peak_length_dice_pytorch(Config, y_pred, y_true, max_angle_error=[0.9], max_length_error=0.1):
    '''
    Per-bundle F1 score of peaks that match the reference in direction and length.

    A voxel counts as predicted-positive when the absolute cosine between
    predicted and true peak exceeds max_angle_error[0] AND the predicted length
    deviates from the true length by less than max_length_error * true length.
    The reference mask is "sum of the true peak components > 0".

    :param y_pred: tensor with channels on dim 1 (4 dims -> 2D case, otherwise 5 dims expected)
    :param y_true: tensor with the same layout as y_pred
    :param max_angle_error: list of cosine thresholds (0.7 ~ 45 degrees, 0.9 ~ 26 degrees);
                            only the first entry is used here
    :param max_length_error: allowed length deviation relative to the true length
    :return: dict mapping bundle name -> value returned by pytorch_utils.f1_score_binary
    '''
    import torch
    from tractseg.libs.pytorch_einsum import einsum
    from tractseg.libs import pytorch_utils

    # Move the channel axis (dim 1) to the end.
    if len(y_pred.shape) == 4:  # 2D
        channels_last = (0, 2, 3, 1)
    else:  # 3D
        channels_last = (0, 2, 3, 4, 1)
    y_true = y_true.permute(*channels_last)
    y_pred = y_pred.permute(*channels_last)

    def angle_last_dim(a, b):
        '''
        Absolute cosine of the angle between the vectors stored along the last dim.

        1 -> 0 degrees, 0 -> 90 degrees. The 1e-7 avoids division by zero for
        zero-length vectors.

        return: one dimension less than the input
        '''
        if len(a.shape) == 4:
            subscripts = 'abcd,abcd->abc'
        else:
            subscripts = 'abcde,abcde->abcd'
        dot = einsum(subscripts, a, b)
        norm_product = torch.norm(a, 2., -1) * torch.norm(b, 2, -1)
        return torch.abs(dot / (norm_product + 1e-7))

    # Single threshold only
    scores = {}
    bundle_names = exp_utils.get_bundle_names(Config.CLASSES)[1:]
    for bundle_nr, bundle_name in enumerate(bundle_names):
        first = bundle_nr * 3
        pred_vec = y_pred[..., first:first + 3].contiguous()
        true_vec = y_true[..., first:first + 3].contiguous()

        cosines = angle_last_dim(pred_vec, true_vec)

        pred_len = torch.norm(pred_vec, 2., -1)
        true_len = torch.norm(true_vec, 2, -1)
        length_ok = torch.abs(pred_len - true_len) < (max_length_error * true_len)
        length_ok = length_ok.view(-1)

        gt_mask = (true_vec.sum(dim=-1) > 0).view(-1)

        angle_ok = (cosines > max_angle_error[0]).view(-1)

        # Both criteria have to hold (element-wise AND via multiplication).
        hits = length_ok * angle_ok

        scores[bundle_name] = pytorch_utils.f1_score_binary(gt_mask, hits)
    return scores
Exemple #7
0
def plot_tracts(classes, bundle_segmentations, affine, out_dir, brain_mask=None):
    '''
    Render a fixed selection of bundles (CST_right, CA, IFO_right) below each other
    and save the result as <out_dir>/preview.png.

    By default this does not work over a remote server connection (ssh -X), because
    -X does not support OpenGL. On the remote server you can run 'export DISPLAY=":0"'
    (use the value that 'echo $DISPLAY' prints when logged in locally on that server).
    Then all graphics are rendered locally and not via -X.
    Important: a graphical session has to be running on the remote server (e.g. via a
    local login); staying at the login screen is not enough.

    :param classes: class-set name passed to exp_utils.get_bundle_names
    :param bundle_segmentations: array indexed as [:, :, :, bundle]
    :param affine: affine passed on to plot_mask
    :param out_dir: directory receiving preview.png
    :param brain_mask: optional mask passed on to plot_mask
    '''
    from dipy.viz import window
    from tractseg.libs import vtk_utils

    SMOOTHING = 10
    WINDOW_SIZE = (800, 800)
    bundles = ["CST_right", "CA", "IFO_right"]
    # Viewing direction per bundle; anything not listed is shown axially.
    orientation_of = {"CST_right": "axial", "CA": "axial", "IFO_right": "sagittal"}

    renderer = window.Renderer()
    renderer.projection('parallel')

    rows = len(bundles)
    X, Y, Z = bundle_segmentations.shape[:3]
    for row, bundle in enumerate(bundles):
        column = 0  # only one method

        all_names = exp_utils.get_bundle_names(classes)[1:]
        mask_data = bundle_segmentations[:, :, :, all_names.index(bundle)]

        orientation = orientation_of.get(bundle, "axial")

        # bigger: more border (same value for every orientation)
        border_y = -100

        row_height = Y * 2 + border_y
        x_current = X * column  # column (width)
        y_current = rows * row_height - row_height * row  # row (height), counted from the bottom

        plot_mask(renderer, mask_data, affine, x_current, y_current,
                  orientation=orientation, smoothing=SMOOTHING, brain_mask=brain_mask)

        # Bundle label
        text_offset_top = -50
        text_offset_side = -100
        label_pos = (0 - int(X) + text_offset_side, y_current + text_offset_top, 50)
        renderer.add(vtk_utils.label(text=bundle, pos=label_pos, scale=(6, 6, 6), color=(1, 1, 1)))

    renderer.reset_camera()
    window.record(renderer, out_path=join(out_dir, "preview.png"),
                  size=(WINDOW_SIZE[0], WINDOW_SIZE[1]), reset_camera=False, magnification=2)
Exemple #8
0
 def test_peakreg_output(self):
     '''
     Each newly created TOM must match its reference file within a small tolerance.
     '''
     ref_dir = "tests/reference_files/TOM/"
     new_dir = "examples/tractseg_output/TOM/"
     for bundle in exp_utils.get_bundle_names("All")[1:]:
         file_name = bundle + ".nii.gz"
         expected = nib.load(ref_dir + file_name).get_data()
         actual = nib.load(new_dir + file_name).get_data()
         # Floats need a small tolerance margin.
         # Allows for a difference of up to 0.002 -> still fine.
         close_enough = np.allclose(expected, actual, rtol=1e-3, atol=1e-3)
         self.assertTrue(close_enough, "TOMs are not correct (bundle: " + bundle + ")")
Exemple #9
0
def calc_peak_length_dice(Config,
                          y_pred,
                          y_true,
                          max_angle_error=[0.9],
                          max_length_error=0.1):
    '''
    Per-bundle F1 score of peaks that match the reference in direction and length.

    A voxel counts as predicted-positive when the absolute cosine between
    predicted and true peak exceeds max_angle_error[0] AND the predicted length
    deviates from the true length by less than max_length_error * true length.
    The reference mask is "sum of the true peak components > 0".

    :param y_pred: array indexed as [:, :, :, 3 * nr_of_bundles]
    :param y_true: array with the same layout as y_pred
    :param max_angle_error: list of cosine thresholds (0.7 ~ 45 degrees, 0.9 ~ 26 degrees);
                            only the first entry is used
    :param max_length_error: allowed length deviation relative to the true length
    :return: dict mapping bundle name -> value returned by my_f1_score
    '''
    def angle_last_dim(a, b):
        '''
        Absolute cosine of the angle between the vectors stored along the last axis.

        1 -> 0 degrees, 0 -> 90 degrees. The 1e-7 avoids division by zero for
        zero-length vectors.

        return: one dimension less than the input
        '''
        dot = np.einsum('...i,...i', a, b)
        norm_product = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1)
        return abs(dot / (norm_product + 1e-7))

    scores = {}
    bundle_names = exp_utils.get_bundle_names(Config.CLASSES)[1:]
    for bundle_nr, bundle_name in enumerate(bundle_names):
        channels = slice(bundle_nr * 3, bundle_nr * 3 + 3)
        pred_vec = y_pred[:, :, :, channels]
        true_vec = y_true[:, :, :, channels]

        cosines = angle_last_dim(pred_vec, true_vec)

        pred_len = np.linalg.norm(pred_vec, axis=-1)
        true_len = np.linalg.norm(true_vec, axis=-1)
        length_ok = abs(pred_len - true_len) < (max_length_error * true_len)
        length_ok = length_ok.flatten()

        gt_mask = (true_vec.sum(axis=-1) > 0).flatten()

        angle_ok = (cosines > max_angle_error[0]).flatten()

        # Both criteria have to hold (element-wise AND via multiplication).
        hits = length_ok * angle_ok

        scores[bundle_name] = my_f1_score(gt_mask, hits)
    return scores
Exemple #10
0
def save_multilabel_img_as_multiple_files_peaks(Config, img, affine, path, name="TOM"):
    '''
    Save the 3 peak channels of every bundle as a separate NIfTI file.

    Each bundle owns 3 consecutive channels of img. Files are written to
    <path>/<name>/<bundle>.nii.gz, or <bundle>_f.nii.gz with the third
    component negated when Config.FLIP_OUTPUT_PEAKS is set.

    The input image is never modified.

    :param Config: object providing CLASSES and FLIP_OUTPUT_PEAKS
    :param img: array indexed as [:, :, :, 3 * nr_of_bundles]
    :param affine: affine passed on to nib.Nifti1Image
    :param path: parent directory of the output folder
    :param name: name of the output folder inside path
    '''
    bundles = exp_utils.get_bundle_names(Config.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        data = img[:, :, :, (idx*3):(idx*3)+3]

        if Config.FLIP_OUTPUT_PEAKS:
            # For a numpy array the slice above is a view into `img`. Flipping it
            # in place would silently alter the caller's data (and flip it back
            # on a second call), so work on a copy.
            data = data.copy()
            data[:, :, :, 2] *= -1  # flip z axis for correct view in MITK
            filename = bundle + "_f.nii.gz"
        else:
            filename = bundle + ".nii.gz"

        img_seg = nib.Nifti1Image(data, affine)
        exp_utils.make_dir(join(path, name))
        nib.save(img_seg, join(path, name, filename))
Exemple #11
0
def calc_peak_dice(Config, y_pred, y_true, max_angle_error=[0.9]):
    '''
    Per-bundle F1 score of peaks whose direction matches the reference.

    A voxel counts as predicted-positive when the absolute cosine between
    predicted and true peak exceeds max_angle_error[0]. The reference mask is
    "sum of the true peak components > 0". Both masks are compared with
    f1_score(average="binary").

    :param y_pred: array indexed as [:, :, :, 3 * nr_of_bundles]
    :param y_true: array with the same layout as y_pred
    :param max_angle_error: list of cosine thresholds (0.7 ~ 45 degrees, 0.9 ~ 26 degrees);
                            only the first entry is used. The default list is never
                            modified by this function.
    :return: dict mapping bundle name -> F1 score
    '''
    def angle_last_dim(a, b):
        '''
        Absolute cosine of the angle between the vectors stored along the last axis.

        The result is a cosine, not an angle: 1 -> 0 degrees, 0 -> 90 degrees.
        The 1e-7 avoids division by zero for zero-length vectors.

        return: one dimension less than the input
        '''
        return abs(
            np.einsum('...i,...i', a, b) /
            (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + 1e-7))

    score_per_bundle = {}
    bundles = exp_utils.get_bundle_names(Config.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
        y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]

        angles = angle_last_dim(y_pred_bund, y_true_bund)
        angles_binary = angles > max_angle_error[0]

        gt_binary = y_true_bund.sum(axis=-1) > 0

        f1 = f1_score(gt_binary.flatten(),
                      angles_binary.flatten(),
                      average="binary")
        score_per_bundle[bundle] = f1

    return score_per_bundle
Exemple #12
0
 def test_tractseg_output_SR_noPP(self):
     '''
     Each SR_noPP bundle segmentation may differ from its reference in fewer than 5 voxels.

     IFO_right is skipped (see comment below).
     '''
     bundles = exp_utils.get_bundle_names("All")[1:]
     for bundle in bundles:
         # IFO somehow very different on travis than locally. Unclear why. All other bundles are fine.
         if bundle == "IFO_right":
             continue

         img_ref = nib.load("tests/reference_files/bundle_segmentations_SR_noPP/" + bundle + ".nii.gz").get_data()
         img_new = nib.load("examples/SR_noPP/tractseg_output/bundle_segmentations/" + bundle + ".nii.gz").get_data()

         # Processing on travis is slightly different from the local environment
         # -> have to allow for a small margin.
         # Count by comparison rather than by abs(img_ref - img_new).sum(): for
         # unsigned integer images the subtraction wraps around (0 - 1 == 255),
         # so a single differing voxel could already exceed the margin.
         nr_differing_voxels = int(np.count_nonzero(img_ref != img_new))
         images_equal = nr_differing_voxels < 5
         self.assertTrue(images_equal, "Tract segmentations are not correct (bundle: " + bundle + ") " +
                                       "(nr of differing voxels: " + str(nr_differing_voxels) + ")")
Exemple #13
0
class Config(TractSegConfig):
    # Experiment configuration overriding/adding attributes on top of TractSegConfig.

    # Experiment name = file name of this module up to the first ".".
    EXP_NAME = os.path.basename(__file__).split(".")[0]

    # Input data selection.
    # NOTE(review): exact meaning of these values is defined by the consumers
    # of the config, which are not visible here - confirm before changing.
    DATASET_FOLDER = "HCP_preproc_all"
    NR_OF_GRADIENTS = 18
    FEATURES_FILENAME = "32g270g_BX"
    P_SAMP = 0.4

    # Class set and number of classes: all names returned for CLASSES except
    # the first one (presumably the background entry - verify in exp_utils).
    CLASSES = "AutoPTX_42"
    NR_OF_CLASSES = len(exp_utils.get_bundle_names(CLASSES)[1:])

    DATASET = "HCP_all"

    # Learning-rate schedule settings.
    LR_SCHEDULE = True
    LR_SCHEDULE_MODE = "min"
    LR_SCHEDULE_PATIENCE = 20

    NUM_EPOCHS = 200  # 130 probably also fine
Exemple #14
0
class Config(DmRegConfig):
    # Experiment configuration overriding/adding attributes on top of DmRegConfig.

    # Experiment name = file name of this module up to the first ".".
    EXP_NAME = os.path.basename(__file__).split(".")[0]

    # Input data selection.
    # NOTE(review): exact meaning of these values is defined by the consumers
    # of the config, which are not visible here - confirm before changing.
    DATASET_FOLDER = "HCP_preproc_all"
    NR_OF_GRADIENTS = 18
    FEATURES_FILENAME = "32g270g_BX"
    P_SAMP = 0.4

    # Class set and number of classes: all names returned for CLASSES except
    # the first one (presumably the background entry - verify in exp_utils).
    CLASSES = "AutoPTX_42"
    NR_OF_CLASSES = len(exp_utils.get_bundle_names(CLASSES)[1:])

    # THRESHOLD = 0.001  # Final DM will be thresholded at this value
    THRESHOLD = 0.0001  # use lower value so user has more choice

    DATASET = "HCP_all"

    # Learning-rate schedule settings.
    LR_SCHEDULE = True
    LR_SCHEDULE_MODE = "min"
    LR_SCHEDULE_PATIENCE = 20

    NUM_EPOCHS = 200  # 130 probably also fine
Exemple #15
0
 def test_bundle_names(self):
     '''
     Requesting the class set "CST_right" must yield exactly ["BG", "CST_right"].
     '''
     actual = exp_utils.get_bundle_names("CST_right")
     expected = ["BG", "CST_right"]
     self.assertListEqual(actual, expected, "Error in list of bundle names")
Exemple #16
0
def plot_tracts_matplotlib(classes,
                           bundle_segmentations,
                           background_img,
                           out_dir,
                           threshold=0.001,
                           exp_type="tract_segmentation"):
    '''
    Draw a grid of 2D previews for a fixed selection of bundles and save it as
    <out_dir>/preview.png.

    :param classes: class-set name; names starting with "AutoPTX" select the
                    AutoPTX bundle list
    :param bundle_segmentations: array indexed as [:, :, :, bundle]; for
                                 exp_type "peak_regression" the last axis is
                                 first regrouped into (bundle, 3)
    :param background_img: image whose [..., 0] volume is used as background
    :param out_dir: directory receiving preview.png
    :param threshold: currently unused - the preview cut-off is fixed at 0.001
    :param exp_type: "peak_regression" switches to peak handling
    '''
    def plot_single_tract(bg, data, orientation, bundle, exp_type):
        '''
        Plot one bundle: background slice in gray with the projected bundle on top.
        '''
        # Axis order that puts the viewing axis last for the non-axial views;
        # axial data is used as is.
        view_axes = {"coronal": (2, 0, 1), "sagittal": (2, 1, 0)}
        if orientation in view_axes:
            axes = view_axes[orientation]
            if exp_type == "peak_regression":
                data = data.transpose(*(axes + (3,)))  # keep the peak axis last
            else:
                data = data.transpose(*axes)
            data = data[::-1, :, :]
            bg = bg.transpose(*axes)[::-1, :, :]

        # Background slice: middle of the bundle's extent along the viewing
        # axis, or the middle of the volume if the bundle is empty.
        nonzero = np.where(data != 0)
        if len(nonzero) > 2 and len(nonzero[2]) > 0:
            lowest = int(np.min(nonzero[2]))
            highest = int(np.max(nonzero[2])) + 1
            slice_idx = int(np.mean([lowest, highest]))
        else:
            slice_idx = int(bg.shape[2] / 2)
        bg = bg[:, :, slice_idx]

        # Project 3D to 2D image.
        # todo: this kind of projection is not sensible for peak images
        data = data.mean(axis=2) if aggregation == "mean" else data.max(axis=2)

        plt.imshow(bg, cmap="gray")
        # even with cmap=autumn peaks are still RGB
        plt.imshow(np.ma.masked_where(data < 0.00001, data), cmap="autumn")
        plt.title(bundle, fontsize=7)

    if classes.startswith("AutoPTX"):
        bundles = ["cst_r", "cst_s_r", "ifo_r", "fx_l", "fx_r", "or_l", "fma"]
    else:
        bundles = [
            "CST_right", "CST_s_right", "CA", "IFO_right", "FX_left",
            "FX_right", "OR_left", "CC_1"
        ]

    if exp_type == "peak_regression":
        s = bundle_segmentations.shape
        bundle_segmentations = bundle_segmentations.reshape(
            [s[0], s[1], s[2], int(s[3] / 3), 3])
        print(bundle_segmentations.shape)
        # can only use bundles from part1
        bundles = ["CST_right", "CST_s_right", "CA", "CC_1", "AF_left"]

    aggregation = "max"  # read by plot_single_tract via closure
    cols = 4
    rows = math.ceil(len(bundles) / cols)

    background_img = background_img[..., 0]

    axial_prefixes = ("ca", "fx_", "or_", "cc_1", "fma")
    sagittal_prefixes = ("ifo_", "icp_", "cst_s_", "af_")

    for j, bundle in enumerate(bundles):
        bun = bundle.lower()
        if bun.startswith(axial_prefixes):
            orientation = "axial"
        elif bun.startswith(sagittal_prefixes):
            # "_s" only marks the sagittal variant; the data lives under the plain name
            bundle = bundle.replace("_s", "")
            orientation = "sagittal"
        elif bun.startswith("cst_"):
            orientation = "coronal"
        else:
            raise ValueError("invalid bundle")

        bundle_idx = exp_utils.get_bundle_names(classes)[1:].index(bundle)
        # Copy, otherwise the thresholding below would also change the data
        # outside of this function.
        mask_data = np.copy(bundle_segmentations[:, :, :, bundle_idx])
        # Higher value is better for the preview, otherwise half of the image is just red.
        mask_data[mask_data < 0.001] = 0

        plt.subplot(rows, cols, j + 1)
        plt.axis("off")
        plot_single_tract(background_img,
                          mask_data,
                          orientation,
                          bundle,
                          exp_type=exp_type)

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.savefig(join(out_dir, "preview.png"), bbox_inches='tight', dpi=300)
Exemple #17
0
def calc_peak_dice_pytorch(Config, y_pred, y_true, max_angle_error=[0.9]):
    '''
    Per-bundle F1 score of peaks whose direction matches the reference.

    The absolute cosine between groundtruth and prediction is computed per voxel;
    voxels whose value exceeds the threshold form the prediction mask. The
    groundtruth mask is "sum of the true peak components > 0". The score is
    computed from these 2 masks.

    -> Penalty on peaks outside of the tract or if the predicted peak is 0
    -> No penalty on very small peaks with the right direction -> bad
    => The score can be high even if peaks inside the tract are almost missing

    :param y_pred: 4D tensor with channels on dim 1
    :param y_true: tensor with the same layout as y_pred
    :param max_angle_error: list of cosine thresholds (0.7 ~ 45 degrees, 0.9 ~ 26 degrees)
    :return: dict mapping bundle name -> single score if max_angle_error has
             exactly one entry, otherwise -> list with one score per threshold
    '''
    import torch
    from tractseg.libs.pytorch_einsum import einsum
    from tractseg.libs import pytorch_utils

    # Move the channel axis (dim 1) to the end.
    y_true = y_true.permute(0, 2, 3, 1)
    y_pred = y_pred.permute(0, 2, 3, 1)

    def angle_last_dim(a, b):
        '''
        Absolute cosine of the angle between the vectors stored along the last dim.

        1 -> 0 degrees, 0 -> 90 degrees. The 1e-7 avoids division by zero for
        zero-length vectors.

        return: one dimension less than the input
        '''
        dot = einsum('abcd,abcd->abc', a, b)
        norm_product = torch.norm(a, 2., -1) * torch.norm(b, 2, -1)
        return torch.abs(dot / (norm_product + 1e-7))

    single_threshold = len(max_angle_error) == 1

    scores = {}
    bundle_names = exp_utils.get_bundle_names(Config.CLASSES)[1:]
    for bundle_nr, bundle_name in enumerate(bundle_names):
        first = bundle_nr * 3
        pred_vec = y_pred[:, :, :, first:first + 3].contiguous()
        true_vec = y_true[:, :, :, first:first + 3].contiguous()

        cosines = angle_last_dim(pred_vec, true_vec)
        gt_mask = (true_vec.sum(dim=-1) > 0).view(-1)

        if single_threshold:
            hits = (cosines > max_angle_error[0]).view(-1)
            scores[bundle_name] = pytorch_utils.f1_score_binary(gt_mask, hits)
        else:
            # Several thresholds -> one score per threshold, in the given order.
            scores[bundle_name] = []
            for limit in max_angle_error:
                hits = (cosines > limit).view(-1)
                scores[bundle_name].append(
                    pytorch_utils.f1_score_binary(gt_mask, hits))

    return scores
Exemple #18
0
def run_tractseg(data,
                 output_type="tract_segmentation",
                 single_orientation=False,
                 dropout_sampling=False,
                 threshold=0.5,
                 bundle_specific_threshold=False,
                 get_probs=False,
                 peak_threshold=0.1,
                 postprocess=False,
                 peak_regression_part="All",
                 input_type="peaks",
                 blob_size_thr=50,
                 nr_cpus=-1,
                 verbose=False,
                 manual_exp_name=None,
                 inference_batch_size=1):
    """
    Run TractSeg

    Args:
        data: input peaks (4D numpy array with shape [x,y,z,9])
        output_type: TractSeg can segment not only bundles, but also the end regions of bundles.
            Moreover it can create Tract Orientation Maps (TOM).
            'tract_segmentation' [DEFAULT]: Segmentation of bundles (72 bundles).
            'endings_segmentation': Segmentation of bundle end regions (72 bundles).
            'TOM': Tract Orientation Maps (20 bundles).
        single_orientation: Do not run model 3 times along x/y/z orientation with subsequent mean fusion.
        dropout_sampling: Create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
        threshold: Threshold for converting probability map to binary map
        bundle_specific_threshold: Set threshold to lower for some bundles which need more sensitivity (CA, CST, FX)
        get_probs: Output raw probability map instead of binary map
        peak_threshold: All peaks shorter than peak_threshold will be set to zero
        postprocess: Simple postprocessing of segmentations: Remove small blobs and fill holes
        peak_regression_part: Only relevant for output type 'TOM'. If set to 'All' (default) it will return all
            72 bundles. If set to 'Part1'-'Part4' it will only run for a subset of the bundles to reduce memory
            load.
        input_type: Always set to "peaks"
        blob_size_thr: If setting postprocess to True, all blobs having a smaller number of voxels than specified in
            this threshold will be removed.
        nr_cpus: Number of CPUs to use. -1 means all available CPUs.
        verbose: Show debugging infos
        manual_exp_name: Name of experiment if do not want to use pretrained model but your own one
        inference_batch_size: batch size (higher: a bit faster but needs more RAM)

    Returns:
        4D numpy array with the output of tractseg
        for tract_segmentation:     [x, y, z, nr_of_bundles]
        for endings_segmentation:   [x, y, z, 2*nr_of_bundles]
        for TOM:                    [x, y, z, 3*nr_of_bundles]
    """
    start_time = time.time()

    config = get_config_name(input_type,
                             output_type,
                             dropout_sampling=dropout_sampling)
    Config = getattr(
        importlib.import_module("tractseg.experiments.pretrained_models." +
                                config), "Config")()
    Config.VERBOSE = verbose
    Config.TRAIN = False
    Config.TEST = False
    Config.SEGMENT = False
    Config.GET_PROBS = get_probs
    Config.LOAD_WEIGHTS = True
    Config.DROPOUT_SAMPLING = dropout_sampling
    Config.THRESHOLD = threshold
    Config.NR_CPUS = nr_cpus
    Config.INPUT_DIM = exp_utils.get_correct_input_dim(Config)

    if bundle_specific_threshold:
        Config.GET_PROBS = True

    if manual_exp_name is not None and Config.EXPERIMENT_TYPE != "peak_regression":
        Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
            join(C.EXP_PATH, manual_exp_name), True)
    elif input_type == "peaks":
        if Config.EXPERIMENT_TYPE == "tract_segmentation" and Config.DROPOUT_SAMPLING:
            Config.WEIGHTS_PATH = join(
                C.WEIGHTS_DIR,
                "pretrained_weights_tract_segmentation_dropout_v2.npz")
        elif Config.EXPERIMENT_TYPE == "tract_segmentation":
            Config.WEIGHTS_PATH = join(
                C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_v2.npz")
            # Config.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DS_DAugAll_RotMirFlip", "best_weights_ep247.npz")
        elif Config.EXPERIMENT_TYPE == "endings_segmentation":
            Config.WEIGHTS_PATH = join(
                C.WEIGHTS_DIR,
                "pretrained_weights_endings_segmentation_v3.npz")
        elif Config.EXPERIMENT_TYPE == "dm_regression":
            Config.WEIGHTS_PATH = join(
                C.WEIGHTS_DIR, "pretrained_weights_dm_regression_v1.npz")
    elif input_type == "T1":
        if Config.EXPERIMENT_TYPE == "tract_segmentation":
            Config.WEIGHTS_PATH = join(
                C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models",
                "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif Config.EXPERIMENT_TYPE == "endings_segmentation":
            Config.WEIGHTS_PATH = join(
                C.WEIGHTS_DIR,
                "pretrained_weights_endings_segmentation_v1.npz")
        elif Config.EXPERIMENT_TYPE == "peak_regression":
            Config.WEIGHTS_PATH = join(
                C.WEIGHTS_DIR, "pretrained_weights_peak_regression_v1.npz")

    if Config.VERBOSE:
        print("Hyperparameters:")
        exp_utils.print_Configs(Config)

    data = np.nan_to_num(data)
    # brain_mask = img_utils.simple_brain_mask(data)
    # if Config.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    if input_type == "T1":
        data = np.reshape(data,
                          (data.shape[0], data.shape[1], data.shape[2], 1))
    data, seg_None, bbox, original_shape = dataset_utils.crop_to_nonzero(data)
    data, transformation = dataset_utils.pad_and_scale_img_to_square_img(
        data, target_size=Config.INPUT_DIM[0])

    if Config.EXPERIMENT_TYPE == "tract_segmentation" or Config.EXPERIMENT_TYPE == "endings_segmentation" or \
            Config.EXPERIMENT_TYPE == "dm_regression":
        print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
        Config.NR_OF_CLASSES = len(
            exp_utils.get_bundle_names(Config.CLASSES)[1:])
        utils.download_pretrained_weights(
            experiment_type=Config.EXPERIMENT_TYPE,
            dropout_sampling=Config.DROPOUT_SAMPLING)
        model = BaseModel(Config)
        if single_orientation:  # mainly needed for testing because of less RAM requirements
            data_loder_inference = DataLoaderInference(Config, data=data)
            if Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" or Config.GET_PROBS:
                seg, img_y = trainer.predict_img(
                    Config,
                    model,
                    data_loder_inference,
                    probs=True,
                    scale_to_world_shape=False,
                    only_prediction=True,
                    batch_size=inference_batch_size)
            else:
                seg, img_y = trainer.predict_img(
                    Config,
                    model,
                    data_loder_inference,
                    probs=False,
                    scale_to_world_shape=False,
                    only_prediction=True,
                    batch_size=inference_batch_size)
        else:
            seg_xyz, gt = direction_merger.get_seg_single_img_3_directions(
                Config,
                model,
                data=data,
                scale_to_world_shape=False,
                only_prediction=True,
                batch_size=inference_batch_size)
            if Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" or Config.GET_PROBS:
                seg = direction_merger.mean_fusion(Config.THRESHOLD,
                                                   seg_xyz,
                                                   probs=True)
            else:
                seg = direction_merger.mean_fusion(Config.THRESHOLD,
                                                   seg_xyz,
                                                   probs=False)

    elif Config.EXPERIMENT_TYPE == "peak_regression":
        weights = {
            "Part1": "pretrained_weights_peak_regression_part1_v1.npz",
            "Part2": "pretrained_weights_peak_regression_part2_v1.npz",
            "Part3": "pretrained_weights_peak_regression_part3_v1.npz",
            "Part4": "pretrained_weights_peak_regression_part4_v1.npz",
        }
        if peak_regression_part == "All":
            parts = ["Part1", "Part2", "Part3", "Part4"]
            seg_all = np.zeros((data.shape[0], data.shape[1], data.shape[2],
                                Config.NR_OF_CLASSES * 3))
        else:
            parts = [peak_regression_part]
            Config.CLASSES = "All_" + peak_regression_part
            Config.NR_OF_CLASSES = 3 * len(
                exp_utils.get_bundle_names(Config.CLASSES)[1:])

        for idx, part in enumerate(parts):
            if manual_exp_name is not None:
                manual_exp_name_peaks = exp_utils.get_manual_exp_name_peaks(
                    manual_exp_name, part)
                Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
                    join(C.EXP_PATH, manual_exp_name_peaks), True)
            else:
                Config.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, weights[part])
            print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
            Config.CLASSES = "All_" + part
            Config.NR_OF_CLASSES = 3 * len(
                exp_utils.get_bundle_names(Config.CLASSES)[1:])
            utils.download_pretrained_weights(
                experiment_type=Config.EXPERIMENT_TYPE,
                dropout_sampling=Config.DROPOUT_SAMPLING,
                part=part)
            data_loder_inference = DataLoaderInference(Config, data=data)
            model = BaseModel(Config)
            seg, img_y = trainer.predict_img(Config,
                                             model,
                                             data_loder_inference,
                                             probs=True,
                                             scale_to_world_shape=False,
                                             only_prediction=True,
                                             batch_size=inference_batch_size)

            if peak_regression_part == "All":
                seg_all[:, :, :,
                        (idx *
                         Config.NR_OF_CLASSES):(idx * Config.NR_OF_CLASSES +
                                                Config.NR_OF_CLASSES)] = seg

        if peak_regression_part == "All":
            Config.CLASSES = "All"
            Config.NR_OF_CLASSES = 3 * len(
                exp_utils.get_bundle_names(Config.CLASSES)[1:])
            seg = seg_all

        #quite fast
        if bundle_specific_threshold:
            seg = img_utils.remove_small_peaks_bundle_specific(
                seg,
                exp_utils.get_bundle_names(Config.CLASSES)[1:],
                len_thr=0.3)
        else:
            seg = img_utils.remove_small_peaks(seg, len_thr=peak_threshold)

        #3 dir for Peaks -> bad results
        # seg_xyz, gt = direction_merger.get_seg_single_img_3_directions(Config, model, data=data,
        #                                                                scale_to_world_shape=False,
        #                                                                only_prediction=True,
        #                                                                batch_size=inference_batch_size)
        # seg = direction_merger.mean_fusion(Config.THRESHOLD, seg_xyz, probs=True)

    if bundle_specific_threshold and Config.EXPERIMENT_TYPE == "tract_segmentation":
        seg = img_utils.probs_to_binary_bundle_specific(
            seg,
            exp_utils.get_bundle_names(Config.CLASSES)[1:])

    #remove following two lines to keep super resolution
    seg = dataset_utils.cut_and_scale_img_back_to_original_img(
        seg, transformation)  # quite slow
    seg = dataset_utils.add_original_zero_padding_again(
        seg, bbox, original_shape, Config.NR_OF_CLASSES)  # quite slow

    if postprocess:
        seg = img_utils.postprocess_segmentations(seg,
                                                  blob_thr=blob_size_thr,
                                                  hole_closing=2)

    exp_utils.print_verbose(
        Config, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg
Exemple #19
0
class Config:
    """Class-level default settings and hyperparameters for an experiment.

    Every option is a plain class attribute; the class defines no methods.

    A few values (``NR_OF_CLASSES``, ``MULTI_PARENT_PATH``, ``EXP_PATH``) are
    computed once, while the class body executes, from attributes defined
    above them. Changing ``CLASSES``, ``EXP_MULTI_NAME`` or ``EXP_NAME``
    afterwards does not update those derived values automatically.

    The list-valued attributes (``*_SUBJECTS``, ``METRIC_TYPES``,
    ``PEAK_DICE_THR``) are class-level objects and therefore shared by all
    instances unless they are reassigned.
    """
    # ----- Experiment identity -----
    EXP_MULTI_NAME = ""  # CV parent directory name; leave empty for a single-bundle experiment
    EXP_NAME = "HCP_TEST"
    MODEL = "UNet_Pytorch_DeepSup"
    # One of: tract_segmentation / endings_segmentation / dm_regression / peak_regression
    EXPERIMENT_TYPE = "tract_segmentation"

    # ----- Training length and data augmentation -----
    DIM = "2D"  # 2D / 3D
    NUM_EPOCHS = 250
    EPOCH_MULTIPLIER = 1 # 2D: 1, 3D: 12 for lowRes, 3 for highRes
    DATA_AUGMENTATION = True
    DAUG_SCALE = True
    DAUG_NOISE = True
    DAUG_NOISE_VARIANCE = (0, 0.05)
    DAUG_ELASTIC_DEFORM = True
    DAUG_ALPHA = (90., 120.)
    DAUG_SIGMA = (9., 11.)
    DAUG_RESAMPLE = False   # does not change validation dice (if using Gaussian_blur) -> deactivate
    DAUG_RESAMPLE_LEGACY = False    # does not change validation dice (at least on AutoPTX) -> deactivate
    DAUG_GAUSSIAN_BLUR = True
    DAUG_BLUR_SIGMA = (0, 1)
    DAUG_ROTATE = False
    DAUG_ROTATE_ANGLE = (-0.2, 0.2)  # rotation: 2*np.pi = 360 degree  (-> 0.4 ~ 22 degree, 0.2 ~ 11 degree)
    DAUG_MIRROR = False
    DAUG_FLIP_PEAKS = False
    SPATIAL_TRANSFORM = "SpatialTransform"  # SpatialTransform / SpatialTransformPeaks
    # Author's observations about P_SAMP:
    # 1.0 slightly less overfitting than 0.4 but not much ("break-even" 20 epochs later)
    # 1.0: CPU bottleneck, 0.4: CPU not 100% all the time anymore, but still GPU utility not 100%
    # 1.0: clearly more complete CA+FX on nonHCP than 0.2
    P_SAMP = 1.0  # use 1.0 for final model
    # Hand-written description string; it is not derived from the DAUG_* values above.
    # NOTE(review): it mentions Rotate(-0.8,0.8) although DAUG_ROTATE is False and
    # DAUG_ROTATE_ANGLE is (-0.2, 0.2) -- confirm whether the text is still accurate.
    DAUG_INFO = "Elastic(90,120)(9,11) - Scale(0.9, 1.5) - CenterDist60 - " \
                "DownsampScipy(0.5,1) - Gaussian(0,0.05) - Rotate(-0.8,0.8)"

    # ----- Dataset and input features -----
    DATASET = "HCP"  # HCP / HCP_32g / Schizo
    RESOLUTION = "1.25mm" # 1.25mm (/ 2.5mm)
    # 12g90g270g / 270g_125mm_xyz / 270g_125mm_peaks / 90g_125mm_peaks / 32g_25mm_peaks / 32g_25mm_xyz
    FEATURES_FILENAME = "12g90g270g"

    # ----- Labels, loss and classes -----
    LABELS_FILENAME = ""        # autofilled
    LOSS_FUNCTION = "default"   # default / soft_batch_dice
    OPTIMIZER = "Adamax"
    CLASSES = "All"             # All / 11 / 20 / CST_right
    NR_OF_GRADIENTS = 9
    # Evaluated once at class-definition time from the CLASSES value above; the first
    # entry returned by get_bundle_names() is dropped (presumably the background class
    # -- verify against exp_utils).
    NR_OF_CLASSES = len(exp_utils.get_bundle_names(CLASSES)[1:])
    # Alternative with three output channels per bundle (looks like the peak_regression
    # variant -- TODO confirm):
    # NR_OF_CLASSES = 3 * len(exp_utils.get_bundle_names(CLASSES)[1:])

    # ----- Model input and regularisation -----
    INPUT_DIM = None  # (80, 80) / (144, 144)
    LOSS_WEIGHT = None  # None: no weighting
    LOSS_WEIGHT_LEN = -1  # -1: constant over all epochs
    SLICE_DIRECTION = "y"  # x, y, z  (combined needs z)
    TRAINING_SLICE_DIRECTION = "xyz"    # y / xyz
    INFO = "-"
    BATCH_NORM = False
    WEIGHT_DECAY = 0
    USE_DROPOUT = False
    DROPOUT_SAMPLING = False

    # ----- Paths -----
    # DATASET_FOLDER = "HCP_batches/270g_125mm_bundle_peaks_Y_subset"
    DATASET_FOLDER = "HCP_preproc" # HCP / Schizo
    LABELS_FOLDER = "bundle_masks"  # bundle_masks / bundle_masks_dm
    # Both paths are built once from EXP_MULTI_NAME / EXP_NAME as defined above.
    MULTI_PARENT_PATH = join(C.EXP_PATH, EXP_MULTI_NAME)
    EXP_PATH = join(C.EXP_PATH, EXP_MULTI_NAME, EXP_NAME)  # default path

    # ----- Optimisation -----
    BATCH_SIZE = 47  # Peak Prediction: 44 # Pytorch: 50  # Lasagne: 56  # Lasagne combined: 42  # Pytorch UpSample: 56
    LEARNING_RATE = 0.001  # 0.002 # LR find: 0.000143 ?  # 0.001
    LR_SCHEDULE = True
    LR_SCHEDULE_MODE = "min"    # "min" / "max"
    LR_SCHEDULE_PATIENCE = 20
    UNET_NR_FILT = 64

    # ----- Weights, cross-validation and run modes -----
    LOAD_WEIGHTS = False
    # WEIGHTS_PATH = join(C.EXP_PATH, "HCP100_45B_UNet_x_DM_lr002_slope2_dec992_ep800/best_weights_ep64.npz")
    WEIGHTS_PATH = ""  # if empty string: autoloading the best_weights in get_best_weights_path()
    TYPE = "single_direction"  # single_direction / combined
    CV_FOLD = 0
    VALIDATE_SUBJECTS = []
    TRAIN_SUBJECTS = []
    TEST_SUBJECTS = []
    TRAIN = True
    TEST = True
    SEGMENT = False
    GET_PROBS = False
    OUTPUT_MULTIPLE_FILES = False
    RESET_LAST_LAYER = False
    UPSAMPLE_TYPE = "bilinear"  # bilinear / nearest
    BEST_EPOCH_SELECTION = "f1"  # f1 / loss
    METRIC_TYPES = ["loss", "f1_macro"]
    FP16 = True

    # ----- Peak regression specific -----
    PEAK_DICE_THR = [0.95]
    PEAK_DICE_LEN_THR = 0.05
    FLIP_OUTPUT_PEAKS = True    # flip peaks along z axis to make them compatible with MITK

    # ----- For the TractSeg.py application -----
    PREDICT_IMG = False
    PREDICT_IMG_OUTPUT = None
    TRACTSEG_DIR = "tractseg_output"
    KEEP_INTERMEDIATE_FILES = False
    CSD_RESOLUTION = "LOW"  # HIGH / LOW
    NR_CPUS = -1

    # ----- Rarely changed -----
    LABELS_TYPE = "int"
    THRESHOLD = 0.5  # Binary: 0.5, Regression: 0.01 ?
    TEST_TIME_DAUG = False
    USE_VISLOGGER = False  # only works with Python 3
    SAVE_WEIGHTS = True
    SEG_INPUT = "Peaks"  # Gradients / Peaks
    NR_SLICES = 1  # adapt manually: NR_OF_GRADIENTS in UNet.py and get_batch... in train() and in get_seg_prediction()
    PRINT_FREQ = 20  # 20
    NORMALIZE_DATA = True
    NORMALIZE_PER_CHANNEL = False
    BEST_EPOCH = 0
    VERBOSE = True
    CALC_F1 = True
    ONLY_VAL = False
Exemple #20
0
class Config:
    """Class-level default settings and hyperparameters for an experiment.

    Every option is a plain class attribute; the class defines no methods.

    A few values (``NR_OF_CLASSES``, ``MULTI_PARENT_PATH``, ``EXP_PATH``) are
    computed once, while the class body executes, from attributes defined
    above them. Changing ``CLASSES``, ``EXP_MULTI_NAME`` or ``EXP_NAME``
    afterwards does not update those derived values automatically.

    The list-valued attributes (``*_SUBJECTS``, ``PEAK_DICE_THR``) are
    class-level objects and therefore shared by all instances unless they are
    reassigned.
    """
    # ----- Experiment identity -----
    EXP_MULTI_NAME = ""  # CV parent directory name; leave empty for a single-bundle experiment
    EXP_NAME = "HCP_TEST"
    MODEL = "UNet_Pytorch"
    # One of: tract_segmentation / endings_segmentation / dm_regression / peak_regression
    EXPERIMENT_TYPE = "tract_segmentation"

    # ----- Training length and data augmentation -----
    DIM = "2D"  # 2D / 3D
    NUM_EPOCHS = 250
    EPOCH_MULTIPLIER = 1 # 2D: 1, 3D: 12 for lowRes, 3 for highRes
    DATA_AUGMENTATION = False
    DAUG_SCALE = True
    DAUG_NOISE = True
    DAUG_ELASTIC_DEFORM = True
    DAUG_RESAMPLE = True
    DAUG_ROTATE = False
    DAUG_MIRROR = False
    DAUG_FLIP_PEAKS = False
    # Hand-written description string; it is not derived from the DAUG_* flags above.
    # NOTE(review): it mentions Rotate(-0.8,0.8) although DAUG_ROTATE is False --
    # confirm whether the text is still accurate.
    DAUG_INFO = "Elastic(90,120)(9,11) - Scale(0.9, 1.5) - CenterDist60 - " \
                "DownsampScipy(0.5,1) - Gaussian(0,0.05) - Rotate(-0.8,0.8)"

    # ----- Dataset and input features -----
    DATASET = "HCP"  # HCP / HCP_32g / Schizo
    RESOLUTION = "1.25mm" # 1.25mm (/ 2.5mm)
    # 12g90g270g / 270g_125mm_xyz / 270g_125mm_peaks / 90g_125mm_peaks / 32g_25mm_peaks / 32g_25mm_xyz
    FEATURES_FILENAME = "12g90g270g"

    # ----- Labels, loss and classes -----
    LABELS_FILENAME = ""        # autofilled
    LOSS_FUNCTION = "default"   # default / soft_batch_dice
    OPTIMIZER = "Adamax"
    CLASSES = "All"             # All / 11 / 20 / CST_right
    NR_OF_GRADIENTS = 9
    # Evaluated once at class-definition time from the CLASSES value above; the first
    # entry returned by get_bundle_names() is dropped (presumably the background class
    # -- verify against exp_utils).
    NR_OF_CLASSES = len(exp_utils.get_bundle_names(CLASSES)[1:])
    # Alternative with three output channels per bundle (looks like the peak_regression
    # variant -- TODO confirm):
    # NR_OF_CLASSES = 3 * len(exp_utils.get_bundle_names(CLASSES)[1:])

    # ----- Model input and regularisation -----
    INPUT_DIM = None  # (80, 80) / (144, 144)
    LOSS_WEIGHT = 1  # 1: no weighting
    LOSS_WEIGHT_LEN = -1  # -1: constant over all epochs
    SLICE_DIRECTION = "y"  # x, y, z  (combined needs z)
    TRAINING_SLICE_DIRECTION = "xyz"    # y / xyz
    INFO = "-"
    BATCH_NORM = False
    WEIGHT_DECAY = 0
    USE_DROPOUT = False
    DROPOUT_SAMPLING = False

    # ----- Paths -----
    # DATASET_FOLDER = "HCP_batches/270g_125mm_bundle_peaks_Y_subset"
    DATASET_FOLDER = "HCP" # HCP / Schizo
    LABELS_FOLDER = "bundle_masks"  # bundle_masks / bundle_masks_dm
    # Both paths are built once from EXP_MULTI_NAME / EXP_NAME as defined above.
    MULTI_PARENT_PATH = join(C.EXP_PATH, EXP_MULTI_NAME)
    EXP_PATH = join(C.EXP_PATH, EXP_MULTI_NAME, EXP_NAME)  # default path

    # ----- Optimisation -----
    BATCH_SIZE = 47  # Peak Prediction: 44 # Pytorch: 50  # Lasagne: 56  # Lasagne combined: 42  # Pytorch UpSample: 56
    LEARNING_RATE = 0.001  # 0.002 # LR find: 0.000143 ?  # 0.001
    LR_SCHEDULE = False
    UNET_NR_FILT = 64

    # ----- Weights, cross-validation and run modes -----
    LOAD_WEIGHTS = False
    # WEIGHTS_PATH = join(C.EXP_PATH, "HCP100_45B_UNet_x_DM_lr002_slope2_dec992_ep800/best_weights_ep64.npz")
    WEIGHTS_PATH = ""  # if empty string: autoloading the best_weights in get_best_weights_path()
    TYPE = "single_direction"  # single_direction / combined
    CV_FOLD = 0
    VALIDATE_SUBJECTS = []
    TRAIN_SUBJECTS = []
    TEST_SUBJECTS = []
    TRAIN = True
    TEST = True
    SEGMENT = False
    GET_PROBS = False
    OUTPUT_MULTIPLE_FILES = False
    RESET_LAST_LAYER = False
    UPSAMPLE_TYPE = "bilinear"  # bilinear / nearest

    # ----- Peak regression specific -----
    PEAK_DICE_THR = [0.95]
    PEAK_DICE_LEN_THR = 0.05
    FLIP_OUTPUT_PEAKS = True    # flip peaks along z axis to make them compatible with MITK

    # ----- For the TractSeg.py application -----
    PREDICT_IMG = False
    PREDICT_IMG_OUTPUT = None
    TRACTSEG_DIR = "tractseg_output"
    KEEP_INTERMEDIATE_FILES = False
    CSD_RESOLUTION = "LOW"  # HIGH / LOW
    NR_CPUS = -1

    # ----- Rarely changed -----
    LABELS_TYPE = np.int16  # Binary: np.int16, Regression: np.float32
    THRESHOLD = 0.5  # Binary: 0.5, Regression: 0.01 ?
    TEST_TIME_DAUG = False
    USE_VISLOGGER = False  # only works with Python 3
    SAVE_WEIGHTS = True
    SEG_INPUT = "Peaks"  # Gradients / Peaks
    NR_SLICES = 1  # adapt manually: NR_OF_GRADIENTS in UNet.py and get_batch... in train() and in get_seg_prediction()
    PRINT_FREQ = 20  # 20
    NORMALIZE_DATA = True
    NORMALIZE_PER_CHANNEL = False
    BEST_EPOCH = 0
    VERBOSE = True
    CALC_F1 = True
Exemple #21
0
def run_tractseg(data, output_type="tract_segmentation",
                 single_orientation=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_postprocessing=True, get_probs=False, peak_threshold=0.1,
                 postprocess=False, peak_regression_part="All", input_type="peaks",
                 blob_size_thr=50, nr_cpus=-1, verbose=False, manual_exp_name=None,
                 inference_batch_size=1, tract_definition="TractQuerier+", bedpostX_input=False,
                 tract_segmentations_path=None, TOM_dilation=1):
    """
    Run a TractSeg model on one image and return its prediction.

    The function loads a configuration (pretrained or from a manual experiment),
    selects the weights file, crops and rescales the input, runs inference and
    finally maps the prediction back onto the original image grid.

    Args:
        data: input peaks (4D numpy array with shape [x,y,z,9] according to the
            original documentation). NaNs are replaced via np.nan_to_num.
        output_type: Requested output; only forwarded to get_config_name() and
            therefore only relevant if manual_exp_name is None. The original
            documentation lists 'tract_segmentation' (default),
            'endings_segmentation' and 'TOM'.
        single_orientation: If True, run the model once via a single
            DataLoaderInference instead of predicting along three directions
            and fusing the results (needs less RAM).
        dropout_sampling: Stored as Config.DROPOUT_SAMPLING; selects the dropout
            weights for tract segmentation and forces probability output.
        threshold: Stored as Config.THRESHOLD and handed to the mean fusion.
        bundle_specific_postprocessing: For tract segmentation, forces
            probability output and applies
            img_utils.bundle_specific_postprocessing() to it.
        get_probs: Request probabilities instead of a binary map.
        peak_threshold: Accepted for backward compatibility; not referenced in
            this function body.
        postprocess: For tract segmentation, additionally run
            img_utils.postprocess_segmentations() with blob_size_thr.
        peak_regression_part: Only used for peak regression. 'All' runs
            'Part1'..'Part4' one after another and concatenates their outputs
            along the last axis; any other value runs only that part.
        input_type: "peaks" selects the peak-input weights; every other value is
            treated as the T1 case when picking pretrained weights.
        blob_size_thr: Forwarded as blob_thr to the postprocessing.
        nr_cpus: Stored as Config.NR_CPUS and forwarded to the rescaling helpers.
        verbose: Stored as Config.VERBOSE; if set, the hyperparameters are printed.
        manual_exp_name: Name of an own experiment below C.EXP_PATH to use
            instead of the pretrained models.
        inference_batch_size: Batch size forwarded to the prediction helpers.
        tract_definition: "TractQuerier+" or anything else, which is handled as
            AutoPTX. AutoPTX supports tract_segmentation and dm_regression only.
        bedpostX_input: Accepted for backward compatibility; not referenced in
            this function body.
        tract_segmentations_path: Forwarded to
            peak_utils.mask_and_normalize_peaks() (peak regression only).
        TOM_dilation: Forwarded to peak_utils.mask_and_normalize_peaks()
            (peak regression only).

    Returns:
        The prediction after it has been scaled back and zero-padded to the
        original shape. According to the original documentation a 4D numpy array:
        for tract_segmentation:     [x, y, z, nr_of_bundles]
        for endings_segmentation:   [x, y, z, 2*nr_of_bundles]
        for TOM:                    [x, y, z, 3*nr_of_bundles]

    Raises:
        ValueError: AutoPTX tract definition combined with an unsupported
            experiment type.
        FileNotFoundError: The AutoPTX weights file does not exist.
    """
    t_begin = time.time()

    # ------------------------------------------------------------------
    # Configuration object
    # ------------------------------------------------------------------
    if manual_exp_name is not None:
        hyperparams_file = join(C.EXP_PATH,
                                exp_utils.get_manual_exp_name_peaks(manual_exp_name, "Part1"),
                                "Hyperparameters.txt")
        cfg = exp_utils.load_config_from_txt(hyperparams_file)
    else:
        config_name = get_config_name(input_type, output_type, dropout_sampling=dropout_sampling,
                                      tract_definition=tract_definition)
        config_module = importlib.import_module("tractseg.experiments.pretrained_models." + config_name)
        cfg = config_module.Config()

    cfg = exp_utils.get_correct_labels_type(cfg)
    cfg.VERBOSE = verbose
    cfg.TRAIN = False
    cfg.TEST = False
    cfg.SEGMENT = False
    cfg.GET_PROBS = get_probs
    cfg.LOAD_WEIGHTS = True
    cfg.DROPOUT_SAMPLING = dropout_sampling
    cfg.THRESHOLD = threshold
    cfg.NR_CPUS = nr_cpus
    cfg.INPUT_DIM = exp_utils.get_correct_input_dim(cfg)

    # The bundle specific postprocessing operates on probabilities
    if cfg.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing:
        cfg.GET_PROBS = True

    # ------------------------------------------------------------------
    # Weights file (peak regression picks its weights per part further down)
    # ------------------------------------------------------------------
    if manual_exp_name is not None and cfg.EXPERIMENT_TYPE != "peak_regression":
        cfg.WEIGHTS_PATH = exp_utils.get_best_weights_path(join(C.EXP_PATH, manual_exp_name), True)
    elif tract_definition != "TractQuerier+":  # AutoPTX
        autoptx_weights = {
            "tract_segmentation": "pretrained_weights_tract_segmentation_aPTX_v1.npz",
            "dm_regression": "pretrained_weights_dm_regression_aPTX_v1.npz",
        }
        if cfg.EXPERIMENT_TYPE not in autoptx_weights:
            raise ValueError("bundle_definition AutoPTX not supported in combination with this output type")
        cfg.WEIGHTS_PATH = join(C.WEIGHTS_DIR, autoptx_weights[cfg.EXPERIMENT_TYPE])
        # todo: remove when aPTX weights are loaded automatically
        if not os.path.exists(cfg.WEIGHTS_PATH):
            raise FileNotFoundError("Could not find weights file: {}".format(cfg.WEIGHTS_PATH))
    elif input_type == "peaks":
        if cfg.EXPERIMENT_TYPE == "tract_segmentation" and cfg.DROPOUT_SAMPLING:
            weights_file = "pretrained_weights_tract_segmentation_dropout_v2.npz"
        else:
            weights_file = {
                "tract_segmentation": "pretrained_weights_tract_segmentation_v2.npz",
                "endings_segmentation": "pretrained_weights_endings_segmentation_v3.npz",
                "dm_regression": "pretrained_weights_dm_regression_v1.npz",
            }.get(cfg.EXPERIMENT_TYPE)
        # Any other experiment type keeps the WEIGHTS_PATH of the loaded config
        if weights_file is not None:
            cfg.WEIGHTS_PATH = join(C.WEIGHTS_DIR, weights_file)
    else:  # T1
        if cfg.EXPERIMENT_TYPE == "tract_segmentation":
            cfg.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models",
                                    "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        else:
            weights_file = {
                "endings_segmentation": "pretrained_weights_endings_segmentation_v1.npz",
                "peak_regression": "pretrained_weights_peak_regression_v1.npz",
            }.get(cfg.EXPERIMENT_TYPE)
            # Any other experiment type keeps the WEIGHTS_PATH of the loaded config
            if weights_file is not None:
                cfg.WEIGHTS_PATH = join(C.WEIGHTS_DIR, weights_file)

    if cfg.VERBOSE:
        print("Hyperparameters:")
        exp_utils.print_Configs(cfg)

    # ------------------------------------------------------------------
    # Preprocessing: remove NaNs, crop to non-zero region, scale to network input size
    # ------------------------------------------------------------------
    data = np.nan_to_num(data)

    # runtime on HCP data: 0.9s
    data, _, bbox, original_shape = dataset_utils.crop_to_nonzero(data)
    # runtime on HCP data: 0.5s
    data, transformation = dataset_utils.pad_and_scale_img_to_square_img(data, target_size=cfg.INPUT_DIM[0],
                                                                         nr_cpus=nr_cpus)

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    if cfg.EXPERIMENT_TYPE in ("tract_segmentation", "endings_segmentation", "dm_regression"):
        print("Loading weights from: {}".format(cfg.WEIGHTS_PATH))
        cfg.NR_OF_CLASSES = len(exp_utils.get_bundle_names(cfg.CLASSES)[1:])
        utils.download_pretrained_weights(experiment_type=cfg.EXPERIMENT_TYPE,
                                          dropout_sampling=cfg.DROPOUT_SAMPLING)
        model = BaseModel(cfg, inference=True)

        if single_orientation:  # mainly needed for testing because of less RAM requirements
            loader = DataLoaderInference(cfg, data=data)
            want_probs = bool(cfg.DROPOUT_SAMPLING or cfg.EXPERIMENT_TYPE == "dm_regression" or cfg.GET_PROBS)
            seg, _ = trainer.predict_img(cfg, model, loader, probs=want_probs,
                                         scale_to_world_shape=False, only_prediction=True,
                                         batch_size=inference_batch_size)
        else:
            seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(cfg, model, data=data,
                                                                           scale_to_world_shape=False,
                                                                           only_prediction=True,
                                                                           batch_size=inference_batch_size)
            want_probs = bool(cfg.DROPOUT_SAMPLING or cfg.EXPERIMENT_TYPE == "dm_regression" or cfg.GET_PROBS)
            seg = direction_merger.mean_fusion(cfg.THRESHOLD, seg_xyz, probs=want_probs)

    elif cfg.EXPERIMENT_TYPE == "peak_regression":
        part_weights = {
            "Part1": "pretrained_weights_peak_regression_part1_v2.npz",
            "Part2": "pretrained_weights_peak_regression_part2_v2.npz",
            "Part3": "pretrained_weights_peak_regression_part3_v2.npz",
            "Part4": "pretrained_weights_peak_regression_part4_v2.npz",
        }
        run_all_parts = peak_regression_part == "All"
        if run_all_parts:
            parts = ["Part1", "Part2", "Part3", "Part4"]
            # Buffer that collects the output channels of all parts
            spatial_shape = (data.shape[0], data.shape[1], data.shape[2])
            seg_all = np.zeros(spatial_shape + (cfg.NR_OF_CLASSES * 3,))
        else:
            parts = [peak_regression_part]
            cfg.CLASSES = "All_" + peak_regression_part
            cfg.NR_OF_CLASSES = 3 * len(exp_utils.get_bundle_names(cfg.CLASSES)[1:])

        for part_idx, part in enumerate(parts):
            if manual_exp_name is None:
                cfg.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, part_weights[part])
            else:
                part_exp_name = exp_utils.get_manual_exp_name_peaks(manual_exp_name, part)
                cfg.WEIGHTS_PATH = exp_utils.get_best_weights_path(join(C.EXP_PATH, part_exp_name), True)
            print("Loading weights from: {}".format(cfg.WEIGHTS_PATH))

            # Each part has its own bundle subset and therefore its own channel count
            cfg.CLASSES = "All_" + part
            cfg.NR_OF_CLASSES = 3 * len(exp_utils.get_bundle_names(cfg.CLASSES)[1:])
            utils.download_pretrained_weights(experiment_type=cfg.EXPERIMENT_TYPE,
                                              dropout_sampling=cfg.DROPOUT_SAMPLING, part=part)
            model = BaseModel(cfg, inference=True)

            if single_orientation:
                loader = DataLoaderInference(cfg, data=data)
                seg, _ = trainer.predict_img(cfg, model, loader, probs=True,
                                             scale_to_world_shape=False, only_prediction=True,
                                             batch_size=inference_batch_size)
            else:
                # 3 dir for Peaks -> bad results
                seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(cfg, model, data=data,
                                                                               scale_to_world_shape=False,
                                                                               only_prediction=True,
                                                                               batch_size=inference_batch_size)
                seg = direction_merger.mean_fusion_peaks(seg_xyz)

            if run_all_parts:
                first_channel = part_idx * cfg.NR_OF_CLASSES
                seg_all[:, :, :, first_channel:first_channel + cfg.NR_OF_CLASSES] = seg

        if run_all_parts:
            # Switch the config back to the complete bundle list
            cfg.CLASSES = "All"
            cfg.NR_OF_CLASSES = 3 * len(exp_utils.get_bundle_names(cfg.CLASSES)[1:])
            seg = seg_all

    # ------------------------------------------------------------------
    # Postprocessing and mapping back to the original image grid
    # ------------------------------------------------------------------
    if cfg.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing:
        # Runtime ~4s
        bundle_names = exp_utils.get_bundle_names(cfg.CLASSES)[1:]
        seg = img_utils.bundle_specific_postprocessing(seg, bundle_names)

    # runtime on HCP data: 5.1s
    seg = dataset_utils.cut_and_scale_img_back_to_original_img(seg, transformation, nr_cpus=nr_cpus)
    # runtime on HCP data: 1.6s
    seg = dataset_utils.add_original_zero_padding_again(seg, bbox, original_shape, cfg.NR_OF_CLASSES)

    if cfg.EXPERIMENT_TYPE == "peak_regression":
        bundle_names = exp_utils.get_bundle_names(cfg.CLASSES)[1:]
        seg = peak_utils.mask_and_normalize_peaks(seg, tract_segmentations_path, bundle_names,
                                                  TOM_dilation, nr_cpus=nr_cpus)

    if cfg.EXPERIMENT_TYPE == "tract_segmentation" and postprocess:
        # Runtime ~7s for 1.25mm resolution
        # Runtime ~1.5s for  2mm resolution
        t_postprocess = time.time()
        bundle_names = exp_utils.get_bundle_names(cfg.CLASSES)[1:]
        seg = img_utils.postprocess_segmentations(seg, bundle_names,
                                                  blob_thr=blob_size_thr, hole_closing=None)
        print("took: {}".format(time.time() - t_postprocess))

    exp_utils.print_verbose(cfg, "Took {}s".format(round(time.time() - t_begin, 2)))
    return seg