def soft_dice_paul(HP, idxs, marker, preds, ys):
    '''
    Soft (differentiable) Dice loss, averaged over all classes, built as a Theano expression.

    Every name returned by ExpUtils.get_bundle_names(HP.CLASSES) counts as one class
    (the background entry is NOT skipped here).

    :param HP: config object; only HP.CLASSES is read
    :param idxs: unused
    :param marker: selector applied to the first axis of preds and ys
    :param preds: symbolic tensor indexed as [marker, class, :, :]
    :param ys: symbolic tensor with the same layout as preds (ground truth)
    :return: symbolic scalar 1 - mean(per-class soft Dice)
    '''
    class_count = len(ExpUtils.get_bundle_names(HP.CLASSES))

    def _class_dice(c):
        # 2 * |P * Y| / (|P| + |Y| + eps) for one class; eps avoids 0 / 0 on empty classes
        p = preds[marker, c, :, :]
        t = ys[marker, c, :, :]
        overlap = T.sum(p * t)
        total = T.sum(p) + T.sum(t)
        return T.constant(2) * overlap / (total + T.constant(1e-6))

    dice_sum = sum((_class_dice(c) for c in range(class_count)), T.constant(0))
    return 1 - (dice_sum / class_count)
def calc_peak_length_dice_pytorch(HP, y_pred, y_true, max_angle_error=(0.9,), max_length_error=0.1):
    '''
    Per-bundle Dice (F1) of peak predictions that must match the ground truth in direction AND length.

    A voxel is predicted-positive when both hold:
      - |cos(angle)| between predicted and true peak > max_angle_error[0]
      - |len_pred - len_true| < max_length_error * len_true
    The ground-truth mask contains every voxel whose true peak has a length > 0.

    :param HP: config object; HP.CLASSES selects the bundles (the first bundle name is the
               background entry and is skipped)
    :param y_pred: 4D torch tensor with channels in dim 1 (3 consecutive channels per bundle)
    :param y_true: 4D torch tensor with the same layout as y_pred
    :param max_angle_error: sequence of |cos| thresholds; only the first entry is used here.
                            0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less
    :param max_length_error: allowed relative length deviation (0.1 -> 10 %)
    :return: dict {bundle_name: value returned by PytorchUtils.f1_score_binary}
    '''
    import torch
    from tractseg.libs.PytorchUtils import PytorchUtils

    # move channels last: [bs, x, y, channels]
    y_true = y_true.permute(0, 2, 3, 1)
    y_pred = y_pred.permute(0, 2, 3, 1)
    angle_thr = max_angle_error[0]

    def abs_cos_last_dim(a, b):
        '''
        |cos| of the angle between vectors stored along the last dimension
        (1->0°, 0.9->23°, 0.7->45°, 0->90°). Returns one dimension less than the inputs.
        '''
        return torch.abs((a * b).sum(dim=-1) /
                         (torch.norm(a, 2, -1) * torch.norm(b, 2, -1) + 1e-7))

    score_per_bundle = {}
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
        y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()

        angles = abs_cos_last_dim(y_pred_bund, y_true_bund)
        lengths_pred = torch.norm(y_pred_bund, 2, -1)
        lengths_true = torch.norm(y_true_bund, 2, -1)

        lengths_binary = torch.abs(lengths_pred - lengths_true) < (max_length_error * lengths_true)
        angles_binary = angles > angle_thr
        combined = (lengths_binary & angles_binary).view(-1)

        # GT mask = every non-zero peak. The angle test above ignores the sign (abs), so the
        # mask must too; the former `sum(components) > 0` dropped peaks such as (-1, 0, 0).
        gt_binary = (lengths_true > 0).view(-1)

        score_per_bundle[bundle] = PytorchUtils.f1_score_binary(gt_binary, combined)
    return score_per_bundle
def theano_f1_score_OLD(HP, idxs, marker, preds, ys):
    '''
    Hard (thresholded) Dice / F1 score averaged over all classes, as a Theano expression.
    Originally written by Paul.

    Predictions are binarised with a fixed threshold of 0.5 before the overlap is computed.
    Every name returned by ExpUtils.get_bundle_names(HP.CLASSES) counts as one class
    (the background entry included).

    :param HP: config object; only HP.CLASSES is read
    :param idxs: unused
    :param marker: selector applied to the first axis of preds and ys
    :param preds: symbolic tensor indexed as [marker, class, :, :]
    :param ys: symbolic tensor with the same layout (ground truth)
    :return: symbolic scalar: mean Dice over all classes
    '''
    class_count = len(ExpUtils.get_bundle_names(HP.CLASSES))

    score = T.constant(0)
    for c in range(class_count):
        target = ys[marker, c, :, :]
        hard_pred = T.gt(preds[marker, c, :, :], T.constant(0.5))
        overlap = T.sum(hard_pred * target)
        volume = T.sum(hard_pred) + T.sum(target)
        # eps keeps empty classes from producing 0 / 0
        score += T.constant(2) * overlap / (volume + T.constant(1e-6))
    return score / class_count
def save_multilabel_img_as_multiple_files_peaks(HP, img, affine, path):
    '''
    Save the peaks of every bundle as a separate NIfTI file in <path>/TOM.

    :param HP: config object; HP.CLASSES selects the bundles (the first name is the background
               entry and is skipped). If HP.FLIP_OUTPUT_PEAKS is set, the z component is negated
               and "_f" is appended to the filename.
    :param img: 4D array [x, y, z, 3 * nr_bundles] (3 consecutive channels per bundle).
                It is not modified.
    :param affine: affine of the written NIfTI images
    :param path: output base directory
    '''
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        data = img[:, :, :, (idx * 3):(idx * 3) + 3]
        if HP.FLIP_OUTPUT_PEAKS:
            # Basic slicing returns a view into img; copy before flipping in place so the
            # caller's array is not silently modified.
            data = data.copy()
            data[:, :, :, 2] *= -1  # flip z axis for correct view in MITK
            filename = bundle + "_f.nii.gz"
        else:
            filename = bundle + ".nii.gz"
        img_seg = nib.Nifti1Image(data, affine)
        ExpUtils.make_dir(join(path, "TOM"))
        nib.save(img_seg, join(path, "TOM", filename))
def save_multilabel_img_as_multiple_files_peaks(HP, img, affine, path):
    '''
    Save the peaks of every bundle as a separate NIfTI file in <path>/TOM.

    :param HP: config object; HP.CLASSES selects the bundles (the first name is the background
               entry and is skipped). If HP.FLIP_OUTPUT_PEAKS is set, the z component is negated
               and "_f" is appended to the filename.
    :param img: 4D array [x, y, z, 3 * nr_bundles] (3 consecutive channels per bundle).
                It is not modified.
    :param affine: affine of the written NIfTI images
    :param path: output base directory
    '''
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        data = img[:, :, :, (idx * 3):(idx * 3) + 3]
        if HP.FLIP_OUTPUT_PEAKS:
            # Basic slicing returns a view into img; copy before flipping in place so the
            # caller's array is not silently modified.
            data = data.copy()
            data[:, :, :, 2] *= -1  # flip z axis for correct view in MITK
            filename = bundle + "_f.nii.gz"
        else:
            filename = bundle + ".nii.gz"
        img_seg = nib.Nifti1Image(data, affine)
        ExpUtils.make_dir(join(path, "TOM"))
        nib.save(img_seg, join(path, "TOM", filename))
def calc_peak_dice(HP, y_pred, y_true, max_angle_error=(0.9,)):
    '''
    Per-bundle Dice (F1) between the ground-truth peak mask and the voxels whose predicted peak
    direction matches the ground truth.

    A voxel is predicted-positive when |cos(angle)| between predicted and true peak exceeds
    max_angle_error[0]. The ground-truth mask contains every voxel whose true peak has length > 0.

    :param HP: config object; HP.CLASSES selects the bundles (first name = background, skipped)
    :param y_pred: 4D numpy array, channels last (3 consecutive channels per bundle)
    :param y_true: 4D numpy array with the same layout as y_pred
    :param max_angle_error: sequence of |cos| thresholds; only the first entry is used.
                            0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less
    :return: dict {bundle_name: f1 score}
    '''
    angle_thr = max_angle_error[0]

    def abs_cos_last_dim(a, b):
        '''
        |cos| of the angle between vectors stored along the last dimension
        (1->0°, 0.9->23°, 0.7->45°, 0->90°). Returns one dimension less than the inputs.
        '''
        dot = np.einsum('...i,...i', a, b)
        return np.abs(dot / (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + 1e-7))

    score_per_bundle = {}
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
        y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]

        angles_binary = abs_cos_last_dim(y_pred_bund, y_true_bund) > angle_thr
        # GT mask = every non-zero peak. The angle test ignores the sign (abs), so the mask must
        # too; the former `sum(components) > 0` dropped peaks such as (-1, 0, 0).
        gt_binary = np.linalg.norm(y_true_bund, axis=-1) > 0

        score_per_bundle[bundle] = f1_score(gt_binary.flatten(), angles_binary.flatten(), average="binary")
    return score_per_bundle
def calc_peak_length_dice(HP, y_pred, y_true, max_angle_error=(0.9,), max_length_error=0.1):
    '''
    Per-bundle Dice (F1) of peak predictions that must match the ground truth in direction AND length.

    A voxel is predicted-positive when both hold:
      - |cos(angle)| between predicted and true peak > max_angle_error[0]
      - |len_pred - len_true| < max_length_error * len_true
    The ground-truth mask contains every voxel whose true peak has length > 0.

    :param HP: config object; HP.CLASSES selects the bundles (first name = background, skipped)
    :param y_pred: 4D numpy array, channels last (3 consecutive channels per bundle)
    :param y_true: 4D numpy array with the same layout as y_pred
    :param max_angle_error: sequence of |cos| thresholds; only the first entry is used.
                            0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less
    :param max_length_error: allowed relative length deviation (0.1 -> 10 %)
    :return: dict {bundle_name: value returned by MetricUtils.my_f1_score}
    '''
    angle_thr = max_angle_error[0]

    def abs_cos_last_dim(a, b):
        '''
        |cos| of the angle between vectors stored along the last dimension
        (1->0°, 0.9->23°, 0.7->45°, 0->90°). Returns one dimension less than the inputs.
        '''
        dot = np.einsum('...i,...i', a, b)
        return np.abs(dot / (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + 1e-7))

    score_per_bundle = {}
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
        y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]

        angles_binary = abs_cos_last_dim(y_pred_bund, y_true_bund) > angle_thr
        lengths_pred = np.linalg.norm(y_pred_bund, axis=-1)
        lengths_true = np.linalg.norm(y_true_bund, axis=-1)
        lengths_binary = np.abs(lengths_pred - lengths_true) < (max_length_error * lengths_true)
        combined = (lengths_binary & angles_binary).flatten()

        # GT mask = every non-zero peak. The angle test ignores the sign (abs), so the mask must
        # too; the former `sum(components) > 0` dropped peaks such as (-1, 0, 0).
        gt_binary = (lengths_true > 0).flatten()

        score_per_bundle[bundle] = MetricUtils.my_f1_score(gt_binary, combined)
    return score_per_bundle
def calc_peak_dice(HP, y_pred, y_true, max_angle_error=(0.9,)):
    '''
    Per-bundle Dice (F1) between the ground-truth peak mask and the voxels whose predicted peak
    direction matches the ground truth.

    A voxel is predicted-positive when |cos(angle)| between predicted and true peak exceeds
    max_angle_error[0]. The ground-truth mask contains every voxel whose true peak has length > 0.

    :param HP: config object; HP.CLASSES selects the bundles (first name = background, skipped)
    :param y_pred: 4D numpy array, channels last (3 consecutive channels per bundle)
    :param y_true: 4D numpy array with the same layout as y_pred
    :param max_angle_error: sequence of |cos| thresholds; only the first entry is used.
                            0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less
    :return: dict {bundle_name: f1 score}
    '''
    angle_thr = max_angle_error[0]

    def abs_cos_last_dim(a, b):
        '''
        |cos| of the angle between vectors stored along the last dimension
        (1->0°, 0.9->23°, 0.7->45°, 0->90°). Returns one dimension less than the inputs.
        '''
        dot = np.einsum('...i,...i', a, b)
        return np.abs(dot / (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + 1e-7))

    score_per_bundle = {}
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
        y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]

        angles_binary = abs_cos_last_dim(y_pred_bund, y_true_bund) > angle_thr
        # GT mask = every non-zero peak. The angle test ignores the sign (abs), so the mask must
        # too; the former `sum(components) > 0` dropped peaks such as (-1, 0, 0).
        gt_binary = np.linalg.norm(y_true_bund, axis=-1) > 0

        score_per_bundle[bundle] = f1_score(gt_binary.flatten(), angles_binary.flatten(), average="binary")
    return score_per_bundle
def create_multilabel_mask(HP, subject, labels_type=np.int16, dataset_folder="HCP", labels_folder="bundle_masks"):
    '''
    One-hot encoding of all bundles in one big image.

    Channel 0 is the background: 1 wherever none of the loaded bundle masks equals 1.
    Channel i (i >= 1) holds the mask of the i-th bundle name after the background entry.

    :param HP: config object; HP.CLASSES selects the bundles
    :param subject: subject folder name inside C.HOME/<dataset_folder>
    :param labels_type: dtype of the returned array
    :param dataset_folder: dataset folder inside C.HOME
    :param labels_folder: folder inside the subject folder holding one <bundle>.nii.gz per bundle
    :return: image of shape (145, 174, 145, nr_of_bundles + 1)
    '''
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)

    # Masks are always HCP_highRes (downsampling happens only later), hence the fixed shape
    mask_ml = np.zeros((145, 174, 145, len(bundles)))
    background = np.ones((145, 174, 145))  # everything that contains no bundle

    # first bundle is the background -> already handled by initialising `background` with ones
    for idx, bundle in enumerate(bundles[1:]):
        mask = nib.load(join(C.HOME, dataset_folder, subject, labels_folder, bundle + ".nii.gz"))
        # img.get_data() is deprecated since nibabel 3.0 and removed in 5.0;
        # np.asanyarray(img.dataobj) is the documented drop-in replacement.
        mask_data = np.asanyarray(mask.dataobj)  # dtype: uint8
        mask_ml[:, :, :, idx + 1] = mask_data
        background[mask_data == 1] = 0  # remove this bundle from background

    mask_ml[:, :, :, 0] = background
    return mask_ml.astype(labels_type)
def save_multilabel_img_as_multiple_files_endings_OLD(HP, img, affine, path, multilabel=True):
    '''
    Write one 3D label image per bundle to <path>/endings/<bundle>.nii.gz.

    Each bundle owns two consecutive channels of img; a voxel counts as set when its value is > 0.

    multilabel True:  first channel -> 1, second channel -> 2 (2 wins where both are set)
    multilabel False: both channels combined into one binary mask (value 1)
    '''
    second_label = 2 if multilabel else 1
    for idx, bundle in enumerate(ExpUtils.get_bundle_names(HP.CLASSES)[1:]):
        active = img[:, :, :, 2 * idx:2 * idx + 2] > 0
        label_img = np.zeros(active.shape[:3])
        label_img[active[:, :, :, 0]] = 1
        label_img[active[:, :, :, 1]] = second_label

        nifti = nib.Nifti1Image(label_img, affine)
        ExpUtils.make_dir(join(path, "endings"))
        nib.save(nifti, join(path, "endings", bundle + ".nii.gz"))
def calc_peak_dice_onlySeg(HP, y_pred, y_true):
    '''
    Per-bundle Dice that ignores peak direction: both peak images are turned into binary masks
    by simple thresholding of the summed absolute peak components, then F1 is computed.

    Prediction threshold 0.2: 0.1 keeps some outliers but already shows some holes; 0.2 is also
    fine (bundles such as CST still look right). The resulting Dice is very similar for both.

    :param HP: config object; HP.CLASSES selects the bundles (first name = background, skipped)
    :param y_pred: 4D numpy array, channels last (3 consecutive channels per bundle)
    :param y_true: 4D numpy array with the same layout as y_pred
    :return: dict {bundle_name: f1 score}
    '''
    scores = {}
    for i, name in enumerate(ExpUtils.get_bundle_names(HP.CLASSES)[1:]):
        channels = slice(i * 3, i * 3 + 3)
        pred_mask = np.abs(y_pred[:, :, :, channels]).sum(axis=-1) > 0.2
        true_mask = np.abs(y_true[:, :, :, channels]).sum(axis=-1) > 1e-3
        scores[name] = f1_score(true_mask.flatten(), pred_mask.flatten(), average="binary")
    return scores
def calc_peak_length_dice_pytorch(HP, y_pred, y_true, max_angle_error=(0.9,), max_length_error=0.1):
    '''
    Per-bundle Dice (F1) of peak predictions that must match the ground truth in direction AND length.

    A voxel is predicted-positive when both hold:
      - |cos(angle)| between predicted and true peak > max_angle_error[0]
      - |len_pred - len_true| < max_length_error * len_true
    The ground-truth mask contains every voxel whose true peak has a length > 0.

    :param HP: config object; HP.CLASSES selects the bundles (the first bundle name is the
               background entry and is skipped)
    :param y_pred: 4D torch tensor with channels in dim 1 (3 consecutive channels per bundle)
    :param y_true: 4D torch tensor with the same layout as y_pred
    :param max_angle_error: sequence of |cos| thresholds; only the first entry is used here.
                            0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less
    :param max_length_error: allowed relative length deviation (0.1 -> 10 %)
    :return: dict {bundle_name: value returned by PytorchUtils.f1_score_binary}
    '''
    import torch
    from tractseg.libs.PytorchUtils import PytorchUtils

    # move channels last: [bs, x, y, channels]
    y_true = y_true.permute(0, 2, 3, 1)
    y_pred = y_pred.permute(0, 2, 3, 1)
    angle_thr = max_angle_error[0]

    def abs_cos_last_dim(a, b):
        '''
        |cos| of the angle between vectors stored along the last dimension
        (1->0°, 0.9->23°, 0.7->45°, 0->90°). Returns one dimension less than the inputs.
        '''
        return torch.abs((a * b).sum(dim=-1) /
                         (torch.norm(a, 2, -1) * torch.norm(b, 2, -1) + 1e-7))

    score_per_bundle = {}
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
        y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()

        angles = abs_cos_last_dim(y_pred_bund, y_true_bund)
        lengths_pred = torch.norm(y_pred_bund, 2, -1)
        lengths_true = torch.norm(y_true_bund, 2, -1)

        lengths_binary = torch.abs(lengths_pred - lengths_true) < (max_length_error * lengths_true)
        angles_binary = angles > angle_thr
        combined = (lengths_binary & angles_binary).view(-1)

        # GT mask = every non-zero peak. The angle test above ignores the sign (abs), so the
        # mask must too; the former `sum(components) > 0` dropped peaks such as (-1, 0, 0).
        gt_binary = (lengths_true > 0).view(-1)

        score_per_bundle[bundle] = PytorchUtils.f1_score_binary(gt_binary, combined)
    return score_per_bundle
def calc_peak_dice_pytorch(HP, y_pred, y_true, max_angle_error=(0.9,)):
    '''
    Per-bundle "peak Dice": F1 between the ground-truth peak mask and the voxels whose predicted
    peak direction is close enough to the ground truth.

    - A voxel is predicted-positive when |cos(angle)| between predicted and true peak exceeds
      the threshold.
    - The ground-truth mask contains all voxels whose true peak has length > 0.

    Consequences:
    - Peaks predicted outside the tract are penalised.
    - Zero-length predictions inside the tract are penalised (|cos| is 0 there).
    - Very short predictions with the right direction are NOT penalised, so the score can be
      high even if the predicted peaks inside the tract are almost zero.

    :param HP: config object; HP.CLASSES selects the bundles (first name = background, skipped)
    :param y_pred: 4D torch tensor with channels in dim 1 (3 consecutive channels per bundle)
    :param y_true: 4D torch tensor with the same layout as y_pred
    :param max_angle_error: sequence of |cos| thresholds.
                            0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less
    :return: if exactly one threshold is given: dict {bundle_name: f1};
             otherwise: dict {bundle_name: [f1 per threshold, in the given order]}
    '''
    import torch
    from tractseg.libs.PytorchUtils import PytorchUtils

    # move channels last: [bs, x, y, channels]
    y_true = y_true.permute(0, 2, 3, 1)
    y_pred = y_pred.permute(0, 2, 3, 1)
    single_threshold = len(max_angle_error) == 1

    def abs_cos_last_dim(a, b):
        '''
        |cos| of the angle between vectors stored along the last dimension
        (1->0°, 0.9->23°, 0.7->45°, 0->90°). Returns one dimension less than the inputs.
        '''
        return torch.abs((a * b).sum(dim=-1) /
                         (torch.norm(a, 2, -1) * torch.norm(b, 2, -1) + 1e-7))

    score_per_bundle = {}
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
        y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()

        angles = abs_cos_last_dim(y_pred_bund, y_true_bund).view(-1)  # [bs*x*y]
        # GT mask = every non-zero peak. The angle test ignores the sign (abs), so the mask must
        # too; the former `sum(components) > 0` dropped peaks such as (-1, 0, 0).
        gt_binary = (torch.norm(y_true_bund, 2, -1) > 0).view(-1)

        scores = [PytorchUtils.f1_score_binary(gt_binary, angles > thr) for thr in max_angle_error]
        # keep the original return shape: scalar for one threshold, list otherwise
        score_per_bundle[bundle] = scores[0] if single_threshold else scores
    return score_per_bundle
def save_multilabel_img_as_multiple_files_endings(HP, img, affine, path):
    '''
    Save channel i of img as <path>/endings_segmentations/<bundle_i>.nii.gz, one file per bundle.

    :param HP: config object; HP.CLASSES selects the bundles (first name = background, skipped)
    :param img: 4D array with one channel per bundle in the last dimension
    :param affine: affine of the written NIfTI images
    :param path: output base directory
    '''
    out_dir = join(path, "endings_segmentations")
    for channel, bundle in enumerate(ExpUtils.get_bundle_names(HP.CLASSES)[1:]):
        seg_img = nib.Nifti1Image(img[:, :, :, channel], affine)
        ExpUtils.make_dir(out_dir)
        nib.save(seg_img, join(out_dir, bundle + ".nii.gz"))
def plot_tracts(HP, bundle_segmentations, affine, out_dir, brain_mask=None):
    '''
    Render three example bundles (CST_right, CA, IFO_right) in one column, one above the other,
    and save the screenshot as <out_dir>/preview.png.

    Rendering needs OpenGL, so by default it does not work over a remote `ssh -X` connection.
    On the remote server run `export DISPLAY=":0"` (use the value that `echo $DISPLAY` prints
    when logged in locally on that server); rendering then happens on the server's display
    instead of via -X.
    Important: a graphical session must be running on the remote server (a user must actually be
    logged in locally; sitting at the login screen is not enough).

    :param HP: config object; HP.CLASSES defines the channel order of bundle_segmentations
               (background entry excluded)
    :param bundle_segmentations: 4D array [x, y, z, nr_bundles]
    :param affine: affine passed on to PlotUtils.plot_mask
    :param out_dir: directory that receives preview.png
    :param brain_mask: optional mask passed on to PlotUtils.plot_mask
    '''
    from dipy.viz import window
    from tractseg.libs.VtkUtils import VtkUtils

    smoothing = 10
    width, height = (800, 800)
    shown_bundles = ["CST_right", "CA", "IFO_right"]
    # viewing direction per bundle; bundles not listed are shown axially
    view_of = {"CST_right": "axial", "CA": "axial", "IFO_right": "sagittal"}
    border_y = -100  # vertical spacing term per row (bigger: more border); same for all orientations
    text_offset_top = -50
    text_offset_side = -100

    renderer = window.Renderer()
    renderer.projection('parallel')

    n_rows = len(shown_bundles)
    size_x, size_y, _ = bundle_segmentations.shape[:3]
    row_height = size_y * 2 + border_y
    bundle_names = ExpUtils.get_bundle_names(HP.CLASSES)[1:]

    for row, bundle in enumerate(shown_bundles):
        mask_data = bundle_segmentations[:, :, :, bundle_names.index(bundle)]
        orientation = view_of.get(bundle, "axial")
        # single column -> x is always 0; each further bundle moves down by one row height
        y_current = (n_rows - row) * row_height
        PlotUtils.plot_mask(renderer, mask_data, affine, 0, y_current,
                            orientation=orientation, smoothing=smoothing, brain_mask=brain_mask)

        # bundle name label to the left of the rendered mask
        label_pos = (-int(size_x) + text_offset_side, y_current + text_offset_top, 50)
        renderer.add(VtkUtils.label(text=bundle, pos=label_pos, scale=(6, 6, 6), color=(1, 1, 1)))

    renderer.reset_camera()
    window.record(renderer, out_path=join(out_dir, "preview.png"),
                  size=(width, height), reset_camera=False, magnification=2)
class HP:
    '''
    Default hyperparameters / configuration of a TractSeg experiment.

    All values are plain class attributes.
    NOTE: NR_OF_CLASSES, MULTI_PARENT_PATH and EXP_PATH are computed once, when this class body is
    executed, from the values defined above them. Changing CLASSES, EXP_NAME or EXP_MULTI_NAME on
    an instance later does NOT update them automatically.
    NOTE(review): the list attributes (VALIDATE_SUBJECTS, TRAIN_SUBJECTS, TEST_SUBJECTS,
    PEAK_DICE_THR) are shared by all instances unless reassigned; mutating them in place on one
    instance changes them for every instance.
    '''
    EXP_MULTI_NAME = ""  #CV Parent Dir name # leave empty for Single Bundle Experiment
    EXP_NAME = "HCP_TEST"  # HCP_TEST
    MODEL = "UNet_Pytorch"  # UNet_Lasagne / UNet_Pytorch
    EXPERIMENT_TYPE = "tract_segmentation"  # tract_segmentation / endings_segmentation / dm_regression / peak_regression

    NUM_EPOCHS = 250

    # Data augmentation switches
    DATA_AUGMENTATION = False
    DAUG_SCALE = True
    DAUG_NOISE = True
    DAUG_ELASTIC_DEFORM = True
    DAUG_RESAMPLE = True
    DAUG_ROTATE = False
    DAUG_MIRROR = False
    DAUG_FLIP_PEAKS = False
    DAUG_INFO = "Elastic(90,120)(9,11) - Scale(0.9, 1.5) - CenterDist60 - DownsampScipy(0.5,1) - Gaussian(0,0.05) - Rotate(-0.8,0.8)"

    DATASET = "HCP"  # HCP / HCP_32g / Schizo
    RESOLUTION = "1.25mm"  # 1.25mm (/ 2.5mm)
    FEATURES_FILENAME = "12g90g270g"  # 12g90g270g / 270g_125mm_xyz / 270g_125mm_peaks / 90g_125mm_peaks / 32g_25mm_peaks / 32g_25mm_xyz
    LABELS_FILENAME = ""  # autofilled
    #"bundle_peaks/CA"   #IMPORTANT: Adapt BatchGen if 808080
    # bundle_masks / bundle_masks_72 / bundle_masks_dm / bundle_peaks
    #Only used when using DataManagerNifti
    LOSS_FUNCTION = "default"  # default / soft_batch_dice
    OPTIMIZER = "Adamax"
    CLASSES = "All"  # All / 11 / 20 / CST_right
    NR_OF_GRADIENTS = 9
    # number of bundles without the background entry (hence [1:])
    NR_OF_CLASSES = len(ExpUtils.get_bundle_names(CLASSES)[1:])
    # NR_OF_CLASSES = 3 * len(ExpUtils.get_bundle_names(CLASSES)[1:])

    INPUT_DIM = (144, 144)  # (80, 80) / (144, 144)
    LOSS_WEIGHT = 1  # 1: no weighting
    LOSS_WEIGHT_LEN = -1  # -1: constant over all epochs
    SLICE_DIRECTION = "y"  # x, y, z  (combined needs z)
    TRAINING_SLICE_DIRECTION = "xyz"  # y / xyz
    INFO = "-"  # Dropout, Deconv, 11bundles, LeakyRelu, PeakDiceThres=0.9
    BATCH_NORM = False
    WEIGHT_DECAY = 0
    USE_DROPOUT = False
    DROPOUT_SAMPLING = False
    # DATASET_FOLDER = "HCP_batches/270g_125mm_bundle_peaks_Y_subset"  # HCP / HCP_batches/XXX / TRACED / HCP_fusion_npy_270g_125mm / HCP_fusion_npy_32g_25mm
    # DATASET_FOLDER = "HCP_batches/270g_125mm_bundle_peaks_XYZ"
    DATASET_FOLDER = "HCP"  # HCP / Schizo
    LABELS_FOLDER = "bundle_masks"  # bundle_masks / bundle_masks_dm
    # evaluated once at class-definition time from the values above
    MULTI_PARENT_PATH = join(C.EXP_PATH, EXP_MULTI_NAME)
    EXP_PATH = join(C.EXP_PATH, EXP_MULTI_NAME, EXP_NAME)  # default path
    BATCH_SIZE = 47  #30/44
    #max:
    #Peak Prediction: 44
    #Pytorch: 50
    #Lasagne: 56
    #Lasagne combined: 42
    #Pytorch UpSample: 56
    #Pytorch_SE_r16: 45
    #Pytorch_SE_r64: 45
    LEARNING_RATE = 0.001  # 0.002  #LR find: 0.000143 ?  # 0.001
    LR_SCHEDULE = False
    UNET_NR_FILT = 64
    LOAD_WEIGHTS = False
    # WEIGHTS_PATH = join(C.EXP_PATH, "HCP100_45B_UNet_x_DM_lr002_slope2_dec992_ep800/best_weights_ep64.npz")
    # Can be absolute path or relative like "exp_folder/weights.npz"
    WEIGHTS_PATH = ""  # if empty string: autoloading the best_weights in get_best_weights_path()
    TYPE = "single_direction"  # single_direction / combined
    CV_FOLD = 0
    VALIDATE_SUBJECTS = []
    TRAIN_SUBJECTS = []
    TEST_SUBJECTS = []
    TRAIN = True
    TEST = True
    SEGMENT = False
    GET_PROBS = False
    OUTPUT_MULTIPLE_FILES = False
    RESET_LAST_LAYER = False

    # Peak_regression specific
    PEAK_DICE_THR = [0.95]
    PEAK_DICE_LEN_THR = 0.05
    FLIP_OUTPUT_PEAKS = True  # flip peaks along z axis to make them compatible with MITK

    # For TractSeg.py application
    PREDICT_IMG = False
    PREDICT_IMG_OUTPUT = None
    TRACTSEG_DIR = "tractseg_output"
    KEEP_INTERMEDIATE_FILES = False
    CSD_RESOLUTION = "LOW"  # HIGH / LOW

    # Unimportant / rarely changed:
    LABELS_TYPE = np.int16  # Binary: np.int16, Regression: np.float32
    THRESHOLD = 0.5  # Binary: 0.5, Regression: 0.01 ?
    TEST_TIME_DAUG = False
    USE_VISLOGGER = False  #only works with Python 3
    SAVE_WEIGHTS = True
    SEG_INPUT = "Peaks"  # Gradients/ Peaks
    NR_SLICES = 1  # adapt manually: NR_OF_GRADIENTS in UNet.py and get_batch... in train() and in get_seg_prediction()
    PRINT_FREQ = 20  #20
    NORMALIZE_DATA = True
    NORMALIZE_PER_CHANNEL = False
    BEST_EPOCH = 0
    VERBOSE = True
    CALC_F1 = True
def calc_peak_dice_pytorch(HP, y_pred, y_true, max_angle_error=(0.9,)):
    '''
    Per-bundle "peak Dice": F1 between the ground-truth peak mask and the voxels whose predicted
    peak direction is close enough to the ground truth.

    - A voxel is predicted-positive when |cos(angle)| between predicted and true peak exceeds
      the threshold.
    - The ground-truth mask contains all voxels whose true peak has length > 0.

    Consequences:
    - Peaks predicted outside the tract are penalised.
    - Zero-length predictions inside the tract are penalised (|cos| is 0 there).
    - Very short predictions with the right direction are NOT penalised, so the score can be
      high even if the predicted peaks inside the tract are almost zero.

    :param HP: config object; HP.CLASSES selects the bundles (first name = background, skipped)
    :param y_pred: 4D torch tensor with channels in dim 1 (3 consecutive channels per bundle)
    :param y_true: 4D torch tensor with the same layout as y_pred
    :param max_angle_error: sequence of |cos| thresholds.
                            0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less
    :return: if exactly one threshold is given: dict {bundle_name: f1};
             otherwise: dict {bundle_name: [f1 per threshold, in the given order]}
    '''
    import torch
    from tractseg.libs.PytorchUtils import PytorchUtils

    # move channels last: [bs, x, y, channels]
    y_true = y_true.permute(0, 2, 3, 1)
    y_pred = y_pred.permute(0, 2, 3, 1)
    single_threshold = len(max_angle_error) == 1

    def abs_cos_last_dim(a, b):
        '''
        |cos| of the angle between vectors stored along the last dimension
        (1->0°, 0.9->23°, 0.7->45°, 0->90°). Returns one dimension less than the inputs.
        '''
        return torch.abs((a * b).sum(dim=-1) /
                         (torch.norm(a, 2, -1) * torch.norm(b, 2, -1) + 1e-7))

    score_per_bundle = {}
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
        y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()

        angles = abs_cos_last_dim(y_pred_bund, y_true_bund).view(-1)  # [bs*x*y]
        # GT mask = every non-zero peak. The angle test ignores the sign (abs), so the mask must
        # too; the former `sum(components) > 0` dropped peaks such as (-1, 0, 0).
        gt_binary = (torch.norm(y_true_bund, 2, -1) > 0).view(-1)

        scores = [PytorchUtils.f1_score_binary(gt_binary, angles > thr) for thr in max_angle_error]
        # keep the original return shape: scalar for one threshold, list otherwise
        score_per_bundle[bundle] = scores[0] if single_threshold else scores
    return score_per_bundle
def test_bundle_names(self):
    '''A single bundle name must expand to the background entry followed by that bundle.'''
    expected_names = ["BG", "CST_right"]
    names = ExpUtils.get_bundle_names("CST_right")
    self.assertListEqual(names, expected_names, "Error in list of bundle names")
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks", single_orientation=False, verbose=False,
                 dropout_sampling=False, threshold=0.5, bundle_specific_threshold=False, get_probs=False):
    '''
    Run TractSeg with pretrained weights on one image.

    :param data: input image (4D numpy array; for input_type "peaks" shape [x,y,z,9]; for "T1"
                 a 3D array is reshaped to [x,y,z,1] below)
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
                        (mapped to a pretrained config via get_config_name)
    :param input_type: "peaks" | "T1"
    :param single_orientation: only for tract_segmentation / endings_segmentation / dm_regression:
                               predict along a single slice orientation instead of fusing three
                               (mainly needed for testing because it needs less RAM)
    :param verbose: show debugging infos
    :param dropout_sampling: create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: Threshold for converting probability map to binary map
    :param bundle_specific_threshold: Threshold is lower for some bundles which need more sensitivity (CA, CST, FX).
                                      Forces HP.GET_PROBS = True.
    :param get_probs: Output raw probability map instead of binary map
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:     [x,y,z,nr_of_bundles]
        for endings_segmentation:   [x,y,z,2*nr_of_bundles]
        for TOM:                    [x,y,z,3*nr_of_bundles]

    NOTE(review): `seg` is only assigned for the experiment types tract_segmentation,
    endings_segmentation, dm_regression and peak_regression; any other HP.EXPERIMENT_TYPE ends in an
    UnboundLocalError at the final rescaling step.
    '''
    start_time = time.time()

    # Load the pretrained-model config class and override it for pure inference
    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    # bundle-specific thresholds are applied further below to the raw probabilities
    if bundle_specific_threshold:
        HP.GET_PROBS = True

    # Select the weights file for the (input_type, experiment type) combination.
    # Combinations not listed keep the WEIGHTS_PATH of the loaded config.
    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DAugAll_Dropout", "best_weights_ep114.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_12g90g270g_125mm_DAugAll", "best_weights_ep126.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888", "best_weights_ep247.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888_SchizoFineT_lr001", "best_weights_ep186.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v2.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "EndingsSeg_12g90g270g_125mm_DAugAll", "best_weights_ep16.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/Peaks20_12g90g270g_125mm_DAugSimp_constW5", "best_weights_ep441.npz")  #more oversegmentation with DAug
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "DmReg_12g90g270g_125mm_DAugAll_Ubuntu", "best_weights_ep80.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # NOTE(review): this points into C.NETWORK_DRIVE (an experiment folder), not into
            # C.TRACT_SEG_HOME like every other branch - confirm it is reachable where this runs.
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    # peak regression predicts 3 values (one peak vector) per bundle
    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3*len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    else:
        HP.NR_OF_CLASSES = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)
    # brain_mask = ImgUtils.simple_brain_mask(data)
    # if HP.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    # T1 input gets an explicit channel dimension
    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    # crop to the non-zero bounding box and scale to the square network input size; both steps
    # are undone at the end (bbox / original_shape / transformation)
    data, seg_None, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE == "tract_segmentation" or HP.EXPERIMENT_TYPE == "endings_segmentation" or HP.EXPERIMENT_TYPE == "dm_regression":
        if single_orientation:     # mainly needed for testing because of less RAM requirements
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            # probabilities are kept for dropout sampling, distance-map regression and when requested
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
            else:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=False, scale_to_world_shape=False, only_prediction=True)
        else:
            # predict along all three slice orientations and fuse them by averaging
            seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=False)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        # peak regression always uses a single orientation (see the note below)
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:], len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=0.3) # set lower for more sensitivity
        #3 dir for Peaks -> not working (?)
        # seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
        # seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)

    # turn the forced probabilities into a binary mask with per-bundle thresholds
    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    #remove following two lines to keep super resolution
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)

    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_threshold=False, get_probs=False, peak_threshold=0.1):
    '''
    Run TractSeg

    :param data: input peaks (4D numpy array with shape [x,y,z,9]); for input_type "T1" a T1 image
        whose first three axes are [x,y,z] (it is reshaped to [x,y,z,1] internally).
        NaNs are replaced by zeros before prediction.
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks" | "T1"
    :param single_orientation: predict from a single orientation only instead of fusing three
        directions (mainly for testing because it needs less RAM; not used for TOM)
    :param verbose: show debugging infos
    :param dropout_sampling: create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: Threshold for converting probability map to binary map
    :param bundle_specific_threshold: Threshold is lower for some bundles which need more sensitivity
        (CA, CST, FX). Only has an effect for tract_segmentation (bundle-specific binarization)
        and TOM (bundle-specific peak-length filtering with a base length threshold of 0.3;
        peak_threshold is NOT used in that case).
    :param get_probs: Output raw probability map instead of binary map. Has no effect for
        tract_segmentation when bundle_specific_threshold is set, because that output is binarized.
    :param peak_threshold: all peaks shorter than peak_threshold will be set to zero
        (TOM only, and only when bundle_specific_threshold is False)
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:     [x,y,z,nr_of_bundles]
        for endings_segmentation:   [x,y,z,2*nr_of_bundles]
        for TOM:                    [x,y,z,3*nr_of_bundles]
        for dm_regression:          raw (non-binarized) network output
    '''
    start_time = time.time()

    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    # Bundle-specific binarization (applied near the end) is only implemented for
    # tract_segmentation, so only there do we need the raw probabilities. Forcing probabilities
    # for other experiment types (e.g. endings_segmentation) used to silently return a probability
    # map instead of a binary one.
    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        HP.GET_PROBS = True

    # Choose pretrained weights. Combinations not listed here keep whatever WEIGHTS_PATH the
    # config class provides (presumably a default - confirm in tractseg.config.PretrainedModels).
    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v3.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # NOTE(review): unlike every other branch this points at C.NETWORK_DRIVE rather than a
            # file under C.TRACT_SEG_HOME; confirm this is intended outside the development setup.
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models",
                                   "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    # Peak regression predicts one 3-vector per bundle; the other types one channel per class.
    # [1:] drops the first entry returned by get_bundle_names.
    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3 * len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    else:
        HP.NR_OF_CLASSES = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE,
                                      dropout_sampling=HP.DROPOUT_SAMPLING)

    # Preprocessing: remove NaNs, add a channel axis for T1, crop to the nonzero bounding box and
    # pad/scale to the square network input size. bbox/original_shape/transformation are needed
    # to undo this at the end.
    data = np.nan_to_num(data)
    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    data, _, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE in ("tract_segmentation", "endings_segmentation", "dm_regression"):
        want_probs = HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS
        if single_orientation:  # mainly needed for testing because of less RAM requirements
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            if want_probs:
                seg, _ = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False,
                                                          only_prediction=True)
            else:
                seg, _ = trainerSingle.get_seg_single_img(HP, probs=False, scale_to_world_shape=False,
                                                          only_prediction=True)
        else:
            seg_xyz, _ = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data,
                                                                         scale_to_world_shape=False,
                                                                         only_prediction=True)
            if want_probs:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=False)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        # Single orientation only: fusing 3 directions for peaks reportedly did not work (?).
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, _ = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False,
                                                  only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:],
                                                              len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=peak_threshold)

    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    # Undo the preprocessing (remove these two lines to keep super resolution).
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)

    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg