Example #1
def test_whole_subject(Config, model, subjects, type):

    metrics = {
        "loss_" + type: [0],
        "f1_macro_" + type: [0],
    }

    metrics_bundles = defaultdict(lambda: [0])

    for subject in subjects:
        print("{} subject {}".format(type, subject))
        start_time = time.time()

        data_loader = DataLoaderInference(Config, subject=subject)
        img_probs, img_y = predict_img(Config, model, data_loader, probs=True)
        # img_probs_xyz, img_y = DirectionMerger.get_seg_single_img_3_directions(Config, model, subject=subject)
        # img_probs = DirectionMerger.mean_fusion(Config.THRESHOLD, img_probs_xyz, probs=True)

        print("Took {}s".format(round(time.time() - start_time, 2)))

        if Config.EXPERIMENT_TYPE == "peak_regression":
            f1 = metric_utils.calc_peak_length_dice(Config, img_probs, img_y,
                                                    max_angle_error=Config.PEAK_DICE_THR,
                                                    max_length_error=Config.PEAK_DICE_LEN_THR)
            peak_f1_mean = np.array(list(f1.values())).mean()  # if f1 for multiple bundles
            metrics = metric_utils.calculate_metrics(metrics, None, None, 0, f1=peak_f1_mean,
                                                     type=type, threshold=Config.THRESHOLD)
            metrics_bundles = metric_utils.calculate_metrics_each_bundle(metrics_bundles, None, None,
                                                                         dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
                                                                         f1, threshold=Config.THRESHOLD)
        else:
            img_probs = np.reshape(img_probs, (-1, img_probs.shape[-1]))  # Flatten all dims except nr_classes dim
            img_y = np.reshape(img_y, (-1, img_y.shape[-1]))
            metrics = metric_utils.calculate_metrics(metrics, img_y, img_probs, 0,
                                                     type=type, threshold=Config.THRESHOLD)
            metrics_bundles = metric_utils.calculate_metrics_each_bundle(metrics_bundles, img_y, img_probs,
                                                                         dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
                                                                         threshold=Config.THRESHOLD)

    metrics = metric_utils.normalize_last_element(metrics, len(subjects), type=type)
    metrics_bundles = metric_utils.normalize_last_element_general(metrics_bundles, len(subjects))

    print("WHOLE SUBJECT:")
    pprint(metrics)
    print("WHOLE SUBJECT BUNDLES:")
    pprint(metrics_bundles)

    with open(join(Config.EXP_PATH, "score_" + type + "-set.txt"), "w") as f:
        pprint(metrics, f)
        f.write("\n\nWeights: {}\n".format(Config.WEIGHTS_PATH))
        f.write("type: {}\n\n".format(type))
        pprint(metrics_bundles, f)
    pickle.dump(metrics, open(join(Config.EXP_PATH, "score_" + type + ".pkl"), "wb"))
    return metrics
Example #2
 def test_tractseg_output_docker(self):
     bundles = dataset_specific_utils.get_bundle_names("All")[1:]
     for bundle in bundles:
         img_ref = nib.load("tests/reference_files/bundle_segmentations/" + bundle + ".nii.gz").get_fdata().astype(np.uint8)
         img_new = nib.load("examples/docker_test/bundle_segmentations/" + bundle + ".nii.gz").get_fdata().astype(np.uint8)
         images_equal = np.array_equal(img_ref, img_new)
         self.assertTrue(images_equal, "Docker tract segmentations are not correct (bundle: " + bundle + ")")
Example #3
def save_multilabel_img_as_multiple_files_endings(classes, img, affine, path):
    bundles = dataset_specific_utils.get_bundle_names(classes)[1:]
    for idx, bundle in enumerate(bundles):
        img_seg = nib.Nifti1Image(img[:, :, :, idx], affine)
        exp_utils.make_dir(join(path, "endings_segmentations"))
        nib.save(img_seg,
                 join(path, "endings_segmentations", bundle + ".nii.gz"))
Example #4
def calc_peak_length_dice(classes,
                          y_pred,
                          y_true,
                          max_angle_error=[0.9],
                          max_length_error=0.1):
    score_per_bundle = {}
    bundles = dataset_specific_utils.get_bundle_names(classes)[1:]
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
        y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]  # [x, y, z, 3]

        angles = abs(peak_utils.angle_last_dim(y_pred_bund, y_true_bund))

        lengths_pred = np.linalg.norm(y_pred_bund, axis=-1)
        lengths_true = np.linalg.norm(y_true_bund, axis=-1)
        lengths_binary = abs(lengths_pred - lengths_true) < (max_length_error *
                                                             lengths_true)
        lengths_binary = lengths_binary.flatten()

        gt_binary = y_true_bund.sum(axis=-1) > 0
        gt_binary = gt_binary.flatten()  # [bs*x*y]

        angles_binary = angles > max_angle_error[0]
        angles_binary = angles_binary.flatten()

        combined = lengths_binary * angles_binary

        f1 = my_f1_score(gt_binary, combined)
        score_per_bundle[bundle] = f1
    return score_per_bundle
Example #5
def create_multilabel_mask(classes,
                           subject,
                           labels_type=np.int16,
                           dataset_folder="HCP",
                           labels_folder="bundle_masks"):
    """
    One-hot encoding of all bundles in one big image
    """
    bundles = dataset_specific_utils.get_bundle_names(classes)

    # Masks are always HCP_highRes (downsampling happens only later)
    mask_ml = np.zeros((145, 174, 145, len(bundles)))
    background = np.ones((145, 174, 145))  # everything that contains no bundle

    # first bundle is background -> already considered by setting np.ones in the beginning
    for idx, bundle in enumerate(bundles[1:]):
        mask = nib.load(
            join(C.HOME, dataset_folder, subject, labels_folder,
                 bundle + ".nii.gz"))
        mask_data = mask.get_data()  # dtype: uint8
        mask_ml[:, :, :, idx + 1] = mask_data
        background[mask_data == 1] = 0  # remove this bundle from background

    mask_ml[:, :, :, 0] = background
    return mask_ml.astype(labels_type)
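A hypothetical call following the signature above (the subject ID is a placeholder); the result has one channel per bundle, with the background in channel 0:

mask_ml = create_multilabel_mask("All", "subject_01")  # placeholder subject ID
print(mask_ml.shape)  # (145, 174, 145, number of bundles incl. background), dtype int16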
Example #6
 def test_peakreg_output(self):
     bundles = dataset_specific_utils.get_bundle_names("All")[1:]
     for bundle in bundles:
         img_ref = nib.load("tests/reference_files/TOM/" + bundle + ".nii.gz").get_fdata()
         img_new = nib.load("examples/tractseg_output/TOM/" + bundle + ".nii.gz").get_fdata()
         # Because of regression, a small tolerance margin is needed
         images_equal = np.allclose(img_ref, img_new, rtol=1e-3, atol=1e-3)
         self.assertTrue(images_equal, "TOMs are not correct (bundle: " + bundle + ")")
Example #7
def save_multilabel_img_as_multiple_files(Config,
                                          img,
                                          affine,
                                          path,
                                          name="bundle_segmentations"):
    bundles = dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        img_seg = nib.Nifti1Image(img[:, :, :, idx], affine)
        exp_utils.make_dir(join(path, name))
        nib.save(img_seg, join(path, name, bundle + ".nii.gz"))
Example #8
    def test_endingsseg_output(self):
        bundles = dataset_specific_utils.get_bundle_names("All")[1:]
        for bundle in bundles:
            img_ref = nib.load("tests/reference_files/endings_segmentations/" + bundle + "_b.nii.gz").get_fdata().astype(np.uint8)
            img_new = nib.load("examples/tractseg_output/endings_segmentations/" + bundle + "_b.nii.gz").get_fdata().astype(np.uint8)
            images_equal = np.array_equal(img_ref, img_new)
            self.assertTrue(images_equal, "Bundle endings are not correct (bundle: " + bundle + "_b)")

            img_ref = nib.load("tests/reference_files/endings_segmentations/" + bundle + "_e.nii.gz").get_fdata().astype(np.uint8)
            img_new = nib.load("examples/tractseg_output/endings_segmentations/" + bundle + "_e.nii.gz").get_fdata().astype(np.uint8)
            images_equal = np.array_equal(img_ref, img_new)
            self.assertTrue(images_equal, "Bundle endings are not correct (bundle: " + bundle + "_e)")
Example #9
 def test_tractseg_output_SR_noPP(self):
     bundles = dataset_specific_utils.get_bundle_names("All")[1:]
     for bundle in bundles:
         # IFO very different on travis than locally. Unclear why. All other bundles are fine.
         if bundle != "IFO_right":
             img_ref = nib.load("tests/reference_files/bundle_segmentations_SR_noPP/" + bundle + ".nii.gz").get_fdata().astype(np.uint8)
             img_new = nib.load("examples/SR_noPP/tractseg_output/bundle_segmentations/" + bundle + ".nii.gz").get_fdata().astype(np.uint8)
             # Processing on travis slightly different from local environment -> have to allow for small margin
             # cast to int before subtracting to avoid uint8 wrap-around
             nr_differing_voxels = np.abs(img_ref.astype(int) - img_new.astype(int)).sum()
             images_equal = nr_differing_voxels < 5
             self.assertTrue(images_equal, "Tract segmentations are not correct (bundle: " + bundle + ") " +
                                           "(nr of differing voxels: " + str(nr_differing_voxels) + ")")
Example #10
class Config(TractSegConfig):
    EXP_NAME = os.path.basename(__file__).split(".")[0]

    DATASET_FOLDER = "HCP_preproc_all"
    NR_OF_GRADIENTS = 18
    FEATURES_FILENAME = "32g270g_BX"
    P_SAMP = 0.4

    CLASSES = "AutoPTX_42"
    NR_OF_CLASSES = len(dataset_specific_utils.get_bundle_names(CLASSES)[1:])

    DATASET = "HCP_all"

    LR_SCHEDULE = True
    LR_SCHEDULE_MODE = "min"
    LR_SCHEDULE_PATIENCE = 20

    NUM_EPOCHS = 200  # 130 probably also fine
Example #11
def calc_peak_dice(classes, y_pred, y_true, max_angle_error=[0.9]):
    score_per_bundle = {}
    bundles = dataset_specific_utils.get_bundle_names(classes)[1:]
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
        y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]  # (x,y,z,3)

        angles = abs(peak_utils.angle_last_dim(y_pred_bund, y_true_bund))
        angles_binary = angles > max_angle_error[0]

        gt_binary = y_true_bund.sum(axis=-1) > 0

        f1 = f1_score(gt_binary.flatten(),
                      angles_binary.flatten(),
                      average="binary")
        score_per_bundle[bundle] = f1

    return score_per_bundle
Example #12
def save_multilabel_img_as_multiple_files_peaks(Config,
                                                img,
                                                affine,
                                                path,
                                                name="TOM"):
    bundles = dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        data = img[:, :, :, (idx * 3):(idx * 3) + 3]

        if Config.FLIP_OUTPUT_PEAKS:
            data[:, :, :, 2] *= -1  # flip z Axis for correct view in MITK
            filename = bundle + "_f.nii.gz"
        else:
            filename = bundle + ".nii.gz"

        img_seg = nib.Nifti1Image(data, affine)
        exp_utils.make_dir(join(path, name))
        nib.save(img_seg, join(path, name, filename))
Example #13
class Config(TractSegConfig):
    EXP_NAME = os.path.basename(__file__).split(".")[0]

    DATASET = "HCP_all"
    DATASET_FOLDER = "HCP_preproc_all"
    FEATURES_FILENAME = "32g90g270g_CSD_BX"
    CLASSES = "xtract"
    NR_OF_CLASSES = len(dataset_specific_utils.get_bundle_names(CLASSES)[1:])
    RESOLUTION = "1.25mm"
    LABELS_FILENAME = "bundle_masks_xtract_thr001"

    NUM_EPOCHS = 300
    EPOCH_MULTIPLIER = 0.5

    DAUG_ROTATE = True
    SPATIAL_TRANSFORM = "SpatialTransformPeaks"
    # rotation: 2*np.pi = 360 degrees  (-> 0.8 ~ 45 degrees, 0.4 ~ 22 degrees)
    DAUG_ROTATE_ANGLE = (-0.4, 0.4)
Example #14
def calc_peak_length_dice_pytorch(classes,
                                  y_pred,
                                  y_true,
                                  max_angle_error=[0.9],
                                  max_length_error=0.1):
    import torch
    from tractseg.libs import pytorch_utils

    if len(y_pred.shape) == 4:  # 2D
        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)
    else:  # 3D
        y_true = y_true.permute(0, 2, 3, 4, 1)
        y_pred = y_pred.permute(0, 2, 3, 4, 1)

    #Single threshold
    score_per_bundle = {}
    bundles = dataset_specific_utils.get_bundle_names(classes)[1:]
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[..., (idx * 3):(idx * 3) + 3].contiguous()
        y_true_bund = y_true[..., (idx * 3):(idx * 3) +
                             3].contiguous()  # [x, y, z, 3]

        angles = pytorch_utils.angle_last_dim(y_pred_bund, y_true_bund)

        lengths_pred = torch.norm(y_pred_bund, 2, -1)
        lengths_true = torch.norm(y_true_bund, 2, -1)
        lengths_binary = torch.abs(lengths_pred - lengths_true) < (
            max_length_error * lengths_true)
        lengths_binary = lengths_binary.view(-1)

        gt_binary = y_true_bund.sum(dim=-1) > 0
        gt_binary = gt_binary.view(-1)  # [bs*x*y]

        angles_binary = angles > max_angle_error[0]
        angles_binary = angles_binary.view(-1)

        combined = lengths_binary * angles_binary

        f1 = pytorch_utils.f1_score_binary(gt_binary, combined)
        score_per_bundle[bundle] = f1
    return score_per_bundle
Example #15
class Config(DmRegConfig):
    EXP_NAME = os.path.basename(__file__).split(".")[0]

    DATASET_FOLDER = "HCP_preproc_all"
    NR_OF_GRADIENTS = 18
    FEATURES_FILENAME = "32g270g_BX"
    P_SAMP = 0.4

    CLASSES = "AutoPTX_42"
    NR_OF_CLASSES = len(dataset_specific_utils.get_bundle_names(CLASSES)[1:])

    # THRESHOLD = 0.001  # Final DM will be thresholded at this value
    THRESHOLD = 0.0001  # use lower value so user has more choice

    DATASET = "HCP_all"

    LR_SCHEDULE = True
    LR_SCHEDULE_MODE = "min"
    LR_SCHEDULE_PATIENCE = 20

    NUM_EPOCHS = 200  # 130 probably also fine
Example #16
def calc_peak_dice_onlySeg(classes, y_pred, y_true):
    """
    Create binary mask of peaks by simple thresholding. Then calculate Dice.
    """
    score_per_bundle = {}
    bundles = dataset_specific_utils.get_bundle_names(classes)[1:]
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
        y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]  # [x,y,z,3]

        # 0.1 -> keep some outliers, but also some holes already; 0.2 also ok (still looks like e.g. CST)
        #  Resulting dice for 0.1 and 0.2 very similar
        y_pred_binary = np.abs(y_pred_bund).sum(axis=-1) > 0.2
        y_true_binary = np.abs(y_true_bund).sum(axis=-1) > 1e-3

        f1 = f1_score(y_true_binary.flatten(),
                      y_pred_binary.flatten(),
                      average="binary")
        score_per_bundle[bundle] = f1

    return score_per_bundle
Example #17
 def test_bundle_names(self):
     bundles = dataset_specific_utils.get_bundle_names("CST_right")
     self.assertListEqual(bundles, ["BG", "CST_right"], "Error in list of bundle names")
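This test also documents the convention behind the recurring [1:] slice in the other examples: index 0 of get_bundle_names() is the background class "BG", so class counts and per-bundle loops exclude it. A minimal sketch of that idiom (variable names are illustrative only):

bundles = dataset_specific_utils.get_bundle_names("CST_right")  # ["BG", "CST_right"]
nr_of_classes = len(bundles[1:])  # 1 -> the background entry "BG" is not counted as a class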
Example #18
def calc_peak_dice_pytorch(classes, y_pred, y_true, max_angle_error=[0.9]):
    """
    Calculate angle between groundtruth and prediction and keep the voxels where
    angle is smaller than MAX_ANGLE_ERROR.

    From groundtruth generate a binary mask by selecting all voxels with len > 0.

    Calculate Dice from these 2 masks.

    -> Penalty on peaks outside of the tract or if the predicted peak is 0
    -> No penalty on very small peaks with the right direction -> bad
    => Peak_dice can be high even if the peaks inside the tract are almost missing (almost 0)

    Args:
        y_pred:
        y_true:
        max_angle_error: 0.7 ->  angle error of 45° or less; 0.9 ->  angle error of 23° or less
                         Can be list with several values -> calculate for several thresholds

    Returns:

    """
    from tractseg.libs import pytorch_utils

    y_true = y_true.permute(0, 2, 3, 1)
    y_pred = y_pred.permute(0, 2, 3, 1)

    #Single threshold
    if len(max_angle_error) == 1:
        score_per_bundle = {}
        bundles = dataset_specific_utils.get_bundle_names(classes)[1:]
        for idx, bundle in enumerate(bundles):
            # if bundle == "CST_right":
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) +
                                 3].contiguous()  # [x, y, z, 3]

            angles = pytorch_utils.angle_last_dim(y_pred_bund, y_true_bund)
            gt_binary = y_true_bund.sum(dim=-1) > 0
            gt_binary = gt_binary.view(-1)  # [bs*x*y]

            angles_binary = angles > max_angle_error[0]
            angles_binary = angles_binary.view(-1)

            f1 = pytorch_utils.f1_score_binary(gt_binary, angles_binary)
            score_per_bundle[bundle] = f1

        return score_per_bundle

    #multiple thresholds
    else:
        score_per_bundle = {}
        bundles = dataset_specific_utils.get_bundle_names(classes)[1:]
        for idx, bundle in enumerate(bundles):
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) +
                                 3].contiguous()  # [x, y, z, 3]

            angles = pytorch_utils.angle_last_dim(y_pred_bund, y_true_bund)
            gt_binary = y_true_bund.sum(dim=-1) > 0
            gt_binary = gt_binary.view(-1)  # [bs*x*y]

            score_per_bundle[bundle] = []
            for threshold in max_angle_error:
                angles_binary = angles > threshold
                angles_binary = angles_binary.view(-1)

                f1 = pytorch_utils.f1_score_binary(gt_binary, angles_binary)
                score_per_bundle[bundle].append(f1)

        return score_per_bundle
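The max_angle_error values in the docstring above are cosine thresholds. A quick sanity check of the quoted degree values (assuming angle_last_dim returns the absolute cosine similarity between predicted and ground-truth peaks, which is what these thresholds imply):

import numpy as np
print(np.degrees(np.arccos(0.7)))  # ~45.6 degrees -> "angle error of 45 degrees or less"
print(np.degrees(np.arccos(0.9)))  # ~25.8 degrees -> roughly the "23 degrees or less" quoted above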
Example #19
def plot_tracts(classes,
                bundle_segmentations,
                affine,
                out_dir,
                brain_mask=None):
    """
    By default this does not work over a remote server connection (ssh -X) because -X does not support OpenGL.
    On the remote server you can run 'export DISPLAY=":0"'
    (set it to the value you get from 'echo $DISPLAY' when logged in locally on the remote server).
    Then all graphics are rendered on the remote server itself and not via -X.
    (important: a graphical session needs to be running on the remote server, e.g. via a local login)
    (important: an actual login is needed, not just sitting at the login screen)

    If running on a headless server without Display using Xvfb might help:
    https://stackoverflow.com/questions/6281998/can-i-run-glu-opengl-on-a-headless-server
    """
    from dipy.viz import window
    from tractseg.libs import vtk_utils

    SMOOTHING = 10
    WINDOW_SIZE = (800, 800)
    bundles = ["CST_right", "CA", "IFO_right"]

    renderer = window.Renderer()
    renderer.projection('parallel')

    rows = len(bundles)
    X, Y, Z = bundle_segmentations.shape[:3]
    for j, bundle in enumerate(bundles):
        i = 0  #only one method

        bundle_idx = dataset_specific_utils.get_bundle_names(
            classes)[1:].index(bundle)
        mask_data = bundle_segmentations[:, :, :, bundle_idx]

        if bundle == "CST_right":
            orientation = "axial"
        elif bundle == "CA":
            orientation = "axial"
        elif bundle == "IFO_right":
            orientation = "sagittal"
        else:
            orientation = "axial"

        #bigger: more border
        if orientation == "axial":
            border_y = -100
        else:
            border_y = -100

        x_current = X * i  # column (width)
        y_current = rows * (Y * 2 + border_y) - (
            Y * 2 + border_y) * j  # row (height)  (starts from bottom)

        plot_mask(renderer,
                  mask_data,
                  affine,
                  x_current,
                  y_current,
                  orientation=orientation,
                  smoothing=SMOOTHING,
                  brain_mask=brain_mask)

        #Bundle label
        text_offset_top = -50
        text_offset_side = -100
        position = (0 - int(X) + text_offset_side, y_current + text_offset_top,
                    50)
        text_actor = vtk_utils.label(text=bundle,
                                     pos=position,
                                     scale=(6, 6, 6),
                                     color=(1, 1, 1))
        renderer.add(text_actor)

    renderer.reset_camera()
    window.record(renderer,
                  out_path=join(out_dir, "preview.png"),
                  size=(WINDOW_SIZE[0], WINDOW_SIZE[1]),
                  reset_camera=False,
                  magnification=2)
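The remote-rendering workaround described in the docstring can also be applied from within Python, as long as it happens before dipy.viz is imported. A minimal sketch (the DISPLAY value is an assumption; use whatever 'echo $DISPLAY' reports in a local session on that server):

import os
os.environ["DISPLAY"] = ":0"  # assumed value; requires a running graphical session on the server
from dipy.viz import window   # import only after DISPLAY has been set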
Example #20
File: base.py  Project: wj202007/TractSeg
class Config:
    """
    Settings and hyperparameters
    """

    # input data
    EXPERIMENT_TYPE = "tract_segmentation"  # tract_segmentation|endings_segmentation|dm_regression|peak_regression
    EXP_NAME = "HCP_TEST"
    EXP_MULTI_NAME = ""  # CV parent directory name; leave empty for single bundle experiment
    DATASET_FOLDER = "HCP_preproc"
    LABELS_FOLDER = "bundle_masks"
    MULTI_PARENT_PATH = join(C.EXP_PATH, EXP_MULTI_NAME)
    EXP_PATH = join(C.EXP_PATH, EXP_MULTI_NAME, EXP_NAME)  # default path
    CLASSES = "All"
    NR_OF_GRADIENTS = 9
    NR_OF_CLASSES = len(dataset_specific_utils.get_bundle_names(CLASSES)[1:])
    INPUT_DIM = None  # autofilled
    DATASET = "HCP"  # HCP | HCP_32g | Schizo
    RESOLUTION = "1.25mm"  # 1.25mm|2.5mm
    # 12g90g270g | 270g_125mm_xyz | 270g_125mm_peaks | 90g_125mm_peaks | 32g_25mm_peaks | 32g_25mm_xyz
    FEATURES_FILENAME = "12g90g270g"
    LABELS_FILENAME = ""  # autofilled
    LABELS_TYPE = "int"
    THRESHOLD = 0.5  # Binary: 0.5, Regression: 0.01

    # hyperparameters
    MODEL = "UNet_Pytorch_DeepSup"
    DIM = "2D"  # 2D | 3D
    BATCH_SIZE = 47
    LEARNING_RATE = 0.001
    LR_SCHEDULE = True
    LR_SCHEDULE_MODE = "min"  # min | max
    LR_SCHEDULE_PATIENCE = 20
    UNET_NR_FILT = 64
    EPOCH_MULTIPLIER = 1  # 2D: 1, 3D: 12 for lowRes, 3 for highRes
    NUM_EPOCHS = 250
    SLICE_DIRECTION = "y"  # x | y | z  ("combined" needs z)
    TRAINING_SLICE_DIRECTION = "xyz"  # y | xyz
    LOSS_FUNCTION = "default"  # default | soft_batch_dice
    OPTIMIZER = "Adamax"
    LOSS_WEIGHT = None  # None = no weighting
    LOSS_WEIGHT_LEN = -1  # -1 = constant over all epochs
    BATCH_NORM = False
    WEIGHT_DECAY = 0
    USE_DROPOUT = False
    DROPOUT_SAMPLING = False
    LOAD_WEIGHTS = False
    # WEIGHTS_PATH = join(C.EXP_PATH, "My_experiment/best_weights_ep64.npz")
    WEIGHTS_PATH = ""  # if empty string: autoloading the best_weights in get_best_weights_path()
    SAVE_WEIGHTS = True
    TYPE = "single_direction"  # single_direction | combined
    CV_FOLD = 0
    VALIDATE_SUBJECTS = []
    TRAIN_SUBJECTS = []
    TEST_SUBJECTS = []
    TRAIN = True
    TEST = True
    SEGMENT = False
    GET_PROBS = False
    OUTPUT_MULTIPLE_FILES = False
    RESET_LAST_LAYER = False
    UPSAMPLE_TYPE = "bilinear"  # bilinear | nearest
    BEST_EPOCH_SELECTION = "f1"  # f1 | loss
    METRIC_TYPES = ["loss", "f1_macro"]
    FP16 = True
    PEAK_DICE_THR = [0.95]
    PEAK_DICE_LEN_THR = 0.05
    FLIP_OUTPUT_PEAKS = False  # flip peaks along z axis to make them compatible with MITK
    USE_VISLOGGER = False
    SEG_INPUT = "Peaks"  # Gradients | Peaks
    NR_SLICES = 1
    PRINT_FREQ = 20
    NORMALIZE_DATA = True
    NORMALIZE_PER_CHANNEL = False
    BEST_EPOCH = 0
    VERBOSE = True
    CALC_F1 = True
    ONLY_VAL = False
    TEST_TIME_DAUG = False
    PAD_TO_SQUARE = True
    INPUT_RESCALING = False  # Resample data to a different resolution (instead of doing it in preprocessing)

    # data augmentation
    DATA_AUGMENTATION = True
    DAUG_SCALE = True
    DAUG_NOISE = True
    DAUG_NOISE_VARIANCE = (0, 0.05)
    DAUG_ELASTIC_DEFORM = True
    DAUG_ALPHA = (90., 120.)
    DAUG_SIGMA = (9., 11.)
    DAUG_RESAMPLE = False  # does not improve validation dice (if using Gaussian_blur) -> deactivate
    DAUG_RESAMPLE_LEGACY = False  # does not improve validation dice (at least on AutoPTX) -> deactivate
    DAUG_GAUSSIAN_BLUR = True
    DAUG_BLUR_SIGMA = (0, 1)
    DAUG_ROTATE = False
    DAUG_ROTATE_ANGLE = (-0.2, 0.2)  # rotation: 2*np.pi = 360 degrees  (0.4 ~= 22 degrees, 0.2 ~= 11 degrees)
    DAUG_MIRROR = False
    DAUG_FLIP_PEAKS = False
    SPATIAL_TRANSFORM = "SpatialTransform"  # SpatialTransform|SpatialTransformPeaks
    P_SAMP = 1.0
    DAUG_INFO = "-"
    INFO = "-"

    # for inference
    PREDICT_IMG = False
    PREDICT_IMG_OUTPUT = None
    TRACTSEG_DIR = "tractseg_output"
    KEEP_INTERMEDIATE_FILES = False
    CSD_RESOLUTION = "LOW"  # HIGH | LOW
    NR_CPUS = -1
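The experiment configs shown earlier (Examples #10, #13 and #15) only override the fields they need; they subclass TractSegConfig, which presumably derives from this base Config. A minimal hypothetical override following the same pattern:

class MyExperimentConfig(Config):  # hypothetical experiment config
    EXP_NAME = "my_experiment"
    CLASSES = "All"
    NR_OF_CLASSES = len(dataset_specific_utils.get_bundle_names(CLASSES)[1:])
    NUM_EPOCHS = 250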
Example #21
def run_tractseg(data,
                 output_type="tract_segmentation",
                 single_orientation=False,
                 dropout_sampling=False,
                 threshold=0.5,
                 bundle_specific_postprocessing=True,
                 get_probs=False,
                 peak_threshold=0.1,
                 postprocess=False,
                 peak_regression_part="All",
                 input_type="peaks",
                 blob_size_thr=50,
                 nr_cpus=-1,
                 verbose=False,
                 manual_exp_name=None,
                 inference_batch_size=1,
                 tract_definition="TractQuerier+",
                 bedpostX_input=False,
                 tract_segmentations_path=None,
                 TOM_dilation=1,
                 unit_test=False):
    """
    Run TractSeg

    Args:
        data: input peaks (4D numpy array with shape [x,y,z,9])
        output_type: TractSeg can segment not only bundles, but also the end regions of bundles.
            Moreover it can create Tract Orientation Maps (TOM).
            'tract_segmentation' [DEFAULT]: Segmentation of bundles (72 bundles).
            'endings_segmentation': Segmentation of bundle end regions (72 bundles).
            'TOM': Tract Orientation Maps (20 bundles).
        single_orientation: Do not run model 3 times along x/y/z orientation with subsequent mean fusion.
        dropout_sampling: Create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
        threshold: Threshold for converting probability map to binary map
        bundle_specific_postprocessing: Set a lower threshold and use hole closing for CA and FX if incomplete
        get_probs: Output raw probability map instead of binary map
        peak_threshold: All peaks shorter than peak_threshold will be set to zero
        postprocess: Simple postprocessing of segmentations: Remove small blobs and fill holes
        peak_regression_part: Only relevant for output type 'TOM'. If set to 'All' (default) it will return all
            72 bundles. If set to 'Part1'-'Part4' it will only run for a subset of the bundles to reduce memory
            load.
        input_type: Always set to "peaks"
        blob_size_thr: If setting postprocess to True, all blobs having a smaller number of voxels than specified in
            this threshold will be removed.
        nr_cpus: Number of CPUs to use. -1 means all available CPUs.
        verbose: Show debugging infos
        manual_exp_name: Name of the experiment if you do not want to use a pretrained model but your own
        inference_batch_size: batch size (higher: a bit faster but needs more RAM)
        tract_definition: Select which tract definitions to use. 'TractQuerier+' defines tracts mainly by their
            cortical start and end region. 'xtract' defines tracts mainly by ROIs in white matter.
        bedpostX_input: Input peaks are generated by bedpostX
        tract_segmentations_path: path to the bundle_segmentations (only needed for peak regression to remove peaks
            outside of the segmentation mask)
        TOM_dilation: Dilation applied to the tract segmentations before using them to mask the TOMs.

    Returns:
        4D numpy array with the output of tractseg
        for tract_segmentation:     [x, y, z, nr_of_bundles]
        for endings_segmentation:   [x, y, z, 2*nr_of_bundles]
        for TOM:                    [x, y, z, 3*nr_of_bundles]
    """
    start_time = time.time()

    if manual_exp_name is None:
        config = get_config_name(input_type,
                                 output_type,
                                 dropout_sampling=dropout_sampling,
                                 tract_definition=tract_definition)
        Config = getattr(
            importlib.import_module("tractseg.experiments.pretrained_models." +
                                    config), "Config")()
    else:
        Config = exp_utils.load_config_from_txt(
            join(C.EXP_PATH,
                 exp_utils.get_manual_exp_name_peaks(manual_exp_name, "Part1"),
                 "Hyperparameters.txt"))

    # Do not do any postprocessing if returning probabilities (because postprocessing only works on binary)
    if get_probs:
        bundle_specific_postprocessing = False
        postprocess = False

    Config = exp_utils.get_correct_labels_type(Config)
    Config.VERBOSE = verbose
    Config.TRAIN = False
    Config.TEST = False
    Config.SEGMENT = False
    Config.GET_PROBS = get_probs
    Config.LOAD_WEIGHTS = True
    Config.DROPOUT_SAMPLING = dropout_sampling
    Config.THRESHOLD = threshold
    Config.NR_CPUS = nr_cpus
    Config.INPUT_DIM = dataset_specific_utils.get_correct_input_dim(Config)
    Config.RESET_LAST_LAYER = False

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing:
        Config.GET_PROBS = True

    if manual_exp_name is not None and Config.EXPERIMENT_TYPE != "peak_regression":
        Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
            join(C.EXP_PATH, manual_exp_name), True)
    else:
        if tract_definition == "TractQuerier+":
            if input_type == "peaks":
                if Config.EXPERIMENT_TYPE == "tract_segmentation" and Config.DROPOUT_SAMPLING:
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_tract_segmentation_v3.npz")
                elif Config.EXPERIMENT_TYPE == "tract_segmentation":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_tract_segmentation_v3.npz")
                elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_endings_segmentation_v4.npz")
                elif Config.EXPERIMENT_TYPE == "dm_regression":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_dm_regression_v2.npz")
            else:  # T1
                if Config.EXPERIMENT_TYPE == "tract_segmentation":
                    Config.WEIGHTS_PATH = join(
                        C.NETWORK_DRIVE,
                        "hcp_exp_nodes/x_Pretrained_TractSeg_Models",
                        "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
                elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_endings_segmentation_v1.npz")
                elif Config.EXPERIMENT_TYPE == "peak_regression":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_peak_regression_v1.npz")
        else:  # xtract
            if Config.EXPERIMENT_TYPE == "tract_segmentation":
                Config.WEIGHTS_PATH = join(
                    C.WEIGHTS_DIR,
                    "pretrained_weights_tract_segmentation_xtract_v1.npz")
            elif Config.EXPERIMENT_TYPE == "dm_regression":
                Config.WEIGHTS_PATH = join(
                    C.WEIGHTS_DIR,
                    "pretrained_weights_dm_regression_xtract_v1.npz")
            else:
                raise ValueError(
                    "bundle_definition xtract not supported in combination with this output type"
                )

    if Config.VERBOSE:
        print("Hyperparameters:")
        exp_utils.print_Configs(Config)

    data = np.nan_to_num(data)

    #runtime on HCP data: 0.9s
    data, seg_None, bbox, original_shape = data_utils.crop_to_nonzero(data)
    # runtime on HCP data: 0.5s
    data, transformation = data_utils.pad_and_scale_img_to_square_img(
        data, target_size=Config.INPUT_DIM[0], nr_cpus=nr_cpus)

    if Config.EXPERIMENT_TYPE == "tract_segmentation" or Config.EXPERIMENT_TYPE == "endings_segmentation" or \
            Config.EXPERIMENT_TYPE == "dm_regression":
        print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
        Config.NR_OF_CLASSES = len(
            dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
        utils.download_pretrained_weights(
            experiment_type=Config.EXPERIMENT_TYPE,
            dropout_sampling=Config.DROPOUT_SAMPLING,
            tract_definition=tract_definition)
        model = BaseModel(Config, inference=True)
        if single_orientation:  # mainly needed for testing because of less RAM requirements
            data_loader_inference = DataLoaderInference(Config, data=data)
            if Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" or Config.GET_PROBS:
                seg, _ = trainer.predict_img(Config,
                                             model,
                                             data_loader_inference,
                                             probs=True,
                                             scale_to_world_shape=False,
                                             only_prediction=True,
                                             batch_size=inference_batch_size,
                                             unit_test=unit_test)
            else:
                seg, _ = trainer.predict_img(Config,
                                             model,
                                             data_loader_inference,
                                             probs=False,
                                             scale_to_world_shape=False,
                                             only_prediction=True,
                                             batch_size=inference_batch_size)
        else:
            seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(
                Config,
                model,
                data=data,
                scale_to_world_shape=False,
                only_prediction=True,
                batch_size=inference_batch_size)
            if Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" or Config.GET_PROBS:
                seg = direction_merger.mean_fusion(Config.THRESHOLD,
                                                   seg_xyz,
                                                   probs=True)
            else:
                seg = direction_merger.mean_fusion(Config.THRESHOLD,
                                                   seg_xyz,
                                                   probs=False)

    elif Config.EXPERIMENT_TYPE == "peak_regression":
        weights = {
            "Part1": "pretrained_weights_peak_regression_part1_v2.npz",
            "Part2": "pretrained_weights_peak_regression_part2_v2.npz",
            "Part3": "pretrained_weights_peak_regression_part3_v2.npz",
            "Part4": "pretrained_weights_peak_regression_part4_v2.npz",
        }
        if peak_regression_part == "All":
            parts = ["Part1", "Part2", "Part3", "Part4"]
            seg_all = np.zeros((data.shape[0], data.shape[1], data.shape[2],
                                Config.NR_OF_CLASSES * 3))
        else:
            parts = [peak_regression_part]
            Config.CLASSES = "All_" + peak_regression_part
            Config.NR_OF_CLASSES = 3 * len(
                dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])

        for idx, part in enumerate(parts):
            if manual_exp_name is not None:
                manual_exp_name_peaks = exp_utils.get_manual_exp_name_peaks(
                    manual_exp_name, part)
                Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
                    join(C.EXP_PATH, manual_exp_name_peaks), True)
            else:
                Config.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, weights[part])
            print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
            Config.CLASSES = "All_" + part
            Config.NR_OF_CLASSES = 3 * len(
                dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
            utils.download_pretrained_weights(
                experiment_type=Config.EXPERIMENT_TYPE,
                dropout_sampling=Config.DROPOUT_SAMPLING,
                part=part,
                tract_definition=tract_definition)
            model = BaseModel(Config, inference=True)

            if single_orientation:
                data_loader_inference = DataLoaderInference(Config, data=data)
                seg, _ = trainer.predict_img(Config,
                                             model,
                                             data_loader_inference,
                                             probs=True,
                                             scale_to_world_shape=False,
                                             only_prediction=True,
                                             batch_size=inference_batch_size)
            else:
                # 3 dir for Peaks -> bad results
                seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(
                    Config,
                    model,
                    data=data,
                    scale_to_world_shape=False,
                    only_prediction=True,
                    batch_size=inference_batch_size)
                seg = direction_merger.mean_fusion_peaks(seg_xyz,
                                                         nr_cpus=nr_cpus)

            if peak_regression_part == "All":
                seg_all[:, :, :, (idx * Config.NR_OF_CLASSES):((idx + 1) * Config.NR_OF_CLASSES)] = seg

        if peak_regression_part == "All":
            Config.CLASSES = "All"
            Config.NR_OF_CLASSES = 3 * len(
                dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
            seg = seg_all

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing and not dropout_sampling:
        # Runtime ~4s
        seg = img_utils.bundle_specific_postprocessing(
            seg,
            dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])

    # runtime on HCP data: 5.1s
    seg = data_utils.cut_and_scale_img_back_to_original_img(seg,
                                                            transformation,
                                                            nr_cpus=nr_cpus)
    # runtime on HCP data: 1.6s
    seg = data_utils.add_original_zero_padding_again(seg, bbox, original_shape,
                                                     Config.NR_OF_CLASSES)

    if Config.EXPERIMENT_TYPE == "peak_regression":
        seg = peak_utils.mask_and_normalize_peaks(
            seg,
            tract_segmentations_path,
            dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
            TOM_dilation,
            nr_cpus=nr_cpus)

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and postprocess and not dropout_sampling:
        # Runtime ~7s for 1.25mm resolution
        # Runtime ~1.5s for  2mm resolution
        st = time.time()
        seg = img_utils.postprocess_segmentations(
            seg,
            dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
            blob_thr=blob_size_thr,
            hole_closing=None)

    exp_utils.print_verbose(
        Config.VERBOSE, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg
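A hypothetical end-to-end usage sketch of run_tractseg based on its docstring (file paths are placeholders, and saving as uint8 assumes the default binary tract_segmentation output):

import nibabel as nib
import numpy as np

peaks_img = nib.load("peaks.nii.gz")  # placeholder path; 4D peaks volume [x, y, z, 9]
seg = run_tractseg(peaks_img.get_fdata(), output_type="tract_segmentation")
nib.save(nib.Nifti1Image(seg.astype(np.uint8), peaks_img.affine),
         "bundle_segmentations.nii.gz")  # placeholder path; [x, y, z, nr_of_bundles]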
Example #22
def plot_tracts_matplotlib(classes,
                           bundle_segmentations,
                           background_img,
                           out_dir,
                           threshold=0.001,
                           exp_type="tract_segmentation"):
    def plot_single_tract(bg, data, orientation, bundle, exp_type):
        if orientation == "coronal":
            data = data.transpose(
                2, 0, 1,
                3) if exp_type == "peak_regression" else data.transpose(
                    2, 0, 1)
            data = data[::-1, :, :]
            bg = bg.transpose(2, 0, 1)[::-1, :, :]
        elif orientation == "sagittal":
            data = data.transpose(
                2, 1, 0,
                3) if exp_type == "peak_regression" else data.transpose(
                    2, 1, 0)
            data = data[::-1, :, :]
            bg = bg.transpose(2, 1, 0)[::-1, :, :]
        else:  # axial
            pass

        mask_voxel_coords = np.where(data != 0)
        if len(mask_voxel_coords) > 2 and len(mask_voxel_coords[2]) > 0:
            minidx = int(np.min(mask_voxel_coords[2]))
            maxidx = int(np.max(mask_voxel_coords[2])) + 1
            mean_slice = int(np.mean([minidx, maxidx]))
        else:
            mean_slice = int(bg.shape[2] / 2)
        bg = bg[:, :, mean_slice]
        # bg = matplotlib.colors.Normalize()(bg)

        # project 3D to 2D image
        if aggregation == "mean":
            data = data.mean(axis=2)
        else:
            data = data.max(axis=2)

        plt.imshow(bg, cmap="gray")
        data = np.ma.masked_where(data < 0.00001, data)
        plt.imshow(data,
                   cmap="autumn")  # even with cmap=autumn peaks still RGB
        plt.title(bundle, fontsize=7)

    if classes.startswith("xtract"):
        bundles = ["cst_r", "cst_s_r", "ifo_r", "fx_l", "fx_r", "or_l", "fma"]
    else:
        if exp_type == "endings_segmentation":
            bundles = [
                "CST_right_b", "CST_right_e", "CST_s_right_b", "CST_s_right_e",
                "CA_b", "CA_e"
            ]
        else:
            bundles = [
                "CST_right", "CST_s_right", "CA", "IFO_right", "FX_left",
                "FX_right", "OR_left", "CC_1"
            ]

    if exp_type == "peak_regression":
        s = bundle_segmentations.shape
        bundle_segmentations = bundle_segmentations.reshape(
            [s[0], s[1], s[2], int(s[3] / 3), 3])
        bundles = ["CST_right", "CST_s_right", "CA", "CC_1",
                   "AF_left"]  # can only use bundles from part1

    aggregation = "max"
    cols = 4
    rows = math.ceil(len(bundles) / cols)

    background_img = background_img[..., 0]

    for j, bundle in enumerate(bundles):
        bun = bundle.lower()
        if bun.startswith("ca") or bun.startswith("fx_") or bun.startswith("or_") or \
                bun.startswith("cc_1") or bun.startswith("fma"):
            orientation = "axial"
        elif bun.startswith("ifo_") or bun.startswith("icp_") or bun.startswith("cst_s_") or \
                bun.startswith("af_"):
            bundle = bundle.replace("_s", "")
            orientation = "sagittal"
        elif bun.startswith("cst_"):
            orientation = "coronal"
        else:
            raise ValueError("invalid bundle")

        bundle_idx = dataset_specific_utils.get_bundle_names(
            classes)[1:].index(bundle)
        mask_data = bundle_segmentations[:, :, :, bundle_idx]
        # copy the data, otherwise the thresholding below would also affect the data outside of this plot function
        mask_data = np.copy(mask_data)
        # mask_data[mask_data < threshold] = 0
        mask_data[mask_data < 0.001] = 0  # a higher value is better for the preview, otherwise half of the image is just red

        plt.subplot(rows, cols, j + 1)
        plt.axis("off")
        plot_single_tract(background_img,
                          mask_data,
                          orientation,
                          bundle,
                          exp_type=exp_type)

    if exp_type == "tract_segmentation":
        file_name = "preview_bundle"
    elif exp_type == "endings_segmentation":
        file_name = "preview_endings"
    elif exp_type == "peak_regression":
        file_name = "preview_TOM"
    elif exp_type == "dm_regression":
        file_name = "preview_dm"
    else:
        file_name = "preview"

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.savefig(join(out_dir, file_name + ".png"),
                bbox_inches='tight',
                dpi=300)