def test_whole_subject(Config, model, subjects, type):
    """Evaluate ``model`` on complete subjects and write the scores to disk.

    For every subject the whole image is predicted (with probabilities). Metrics
    are accumulated over all subjects and finally normalized by ``len(subjects)``.

    Args:
        Config: experiment configuration. Reads EXPERIMENT_TYPE, CLASSES, THRESHOLD,
            PEAK_DICE_THR, PEAK_DICE_LEN_THR, EXP_PATH and WEIGHTS_PATH.
        model: trained model passed to ``predict_img``.
        subjects: subject ids to evaluate (must support ``len``).
        type: name of the evaluated split. It is used as suffix of the metric keys
            and in the output file names. (Shadows the builtin ``type``; the name is
            kept for caller compatibility.)

    Returns:
        dict with the normalized metrics (keys like ``"loss_<type>"``, ``"f1_macro_<type>"``).

    Side effects:
        Writes ``score_<type>-set.txt`` and ``score_<type>.pkl`` into Config.EXP_PATH.
    """
    metrics = {
        "loss_" + type: [0],
        "f1_macro_" + type: [0],
    }
    metrics_bundles = defaultdict(lambda: [0])
    # Bundle names without the first (background) entry; identical for every subject.
    bundle_names = dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:]

    for subject in subjects:
        print("{} subject {}".format(type, subject))
        start_time = time.time()

        data_loader = DataLoaderInference(Config, subject=subject)
        img_probs, img_y = predict_img(Config, model, data_loader, probs=True)

        print("Took {}s".format(round(time.time() - start_time, 2)))

        if Config.EXPERIMENT_TYPE == "peak_regression":
            # calc_peak_length_dice expects the class-set name (it calls
            # get_bundle_names(classes) itself), not the whole Config object.
            f1 = metric_utils.calc_peak_length_dice(Config.CLASSES, img_probs, img_y,
                                                    max_angle_error=Config.PEAK_DICE_THR,
                                                    max_length_error=Config.PEAK_DICE_LEN_THR)
            peak_f1_mean = np.array(list(f1.values())).mean()  # if f1 for multiple bundles
            metrics = metric_utils.calculate_metrics(metrics, None, None, 0, f1=peak_f1_mean,
                                                     type=type, threshold=Config.THRESHOLD)
            metrics_bundles = metric_utils.calculate_metrics_each_bundle(metrics_bundles, None, None,
                                                                         bundle_names, f1,
                                                                         threshold=Config.THRESHOLD)
        else:
            # Flatten all dims except the last (nr_classes) dim
            img_probs = np.reshape(img_probs, (-1, img_probs.shape[-1]))
            img_y = np.reshape(img_y, (-1, img_y.shape[-1]))
            metrics = metric_utils.calculate_metrics(metrics, img_y, img_probs, 0,
                                                     type=type, threshold=Config.THRESHOLD)
            metrics_bundles = metric_utils.calculate_metrics_each_bundle(metrics_bundles, img_y, img_probs,
                                                                         bundle_names,
                                                                         threshold=Config.THRESHOLD)

    metrics = metric_utils.normalize_last_element(metrics, len(subjects), type=type)
    metrics_bundles = metric_utils.normalize_last_element_general(metrics_bundles, len(subjects))

    print("WHOLE SUBJECT:")
    pprint(metrics)
    print("WHOLE SUBJECT BUNDLES:")
    pprint(metrics_bundles)

    with open(join(Config.EXP_PATH, "score_" + type + "-set.txt"), "w") as f:
        pprint(metrics, f)
        f.write("\n\nWeights: {}\n".format(Config.WEIGHTS_PATH))
        f.write("type: {}\n\n".format(type))
        pprint(metrics_bundles, f)

    # Use a context manager so the pickle file is flushed and closed deterministically.
    with open(join(Config.EXP_PATH, "score_" + type + ".pkl"), "wb") as f:
        pickle.dump(metrics, f)

    return metrics
def test_tractseg_output_docker(self):
    """Every bundle segmentation of the docker run must match its reference exactly (as uint8)."""
    for bundle_name in dataset_specific_utils.get_bundle_names("All")[1:]:
        ref_file = "tests/reference_files/bundle_segmentations/{}.nii.gz".format(bundle_name)
        new_file = "examples/docker_test/bundle_segmentations/{}.nii.gz".format(bundle_name)
        reference = nib.load(ref_file).get_fdata().astype(np.uint8)
        produced = nib.load(new_file).get_fdata().astype(np.uint8)
        failure_msg = "Docker tract segmentations are not correct (bundle: {})".format(bundle_name)
        self.assertTrue(np.array_equal(reference, produced), failure_msg)
def save_multilabel_img_as_multiple_files_endings(classes, img, affine, path, name="endings_segmentations"):
    """Save every channel of a multi-label endings image as its own NIfTI file.

    Channel ``idx`` (``img[:, :, :, idx]``) is written to ``<path>/<name>/<bundle>.nii.gz``,
    where ``bundle`` is the idx-th entry of ``get_bundle_names(classes)`` after dropping
    the first (background) entry.

    Args:
        classes: class-set name passed to get_bundle_names.
        img: 4D image with one channel per bundle.
        affine: affine passed to nib.Nifti1Image.
        path: parent output directory.
        name: sub-directory name (default keeps the previous location).
    """
    bundles = dataset_specific_utils.get_bundle_names(classes)[1:]
    out_dir = join(path, name)
    # Create the directory once instead of on every iteration; as before, nothing is
    # created when there are no bundles to save.
    if bundles:
        exp_utils.make_dir(out_dir)
    for idx, bundle in enumerate(bundles):
        img_seg = nib.Nifti1Image(img[:, :, :, idx], affine)
        nib.save(img_seg, join(out_dir, bundle + ".nii.gz"))
def calc_peak_length_dice(classes, y_pred, y_true, max_angle_error=(0.9,), max_length_error=0.1):
    """Per-bundle F1 score of peaks that match the ground truth in angle AND length.

    A voxel counts as a positive prediction when both hold:
      * ``abs(peak_utils.angle_last_dim(pred, true)) > max_angle_error[0]``
        (presumably |cos(angle)|, i.e. higher = better aligned; confirm in peak_utils)
      * ``|len(pred) - len(true)| < max_length_error * len(true)``
    The reference mask contains every voxel whose ground-truth components sum to > 0.

    Args:
        classes: class-set name for get_bundle_names; the first (background) entry is skipped.
        y_pred: prediction whose last axis holds 3 components per bundle
            (bundle ``idx`` -> channels ``idx*3 .. idx*3+2``).
        y_true: ground truth with the same layout as ``y_pred``.
        max_angle_error: sequence of thresholds; only the first element is used.
        max_length_error: allowed relative length deviation.

    Returns:
        dict mapping bundle name -> F1 score.
    """
    score_per_bundle = {}
    bundles = dataset_specific_utils.get_bundle_names(classes)[1:]
    for idx, bundle in enumerate(bundles):
        channels = slice(idx * 3, idx * 3 + 3)
        y_pred_bund = y_pred[..., channels]
        y_true_bund = y_true[..., channels]  # [x, y, z, 3] for volumes

        angles = abs(peak_utils.angle_last_dim(y_pred_bund, y_true_bund))
        lengths_pred = np.linalg.norm(y_pred_bund, axis=-1)
        lengths_true = np.linalg.norm(y_true_bund, axis=-1)
        lengths_binary = (abs(lengths_pred - lengths_true) < (max_length_error * lengths_true)).flatten()

        gt_binary = (y_true_bund.sum(axis=-1) > 0).flatten()
        angles_binary = (angles > max_angle_error[0]).flatten()

        # Voxel is correct only if both angle and length are within tolerance.
        combined = lengths_binary & angles_binary
        score_per_bundle[bundle] = my_f1_score(gt_binary, combined)
    return score_per_bundle
def create_multilabel_mask(classes, subject, labels_type=np.int16, dataset_folder="HCP",
                           labels_folder="bundle_masks"):
    """
    One-hot encoding of all bundles in one big image.

    Channel 0 is the background: every voxel not equal to 1 in any bundle mask.
    Channel ``idx + 1`` holds the mask of the idx-th bundle of
    ``get_bundle_names(classes)`` after its first (background) entry. Masks are read from
    ``<C.HOME>/<dataset_folder>/<subject>/<labels_folder>/<bundle>.nii.gz``.

    Returns:
        array of shape (145, 174, 145, len(get_bundle_names(classes))) cast to ``labels_type``.
    """
    bundles = dataset_specific_utils.get_bundle_names(classes)

    # Masks are always HCP_highRes (downsampling happens only later), hence the fixed shape.
    mask_ml = np.zeros((145, 174, 145, len(bundles)))
    background = np.ones((145, 174, 145))  # everything that contains no bundle

    # first bundle is background -> already considered by setting np.ones in the beginning
    for idx, bundle in enumerate(bundles[1:]):
        mask = nib.load(join(C.HOME, dataset_folder, subject, labels_folder, bundle + ".nii.gz"))
        # get_data() is deprecated (removed in nibabel 5). get_fdata() yields the same values
        # as float64, which matches mask_ml's dtype anyway.
        mask_data = mask.get_fdata()
        mask_ml[:, :, :, idx + 1] = mask_data
        background[mask_data == 1] = 0  # remove this bundle from background

    mask_ml[:, :, :, 0] = background
    return mask_ml.astype(labels_type)
def test_peakreg_output(self):
    """TOMs of every bundle must match the reference within a small tolerance.

    The outputs come from a regression model, so exact equality cannot be expected.
    """
    tolerance = 1e-3
    for bundle_name in dataset_specific_utils.get_bundle_names("All")[1:]:
        reference = nib.load("tests/reference_files/TOM/{}.nii.gz".format(bundle_name)).get_fdata()
        produced = nib.load("examples/tractseg_output/TOM/{}.nii.gz".format(bundle_name)).get_fdata()
        close_enough = np.allclose(reference, produced, rtol=tolerance, atol=tolerance)
        self.assertTrue(close_enough, "TOMs are not correct (bundle: {})".format(bundle_name))
def save_multilabel_img_as_multiple_files(Config, img, affine, path, name="bundle_segmentations"):
    """Save every bundle channel of a multi-label image as its own NIfTI file.

    Channel ``idx`` (``img[:, :, :, idx]``) is written to ``<path>/<name>/<bundle>.nii.gz``,
    where ``bundle`` is the idx-th entry of ``get_bundle_names(Config.CLASSES)`` after
    dropping the first (background) entry.

    Args:
        Config: configuration; only ``Config.CLASSES`` is read.
        img: 4D image with one channel per bundle.
        affine: affine passed to nib.Nifti1Image.
        path: parent output directory.
        name: sub-directory name.
    """
    bundles = dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:]
    out_dir = join(path, name)
    # Create the directory once instead of on every iteration; as before, nothing is
    # created when there are no bundles to save.
    if bundles:
        exp_utils.make_dir(out_dir)
    for idx, bundle in enumerate(bundles):
        img_seg = nib.Nifti1Image(img[:, :, :, idx], affine)
        nib.save(img_seg, join(out_dir, bundle + ".nii.gz"))
def test_endingsseg_output(self):
    """Both end regions ("_b" and "_e") of every bundle must match the reference exactly."""
    reference_dir = "tests/reference_files/endings_segmentations/"
    output_dir = "examples/tractseg_output/endings_segmentations/"
    for bundle in dataset_specific_utils.get_bundle_names("All")[1:]:
        for suffix in ("_b", "_e"):
            file_name = bundle + suffix + ".nii.gz"
            expected = nib.load(reference_dir + file_name).get_fdata().astype(np.uint8)
            actual = nib.load(output_dir + file_name).get_fdata().astype(np.uint8)
            self.assertTrue(np.array_equal(expected, actual),
                            "Bundle endings are not correct (bundle: " + bundle + suffix + ")")
def test_tractseg_output_SR_noPP(self):
    """Compare the SR_noPP example segmentations with the reference files.

    Processing on travis differs slightly from a local environment, so up to 4
    differing voxels per bundle are tolerated.
    """
    max_differing_voxels = 5  # strictly fewer than this many voxels may differ
    bundles = dataset_specific_utils.get_bundle_names("All")[1:]
    for bundle in bundles:
        # IFO very different on travis than locally. Unclear why. All other bundles are fine.
        if bundle == "IFO_right":
            continue
        img_ref = nib.load("tests/reference_files/bundle_segmentations_SR_noPP/" + bundle +
                           ".nii.gz").get_fdata().astype(np.uint8)
        img_new = nib.load("examples/SR_noPP/tractseg_output/bundle_segmentations/" + bundle +
                           ".nii.gz").get_fdata().astype(np.uint8)
        # Count differing voxels directly. Subtracting the uint8 arrays would wrap around
        # (0 - 1 == 255), adding 255 for every voxel where img_new > img_ref.
        nr_differing_voxels = int(np.count_nonzero(img_ref != img_new))
        images_equal = nr_differing_voxels < max_differing_voxels
        self.assertTrue(images_equal, "Tract segmentations are not correct (bundle: " + bundle + ") " +
                        "(nr of differing voxels: " + str(nr_differing_voxels) + ")")
class Config(TractSegConfig):
    """Experiment settings overriding the TractSegConfig defaults (AutoPTX_42 classes)."""
    # Experiment name = file name of this module without its extension.
    EXP_NAME = os.path.basename(__file__).split(".")[0]
    DATASET_FOLDER = "HCP_preproc_all"
    NR_OF_GRADIENTS = 18
    FEATURES_FILENAME = "32g270g_BX"
    P_SAMP = 0.4
    CLASSES = "AutoPTX_42"
    # One output channel per bundle; the first entry of get_bundle_names (background) is excluded.
    NR_OF_CLASSES = len(dataset_specific_utils.get_bundle_names(CLASSES)[1:])
    DATASET = "HCP_all"
    LR_SCHEDULE = True
    LR_SCHEDULE_MODE = "min"
    LR_SCHEDULE_PATIENCE = 20
    NUM_EPOCHS = 200  # 130 probably also fine
def calc_peak_dice(classes, y_pred, y_true, max_angle_error=(0.9,)):
    """Per-bundle F1 between the ground-truth mask and voxels with small angular error.

    A voxel is predicted positive when ``abs(peak_utils.angle_last_dim(pred, true))``
    exceeds ``max_angle_error[0]``. The reference mask contains every voxel whose
    ground-truth components sum to > 0.

    Args:
        classes: class-set name for get_bundle_names; the first (background) entry is skipped.
        y_pred: prediction whose last axis holds 3 components per bundle
            (bundle ``idx`` -> channels ``idx*3 .. idx*3+2``).
        y_true: ground truth with the same layout as ``y_pred``.
        max_angle_error: sequence of thresholds; only the first element is used.

    Returns:
        dict mapping bundle name -> F1 score (sklearn ``f1_score`` with ``average="binary"``).
    """
    score_per_bundle = {}
    bundles = dataset_specific_utils.get_bundle_names(classes)[1:]
    for idx, bundle in enumerate(bundles):
        channels = slice(idx * 3, idx * 3 + 3)
        y_pred_bund = y_pred[..., channels]
        y_true_bund = y_true[..., channels]  # (x, y, z, 3) for volumes

        angles = abs(peak_utils.angle_last_dim(y_pred_bund, y_true_bund))
        angles_binary = angles > max_angle_error[0]
        gt_binary = y_true_bund.sum(axis=-1) > 0

        score_per_bundle[bundle] = f1_score(gt_binary.flatten(), angles_binary.flatten(), average="binary")
    return score_per_bundle
def save_multilabel_img_as_multiple_files_peaks(Config, img, affine, path, name="TOM"):
    """Save the 3-component peak image of every bundle as its own NIfTI file.

    Bundle ``idx`` occupies channels ``idx*3 .. idx*3+2`` of ``img`` and is written to
    ``<path>/<name>/<bundle>.nii.gz``. If ``Config.FLIP_OUTPUT_PEAKS`` is set, the z
    component is negated (for a correct view in MITK) and the file name gets an
    ``_f`` suffix. The caller's ``img`` is never modified.

    Args:
        Config: configuration; reads CLASSES and FLIP_OUTPUT_PEAKS.
        img: 4D peak image with 3 channels per bundle.
        affine: affine passed to nib.Nifti1Image.
        path: parent output directory.
        name: sub-directory name.
    """
    bundles = dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:]
    out_dir = join(path, name)
    # Create the directory once; as before, nothing is created when there are no bundles.
    if bundles:
        exp_utils.make_dir(out_dir)
    for idx, bundle in enumerate(bundles):
        data = img[:, :, :, (idx * 3):(idx * 3) + 3]
        if Config.FLIP_OUTPUT_PEAKS:
            # Slicing returns a view: copy first, otherwise the in-place flip would
            # also negate the z component in the caller's img.
            data = data.copy()
            data[:, :, :, 2] *= -1  # flip z Axis for correct view in MITK
            filename = bundle + "_f.nii.gz"
        else:
            filename = bundle + ".nii.gz"

        img_seg = nib.Nifti1Image(data, affine)
        nib.save(img_seg, join(out_dir, filename))
class Config(TractSegConfig):
    """Experiment settings overriding the TractSegConfig defaults (xtract tract definitions)."""
    # Experiment name = file name of this module without its extension.
    EXP_NAME = os.path.basename(__file__).split(".")[0]
    DATASET = "HCP_all"
    DATASET_FOLDER = "HCP_preproc_all"
    FEATURES_FILENAME = "32g90g270g_CSD_BX"
    CLASSES = "xtract"
    # One output channel per bundle; the first entry of get_bundle_names (background) is excluded.
    NR_OF_CLASSES = len(dataset_specific_utils.get_bundle_names(CLASSES)[1:])
    RESOLUTION = "1.25mm"
    LABELS_FILENAME = "bundle_masks_xtract_thr001"
    NUM_EPOCHS = 300
    EPOCH_MULTIPLIER = 0.5
    DAUG_ROTATE = True
    SPATIAL_TRANSFORM = "SpatialTransformPeaks"
    # rotation: 2*np.pi = 360 degree (-> 0.8 ~ 45 degree, 0.4 ~ 22 degree)
    DAUG_ROTATE_ANGLE = (-0.4, 0.4)
def calc_peak_length_dice_pytorch(classes, y_pred, y_true, max_angle_error=(0.9,), max_length_error=0.1):
    """Per-bundle F1 of peaks matching the ground truth in angle AND length (PyTorch).

    A voxel counts as a positive prediction when both hold:
      * ``pytorch_utils.angle_last_dim(pred, true) > max_angle_error[0]``
        (note: unlike the numpy variant no ``abs`` is applied here)
      * ``|len(pred) - len(true)| < max_length_error * len(true)``
    The reference mask contains every voxel whose ground-truth components sum to > 0.

    Args:
        classes: class-set name for get_bundle_names; the first (background) entry is skipped.
        y_pred: channels-first tensor, 4D for 2D data ([bs, C, x, y]) or 5D for 3D data
            ([bs, C, x, y, z]), with 3 channels per bundle.
        y_true: ground truth with the same layout as ``y_pred``.
        max_angle_error: sequence of thresholds; only the first element is used.
        max_length_error: allowed relative length deviation.

    Returns:
        dict mapping bundle name -> value returned by ``pytorch_utils.f1_score_binary``.
    """
    import torch
    from tractseg.libs import pytorch_utils

    # Move the channel axis to the end so each bundle's 3 components are in the last dim.
    if len(y_pred.shape) == 4:  # 2D
        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)
    else:  # 3D
        y_true = y_true.permute(0, 2, 3, 4, 1)
        y_pred = y_pred.permute(0, 2, 3, 4, 1)

    score_per_bundle = {}
    bundles = dataset_specific_utils.get_bundle_names(classes)[1:]
    for idx, bundle in enumerate(bundles):
        channels = slice(idx * 3, idx * 3 + 3)
        y_pred_bund = y_pred[..., channels].contiguous()
        y_true_bund = y_true[..., channels].contiguous()  # [..., 3]

        angles = pytorch_utils.angle_last_dim(y_pred_bund, y_true_bund)
        lengths_pred = torch.norm(y_pred_bund, 2, -1)
        lengths_true = torch.norm(y_true_bund, 2, -1)
        lengths_binary = (torch.abs(lengths_pred - lengths_true) < (max_length_error * lengths_true)).view(-1)

        gt_binary = (y_true_bund.sum(dim=-1) > 0).view(-1)  # [bs*x*y(*z)]
        angles_binary = (angles > max_angle_error[0]).view(-1)

        # Voxel is correct only if both angle and length are within tolerance.
        combined = lengths_binary & angles_binary
        score_per_bundle[bundle] = pytorch_utils.f1_score_binary(gt_binary, combined)
    return score_per_bundle
class Config(DmRegConfig):
    """Experiment settings overriding the DmRegConfig defaults (AutoPTX_42 classes)."""
    # Experiment name = file name of this module without its extension.
    EXP_NAME = os.path.basename(__file__).split(".")[0]
    DATASET_FOLDER = "HCP_preproc_all"
    NR_OF_GRADIENTS = 18
    FEATURES_FILENAME = "32g270g_BX"
    P_SAMP = 0.4
    CLASSES = "AutoPTX_42"
    # One output channel per bundle; the first entry of get_bundle_names (background) is excluded.
    NR_OF_CLASSES = len(dataset_specific_utils.get_bundle_names(CLASSES)[1:])
    # THRESHOLD = 0.001  # Final DM will be thresholded at this value
    THRESHOLD = 0.0001  # use lower value so user has more choice
    DATASET = "HCP_all"
    LR_SCHEDULE = True
    LR_SCHEDULE_MODE = "min"
    LR_SCHEDULE_PATIENCE = 20
    NUM_EPOCHS = 200  # 130 probably also fine
def calc_peak_dice_onlySeg(classes, y_pred, y_true, pred_threshold=0.2, true_threshold=1e-3):
    """
    Create binary mask of peaks by simple thresholding. Then calculate Dice.

    A voxel is foreground when the sum of the absolute values of its 3 peak
    components exceeds the threshold (``pred_threshold`` for the prediction,
    ``true_threshold`` for the ground truth).

    Args:
        classes: class-set name for get_bundle_names; the first (background) entry is skipped.
        y_pred: prediction whose last axis holds 3 components per bundle
            (bundle ``idx`` -> channels ``idx*3 .. idx*3+2``).
        y_true: ground truth with the same layout as ``y_pred``.
        pred_threshold: 0.1 -> keeps some outliers but already some holes; 0.2 also ok
            (still looks like e.g. CST). Resulting Dice for 0.1 and 0.2 is very similar.
        true_threshold: threshold for the ground truth (effectively "non-zero").

    Returns:
        dict mapping bundle name -> F1 score (sklearn ``f1_score`` with ``average="binary"``).
    """
    score_per_bundle = {}
    bundles = dataset_specific_utils.get_bundle_names(classes)[1:]
    for idx, bundle in enumerate(bundles):
        channels = slice(idx * 3, idx * 3 + 3)
        y_pred_bund = y_pred[..., channels]
        y_true_bund = y_true[..., channels]  # [x, y, z, 3] for volumes

        y_pred_binary = np.abs(y_pred_bund).sum(axis=-1) > pred_threshold
        y_true_binary = np.abs(y_true_bund).sum(axis=-1) > true_threshold

        score_per_bundle[bundle] = f1_score(y_true_binary.flatten(), y_pred_binary.flatten(), average="binary")
    return score_per_bundle
def test_bundle_names(self):
    """A single-bundle class set yields the background entry followed by that bundle."""
    actual = dataset_specific_utils.get_bundle_names("CST_right")
    expected = ["BG", "CST_right"]
    self.assertListEqual(actual, expected, "Error in list of bundle names")
def calc_peak_dice_pytorch(classes, y_pred, y_true, max_angle_error=(0.9,)):
    """
    Calculate angle between groundtruth and prediction and keep the voxels where
    angle is smaller than MAX_ANGLE_ERROR.

    From groundtruth generate a binary mask by selecting all voxels with len > 0.

    Calculate Dice from these 2 masks.

    -> Penalty on peaks outside of tract or if predicted peak=0
    -> No penalty on very small peaks with the right direction -> bad
    => Peak_dice can be high even if peaks inside of tract are almost missing (almost 0)

    Args:
        classes: class-set name for get_bundle_names; the first (background) entry is skipped.
        y_pred: channels-first tensor, 4D for 2D data ([bs, C, x, y]) or 5D for 3D data
            ([bs, C, x, y, z]), with 3 channels per bundle.
        y_true: ground truth with the same layout as ``y_pred``.
        max_angle_error: sequence of thresholds compared against
            ``pytorch_utils.angle_last_dim``: 0.7 -> angle error of 45° or less;
            0.9 -> angle error of 23° or less. Several values -> one score per threshold.

    Returns:
        dict bundle name -> F1 score if ``max_angle_error`` has exactly one element,
        otherwise dict bundle name -> list of F1 scores (same order as the thresholds).
    """
    from tractseg.libs import pytorch_utils

    # Move the channel axis to the end so each bundle's 3 components are in the last dim.
    if len(y_pred.shape) == 4:  # 2D
        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)
    else:  # 3D
        y_true = y_true.permute(0, 2, 3, 4, 1)
        y_pred = y_pred.permute(0, 2, 3, 4, 1)

    single_threshold = len(max_angle_error) == 1
    score_per_bundle = {}
    bundles = dataset_specific_utils.get_bundle_names(classes)[1:]
    for idx, bundle in enumerate(bundles):
        channels = slice(idx * 3, idx * 3 + 3)
        y_pred_bund = y_pred[..., channels].contiguous()
        y_true_bund = y_true[..., channels].contiguous()  # [..., 3]

        angles = pytorch_utils.angle_last_dim(y_pred_bund, y_true_bund)
        gt_binary = (y_true_bund.sum(dim=-1) > 0).view(-1)  # [bs*x*y(*z)]

        scores = [pytorch_utils.f1_score_binary(gt_binary, (angles > threshold).view(-1))
                  for threshold in max_angle_error]
        # Keep the historical return format: a scalar for one threshold, a list otherwise.
        score_per_bundle[bundle] = scores[0] if single_threshold else scores
    return score_per_bundle
def plot_tracts(classes, bundle_segmentations, affine, out_dir, brain_mask=None):
    """
    Render a 3D preview of CST_right, CA and IFO_right and save it as ``preview.png`` in ``out_dir``.

    By default this does not work on a remote server connection (ssh -X) because -X does not
    support OpenGL. On the remote server you can do 'export DISPLAY=":0"' (use the value that
    'echo $DISPLAY' shows when logged in locally on the remote server). Then all graphics are
    rendered locally and not via -X.
    (important: a graphical session needs to be running on the remote server, e.g. via a local login)
    (important: a login is needed, just staying at the login screen is not enough)

    On a headless server without a display, using Xvfb might help:
    https://stackoverflow.com/questions/6281998/can-i-run-glu-opengl-on-a-headless-server
    """
    from dipy.viz import window
    from tractseg.libs import vtk_utils

    smoothing = 10
    window_size = (800, 800)
    bundles = ["CST_right", "CA", "IFO_right"]
    # IFO_right is viewed sagittally, everything else axially.
    orientation_of = {"IFO_right": "sagittal"}
    border_y = -100  # bigger: more border (same for all orientations)
    column = 0  # only one method -> a single column

    renderer = window.Renderer()
    renderer.projection('parallel')

    nr_rows = len(bundles)
    X, Y, _ = bundle_segmentations.shape[:3]
    all_bundle_names = dataset_specific_utils.get_bundle_names(classes)[1:]

    for row, bundle in enumerate(bundles):
        mask_data = bundle_segmentations[:, :, :, all_bundle_names.index(bundle)]
        orientation = orientation_of.get(bundle, "axial")

        row_height = Y * 2 + border_y
        x_current = X * column  # column (width)
        y_current = nr_rows * row_height - row_height * row  # row (height), counted from the bottom

        plot_mask(renderer, mask_data, affine, x_current, y_current,
                  orientation=orientation, smoothing=smoothing, brain_mask=brain_mask)

        # Bundle label
        text_offset_top = -50
        text_offset_side = -100
        position = (0 - int(X) + text_offset_side, y_current + text_offset_top, 50)
        renderer.add(vtk_utils.label(text=bundle, pos=position, scale=(6, 6, 6), color=(1, 1, 1)))

    renderer.reset_camera()
    window.record(renderer, out_path=join(out_dir, "preview.png"),
                  size=window_size, reset_camera=False, magnification=2)
class Config:
    """ Settings and hyperparameters, stored as class attributes. """
    # NOTE(review): MULTI_PARENT_PATH, EXP_PATH and NR_OF_CLASSES are evaluated once when the
    # class body runs; overriding EXP_NAME / CLASSES in a subclass does not update them
    # automatically (subclasses in this project recompute NR_OF_CLASSES themselves).
    # NOTE(review): the list attributes below are shared class-level objects; reassign rather
    # than mutate them in place.

    # input data
    EXPERIMENT_TYPE = "tract_segmentation"  # tract_segmentation|endings_segmentation|dm_regression|peak_regression
    EXP_NAME = "HCP_TEST"
    EXP_MULTI_NAME = ""  # CV parent directory name; leave empty for single bundle experiment
    DATASET_FOLDER = "HCP_preproc"
    LABELS_FOLDER = "bundle_masks"
    MULTI_PARENT_PATH = join(C.EXP_PATH, EXP_MULTI_NAME)
    EXP_PATH = join(C.EXP_PATH, EXP_MULTI_NAME, EXP_NAME)  # default path
    CLASSES = "All"
    NR_OF_GRADIENTS = 9
    # One output channel per bundle; the first entry of get_bundle_names (background) is excluded.
    NR_OF_CLASSES = len(dataset_specific_utils.get_bundle_names(CLASSES)[1:])
    INPUT_DIM = None  # autofilled
    DATASET = "HCP"  # HCP | HCP_32g | Schizo
    RESOLUTION = "1.25mm"  # 1.25mm|2.5mm
    # 12g90g270g | 270g_125mm_xyz | 270g_125mm_peaks | 90g_125mm_peaks | 32g_25mm_peaks | 32g_25mm_xyz
    FEATURES_FILENAME = "12g90g270g"
    LABELS_FILENAME = ""  # autofilled
    LABELS_TYPE = "int"
    THRESHOLD = 0.5  # Binary: 0.5, Regression: 0.01

    # hyperparameters
    MODEL = "UNet_Pytorch_DeepSup"
    DIM = "2D"  # 2D | 3D
    BATCH_SIZE = 47
    LEARNING_RATE = 0.001
    LR_SCHEDULE = True
    LR_SCHEDULE_MODE = "min"  # min | max
    LR_SCHEDULE_PATIENCE = 20
    UNET_NR_FILT = 64
    EPOCH_MULTIPLIER = 1  # 2D: 1, 3D: 12 for lowRes, 3 for highRes
    NUM_EPOCHS = 250
    SLICE_DIRECTION = "y"  # x | y | z  ("combined" needs z)
    TRAINING_SLICE_DIRECTION = "xyz"  # y | xyz
    LOSS_FUNCTION = "default"  # default | soft_batch_dice
    OPTIMIZER = "Adamax"
    LOSS_WEIGHT = None  # None = no weighting
    LOSS_WEIGHT_LEN = -1  # -1 = constant over all epochs
    BATCH_NORM = False
    WEIGHT_DECAY = 0
    USE_DROPOUT = False
    DROPOUT_SAMPLING = False
    LOAD_WEIGHTS = False
    # WEIGHTS_PATH = join(C.EXP_PATH, "My_experiment/best_weights_ep64.npz")
    WEIGHTS_PATH = ""  # if empty string: autoloading the best_weights in get_best_weights_path()
    SAVE_WEIGHTS = True
    TYPE = "single_direction"  # single_direction | combined
    CV_FOLD = 0
    VALIDATE_SUBJECTS = []
    TRAIN_SUBJECTS = []
    TEST_SUBJECTS = []
    TRAIN = True
    TEST = True
    SEGMENT = False
    GET_PROBS = False
    OUTPUT_MULTIPLE_FILES = False
    RESET_LAST_LAYER = False
    UPSAMPLE_TYPE = "bilinear"  # bilinear | nearest
    BEST_EPOCH_SELECTION = "f1"  # f1 | loss
    METRIC_TYPES = ["loss", "f1_macro"]
    FP16 = True
    PEAK_DICE_THR = [0.95]
    PEAK_DICE_LEN_THR = 0.05
    FLIP_OUTPUT_PEAKS = False  # flip peaks along z axis to make them compatible with MITK
    USE_VISLOGGER = False
    SEG_INPUT = "Peaks"  # Gradients | Peaks
    NR_SLICES = 1
    PRINT_FREQ = 20
    NORMALIZE_DATA = True
    NORMALIZE_PER_CHANNEL = False
    BEST_EPOCH = 0
    VERBOSE = True
    CALC_F1 = True
    ONLY_VAL = False
    TEST_TIME_DAUG = False
    PAD_TO_SQUARE = True
    INPUT_RESCALING = False  # Resample data to different resolution (instead of doing it in preprocessing)

    # data augmentation
    DATA_AUGMENTATION = True
    DAUG_SCALE = True
    DAUG_NOISE = True
    DAUG_NOISE_VARIANCE = (0, 0.05)
    DAUG_ELASTIC_DEFORM = True
    DAUG_ALPHA = (90., 120.)
    DAUG_SIGMA = (9., 11.)
    DAUG_RESAMPLE = False  # does not improve validation dice (if using Gaussian_blur) -> deactivate
    DAUG_RESAMPLE_LEGACY = False  # does not improve validation dice (at least on AutoPTX) -> deactivate
    DAUG_GAUSSIAN_BLUR = True
    DAUG_BLUR_SIGMA = (0, 1)
    DAUG_ROTATE = False
    DAUG_ROTATE_ANGLE = (-0.2, 0.2)  # rotation: 2*np.pi = 360 degree (0.4 ~= 22 degree, 0.2 ~= 11 degree)
    DAUG_MIRROR = False
    DAUG_FLIP_PEAKS = False
    SPATIAL_TRANSFORM = "SpatialTransform"  # SpatialTransform|SpatialTransformPeaks
    P_SAMP = 1.0
    DAUG_INFO = "-"
    INFO = "-"

    # for inference
    PREDICT_IMG = False
    PREDICT_IMG_OUTPUT = None
    TRACTSEG_DIR = "tractseg_output"
    KEEP_INTERMEDIATE_FILES = False
    CSD_RESOLUTION = "LOW"  # HIGH | LOW
    NR_CPUS = -1
def run_tractseg(data, output_type="tract_segmentation",
                 single_orientation=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_postprocessing=True, get_probs=False, peak_threshold=0.1,
                 postprocess=False, peak_regression_part="All", input_type="peaks",
                 blob_size_thr=50, nr_cpus=-1, verbose=False, manual_exp_name=None,
                 inference_batch_size=1, tract_definition="TractQuerier+", bedpostX_input=False,
                 tract_segmentations_path=None, TOM_dilation=1, unit_test=False):
    """
    Run TractSeg

    Args:
        data: input peaks (4D numpy array with shape [x,y,z,9])
        output_type: TractSeg can segment not only bundles, but also the end regions of bundles.
            Moreover it can create Tract Orientation Maps (TOM).
            'tract_segmentation' [DEFAULT]: Segmentation of bundles (72 bundles).
            'endings_segmentation': Segmentation of bundle end regions (72 bundles).
            'TOM': Tract Orientation Maps (20 bundles).
        single_orientation: Do not run model 3 times along x/y/z orientation with subsequent mean fusion.
        dropout_sampling: Create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
        threshold: Threshold for converting probability map to binary map
        bundle_specific_postprocessing: Set threshold to lower and use hole closing for CA and FX if incomplete
        get_probs: Output raw probability map instead of binary map
        peak_threshold: All peaks shorter than peak_threshold will be set to zero
            (NOTE: not referenced in this function's body).
        postprocess: Simple postprocessing of segmentations: Remove small blobs and fill holes
        peak_regression_part: Only relevant for output type 'TOM'. If set to 'All' (default) it will return all
            72 bundles. If set to 'Part1'-'Part4' it will only run for a subset of the bundles to reduce memory
            load.
        input_type: Always set to "peaks"
        blob_size_thr: If setting postprocess to True, all blobs having a smaller number of voxels than specified
            in this threshold will be removed.
        nr_cpus: Number of CPUs to use. -1 means all available CPUs.
        verbose: Show debugging infos
        manual_exp_name: Name of experiment if you do not want to use the pretrained model but your own one
        inference_batch_size: batch size (higher: a bit faster but needs more RAM)
        tract_definition: Select which tract definitions to use. 'TractQuerier+' defines tracts mainly by their
            cortical start and end region. 'xtract' defines tracts mainly by ROIs in white matter.
        bedpostX_input: Input peaks are generated by bedpostX (NOTE: not referenced in this function's body).
        tract_segmentations_path: path to the bundle_segmentations (only needed for peak regression to remove
            peaks outside of the segmentation mask)
        TOM_dilation: Dilation applied to the tract segmentations before using them to mask the TOMs.
        unit_test: forwarded to trainer.predict_img in the single-orientation probability path.

    Returns:
        4D numpy array with the output of tractseg
        for tract_segmentation:     [x, y, z, nr_of_bundles]
        for endings_segmentation:   [x, y, z, 2*nr_of_bundles]
        for TOM:                    [x, y, z, 3*nr_of_bundles]

    Raises:
        ValueError: for xtract combined with an unsupported output type, or if the loaded
            config has an unknown EXPERIMENT_TYPE.
    """
    start_time = time.time()

    if manual_exp_name is None:
        config = get_config_name(input_type, output_type, dropout_sampling=dropout_sampling,
                                 tract_definition=tract_definition)
        Config = getattr(importlib.import_module("tractseg.experiments.pretrained_models." + config), "Config")()
    else:
        Config = exp_utils.load_config_from_txt(
            join(C.EXP_PATH, exp_utils.get_manual_exp_name_peaks(manual_exp_name, "Part1"), "Hyperparameters.txt"))

    # Do not do any postprocessing if returning probabilities (because postprocessing only works on binary)
    if get_probs:
        bundle_specific_postprocessing = False
        postprocess = False

    Config = exp_utils.get_correct_labels_type(Config)
    Config.VERBOSE = verbose
    Config.TRAIN = False
    Config.TEST = False
    Config.SEGMENT = False
    Config.GET_PROBS = get_probs
    Config.LOAD_WEIGHTS = True
    Config.DROPOUT_SAMPLING = dropout_sampling
    Config.THRESHOLD = threshold
    Config.NR_CPUS = nr_cpus
    Config.INPUT_DIM = dataset_specific_utils.get_correct_input_dim(Config)
    Config.RESET_LAST_LAYER = False

    # Bundle specific postprocessing works on probabilities (per the docstring it applies its
    # own lower threshold for some bundles), so request probabilities from the model.
    if Config.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing:
        Config.GET_PROBS = True

    if manual_exp_name is not None and Config.EXPERIMENT_TYPE != "peak_regression":
        Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(join(C.EXP_PATH, manual_exp_name), True)
    else:
        if tract_definition == "TractQuerier+":
            if input_type == "peaks":
                # The same weights are used with and without dropout sampling.
                if Config.EXPERIMENT_TYPE == "tract_segmentation":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_v3.npz")
                elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v4.npz")
                elif Config.EXPERIMENT_TYPE == "dm_regression":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_dm_regression_v2.npz")
            else:  # T1
                if Config.EXPERIMENT_TYPE == "tract_segmentation":
                    Config.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models",
                                               "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
                elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v1.npz")
                elif Config.EXPERIMENT_TYPE == "peak_regression":
                    Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_peak_regression_v1.npz")
        else:  # xtract
            if Config.EXPERIMENT_TYPE == "tract_segmentation":
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_xtract_v1.npz")
            elif Config.EXPERIMENT_TYPE == "dm_regression":
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, "pretrained_weights_dm_regression_xtract_v1.npz")
            else:
                raise ValueError("bundle_definition xtract not supported in combination with this output type")

    if Config.VERBOSE:
        print("Hyperparameters:")
        exp_utils.print_Configs(Config)

    data = np.nan_to_num(data)
    # runtime on HCP data: 0.9s
    data, _, bbox, original_shape = data_utils.crop_to_nonzero(data)
    # runtime on HCP data: 0.5s
    data, transformation = data_utils.pad_and_scale_img_to_square_img(data, target_size=Config.INPUT_DIM[0],
                                                                       nr_cpus=nr_cpus)

    if Config.EXPERIMENT_TYPE in ("tract_segmentation", "endings_segmentation", "dm_regression"):
        print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
        Config.NR_OF_CLASSES = len(dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
        utils.download_pretrained_weights(experiment_type=Config.EXPERIMENT_TYPE,
                                          dropout_sampling=Config.DROPOUT_SAMPLING,
                                          tract_definition=tract_definition)
        model = BaseModel(Config, inference=True)
        # Probabilities are kept for uncertainty maps, for regression output and whenever
        # GET_PROBS is set (including the bundle specific postprocessing case above).
        want_probs = bool(Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" or
                          Config.GET_PROBS)
        if single_orientation:  # mainly needed for testing because of less RAM requirements
            data_loader_inference = DataLoaderInference(Config, data=data)
            if want_probs:
                seg, _ = trainer.predict_img(Config, model, data_loader_inference, probs=True,
                                             scale_to_world_shape=False, only_prediction=True,
                                             batch_size=inference_batch_size, unit_test=unit_test)
            else:
                seg, _ = trainer.predict_img(Config, model, data_loader_inference, probs=False,
                                             scale_to_world_shape=False, only_prediction=True,
                                             batch_size=inference_batch_size)
        else:
            seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(Config, model, data=data,
                                                                          scale_to_world_shape=False,
                                                                          only_prediction=True,
                                                                          batch_size=inference_batch_size)
            seg = direction_merger.mean_fusion(Config.THRESHOLD, seg_xyz, probs=want_probs)

    elif Config.EXPERIMENT_TYPE == "peak_regression":
        weights = {
            "Part1": "pretrained_weights_peak_regression_part1_v2.npz",
            "Part2": "pretrained_weights_peak_regression_part2_v2.npz",
            "Part3": "pretrained_weights_peak_regression_part3_v2.npz",
            "Part4": "pretrained_weights_peak_regression_part4_v2.npz",
        }

        if peak_regression_part == "All":
            parts = ["Part1", "Part2", "Part3", "Part4"]
            # NOTE(review): assumes Config.NR_OF_CLASSES * 3 (value from the loaded config) equals the
            # total number of channels of all four parts; confirm against the pretrained config.
            seg_all = np.zeros((data.shape[0], data.shape[1], data.shape[2], Config.NR_OF_CLASSES * 3))
        else:
            parts = [peak_regression_part]
            Config.CLASSES = "All_" + peak_regression_part
            Config.NR_OF_CLASSES = 3 * len(dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])

        for idx, part in enumerate(parts):
            if manual_exp_name is not None:
                manual_exp_name_peaks = exp_utils.get_manual_exp_name_peaks(manual_exp_name, part)
                Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(join(C.EXP_PATH, manual_exp_name_peaks), True)
            else:
                Config.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, weights[part])
            print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
            Config.CLASSES = "All_" + part
            Config.NR_OF_CLASSES = 3 * len(dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
            utils.download_pretrained_weights(experiment_type=Config.EXPERIMENT_TYPE,
                                              dropout_sampling=Config.DROPOUT_SAMPLING, part=part,
                                              tract_definition=tract_definition)
            model = BaseModel(Config, inference=True)

            if single_orientation:
                data_loader_inference = DataLoaderInference(Config, data=data)
                seg, _ = trainer.predict_img(Config, model, data_loader_inference, probs=True,
                                             scale_to_world_shape=False, only_prediction=True,
                                             batch_size=inference_batch_size)
            else:
                # 3 dir for Peaks -> bad results
                seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(Config, model, data=data,
                                                                              scale_to_world_shape=False,
                                                                              only_prediction=True,
                                                                              batch_size=inference_batch_size)
                seg = direction_merger.mean_fusion_peaks(seg_xyz, nr_cpus=nr_cpus)

            if peak_regression_part == "All":
                # Each part fills its own consecutive block of channels.
                start = idx * Config.NR_OF_CLASSES
                seg_all[:, :, :, start:start + Config.NR_OF_CLASSES] = seg

        if peak_regression_part == "All":
            Config.CLASSES = "All"
            Config.NR_OF_CLASSES = 3 * len(dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
            seg = seg_all

    else:
        # Previously this fell through and failed later with an UnboundLocalError on `seg`.
        raise ValueError("Unsupported experiment type: {}".format(Config.EXPERIMENT_TYPE))

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing and not dropout_sampling:
        # Runtime ~4s
        seg = img_utils.bundle_specific_postprocessing(seg,
                                                       dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])

    # runtime on HCP data: 5.1s
    seg = data_utils.cut_and_scale_img_back_to_original_img(seg, transformation, nr_cpus=nr_cpus)
    # runtime on HCP data: 1.6s
    seg = data_utils.add_original_zero_padding_again(seg, bbox, original_shape, Config.NR_OF_CLASSES)

    if Config.EXPERIMENT_TYPE == "peak_regression":
        seg = peak_utils.mask_and_normalize_peaks(seg, tract_segmentations_path,
                                                  dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
                                                  TOM_dilation, nr_cpus=nr_cpus)

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and postprocess and not dropout_sampling:
        # Runtime ~7s for 1.25mm resolution
        # Runtime ~1.5s for 2mm resolution
        seg = img_utils.postprocess_segmentations(seg, dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
                                                  blob_thr=blob_size_thr, hole_closing=None)

    exp_utils.print_verbose(Config.VERBOSE, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg
def plot_tracts_matplotlib(classes, bundle_segmentations, background_img, out_dir,
                           threshold=0.001, exp_type="tract_segmentation"):
    """Save a grid of 2D preview images (max projection) of selected bundles to ``out_dir``.

    The file name depends on ``exp_type``: preview_bundle / preview_endings / preview_TOM /
    preview_dm, or ``preview`` for any other value (always ``.png``).

    Args:
        classes: class-set name for get_bundle_names; names starting with "xtract" select the
            xtract preview bundles.
        bundle_segmentations: 4D array with one channel per bundle (for peak_regression:
            3 channels per bundle, reshaped internally to [x, y, z, nr_bundles, 3]).
        background_img: 4D array; only channel 0 is used as background.
        out_dir: directory the PNG is written to.
        threshold: currently NOT used; values below a fixed 0.001 are masked instead
            (a lower value made half of the preview red).
        exp_type: experiment type; selects bundles, data layout and file name.
    """

    def plot_single_tract(bg, data, orientation, bundle, exp_type):
        # Reorder axes for coronal/sagittal views and reverse axis 0; axial keeps the layout.
        if orientation == "coronal":
            data = data.transpose(2, 0, 1, 3) if exp_type == "peak_regression" else data.transpose(2, 0, 1)
            data = data[::-1, :, :]
            bg = bg.transpose(2, 0, 1)[::-1, :, :]
        elif orientation == "sagittal":
            data = data.transpose(2, 1, 0, 3) if exp_type == "peak_regression" else data.transpose(2, 1, 0)
            data = data[::-1, :, :]
            bg = bg.transpose(2, 1, 0)[::-1, :, :]
        else:  # axial
            pass

        # Show the background slice in the middle of the bundle's extent along axis 2
        # (or the middle of the volume if the bundle is empty).
        mask_voxel_coords = np.where(data != 0)
        if len(mask_voxel_coords) > 2 and len(mask_voxel_coords[2]) > 0:
            minidx = int(np.min(mask_voxel_coords[2]))
            maxidx = int(np.max(mask_voxel_coords[2])) + 1
            mean_slice = int(np.mean([minidx, maxidx]))
        else:
            mean_slice = int(bg.shape[2] / 2)
        bg = bg[:, :, mean_slice]

        # project 3D to 2D image
        if aggregation == "mean":
            data = data.mean(axis=2)
        else:
            data = data.max(axis=2)

        plt.imshow(bg, cmap="gray")
        data = np.ma.masked_where(data < 0.00001, data)
        plt.imshow(data, cmap="autumn")  # even with cmap=autumn peaks still RGB
        plt.title(bundle, fontsize=7)

    if classes.startswith("xtract"):
        bundles = ["cst_r", "cst_s_r", "ifo_r", "fx_l", "fx_r", "or_l", "fma"]
    elif exp_type == "endings_segmentation":
        bundles = ["CST_right_b", "CST_right_e", "CST_s_right_b", "CST_s_right_e", "CA_b", "CA_e"]
    else:
        bundles = ["CST_right", "CST_s_right", "CA", "IFO_right", "FX_left", "FX_right", "OR_left", "CC_1"]

    if exp_type == "peak_regression":
        s = bundle_segmentations.shape
        bundle_segmentations = bundle_segmentations.reshape([s[0], s[1], s[2], int(s[3] / 3), 3])
        bundles = ["CST_right", "CST_s_right", "CA", "CC_1", "AF_left"]  # can only use bundles from part1

    aggregation = "max"  # read by plot_single_tract (closure)
    cols = 4
    rows = math.ceil(len(bundles) / cols)

    background_img = background_img[..., 0]
    all_bundle_names = dataset_specific_utils.get_bundle_names(classes)[1:]

    for j, bundle in enumerate(bundles):
        bun = bundle.lower()
        if bun.startswith(("ca", "fx_", "or_", "cc_1", "fma")):
            orientation = "axial"
        elif bun.startswith(("ifo_", "icp_", "cst_s_", "af_")):
            # "_s" marks a sagittal view of a bundle; strip it to get the real bundle name.
            bundle = bundle.replace("_s", "")
            orientation = "sagittal"
        elif bun.startswith("cst_"):
            orientation = "coronal"
        else:
            raise ValueError("invalid bundle")

        bundle_idx = all_bundle_names.index(bundle)
        # copy data otherwise will also threshold data outside of plot function
        mask_data = np.copy(bundle_segmentations[:, :, :, bundle_idx])
        mask_data[mask_data < 0.001] = 0  # higher value better for preview, otherwise half of image just red

        plt.subplot(rows, cols, j + 1)
        plt.axis("off")
        plot_single_tract(background_img, mask_data, orientation, bundle, exp_type=exp_type)

    file_names = {
        "tract_segmentation": "preview_bundle",
        "endings_segmentation": "preview_endings",
        "peak_regression": "preview_TOM",
        "dm_regression": "preview_dm",
    }
    file_name = file_names.get(exp_type, "preview")

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.savefig(join(out_dir, file_name + ".png"), bbox_inches='tight', dpi=300)
    # Close the figure: otherwise the next call draws its subplots on top of this figure
    # and the figure's memory is never released.
    plt.close()