コード例 #1
0
ファイル: Layers.py プロジェクト: doctoryfx/TractSeg
def soft_dice_paul(HP, idxs, marker, preds, ys):
    '''
    Soft Dice loss, averaged over all classes.

    For every class the raw (non-thresholded) predictions are compared with the
    labels and the per-class Dice values are summed up. The result is
    1 - (summed Dice / number of classes), so smaller values mean more overlap.

    NOTE(review): T is presumably theano.tensor -- confirm against the file's imports.

    :param HP: hyperparameter object; only HP.CLASSES is read
    :param idxs: not used inside this function
    :param marker: selector applied to the first axis of preds and ys
    :param preds: predictions, indexed as [marker, class, :, :]
    :param ys: labels, indexed the same way as preds
    :return: 1 - mean Dice over all classes
    '''
    nr_classes = len(ExpUtils.get_bundle_names(HP.CLASSES))

    dice_sum = T.constant(0)
    for class_idx in range(nr_classes):
        class_pred = preds[marker, class_idx, :, :]
        class_true = ys[marker, class_idx, :, :]

        overlap = T.sum(class_pred * class_true)
        total = T.sum(class_pred) + T.sum(class_true)
        # The small constant keeps the division defined if a class is empty in both inputs
        class_dice = T.constant(2) * overlap / (total + T.constant(1e-6))
        dice_sum = dice_sum + class_dice

    return 1 - (dice_sum / nr_classes)
コード例 #2
0
ファイル: MetricUtils.py プロジェクト: doctoryfx/TractSeg
    def calc_peak_length_dice_pytorch(HP, y_pred, y_true, max_angle_error=[0.9], max_length_error=0.1):
        '''
        Per-bundle F1 score for predicted peaks that checks direction AND length.

        A position counts as predicted-positive only if
          - the angle score (absolute cosine) between predicted and true peak is above
            max_angle_error[0], and
          - the predicted length deviates from the true length by less than
            max_length_error * true length.
        The ground-truth mask contains all positions where the components of the true peak
        sum up to a value > 0.

        :param HP: hyperparameter object; only HP.CLASSES is read
        :param y_pred: torch tensor; permuted with (0, 2, 3, 1) so that axis 1 becomes the last axis
        :param y_true: torch tensor with the same layout as y_pred
        :param max_angle_error: list of thresholds on the angle score (0.7 -> ~45°, 0.9 -> ~26°);
                                only the first entry is used here
        :param max_length_error: maximal relative length deviation
        :return: dict: bundle name -> result of PytorchUtils.f1_score_binary
        '''
        import torch
        from tractseg.libs.PytorchEinsum import einsum
        from tractseg.libs.PytorchUtils import PytorchUtils

        # Bring axis 1 to the end: 3 consecutive entries of the last axis form one peak
        y_pred = y_pred.permute(0, 2, 3, 1)
        y_true = y_true.permute(0, 2, 3, 1)

        def angle_last_dim(a, b):
            '''
            Absolute cosine of the angle between the vectors stored along the last dimension.

            1 -> 0°, 0.7 -> ~45°, 0 -> 90°

            return: one dimension less than the input
            '''
            dot = einsum('abcd,abcd->abc', a, b)
            norm_product = torch.norm(a, 2., -1) * torch.norm(b, 2, -1)
            return torch.abs(dot / (norm_product + 1e-7))   # 1e-7: no division by zero for empty peaks

        scores = {}
        bundle_names = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for bundle_nr, bundle_name in enumerate(bundle_names):
            first = bundle_nr * 3
            pred_peaks = y_pred[:, :, :, first:first + 3].contiguous()
            true_peaks = y_true[:, :, :, first:first + 3].contiguous()

            angle_scores = angle_last_dim(pred_peaks, true_peaks)

            pred_len = torch.norm(pred_peaks, 2., -1)
            true_len = torch.norm(true_peaks, 2, -1)
            length_ok = (torch.abs(pred_len - true_len) < (max_length_error * true_len)).view(-1)

            is_tract = (true_peaks.sum(dim=-1) > 0).view(-1)

            angle_ok = (angle_scores > max_angle_error[0]).view(-1)

            # Elementwise product of the two binary masks acts as a logical AND
            predicted = length_ok * angle_ok

            scores[bundle_name] = PytorchUtils.f1_score_binary(is_tract, predicted)
        return scores
コード例 #3
0
ファイル: Layers.py プロジェクト: doctoryfx/TractSeg
def theano_f1_score_OLD(HP, idxs, marker, preds, ys):
    '''
    By Paul.

    Dice score averaged over all classes. Unlike a soft Dice, the predictions
    are binarized first (prediction > 0.5) before they are compared with the labels.

    NOTE(review): T is presumably theano.tensor -- confirm against the file's imports.

    :param HP: hyperparameter object; only HP.CLASSES is read
    :param idxs: not used inside this function
    :param marker: selector applied to the first axis of preds and ys
    :param preds: predictions, indexed as [marker, class, :, :]
    :param ys: labels, indexed the same way as preds
    :return: mean Dice over all classes
    '''
    nr_classes = len(ExpUtils.get_bundle_names(HP.CLASSES))

    dice_sum = T.constant(0)
    for class_idx in range(nr_classes):
        class_true = ys[marker, class_idx, :, :]
        # Binarize the prediction of this class
        class_pred = T.gt(preds[marker, class_idx, :, :], T.constant(0.5))

        overlap = T.sum(class_pred * class_true)
        total = T.sum(class_pred) + T.sum(class_true)
        # The small constant keeps the division defined if a class is empty in both inputs
        class_dice = T.constant(2) * overlap / (total + T.constant(1e-6))
        dice_sum = dice_sum + class_dice

    return dice_sum / nr_classes
コード例 #4
0
ファイル: ImgUtils.py プロジェクト: doctoryfx/TractSeg
    def save_multilabel_img_as_multiple_files_peaks(HP, img, affine, path):
        '''
        Save the peaks of every bundle as a separate nifti file in <path>/TOM.

        Along its last axis img holds 3 consecutive channels (one peak vector) per bundle,
        in the order of ExpUtils.get_bundle_names(HP.CLASSES)[1:].
        If HP.FLIP_OUTPUT_PEAKS is set, the third component is negated in the saved file and
        the filename gets the suffix "_f". img itself is never modified.

        :param HP: hyperparameter object; HP.CLASSES and HP.FLIP_OUTPUT_PEAKS are read
        :param img: 4D array with 3 channels per bundle (assumed to be a numpy array -- TODO confirm)
        :param affine: affine passed on to nib.Nifti1Image
        :param path: output directory; the files are written to the subfolder "TOM"
        :return: None
        '''
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            data = img[:, :, :, (idx*3):(idx*3)+3]

            if HP.FLIP_OUTPUT_PEAKS:
                # Slicing returns a view on img: copy first, otherwise the flip would
                # silently change the caller's array as well
                data = data.copy()
                data[:, :, :, 2] *= -1  # flip z Axis for correct view in MITK
                filename = bundle + "_f.nii.gz"
            else:
                filename = bundle + ".nii.gz"

            img_seg = nib.Nifti1Image(data, affine)
            ExpUtils.make_dir(join(path, "TOM"))
            nib.save(img_seg, join(path, "TOM", filename))
コード例 #5
0
ファイル: ImgUtils.py プロジェクト: doctoryfx/TractSeg
    def save_multilabel_img_as_multiple_files_peaks(HP, img, affine, path):
        '''
        Save the peaks of every bundle as a separate nifti file in <path>/TOM.

        Along its last axis img holds 3 consecutive channels (one peak vector) per bundle,
        in the order of ExpUtils.get_bundle_names(HP.CLASSES)[1:].
        If HP.FLIP_OUTPUT_PEAKS is set, the third component is negated in the saved file and
        the filename gets the suffix "_f". img itself is never modified.

        :param HP: hyperparameter object; HP.CLASSES and HP.FLIP_OUTPUT_PEAKS are read
        :param img: 4D array with 3 channels per bundle (assumed to be a numpy array -- TODO confirm)
        :param affine: affine passed on to nib.Nifti1Image
        :param path: output directory; the files are written to the subfolder "TOM"
        :return: None
        '''
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            data = img[:, :, :, (idx * 3):(idx * 3) + 3]

            if HP.FLIP_OUTPUT_PEAKS:
                # Slicing returns a view on img: copy first, otherwise the flip would
                # silently change the caller's array as well
                data = data.copy()
                data[:, :, :, 2] *= -1  # flip z Axis for correct view in MITK
                filename = bundle + "_f.nii.gz"
            else:
                filename = bundle + ".nii.gz"

            img_seg = nib.Nifti1Image(data, affine)
            ExpUtils.make_dir(join(path, "TOM"))
            nib.save(img_seg, join(path, "TOM", filename))
コード例 #6
0
ファイル: MetricUtils.py プロジェクト: silongGG/TractSeg
    def calc_peak_dice(HP, y_pred, y_true, max_angle_error=[0.9]):
        '''
        Per-bundle F1 score between the ground-truth tract mask and the positions where the
        predicted peak direction is close enough to the true direction.

        Ground-truth mask: positions where the components of the true peak sum up to a value > 0.
        Prediction mask:   positions where the angle score (absolute cosine) is above
                           max_angle_error[0].

        :param HP: hyperparameter object; only HP.CLASSES is read
        :param y_pred: 4D array, 3 consecutive channels (one peak) per bundle along the last axis
        :param y_true: 4D array with the same layout as y_pred
        :param max_angle_error: list of thresholds on the angle score (0.7 -> ~45°, 0.9 -> ~26°);
                                only the first entry is used. The default list is never mutated.
        :return: dict: bundle name -> F1 score
        '''
        def angle_last_dim(a, b):
            '''
            Absolute cosine of the angle between two nd-arrays of vectors along the last dimension

            1 -> 0°, 0.7 -> ~45°, 0 -> 90°

            return: one dimension less than the input
            '''
            dot = np.einsum('...i,...i', a, b)
            norm_product = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1)
            # 1e-7: avoid a division by zero where one of the peaks is the zero vector
            return abs(dot / (norm_product + 1e-7))

        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]  # [x,y,z,3]

            angles = angle_last_dim(y_pred_bund, y_true_bund)
            angles_binary = angles > max_angle_error[0]

            gt_binary = y_true_bund.sum(axis=-1) > 0

            f1 = f1_score(gt_binary.flatten(),
                          angles_binary.flatten(),
                          average="binary")
            score_per_bundle[bundle] = f1

        return score_per_bundle
コード例 #7
0
ファイル: MetricUtils.py プロジェクト: doctoryfx/TractSeg
    def calc_peak_length_dice(HP, y_pred, y_true, max_angle_error=[0.9], max_length_error=0.1):
        '''
        Per-bundle F1 score for predicted peaks that checks direction AND length.

        A position counts as predicted-positive only if
          - the angle score (absolute cosine) between predicted and true peak is above
            max_angle_error[0], and
          - the predicted length deviates from the true length by less than
            max_length_error * true length.
        The ground-truth mask contains all positions where the components of the true peak
        sum up to a value > 0.

        :param HP: hyperparameter object; only HP.CLASSES is read
        :param y_pred: 4D array, 3 consecutive channels (one peak) per bundle along the last axis
        :param y_true: 4D array with the same layout as y_pred
        :param max_angle_error: list of thresholds on the angle score (0.7 -> ~45°, 0.9 -> ~26°);
                                only the first entry is used
        :param max_length_error: maximal relative length deviation
        :return: dict: bundle name -> result of MetricUtils.my_f1_score
        '''

        def angle_last_dim(a, b):
            '''
            Absolute cosine of the angle between two nd-arrays of vectors along the last dimension

            1 -> 0°, 0.7 -> ~45°, 0 -> 90°

            return: one dimension less than the input
            '''
            dot = np.einsum('...i,...i', a, b)
            norm_product = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1)
            return abs(dot / (norm_product + 1e-7))   # 1e-7: no division by zero for empty peaks

        scores = {}
        bundle_names = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for bundle_nr, bundle_name in enumerate(bundle_names):
            first = bundle_nr * 3
            pred_peaks = y_pred[:, :, :, first:first + 3]
            true_peaks = y_true[:, :, :, first:first + 3]

            angle_scores = angle_last_dim(pred_peaks, true_peaks)

            pred_len = np.linalg.norm(pred_peaks, axis=-1)
            true_len = np.linalg.norm(true_peaks, axis=-1)
            length_ok = (abs(pred_len - true_len) < (max_length_error * true_len)).flatten()

            is_tract = (true_peaks.sum(axis=-1) > 0).flatten()

            angle_ok = (angle_scores > max_angle_error[0]).flatten()

            # Elementwise product of the two boolean masks acts as a logical AND
            predicted = length_ok * angle_ok

            scores[bundle_name] = MetricUtils.my_f1_score(is_tract, predicted)
        return scores
コード例 #8
0
ファイル: MetricUtils.py プロジェクト: doctoryfx/TractSeg
    def calc_peak_dice(HP, y_pred, y_true, max_angle_error=[0.9]):
        '''
        Per-bundle F1 score between the ground-truth tract mask and the positions where the
        predicted peak direction is close enough to the true direction.

        Ground-truth mask: positions where the components of the true peak sum up to a value > 0.
        Prediction mask:   positions where the angle score (absolute cosine) is above
                           max_angle_error[0].

        :param HP: hyperparameter object; only HP.CLASSES is read
        :param y_pred: 4D array, 3 consecutive channels (one peak) per bundle along the last axis
        :param y_true: 4D array with the same layout as y_pred
        :param max_angle_error: list of thresholds on the angle score (0.7 -> ~45°, 0.9 -> ~26°);
                                only the first entry is used. The default list is never mutated.
        :return: dict: bundle name -> F1 score
        '''

        def angle_last_dim(a, b):
            '''
            Absolute cosine of the angle between two nd-arrays of vectors along the last dimension

            1 -> 0°, 0.7 -> ~45°, 0 -> 90°

            return: one dimension less than the input
            '''
            dot = np.einsum('...i,...i', a, b)
            norm_product = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1)
            # 1e-7: avoid a division by zero where one of the peaks is the zero vector
            return abs(dot / (norm_product + 1e-7))

        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]      # [x,y,z,3]

            angles = angle_last_dim(y_pred_bund, y_true_bund)
            angles_binary = angles > max_angle_error[0]

            gt_binary = y_true_bund.sum(axis=-1) > 0

            f1 = f1_score(gt_binary.flatten(), angles_binary.flatten(), average="binary")
            score_per_bundle[bundle] = f1

        return score_per_bundle
コード例 #9
0
    def create_multilabel_mask(HP, subject, labels_type=np.int16, dataset_folder="HCP", labels_folder="bundle_masks"):
        '''
        One-hot encoding of all bundles in one big image.

        Channel 0 is the background: 1 everywhere except where any bundle mask equals 1.
        Channel i (i >= 1) holds the mask loaded from
        <C.HOME>/<dataset_folder>/<subject>/<labels_folder>/<bundle>.nii.gz

        :param HP: hyperparameter object; only HP.CLASSES is read
        :param subject: name of the subject folder
        :param labels_type: dtype the result is cast to
        :param dataset_folder: dataset folder below C.HOME
        :param labels_folder: folder below the subject folder that contains the bundle masks
        :return: image of shape (x, y, z, nr_of_bundles + 1)
        '''
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)

        # Masks are always HCP high resolution (downsampling happens later)
        mask_ml = np.zeros((145, 174, 145, len(bundles)))
        background = np.ones((145, 174, 145))   # everything that contains no bundle

        # First bundle is background -> already considered by setting np.ones in the beginning
        for idx, bundle in enumerate(bundles[1:]):
            mask = nib.load(join(C.HOME, dataset_folder, subject, labels_folder, bundle + ".nii.gz"))
            # get_data() is deprecated (and removed in recent nibabel versions);
            # np.asanyarray(img.dataobj) is the documented drop-in replacement
            mask_data = np.asanyarray(mask.dataobj)
            mask_ml[:, :, :, idx+1] = mask_data
            background[mask_data == 1] = 0    # remove this bundle from background

        mask_ml[:, :, :, 0] = background
        return mask_ml.astype(labels_type)
コード例 #10
0
ファイル: ImgUtils.py プロジェクト: doctoryfx/TractSeg
    def create_multilabel_mask(HP, subject, labels_type=np.int16, dataset_folder="HCP", labels_folder="bundle_masks"):
        '''
        One-hot encoding of all bundles in one big image.

        Channel 0 is the background: 1 everywhere except where any bundle mask equals 1.
        Channel i (i >= 1) holds the mask loaded from
        <C.HOME>/<dataset_folder>/<subject>/<labels_folder>/<bundle>.nii.gz

        :param HP: hyperparameter object; only HP.CLASSES is read
        :param subject: name of the subject folder
        :param labels_type: dtype the result is cast to
        :param dataset_folder: dataset folder below C.HOME
        :param labels_folder: folder below the subject folder that contains the bundle masks
        :return: image of shape (x, y, z, nr_of_bundles + 1)
        '''
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)

        # Masks are always HCP high resolution (downsampling happens later)
        mask_ml = np.zeros((145, 174, 145, len(bundles)))
        background = np.ones((145, 174, 145))   # everything that contains no bundle

        # First bundle is background -> already considered by setting np.ones in the beginning
        for idx, bundle in enumerate(bundles[1:]):
            mask = nib.load(join(C.HOME, dataset_folder, subject, labels_folder, bundle + ".nii.gz"))
            # get_data() is deprecated (and removed in recent nibabel versions);
            # np.asanyarray(img.dataobj) is the documented drop-in replacement
            mask_data = np.asanyarray(mask.dataobj)
            mask_ml[:, :, :, idx+1] = mask_data
            background[mask_data == 1] = 0    # remove this bundle from background

        mask_ml[:, :, :, 0] = background
        return mask_ml.astype(labels_type)
コード例 #11
0
    def save_multilabel_img_as_multiple_files_endings_OLD(HP, img, affine, path, multilabel=True):
        '''
        Save the endings of every bundle as one 3D nifti file in <path>/endings.

        img holds 2 consecutive channels per bundle along the last axis; a voxel belongs to
        a channel if its value is > 0. Where both channels are set, the second one wins.

        multilabel True:    first channel is saved as 1, second channel as 2 (no fourth dimension)
        multilabel False:   both channels are saved as 1 (combined)

        :param HP: hyperparameter object; only HP.CLASSES is read
        :param img: 4D array with 2 channels per bundle
        :param affine: affine passed on to nib.Nifti1Image
        :param path: output directory; the files are written to the subfolder "endings"
        :param multilabel: see above
        :return: None
        '''
        bundle_names = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for bundle_nr, bundle_name in enumerate(bundle_names):
            first = bundle_nr * 2
            is_set = img[:, :, :, first:first + 2] > 0

            # Label of the second channel depends on the mode, the first channel is always 1
            second_label = 2 if multilabel else 1

            label_img = np.zeros(is_set.shape[:3])
            label_img[is_set[:, :, :, 0]] = 1
            label_img[is_set[:, :, :, 1]] = second_label

            nifti = nib.Nifti1Image(label_img, affine)
            ExpUtils.make_dir(join(path, "endings"))
            nib.save(nifti, join(path, "endings", bundle_name + ".nii.gz"))
コード例 #12
0
ファイル: ImgUtils.py プロジェクト: doctoryfx/TractSeg
    def save_multilabel_img_as_multiple_files_endings_OLD(HP, img, affine, path, multilabel=True):
        '''
        Save the endings of every bundle as one 3D nifti file in <path>/endings.

        img holds 2 consecutive channels per bundle along the last axis; a voxel belongs to
        a channel if its value is > 0. Where both channels are set, the second one wins.

        multilabel True:    first channel is saved as 1, second channel as 2 (no fourth dimension)
        multilabel False:   both channels are saved as 1 (combined)

        :param HP: hyperparameter object; only HP.CLASSES is read
        :param img: 4D array with 2 channels per bundle
        :param affine: affine passed on to nib.Nifti1Image
        :param path: output directory; the files are written to the subfolder "endings"
        :param multilabel: see above
        :return: None
        '''
        bundle_names = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for bundle_nr, bundle_name in enumerate(bundle_names):
            first = bundle_nr * 2
            is_set = img[:, :, :, first:first + 2] > 0

            # Label of the second channel depends on the mode, the first channel is always 1
            second_label = 2 if multilabel else 1

            label_img = np.zeros(is_set.shape[:3])
            label_img[is_set[:, :, :, 0]] = 1
            label_img[is_set[:, :, :, 1]] = second_label

            nifti = nib.Nifti1Image(label_img, affine)
            ExpUtils.make_dir(join(path, "endings"))
            nib.save(nifti, join(path, "endings", bundle_name + ".nii.gz"))
コード例 #13
0
ファイル: MetricUtils.py プロジェクト: doctoryfx/TractSeg
    def calc_peak_dice_onlySeg(HP, y_pred, y_true):
        '''
        Turn the peaks into binary masks by simple thresholding, then calculate the F1 score
        (Dice) per bundle. Directions are ignored.

        :param HP: hyperparameter object; only HP.CLASSES is read
        :param y_pred: 4D array, 3 consecutive channels (one peak) per bundle along the last axis
        :param y_true: 4D array with the same layout as y_pred
        :return: dict: bundle name -> F1 score
        '''
        # Thresholds on the sum of the absolute peak components.
        # Prediction: 0.1 -> keeps some outliers, but also some holes already; 0.2 also ok
        # (still looks like e.g. CST). Resulting dice for 0.1 and 0.2 very similar.
        pred_threshold = 0.2
        true_threshold = 1e-3

        scores = {}
        bundle_names = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for bundle_nr, bundle_name in enumerate(bundle_names):
            first = bundle_nr * 3
            pred_peaks = y_pred[:, :, :, first:first + 3]
            true_peaks = y_true[:, :, :, first:first + 3]

            pred_mask = np.abs(pred_peaks).sum(axis=-1) > pred_threshold
            true_mask = np.abs(true_peaks).sum(axis=-1) > true_threshold

            scores[bundle_name] = f1_score(true_mask.flatten(), pred_mask.flatten(), average="binary")

        return scores
コード例 #14
0
ファイル: MetricUtils.py プロジェクト: silongGG/TractSeg
    def calc_peak_length_dice_pytorch(HP,
                                      y_pred,
                                      y_true,
                                      max_angle_error=[0.9],
                                      max_length_error=0.1):
        '''
        Per-bundle F1 score for predicted peaks that checks direction AND length.

        A position counts as predicted-positive only if
          - the angle score (absolute cosine) between predicted and true peak is above
            max_angle_error[0], and
          - the predicted length deviates from the true length by less than
            max_length_error * true length.
        The ground-truth mask contains all positions where the components of the true peak
        sum up to a value > 0.

        :param HP: hyperparameter object; only HP.CLASSES is read
        :param y_pred: torch tensor; permuted with (0, 2, 3, 1) so that axis 1 becomes the last axis
        :param y_true: torch tensor with the same layout as y_pred
        :param max_angle_error: list of thresholds on the angle score (0.7 -> ~45°, 0.9 -> ~26°);
                                only the first entry is used here
        :param max_length_error: maximal relative length deviation
        :return: dict: bundle name -> result of PytorchUtils.f1_score_binary
        '''
        import torch
        from tractseg.libs.PytorchEinsum import einsum
        from tractseg.libs.PytorchUtils import PytorchUtils

        # Bring axis 1 to the end: 3 consecutive entries of the last axis form one peak
        y_pred = y_pred.permute(0, 2, 3, 1)
        y_true = y_true.permute(0, 2, 3, 1)

        def angle_last_dim(a, b):
            '''
            Absolute cosine of the angle between the vectors stored along the last dimension.

            1 -> 0°, 0.7 -> ~45°, 0 -> 90°

            return: one dimension less than the input
            '''
            dot = einsum('abcd,abcd->abc', a, b)
            norm_product = torch.norm(a, 2., -1) * torch.norm(b, 2, -1)
            return torch.abs(dot / (norm_product + 1e-7))   # 1e-7: no division by zero for empty peaks

        scores = {}
        bundle_names = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for bundle_nr, bundle_name in enumerate(bundle_names):
            first = bundle_nr * 3
            pred_peaks = y_pred[:, :, :, first:first + 3].contiguous()
            true_peaks = y_true[:, :, :, first:first + 3].contiguous()

            angle_scores = angle_last_dim(pred_peaks, true_peaks)

            pred_len = torch.norm(pred_peaks, 2., -1)
            true_len = torch.norm(true_peaks, 2, -1)
            length_ok = (torch.abs(pred_len - true_len) < (max_length_error * true_len)).view(-1)

            is_tract = (true_peaks.sum(dim=-1) > 0).view(-1)

            angle_ok = (angle_scores > max_angle_error[0]).view(-1)

            # Elementwise product of the two binary masks acts as a logical AND
            predicted = length_ok * angle_ok

            scores[bundle_name] = PytorchUtils.f1_score_binary(is_tract, predicted)
        return scores
コード例 #15
0
ファイル: MetricUtils.py プロジェクト: doctoryfx/TractSeg
    def calc_peak_dice_pytorch(HP, y_pred, y_true, max_angle_error=[0.9]):
        '''
        Calculate angle between groundtruth and prediction and keep the voxels where
        angle is smaller than MAX_ANGLE_ERROR.

        From groundtruth generate a binary mask by selecting all voxels where the peak
        components sum up to a value > 0.

        Calculate Dice (F1) from these 2 masks.

        -> Penalty on peaks outside of tract or if predicted peak=0
        -> no penalty on very very small with right direction -> bad
        => Peak_dice can be high even if peaks inside of tract almost missing (almost 0)

        :param HP: hyperparameter object; only HP.CLASSES is read
        :param y_pred: torch tensor; permuted with (0, 2, 3, 1) so that axis 1 becomes the last axis
        :param y_true: torch tensor with the same layout as y_pred
        :param max_angle_error:  0.7 ->  angle error of ~45° or less; 0.9 ->  angle error of ~26° or less
                                 Can be list with several values -> calculate for several thresholds
        :return: dict: bundle name -> F1 score if max_angle_error has exactly one entry,
                 otherwise bundle name -> list of F1 scores (one per threshold, same order)
        '''
        import torch
        from tractseg.libs.PytorchEinsum import einsum
        from tractseg.libs.PytorchUtils import PytorchUtils

        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)

        def angle_last_dim(a, b):
            '''
            Absolute cosine of the angle between the vectors stored along the last dimension.

            1 -> 0°, 0.7 -> ~45°, 0 -> 90°

            return: one dimension less than the input
            '''
            return torch.abs(einsum('abcd,abcd->abc', a, b) / (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

        # With exactly one threshold the score is stored directly, otherwise as a list
        single_threshold = len(max_angle_error) == 1

        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()      # [x,y,z,3]

            angles = angle_last_dim(y_pred_bund, y_true_bund)
            gt_binary = y_true_bund.sum(dim=-1) > 0
            gt_binary = gt_binary.view(-1)  # [bs*x*y]

            f1_per_threshold = []
            for threshold in max_angle_error:
                angles_binary = angles > threshold
                angles_binary = angles_binary.view(-1)
                f1_per_threshold.append(PytorchUtils.f1_score_binary(gt_binary, angles_binary))

            if single_threshold:
                score_per_bundle[bundle] = f1_per_threshold[0]
            else:
                score_per_bundle[bundle] = f1_per_threshold

        return score_per_bundle
コード例 #16
0
 def save_multilabel_img_as_multiple_files_endings(HP, img, affine, path):
     '''
     Save every channel of img as a separate nifti file in <path>/endings_segmentations.

     Channel i of the last axis is written to <bundle name i>.nii.gz, where the bundle names
     are ExpUtils.get_bundle_names(HP.CLASSES) without the first entry.
     '''
     bundle_names = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
     for channel, bundle_name in enumerate(bundle_names):
         channel_img = nib.Nifti1Image(img[:, :, :, channel], affine)
         out_dir = join(path, "endings_segmentations")
         ExpUtils.make_dir(out_dir)
         nib.save(channel_img, join(out_dir, bundle_name + ".nii.gz"))
コード例 #17
0
ファイル: PlotUtils.py プロジェクト: qiaotian/TractSeg
    def plot_tracts(HP, bundle_segmentations, affine, out_dir, brain_mask=None):
        '''
        Render a preview of three bundles (CST_right, CA, IFO_right) below each other and
        save it as "preview.png" in out_dir.

        By default this does not work on a remote server connection (ssh -X) because -X does
        not support OpenGL. On the remote server you can do 'export DISPLAY=":0"' (use the value
        you get from 'echo $DISPLAY' when logged in locally on the remote server). Then all
        graphics are rendered locally and not via -X.
        (important: a graphical session needs to be running on the remote server, e.g. via local login)
        (important: an actual login is needed, staying at the login screen is not enough)

        :param HP: hyperparameter object; only HP.CLASSES is read
        :param bundle_segmentations: 4D array, one channel per bundle along the last axis
        :param affine: affine passed on to PlotUtils.plot_mask
        :param out_dir: directory the preview image is written to
        :param brain_mask: passed on to PlotUtils.plot_mask
        :return: None
        '''
        from dipy.viz import window
        from tractseg.libs.VtkUtils import VtkUtils

        smoothing = 10
        window_size = (800, 800)
        shown_bundles = ["CST_right", "CA", "IFO_right"]
        # Viewing direction per bundle; everything not listed here is shown axial
        orientations = {"CST_right": "axial", "CA": "axial", "IFO_right": "sagittal"}

        # Offsets of the bundle label relative to the rendered mask
        text_offset_top = -50  # 60
        text_offset_side = -100  # -30

        renderer = window.Renderer()
        renderer.projection('parallel')

        nr_rows = len(shown_bundles)
        size_x, size_y, _size_z = bundle_segmentations.shape[:3]
        column = 0  # only one method -> only one column
        for row, bundle_name in enumerate(shown_bundles):
            channel = ExpUtils.get_bundle_names(HP.CLASSES)[1:].index(bundle_name)
            mask = bundle_segmentations[:, :, :, channel]

            orientation = orientations.get(bundle_name, "axial")

            # bigger: more border (currently the same value for every orientation)
            border_y = -100

            row_height = size_y * 2 + border_y
            x_current = size_x * column  # column (width)
            y_current = nr_rows * row_height - row_height * row  # row (height)  (starts from bottom?)

            PlotUtils.plot_mask(renderer, mask, affine, x_current, y_current,
                                orientation=orientation, smoothing=smoothing, brain_mask=brain_mask)

            # Bundle label
            label_pos = (text_offset_side - int(size_x), y_current + text_offset_top, 50)
            label = VtkUtils.label(text=bundle_name, pos=label_pos, scale=(6, 6, 6), color=(1, 1, 1))
            renderer.add(label)

        renderer.reset_camera()
        window.record(renderer,
                      out_path=join(out_dir, "preview.png"),
                      size=window_size,
                      reset_camera=False,
                      magnification=2)
コード例 #18
0
ファイル: BaseHP.py プロジェクト: silongGG/TractSeg
class HP:
    """
    Default hyperparameters and settings of an experiment, stored as class attributes.

    NR_OF_CLASSES, MULTI_PARENT_PATH and EXP_PATH are computed once, while the class body is
    executed, from CLASSES, EXP_MULTI_NAME and EXP_NAME. Changing those attributes afterwards
    does not update the derived values.
    """
    EXP_MULTI_NAME = ""  # CV parent dir name; leave empty for single bundle experiment
    EXP_NAME = "HCP_TEST"  # HCP_TEST
    MODEL = "UNet_Pytorch"  # UNet_Lasagne / UNet_Pytorch
    EXPERIMENT_TYPE = "tract_segmentation"  # tract_segmentation / endings_segmentation / dm_regression / peak_regression

    NUM_EPOCHS = 250
    # Data augmentation: master switch followed by the individual DAUG_* switches
    DATA_AUGMENTATION = False
    DAUG_SCALE = True
    DAUG_NOISE = True
    DAUG_ELASTIC_DEFORM = True
    DAUG_RESAMPLE = True
    DAUG_ROTATE = False
    DAUG_MIRROR = False
    DAUG_FLIP_PEAKS = False
    # Free-text description of the augmentation settings
    DAUG_INFO = "Elastic(90,120)(9,11) - Scale(0.9, 1.5) - CenterDist60 - DownsampScipy(0.5,1) - Gaussian(0,0.05) - Rotate(-0.8,0.8)"
    DATASET = "HCP"  # HCP / HCP_32g / Schizo
    RESOLUTION = "1.25mm"  # 1.25mm (/ 2.5mm)
    FEATURES_FILENAME = "12g90g270g"  # 12g90g270g / 270g_125mm_xyz / 270g_125mm_peaks / 90g_125mm_peaks / 32g_25mm_peaks / 32g_25mm_xyz
    LABELS_FILENAME = ""  # autofilled      #"bundle_peaks/CA"  #IMPORTANT: Adapt BatchGen if 808080              # bundle_masks / bundle_masks_72 / bundle_masks_dm / bundle_peaks      #Only used when using DataManagerNifti
    LOSS_FUNCTION = "default"  # default / soft_batch_dice
    OPTIMIZER = "Adamax"
    CLASSES = "All"  # All / 11 / 20 / CST_right
    NR_OF_GRADIENTS = 9
    # Number of bundle names for CLASSES without the first entry; evaluated once at class creation
    NR_OF_CLASSES = len(ExpUtils.get_bundle_names(CLASSES)[1:])
    # NR_OF_CLASSES = 3 * len(ExpUtils.get_bundle_names(CLASSES)[1:])

    INPUT_DIM = (144, 144)  # (80, 80) / (144, 144)
    LOSS_WEIGHT = 1  # 1: no weighting
    LOSS_WEIGHT_LEN = -1  # -1: constant over all epochs
    SLICE_DIRECTION = "y"  # x, y, z  (combined needs z)
    TRAINING_SLICE_DIRECTION = "xyz"  # y / xyz
    INFO = "-"  # Dropout, Deconv, 11bundles, LeakyRelu, PeakDiceThres=0.9
    BATCH_NORM = False
    WEIGHT_DECAY = 0
    USE_DROPOUT = False
    DROPOUT_SAMPLING = False

    # DATASET_FOLDER = "HCP_batches/270g_125mm_bundle_peaks_Y_subset"  # HCP / HCP_batches/XXX / TRACED / HCP_fusion_npy_270g_125mm / HCP_fusion_npy_32g_25mm
    # DATASET_FOLDER = "HCP_batches/270g_125mm_bundle_peaks_XYZ"
    DATASET_FOLDER = "HCP"  # HCP / Schizo
    LABELS_FOLDER = "bundle_masks"  # bundle_masks / bundle_masks_dm
    # Both paths are built once from the values of EXP_MULTI_NAME / EXP_NAME defined above
    MULTI_PARENT_PATH = join(C.EXP_PATH, EXP_MULTI_NAME)
    EXP_PATH = join(C.EXP_PATH, EXP_MULTI_NAME, EXP_NAME)  # default path
    BATCH_SIZE = 47  #30/44  #max: #Peak Prediction: 44 #Pytorch: 50  #Lasagne: 56  #Lasagne combined: 42  #Pytorch UpSample: 56   #Pytorch_SE_r16: 45    #Pytorch_SE_r64: 45
    LEARNING_RATE = 0.001  # 0.002 #LR find: 0.000143 ?  # 0.001
    LR_SCHEDULE = False
    UNET_NR_FILT = 64
    LOAD_WEIGHTS = False
    # WEIGHTS_PATH = join(C.EXP_PATH, "HCP100_45B_UNet_x_DM_lr002_slope2_dec992_ep800/best_weights_ep64.npz")    # Can be absolute path or relative like "exp_folder/weights.npz"
    WEIGHTS_PATH = ""  # if empty string: autoloading the best_weights in get_best_weights_path()
    TYPE = "single_direction"  # single_direction / combined
    CV_FOLD = 0
    VALIDATE_SUBJECTS = []
    TRAIN_SUBJECTS = []
    TEST_SUBJECTS = []
    TRAIN = True
    TEST = True
    SEGMENT = False
    GET_PROBS = False
    OUTPUT_MULTIPLE_FILES = False
    RESET_LAST_LAYER = False

    # Peak_regression specific
    PEAK_DICE_THR = [0.95]
    PEAK_DICE_LEN_THR = 0.05
    FLIP_OUTPUT_PEAKS = True  # flip peaks along z axis to make them compatible with MITK

    # For TractSeg.py application
    PREDICT_IMG = False
    PREDICT_IMG_OUTPUT = None
    TRACTSEG_DIR = "tractseg_output"
    KEEP_INTERMEDIATE_FILES = False
    CSD_RESOLUTION = "LOW"  # HIGH / LOW

    # Unimportant / rarely changed:
    LABELS_TYPE = np.int16  # Binary: np.int16, Regression: np.float32
    THRESHOLD = 0.5  # Binary: 0.5, Regression: 0.01 ?
    TEST_TIME_DAUG = False
    USE_VISLOGGER = False  #only works with Python 3
    SAVE_WEIGHTS = True
    SEG_INPUT = "Peaks"  # Gradients/ Peaks
    NR_SLICES = 1  # adapt manually: NR_OF_GRADIENTS in UNet.py and get_batch... in train() and in get_seg_prediction()
    PRINT_FREQ = 20  #20
    NORMALIZE_DATA = True
    NORMALIZE_PER_CHANNEL = False
    BEST_EPOCH = 0
    VERBOSE = True
    CALC_F1 = True
コード例 #19
0
ファイル: MetricUtils.py プロジェクト: silongGG/TractSeg
    def calc_peak_dice_pytorch(HP, y_pred, y_true, max_angle_error=[0.9]):
        '''
        Calculate angle between groundtruth and prediction and keep the voxels where
        angle is smaller than MAX_ANGLE_ERROR.

        From groundtruth generate a binary mask by selecting all voxels where the peak
        components sum up to a value > 0.

        Calculate Dice (F1) from these 2 masks.

        -> Penalty on peaks outside of tract or if predicted peak=0
        -> no penalty on very very small with right direction -> bad
        => Peak_dice can be high even if peaks inside of tract almost missing (almost 0)

        :param HP: hyperparameter object; only HP.CLASSES is read
        :param y_pred: torch tensor; permuted with (0, 2, 3, 1) so that axis 1 becomes the last axis
        :param y_true: torch tensor with the same layout as y_pred
        :param max_angle_error:  0.7 ->  angle error of ~45° or less; 0.9 ->  angle error of ~26° or less
                                 Can be list with several values -> calculate for several thresholds
        :return: dict: bundle name -> F1 score if max_angle_error has exactly one entry,
                 otherwise bundle name -> list of F1 scores (one per threshold, same order)
        '''
        import torch
        from tractseg.libs.PytorchEinsum import einsum
        from tractseg.libs.PytorchUtils import PytorchUtils

        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)

        def angle_last_dim(a, b):
            '''
            Absolute cosine of the angle between the vectors stored along the last dimension.

            1 -> 0°, 0.7 -> ~45°, 0 -> 90°

            return: one dimension less than the input
            '''
            return torch.abs(
                einsum('abcd,abcd->abc', a, b) /
                (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

        # With exactly one threshold the score is stored directly, otherwise as a list
        single_threshold = len(max_angle_error) == 1

        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            y_pred_bund = y_pred[:, :, :,
                                 (idx * 3):(idx * 3) + 3].contiguous()
            y_true_bund = y_true[:, :, :,
                                 (idx * 3):(idx * 3) + 3].contiguous()  # [x,y,z,3]

            angles = angle_last_dim(y_pred_bund, y_true_bund)
            gt_binary = y_true_bund.sum(dim=-1) > 0
            gt_binary = gt_binary.view(-1)  # [bs*x*y]

            f1_per_threshold = []
            for threshold in max_angle_error:
                angles_binary = angles > threshold
                angles_binary = angles_binary.view(-1)
                f1_per_threshold.append(
                    PytorchUtils.f1_score_binary(gt_binary, angles_binary))

            if single_threshold:
                score_per_bundle[bundle] = f1_per_threshold[0]
            else:
                score_per_bundle[bundle] = f1_per_threshold

        return score_per_bundle
コード例 #20
0
 def test_bundle_names(self):
     '''Requesting the classes "CST_right" must return the background entry plus that bundle.'''
     expected = ["BG", "CST_right"]
     actual = ExpUtils.get_bundle_names("CST_right")
     self.assertListEqual(actual, expected, "Error in list of bundle names")
コード例 #21
0
ファイル: ImgUtils.py プロジェクト: doctoryfx/TractSeg
 def save_multilabel_img_as_multiple_files_endings(HP, img, affine, path):
     '''
     Save every channel of img as a separate nifti file in <path>/endings_segmentations.

     Channel i of the last axis is written to <bundle name i>.nii.gz, where the bundle names
     are ExpUtils.get_bundle_names(HP.CLASSES) without the first entry.
     '''
     bundle_names = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
     for channel, bundle_name in enumerate(bundle_names):
         channel_img = nib.Nifti1Image(img[:, :, :, channel], affine)
         out_dir = join(path, "endings_segmentations")
         ExpUtils.make_dir(out_dir)
         nib.save(channel_img, join(out_dir, bundle_name + ".nii.gz"))
コード例 #22
0
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_threshold=False, get_probs=False):
    '''
    Run TractSeg

    Loads the pretrained configuration for (input_type, output_type), selects the matching weights file,
    crops and rescales the input to the network input size, runs the prediction and maps the result back
    to the original image shape.

    :param data: input image as numpy array. For input_type "peaks": 4D array with shape [x,y,z,9].
        For input_type "T1" the array is reshaped to [x,y,z,1] inside this function.
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks" | "T1"
    :param single_orientation: predict with a single Trainer pass instead of fusing the predictions of
        3 directions (mainly needed for testing because of less RAM requirements). Only has an effect for
        the experiment types tract_segmentation, endings_segmentation and dm_regression.
    :param verbose: show debugging infos
    :param dropout_sampling: create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: Threshold for converting probability map to binary map
    :param bundle_specific_threshold: Threshold is lower for some bundles which need more sensitivity (CA, CST, FX)
    :param get_probs: Output raw probability map instead of binary map
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:     [x,y,z,nr_of_bundles]
        for endings_segmentation:   [x,y,z,2*nr_of_bundles]
        for TOM:                    [x,y,z,3*nr_of_bundles]
    :raises ValueError: if the experiment type of the loaded config is not one of tract_segmentation,
        endings_segmentation, dm_regression or peak_regression
    '''
    start_time = time.time()

    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    # Bundle specific thresholding is applied to the probabilities further down -> always request them
    if bundle_specific_threshold:
        HP.GET_PROBS = True

    # Select the pretrained weights. If no branch matches, HP.WEIGHTS_PATH keeps the value from the config.
    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DAugAll_Dropout", "best_weights_ep114.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_12g90g270g_125mm_DAugAll", "best_weights_ep126.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888", "best_weights_ep247.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888_SchizoFineT_lr001", "best_weights_ep186.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v2.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "EndingsSeg_12g90g270g_125mm_DAugAll", "best_weights_ep16.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/Peaks20_12g90g270g_125mm_DAugSimp_constW5", "best_weights_ep441.npz")  #more oversegmentation with DAug
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "DmReg_12g90g270g_125mm_DAugAll_Ubuntu", "best_weights_ep80.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    # Peak regression predicts 3 values per bundle, all other experiment types one value per bundle
    nr_of_bundles = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3 * nr_of_bundles
    else:
        HP.NR_OF_CLASSES = nr_of_bundles

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)
    # brain_mask = ImgUtils.simple_brain_mask(data)
    # if HP.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    if input_type == "T1":
        # add a channel axis of size 1
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    # bbox, original_shape and transformation are needed at the end to undo cropping and scaling
    data, _, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE in ("tract_segmentation", "endings_segmentation", "dm_regression"):
        # Unthresholded output is needed for dropout sampling, density map regression and if probabilities
        # were requested
        want_probs = bool(HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS)
        if single_orientation:     # mainly needed for testing because of less RAM requirements
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            seg, _ = trainerSingle.get_seg_single_img(HP, probs=want_probs, scale_to_world_shape=False, only_prediction=True)
        else:
            seg_xyz, _ = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=want_probs)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, _ = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:], len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=0.3)  # set lower for more sensitivity
        #3 dir for Peaks -> not working (?)
        # seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
        # seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)

    else:
        # Without this branch seg would be undefined below and the function would fail with an
        # UnboundLocalError which does not tell what went wrong
        raise ValueError("Unsupported experiment type: {}".format(HP.EXPERIMENT_TYPE))

    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    #remove following two lines to keep super resolution
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)
    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))

    return seg
コード例 #23
0
ファイル: TractSeg.py プロジェクト: doctoryfx/TractSeg
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_threshold=False, get_probs=False, peak_threshold=0.1):
    '''
    Run TractSeg

    Loads the pretrained configuration for (input_type, output_type), selects the matching weights file,
    crops and rescales the input to the network input size, runs the prediction and maps the result back
    to the original image shape.

    :param data: input image as numpy array. For input_type "peaks": 4D array with shape [x,y,z,9].
        For input_type "T1" the array is reshaped to [x,y,z,1] inside this function.
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks" | "T1"
    :param single_orientation: predict with a single Trainer pass instead of fusing the predictions of
        3 directions (mainly needed for testing because of less RAM requirements). Only has an effect for
        the experiment types tract_segmentation, endings_segmentation and dm_regression.
    :param verbose: show debugging infos
    :param dropout_sampling: create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: Threshold for converting probability map to binary map
    :param bundle_specific_threshold: Threshold is lower for some bundles which need more sensitivity (CA, CST, FX)
    :param get_probs: Output raw probability map instead of binary map
    :param peak_threshold: all peaks shorter than peak_threshold will be set to zero. Only used for
        peak_regression without bundle_specific_threshold; with bundle_specific_threshold a fixed
        len_thr=0.3 is passed to ImgUtils.remove_small_peaks_bundle_specific instead.
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:     [x,y,z,nr_of_bundles]
        for endings_segmentation:   [x,y,z,2*nr_of_bundles]
        for TOM:                    [x,y,z,3*nr_of_bundles]
    :raises ValueError: if the experiment type of the loaded config is not one of tract_segmentation,
        endings_segmentation, dm_regression or peak_regression
    '''
    start_time = time.time()

    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    # Bundle specific thresholding is applied to the probabilities further down -> always request them
    if bundle_specific_threshold:
        HP.GET_PROBS = True

    # Select the pretrained weights. If no branch matches, HP.WEIGHTS_PATH keeps the value from the config.
    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DAugAll_Dropout", "best_weights_ep114.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/TractSeg_T1_12g90g270g_125mm_DAugAll", "best_weights_ep392.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888", "best_weights_ep247.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888_SchizoFineT_lr001", "best_weights_ep186.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DS_DAugAll_RotMir", "best_weights_ep200.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v3.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "EndingsSeg_12g90g270g_125mm_DS_DAugAll", "best_weights_ep234.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/Peaks20_12g90g270g_125mm_DAugSimp_constW5", "best_weights_ep441.npz")  #more oversegmentation with DAug
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "DmReg_12g90g270g_125mm_DAugAll_Ubuntu", "best_weights_ep80.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    # Peak regression predicts 3 values per bundle, all other experiment types one value per bundle
    nr_of_bundles = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3 * nr_of_bundles
    else:
        HP.NR_OF_CLASSES = nr_of_bundles

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)
    # brain_mask = ImgUtils.simple_brain_mask(data)
    # if HP.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    if input_type == "T1":
        # add a channel axis of size 1
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    # bbox, original_shape and transformation are needed at the end to undo cropping and scaling
    data, _, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE in ("tract_segmentation", "endings_segmentation", "dm_regression"):
        # Unthresholded output is needed for dropout sampling, density map regression and if probabilities
        # were requested
        want_probs = bool(HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS)
        if single_orientation:     # mainly needed for testing because of less RAM requirements
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            seg, _ = trainerSingle.get_seg_single_img(HP, probs=want_probs, scale_to_world_shape=False, only_prediction=True)
        else:
            seg_xyz, _ = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=want_probs)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, _ = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            # NOTE: peak_threshold is not used in this branch
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:], len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=peak_threshold)
        #3 dir for Peaks -> not working (?)
        # seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
        # seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)

    else:
        # Without this branch seg would be undefined below and the function would fail with an
        # UnboundLocalError which does not tell what went wrong
        raise ValueError("Unsupported experiment type: {}".format(HP.EXPERIMENT_TYPE))

    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    #remove following two lines to keep super resolution
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)
    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))

    return seg