def create_multilabel_mask(Config, subject, labels_type=np.int16, dataset_folder="HCP", labels_folder="bundle_masks"):
    """One-hot encode all bundle masks of one subject into one big 4D image.

    Channel 0 is the background (every voxel not equal to 1 in any bundle mask),
    channel ``i`` (i >= 1) holds the mask of the i-th entry of
    ``exp_utils.get_bundle_names(Config.CLASSES)``.

    :param Config: experiment config; only ``Config.CLASSES`` is read.
    :param subject: subject folder name below ``C.HOME/<dataset_folder>``.
    :param labels_type: dtype of the returned array.
    :param dataset_folder: dataset directory below ``C.HOME``.
    :param labels_folder: folder inside the subject directory holding ``<bundle>.nii.gz``.
    :return: image of shape (145, 174, 145, nr_of_bundles + 1) with dtype ``labels_type``.
    """
    bundles = exp_utils.get_bundle_names(Config.CLASSES)

    # Masks are always HCP_highRes (downsampling happens later).
    shape = (145, 174, 145)
    # Allocate directly in the target dtype: a float64 intermediate for ~70 bundles
    # needs more than 2 GB and the final astype() used to copy it once more.
    mask_ml = np.zeros(shape + (len(bundles),), dtype=labels_type)
    background = np.ones(shape, dtype=labels_type)  # everything that contains no bundle

    # bundles[0] is the background class -> already covered by `background`.
    for idx, bundle in enumerate(bundles[1:], start=1):
        mask = nib.load(join(C.HOME, dataset_folder, subject, labels_folder, bundle + ".nii.gz"))
        # np.asanyarray(dataobj) keeps the on-disk dtype (like the old get_data(),
        # which was removed in nibabel 5).
        mask_data = np.asanyarray(mask.dataobj)
        mask_ml[:, :, :, idx] = mask_data
        background[mask_data == 1] = 0  # remove this bundle from background
    mask_ml[:, :, :, 0] = background
    return mask_ml
def save_multilabel_img_as_multiple_files_endings(Config, img, affine, path):
    """Save every channel of an endings image as its own NIfTI file.

    Channel ``idx`` of ``img`` is written to
    ``<path>/endings_segmentations/<bundle>.nii.gz``, where ``bundle`` is the
    idx-th entry of ``exp_utils.get_bundle_names(Config.CLASSES)[1:]``
    (the background entry is skipped).

    :param Config: experiment config; only ``Config.CLASSES`` is read.
    :param img: 4D array with one channel per bundle.
    :param affine: affine used for every written ``Nifti1Image``.
    :param path: output root directory.
    """
    bundles = exp_utils.get_bundle_names(Config.CLASSES)[1:]
    out_dir = join(path, "endings_segmentations")
    # Directory creation is loop-invariant: do it once instead of once per bundle.
    exp_utils.make_dir(out_dir)
    for idx, bundle in enumerate(bundles):
        img_seg = nib.Nifti1Image(img[:, :, :, idx], affine)
        nib.save(img_seg, join(out_dir, bundle + ".nii.gz"))
def calc_peak_dice_onlySeg(Config, y_pred, y_true):
    """Per-bundle Dice (F1) of peak images binarized by simple thresholding.

    For every bundle, the summed absolute components of its 3 peak channels are
    thresholded: prediction > 0.2, ground truth > 1e-3. Then Dice is computed.

    :param Config: experiment config; only ``Config.CLASSES`` is read.
    :param y_pred: predicted peaks, [x, y, z, 3 * nr_of_bundles]
    :param y_true: ground-truth peaks, same layout as ``y_pred``
    :return: dict mapping bundle name -> F1 score
    """
    # 0.1 -> keeps some outliers but already produces some holes; 0.2 is also ok
    # (still looks like e.g. CST). Resulting Dice for 0.1 and 0.2 is very similar.
    pred_thr = 0.2
    true_thr = 1e-3

    bundle_names = exp_utils.get_bundle_names(Config.CLASSES)[1:]
    scores = {}
    for b_idx, name in enumerate(bundle_names):
        start = b_idx * 3
        pred_mask = np.abs(y_pred[:, :, :, start:start + 3]).sum(axis=-1) > pred_thr
        true_mask = np.abs(y_true[:, :, :, start:start + 3]).sum(axis=-1) > true_thr
        scores[name] = f1_score(true_mask.flatten(), pred_mask.flatten(), average="binary")
    return scores
def save_multilabel_img_as_multiple_files_endings_OLD(Config, img, affine, path, multilabel=True):
    """Legacy writer: save the start/end region pair of every bundle as one 3D file.

    Channels ``2*idx`` and ``2*idx + 1`` of ``img`` belong to bundle ``idx``.
    Voxels > 0 in the first channel get label 1. Voxels > 0 in the second channel
    get label 2 if ``multilabel`` is True (no fourth dimension). Otherwise they
    also get label 1, i.e. beginnings and endings are combined. Where both overlap,
    the second channel wins. Output: ``<path>/endings/<bundle>.nii.gz``.
    """
    bundle_names = exp_utils.get_bundle_names(Config.CLASSES)[1:]
    second_label = 2 if multilabel else 1
    for pos, name in enumerate(bundle_names):
        pair = img[:, :, :, 2 * pos:2 * pos + 2] > 0
        labels = np.zeros(pair.shape[:3])
        labels[pair[..., 0]] = 1
        labels[pair[..., 1]] = second_label
        exp_utils.make_dir(join(path, "endings"))
        nib.save(nib.Nifti1Image(labels, affine), join(path, "endings", name + ".nii.gz"))
def test_tractseg_output(self):
    """Bundle segmentations in examples/tractseg_output must equal the reference files exactly."""
    bundles = exp_utils.get_bundle_names("All")[1:]
    for bundle in bundles:
        # np.asanyarray(img.dataobj) replaces get_data(), which nibabel 5 removed.
        img_ref = np.asanyarray(
            nib.load("tests/reference_files/bundle_segmentations/" + bundle + ".nii.gz").dataobj)
        img_new = np.asanyarray(
            nib.load("examples/tractseg_output/bundle_segmentations/" + bundle + ".nii.gz").dataobj)
        images_equal = np.array_equal(img_ref, img_new)
        self.assertTrue(images_equal, "Tract segmentations are not correct (bundle: " + bundle + ")")
def calc_peak_length_dice_pytorch(Config, y_pred, y_true, max_angle_error=(0.9,), max_length_error=0.1):
    """Per-bundle Dice of peaks that are correct in direction AND length (PyTorch).

    A voxel counts as correctly predicted if two conditions hold:
    - the absolute cosine between the predicted and ground-truth peak is larger
      than ``max_angle_error[0]``;
    - ``|len_pred - len_true| < max_length_error * len_true``.
    The reference mask is every voxel whose ground-truth peak has non-zero length.

    :param Config: experiment config; only ``Config.CLASSES`` is read.
    :param y_pred: predicted peaks, channels at dim 1: [bs, 3*nr_bundles, x, y] (2D)
        or [bs, 3*nr_bundles, x, y, z] (3D).
    :param y_true: ground-truth peaks, same layout as ``y_pred``.
    :param max_angle_error: sequence of cosine thresholds; only the first entry is used.
        0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less.
    :param max_length_error: allowed length error relative to the ground-truth length.
    :return: dict mapping bundle name -> result of ``pytorch_utils.f1_score_binary``.
    """
    import torch
    from tractseg.libs.pytorch_einsum import einsum
    from tractseg.libs import pytorch_utils

    # Move channels to the last dimension.
    if len(y_pred.shape) == 4:  # 2D
        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)
    else:  # 3D
        y_true = y_true.permute(0, 2, 3, 4, 1)
        y_pred = y_pred.permute(0, 2, 3, 4, 1)

    def angle_last_dim(a, b):
        """Absolute cosine between vectors along the last dimension.

        1->0°, 0.9->23°, 0.7->45°, 0->90°. Returns one dimension less than the input.
        """
        if len(a.shape) == 4:
            dot = einsum('abcd,abcd->abc', a, b)
        else:
            dot = einsum('abcde,abcde->abcd', a, b)
        return torch.abs(dot / (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

    # Single threshold
    score_per_bundle = {}
    bundles = exp_utils.get_bundle_names(Config.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[..., (idx * 3):(idx * 3) + 3].contiguous()
        y_true_bund = y_true[..., (idx * 3):(idx * 3) + 3].contiguous()  # [..., 3]

        angles = angle_last_dim(y_pred_bund, y_true_bund)
        lengths_pred = torch.norm(y_pred_bund, 2., -1)
        lengths_true = torch.norm(y_true_bund, 2, -1)
        lengths_binary = torch.abs(lengths_pred - lengths_true) < (max_length_error * lengths_true)
        lengths_binary = lengths_binary.view(-1)

        # Ground truth = voxels whose peak has non-zero length. The former
        # `y_true_bund.sum(dim=-1) > 0` dropped sign-ambiguous peaks whose
        # components sum to <= 0, e.g. (-1, 0, 0).
        gt_binary = (lengths_true > 0).view(-1)  # [bs*x*y(*z)]

        angles_binary = (angles > max_angle_error[0]).view(-1)

        combined = lengths_binary * angles_binary
        score_per_bundle[bundle] = pytorch_utils.f1_score_binary(gt_binary, combined)
    return score_per_bundle
def plot_tracts(classes, bundle_segmentations, affine, out_dir, brain_mask=None):
    """Render a dipy/VTK 3D preview of a few bundles to ``<out_dir>/preview.png``.

    By default this does not work on a remote server connection (ssh -X), because
    -X does not support OpenGL. On the remote server you can run
    'export DISPLAY=":0"' (use the value 'echo $DISPLAY' prints when you are logged
    in locally on that server). Then all graphics are rendered locally and not via -X.
    (important: a graphical session needs to be running on the remote server,
    e.g. via a local login)
    (important: a login is needed; just staying at the login screen is not enough)
    """
    from dipy.viz import window
    from tractseg.libs import vtk_utils

    smoothing = 10
    window_size = (800, 800)
    selected = ["CST_right", "CA", "IFO_right"]
    # Viewing orientation per bundle; anything else falls back to axial.
    orientation_of = {"CST_right": "axial", "CA": "axial", "IFO_right": "sagittal"}
    border_y = -100  # bigger: more border (same value for every orientation)
    text_offset_top = -50
    text_offset_side = -100

    renderer = window.Renderer()
    renderer.projection('parallel')

    nr_rows = len(selected)
    size_x, size_y, _ = bundle_segmentations.shape[:3]
    row_height = size_y * 2 + border_y
    all_names = exp_utils.get_bundle_names(classes)[1:]

    for row, name in enumerate(selected):
        mask_data = bundle_segmentations[:, :, :, all_names.index(name)]
        orientation = orientation_of.get(name, "axial")

        x_current = 0  # single column: only one method is shown
        y_current = nr_rows * row_height - row_height * row  # row (height), starts from bottom?

        plot_mask(renderer, mask_data, affine, x_current, y_current,
                  orientation=orientation, smoothing=smoothing, brain_mask=brain_mask)

        # Bundle label
        label_pos = (0 - int(size_x) + text_offset_side, y_current + text_offset_top, 50)
        renderer.add(vtk_utils.label(text=name, pos=label_pos, scale=(6, 6, 6), color=(1, 1, 1)))

    renderer.reset_camera()
    window.record(renderer, out_path=join(out_dir, "preview.png"),
                  size=(window_size[0], window_size[1]), reset_camera=False, magnification=2)
def test_peakreg_output(self):
    """TOMs in examples/tractseg_output/TOM must match the reference files within a float tolerance."""
    bundles = exp_utils.get_bundle_names("All")[1:]
    for bundle in bundles:
        # np.asanyarray(img.dataobj) replaces get_data(), which nibabel 5 removed.
        img_ref = np.asanyarray(nib.load("tests/reference_files/TOM/" + bundle + ".nii.gz").dataobj)
        img_new = np.asanyarray(nib.load("examples/tractseg_output/TOM/" + bundle + ".nii.gz").dataobj)
        # Because of floats a small tolerance margin is needed:
        # allclose accepts |ref - new| <= 1e-3 + 1e-3 * |new| (0.002 for |new| == 1) -> still fine.
        images_equal = np.allclose(img_ref, img_new, rtol=1e-3, atol=1e-3)
        self.assertTrue(images_equal, "TOMs are not correct (bundle: " + bundle + ")")
def calc_peak_length_dice(Config, y_pred, y_true, max_angle_error=(0.9,), max_length_error=0.1):
    """Per-bundle Dice of peaks that are correct in direction AND length (NumPy).

    A voxel counts as correctly predicted if two conditions hold:
    - the absolute cosine between the predicted and ground-truth peak is larger
      than ``max_angle_error[0]``;
    - ``|len_pred - len_true| < max_length_error * len_true``.
    The reference mask is every voxel whose ground-truth peak has non-zero length.

    :param Config: experiment config; only ``Config.CLASSES`` is read.
    :param y_pred: predicted peaks, [x, y, z, 3 * nr_of_bundles]
    :param y_true: ground-truth peaks, same layout as ``y_pred``
    :param max_angle_error: sequence of cosine thresholds; only the first entry is used.
        0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less
    :param max_length_error: allowed length error relative to the ground-truth length.
    :return: dict mapping bundle name -> result of ``my_f1_score``
    """
    def angle_last_dim(a, b):
        """Absolute cosine between vectors along the last dimension.

        1->0°, 0.9->23°, 0.7->45°, 0->90°. Returns one dimension less than the input.
        """
        return abs(
            np.einsum('...i,...i', a, b) / (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + 1e-7))

    score_per_bundle = {}
    bundles = exp_utils.get_bundle_names(Config.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
        y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]  # [x, y, z, 3]

        angles = angle_last_dim(y_pred_bund, y_true_bund)
        lengths_pred = np.linalg.norm(y_pred_bund, axis=-1)
        lengths_true = np.linalg.norm(y_true_bund, axis=-1)
        lengths_binary = abs(lengths_pred - lengths_true) < (max_length_error * lengths_true)
        lengths_binary = lengths_binary.flatten()

        # Ground truth = voxels whose peak has non-zero length. The former
        # `y_true_bund.sum(axis=-1) > 0` dropped sign-ambiguous peaks whose
        # components sum to <= 0, e.g. (-1, 0, 0).
        gt_binary = (lengths_true > 0).flatten()

        angles_binary = (angles > max_angle_error[0]).flatten()

        combined = lengths_binary * angles_binary
        score_per_bundle[bundle] = my_f1_score(gt_binary, combined)
    return score_per_bundle
def save_multilabel_img_as_multiple_files_peaks(Config, img, affine, path, name="TOM"):
    """Save the 3 peak channels of every bundle as their own NIfTI file.

    Channels ``3*idx .. 3*idx + 2`` of ``img`` belong to the idx-th entry of
    ``exp_utils.get_bundle_names(Config.CLASSES)[1:]``. Files are written to
    ``<path>/<name>/``. If ``Config.FLIP_OUTPUT_PEAKS`` is set, the z component is
    negated (for a correct view in MITK) and the file is ``<bundle>_f.nii.gz``;
    otherwise it is ``<bundle>.nii.gz``. ``img`` itself is not modified.

    :param Config: experiment config; reads ``CLASSES`` and ``FLIP_OUTPUT_PEAKS``.
    :param img: 4D array [x, y, z, 3 * nr_of_bundles].
    :param affine: affine used for every written ``Nifti1Image``.
    :param path: output root directory.
    :param name: name of the sub-directory below ``path``.
    """
    bundles = exp_utils.get_bundle_names(Config.CLASSES)[1:]
    out_dir = join(path, name)
    exp_utils.make_dir(out_dir)  # loop-invariant
    for idx, bundle in enumerate(bundles):
        data = img[:, :, :, (idx * 3):(idx * 3) + 3]
        if Config.FLIP_OUTPUT_PEAKS:
            # Slicing returns a view: copy before flipping. Otherwise the caller's
            # `img` is modified in place (and flipped back on a second call).
            data = data.copy()
            data[:, :, :, 2] *= -1  # flip z axis for correct view in MITK
            filename = bundle + "_f.nii.gz"
        else:
            filename = bundle + ".nii.gz"
        img_seg = nib.Nifti1Image(data, affine)
        nib.save(img_seg, join(out_dir, filename))
def calc_peak_dice(Config, y_pred, y_true, max_angle_error=(0.9,)):
    """Per-bundle Dice between correctly oriented predicted peaks and ground-truth peaks.

    Predicted voxels are those where the absolute cosine between the predicted and
    ground-truth peak exceeds ``max_angle_error[0]``. Reference voxels are those
    whose ground-truth peak has non-zero length.

    :param Config: experiment config; only ``Config.CLASSES`` is read.
    :param y_pred: predicted peaks, [x, y, z, 3 * nr_of_bundles]
    :param y_true: ground-truth peaks, same layout as ``y_pred``
    :param max_angle_error: sequence of cosine thresholds; only the first entry is used.
        0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less
    :return: dict mapping bundle name -> F1 score
    """
    def angle_last_dim(a, b):
        """Absolute cosine between vectors along the last dimension.

        1->0°, 0.9->23°, 0.7->45°, 0->90°. Returns one dimension less than the input.
        """
        return abs(
            np.einsum('...i,...i', a, b) / (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + 1e-7))

    score_per_bundle = {}
    bundles = exp_utils.get_bundle_names(Config.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
        y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]  # [x, y, z, 3]

        angles = angle_last_dim(y_pred_bund, y_true_bund)
        angles_binary = angles > max_angle_error[0]
        # Ground truth = voxels with a non-zero peak. A plain `sum(axis=-1) > 0`
        # dropped sign-ambiguous peaks whose components sum to <= 0, e.g. (-1, 0, 0).
        gt_binary = np.abs(y_true_bund).sum(axis=-1) > 0

        f1 = f1_score(gt_binary.flatten(), angles_binary.flatten(), average="binary")
        score_per_bundle[bundle] = f1
    return score_per_bundle
def test_tractseg_output_SR_noPP(self):
    """Super-resolution segmentations without postprocessing must (almost) match the references.

    Fewer than 5 differing voxels per bundle are tolerated, because processing on
    travis differs slightly from the local environment.
    """
    bundles = exp_utils.get_bundle_names("All")[1:]
    for bundle in bundles:
        # IFO somehow very different on travis than locally. Unclear why. All other bundles are fine.
        if bundle == "IFO_right":
            continue
        # np.asanyarray(img.dataobj) replaces get_data(), which nibabel 5 removed.
        img_ref = np.asanyarray(
            nib.load("tests/reference_files/bundle_segmentations_SR_noPP/" + bundle + ".nii.gz").dataobj)
        img_new = np.asanyarray(
            nib.load("examples/SR_noPP/tractseg_output/bundle_segmentations/" + bundle + ".nii.gz").dataobj)
        self.assertEqual(img_ref.shape, img_new.shape,
                         "Tract segmentation has wrong shape (bundle: " + bundle + ")")

        # Count differing voxels directly. `np.abs(img_ref - img_new).sum()` wraps
        # around for unsigned dtypes (uint8: 0 - 1 == 255), so it reported and
        # compared a wrong number.
        nr_differing_voxels = int(np.count_nonzero(img_ref != img_new))
        images_equal = nr_differing_voxels < 5
        self.assertTrue(images_equal, "Tract segmentations are not correct (bundle: " + bundle + ") " +
                        "(nr of differing voxels: " + str(nr_differing_voxels) + ")")
class Config(TractSegConfig):
    """Experiment: tract segmentation on AutoPTX_42 bundles (HCP_preproc_all data).

    Only overrides of the inherited TractSegConfig defaults are listed here
    (base class not visible in this file).
    """

    # Experiment name = this file's name up to the first ".".
    EXP_NAME = os.path.basename(__file__).split(".")[0]

    DATASET_FOLDER = "HCP_preproc_all"
    NR_OF_GRADIENTS = 18
    FEATURES_FILENAME = "32g270g_BX"
    P_SAMP = 0.4

    CLASSES = "AutoPTX_42"
    # Recomputed here (not inherited) because the class body is evaluated at
    # definition time: number of bundles without the leading background entry.
    NR_OF_CLASSES = len(exp_utils.get_bundle_names(CLASSES)[1:])

    DATASET = "HCP_all"

    LR_SCHEDULE = True
    LR_SCHEDULE_MODE = "min"
    LR_SCHEDULE_PATIENCE = 20

    NUM_EPOCHS = 200  # 130 probably also fine
class Config(DmRegConfig):
    """Experiment: density-map regression on AutoPTX_42 bundles (HCP_preproc_all data).

    Only overrides of the inherited DmRegConfig defaults are listed here
    (base class not visible in this file).
    """

    # Experiment name = this file's name up to the first ".".
    EXP_NAME = os.path.basename(__file__).split(".")[0]

    DATASET_FOLDER = "HCP_preproc_all"
    NR_OF_GRADIENTS = 18
    FEATURES_FILENAME = "32g270g_BX"
    P_SAMP = 0.4

    CLASSES = "AutoPTX_42"
    # Recomputed here (not inherited) because the class body is evaluated at
    # definition time: number of bundles without the leading background entry.
    NR_OF_CLASSES = len(exp_utils.get_bundle_names(CLASSES)[1:])

    # THRESHOLD = 0.001  # Final DM will be thresholded at this value
    THRESHOLD = 0.0001  # use a lower value so the user has more choice

    DATASET = "HCP_all"

    LR_SCHEDULE = True
    LR_SCHEDULE_MODE = "min"
    LR_SCHEDULE_PATIENCE = 20

    NUM_EPOCHS = 200  # 130 probably also fine
def test_bundle_names(self):
    """A single-bundle class spec yields the background entry followed by that bundle."""
    expected = ["BG", "CST_right"]
    actual = exp_utils.get_bundle_names("CST_right")
    self.assertListEqual(actual, expected, "Error in list of bundle names")
def plot_tracts_matplotlib(classes, bundle_segmentations, background_img, out_dir,
                           threshold=0.001, exp_type="tract_segmentation"):
    """Save a 2D matplotlib preview of selected bundles to ``<out_dir>/preview.png``.

    Each selected bundle gets its own subplot (4 columns). The bundle is drawn as
    a projection over one axis on top of a single background slice.

    :param classes: class spec passed to ``exp_utils.get_bundle_names``; specs
        starting with "AutoPTX" select the AutoPTX bundle list.
    :param bundle_segmentations: 4D array [x, y, z, channels]. For
        ``exp_type == "peak_regression"`` the channels are 3 * nr_of_bundles.
    :param background_img: image whose channel 0 (``[..., 0]``) is used as background.
    :param out_dir: output directory for preview.png.
    :param threshold: currently unused. Values below a fixed 0.001 are hidden
        (a higher value is better for the preview; otherwise half of the image is
        just red). Kept for interface compatibility.
    :param exp_type: experiment type. "peak_regression" reshapes the channels into
        (bundle, xyz) and uses a reduced bundle list.
    """
    aggregation = "max"  # projection 3D -> 2D: "max" or "mean"

    def plot_single_tract(bg, data, orientation, bundle, exp_type):
        """Draw one bundle over the matching background slice into the current subplot."""
        is_peaks = exp_type == "peak_regression"
        if orientation == "coronal":
            data = data.transpose(2, 0, 1, 3) if is_peaks else data.transpose(2, 0, 1)
            data = data[::-1, :, :]
            bg = bg.transpose(2, 0, 1)[::-1, :, :]
        elif orientation == "sagittal":
            data = data.transpose(2, 1, 0, 3) if is_peaks else data.transpose(2, 1, 0)
            data = data[::-1, :, :]
            bg = bg.transpose(2, 1, 0)[::-1, :, :]
        # axial: orientation stays as is

        # Background slice: middle of the bundle's extent along axis 2
        # (middle of the volume if the bundle is empty).
        mask_voxel_coords = np.where(data != 0)
        if len(mask_voxel_coords) > 2 and len(mask_voxel_coords[2]) > 0:
            minidx = int(np.min(mask_voxel_coords[2]))
            maxidx = int(np.max(mask_voxel_coords[2])) + 1
            mean_slice = int(np.mean([minidx, maxidx]))
        else:
            mean_slice = int(bg.shape[2] / 2)
        bg = bg[:, :, mean_slice]

        # Project the 3D image to a 2D image.
        # todo: this kind of projection is not sensible for peak images
        if aggregation == "mean":
            data = data.mean(axis=2)
        else:
            data = data.max(axis=2)

        plt.imshow(bg, cmap="gray")
        data = np.ma.masked_where(data < 0.00001, data)
        plt.imshow(data, cmap="autumn")  # even with cmap=autumn peaks are still shown as RGB
        plt.title(bundle, fontsize=7)

    if classes.startswith("AutoPTX"):
        bundles = ["cst_r", "cst_s_r", "ifo_r", "fx_l", "fx_r", "or_l", "fma"]
    else:
        bundles = ["CST_right", "CST_s_right", "CA", "IFO_right", "FX_left", "FX_right", "OR_left", "CC_1"]

    if exp_type == "peak_regression":
        s = bundle_segmentations.shape
        bundle_segmentations = bundle_segmentations.reshape([s[0], s[1], s[2], int(s[3] / 3), 3])
        bundles = ["CST_right", "CST_s_right", "CA", "CC_1", "AF_left"]  # can only use bundles from part1

    cols = 4
    rows = math.ceil(len(bundles) / cols)
    background_img = background_img[..., 0]
    all_bundle_names = exp_utils.get_bundle_names(classes)[1:]

    # Draw on a dedicated figure and always close it. Otherwise the subplots end up
    # in whatever figure is current and repeated calls keep accumulating figures.
    fig = plt.figure()
    try:
        for j, bundle in enumerate(bundles):
            bun = bundle.lower()
            if bun.startswith(("ca", "fx_", "or_", "cc_1", "fma")):
                orientation = "axial"
            elif bun.startswith(("ifo_", "icp_", "cst_s_", "af_")):
                # "*_s_*" entries are sagittal views of the base bundle (e.g. CST_s_right -> CST_right).
                bundle = bundle.replace("_s", "")
                orientation = "sagittal"
            elif bun.startswith("cst_"):
                orientation = "coronal"
            else:
                raise ValueError("invalid bundle")

            bundle_idx = all_bundle_names.index(bundle)
            # Copy: otherwise the thresholding below would also change the caller's data.
            mask_data = np.copy(bundle_segmentations[:, :, :, bundle_idx])
            mask_data[mask_data < 0.001] = 0  # higher value better for preview, otherwise half of image just red

            plt.subplot(rows, cols, j + 1)
            plt.axis("off")
            plot_single_tract(background_img, mask_data, orientation, bundle, exp_type=exp_type)

        plt.subplots_adjust(wspace=0, hspace=0)
        plt.savefig(join(out_dir, "preview.png"), bbox_inches='tight', dpi=300)
    finally:
        plt.close(fig)
def calc_peak_dice_pytorch(Config, y_pred, y_true, max_angle_error=(0.9,)):
    """Per-bundle peak Dice for 2D batches (PyTorch).

    Calculate the angle between ground truth and prediction and keep the voxels
    where the absolute cosine is larger than the threshold (i.e. the angle is
    smaller than the allowed error). From the ground truth, generate a binary
    mask by selecting all voxels with peak length > 0. Then calculate Dice from
    these 2 masks.

    -> Penalty on peaks outside of the tract or if the predicted peak = 0
    -> No penalty on very, very small peaks with the right direction -> bad
    => Peak Dice can be high even if the peaks inside the tract are almost
       missing (almost 0)

    :param Config: experiment config; only ``Config.CLASSES`` is read.
    :param y_pred: predicted peaks, [bs, 3 * nr_of_bundles, x, y].
    :param y_true: ground-truth peaks, same layout as ``y_pred``.
    :param max_angle_error: sequence of cosine thresholds.
        0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less.
        With several values the score is calculated for every threshold.
    :return: dict mapping bundle name -> f1 (one threshold) or list of f1 values
        in the order of ``max_angle_error`` (several thresholds).
    """
    import torch
    from tractseg.libs.pytorch_einsum import einsum
    from tractseg.libs import pytorch_utils

    # [bs, channels, x, y] -> [bs, x, y, channels]
    y_true = y_true.permute(0, 2, 3, 1)
    y_pred = y_pred.permute(0, 2, 3, 1)

    def angle_last_dim(a, b):
        """Absolute cosine between vectors along the last dimension.

        1->0°, 0.9->23°, 0.7->45°, 0->90°. Returns one dimension less than the input.
        """
        return torch.abs(
            einsum('abcd,abcd->abc', a, b) / (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

    single_threshold = len(max_angle_error) == 1
    score_per_bundle = {}
    bundles = exp_utils.get_bundle_names(Config.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
        y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()  # [bs, x, y, 3]

        angles = angle_last_dim(y_pred_bund, y_true_bund)
        # Ground truth = voxels with a peak of length > 0 (see docstring). The former
        # `y_true_bund.sum(dim=-1) > 0` dropped sign-ambiguous peaks whose
        # components sum to <= 0, e.g. (-1, 0, 0).
        gt_binary = (torch.abs(y_true_bund).sum(dim=-1) > 0).view(-1)  # [bs*x*y]

        scores = [pytorch_utils.f1_score_binary(gt_binary, (angles > thr).view(-1))
                  for thr in max_angle_error]
        score_per_bundle[bundle] = scores[0] if single_threshold else scores
    return score_per_bundle
def run_tractseg(data, output_type="tract_segmentation", single_orientation=False, dropout_sampling=False,
                 threshold=0.5, bundle_specific_threshold=False, get_probs=False, peak_threshold=0.1,
                 postprocess=False, peak_regression_part="All", input_type="peaks", blob_size_thr=50,
                 nr_cpus=-1, verbose=False, manual_exp_name=None, inference_batch_size=1):
    """
    Run TractSeg

    Args:
        data: input peaks (4D numpy array with shape [x,y,z,9])
        output_type: TractSeg can segment not only bundles, but also the end regions of bundles.
            Moreover it can create Tract Orientation Maps (TOM).
            'tract_segmentation' [DEFAULT]: Segmentation of bundles (72 bundles).
            'endings_segmentation': Segmentation of bundle end regions (72 bundles).
            'TOM': Tract Orientation Maps (20 bundles).
        single_orientation: Do not run model 3 times along x/y/z orientation with subsequent mean fusion.
        dropout_sampling: Create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
        threshold: Threshold for converting probability map to binary map
        bundle_specific_threshold: Set threshold to lower for some bundles which need more sensitivity (CA, CST, FX)
        get_probs: Output raw probability map instead of binary map
        peak_threshold: All peaks shorter than peak_threshold will be set to zero
        postprocess: Simple postprocessing of segmentations: Remove small blobs and fill holes
        peak_regression_part: Only relevant for output type 'TOM'. If set to 'All' (default) it will return all
            72 bundles. If set to 'Part1'-'Part4' it will only run for a subset of the bundles to reduce memory
            load.
        input_type: Always set to "peaks"
        blob_size_thr: If setting postprocess to True, all blobs having a smaller number of voxels than specified in
            this threshold will be removed.
        nr_cpus: Number of CPUs to use. -1 means all available CPUs.
        verbose: Show debugging information
        manual_exp_name: Name of experiment if you do not want to use the pretrained model but your own one
        inference_batch_size: batch size (higher: a bit faster but needs more RAM)

    Returns:
        4D numpy array with the output of tractseg
        for tract_segmentation:     [x, y, z, nr_of_bundles]
        for endings_segmentation:   [x, y, z, 2*nr_of_bundles]
        for TOM:                    [x, y, z, 3*nr_of_bundles]
    """
    start_time = time.time()

    # Instantiate the pretrained-model Config chosen by input/output type and
    # override it with the call arguments (inference only: no training/testing).
    config = get_config_name(input_type, output_type, dropout_sampling=dropout_sampling)
    Config = getattr(importlib.import_module("tractseg.experiments.pretrained_models." + config), "Config")()
    Config.VERBOSE = verbose
    Config.TRAIN = False
    Config.TEST = False
    Config.SEGMENT = False
    Config.GET_PROBS = get_probs
    Config.LOAD_WEIGHTS = True
    Config.DROPOUT_SAMPLING = dropout_sampling
    Config.THRESHOLD = threshold
    Config.NR_CPUS = nr_cpus
    Config.INPUT_DIM = exp_utils.get_correct_input_dim(Config)

    # Bundle-specific thresholding needs the raw probabilities; they are binarized
    # per bundle further below (probs_to_binary_bundle_specific).
    if bundle_specific_threshold:
        Config.GET_PROBS = True

    # Select the weights file. For peak_regression, a manual experiment is
    # resolved per part inside the peak_regression branch further below.
    if manual_exp_name is not None and Config.EXPERIMENT_TYPE != "peak_regression":
        Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
            join(C.EXP_PATH, manual_exp_name), True)
    elif input_type == "peaks":
        if Config.EXPERIMENT_TYPE == "tract_segmentation" and Config.DROPOUT_SAMPLING:
            Config.WEIGHTS_PATH = join(
                C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_dropout_v2.npz")
        elif Config.EXPERIMENT_TYPE == "tract_segmentation":
            Config.WEIGHTS_PATH = join(
                C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_v2.npz")
            # Config.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DS_DAugAll_RotMirFlip", "best_weights_ep247.npz")
        elif Config.EXPERIMENT_TYPE == "endings_segmentation":
            Config.WEIGHTS_PATH = join(
                C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v3.npz")
        elif Config.EXPERIMENT_TYPE == "dm_regression":
            Config.WEIGHTS_PATH = join(
                C.WEIGHTS_DIR, "pretrained_weights_dm_regression_v1.npz")
    elif input_type == "T1":
        if Config.EXPERIMENT_TYPE == "tract_segmentation":
            # NOTE(review): unlike every other entry this points to C.NETWORK_DRIVE,
            # not C.WEIGHTS_DIR -- looks like a developer-only path; confirm.
            Config.WEIGHTS_PATH = join(
                C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models",
                "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif Config.EXPERIMENT_TYPE == "endings_segmentation":
            Config.WEIGHTS_PATH = join(
                C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v1.npz")
        elif Config.EXPERIMENT_TYPE == "peak_regression":
            Config.WEIGHTS_PATH = join(
                C.WEIGHTS_DIR, "pretrained_weights_peak_regression_v1.npz")

    if Config.VERBOSE:
        print("Hyperparameters:")
        exp_utils.print_Configs(Config)

    data = np.nan_to_num(data)
    # brain_mask = img_utils.simple_brain_mask(data)
    # if Config.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    # T1 input is 3D: add a channel axis so it has the same rank as peak input.
    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    # bbox/original_shape and transformation are used at the very end to map the
    # network output back to the input geometry.
    data, seg_None, bbox, original_shape = dataset_utils.crop_to_nonzero(data)
    data, transformation = dataset_utils.pad_and_scale_img_to_square_img(
        data, target_size=Config.INPUT_DIM[0])

    if Config.EXPERIMENT_TYPE == "tract_segmentation" or Config.EXPERIMENT_TYPE == "endings_segmentation" or \
            Config.EXPERIMENT_TYPE == "dm_regression":
        print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
        Config.NR_OF_CLASSES = len(
            exp_utils.get_bundle_names(Config.CLASSES)[1:])
        utils.download_pretrained_weights(
            experiment_type=Config.EXPERIMENT_TYPE,
            dropout_sampling=Config.DROPOUT_SAMPLING)
        model = BaseModel(Config)
        # probs=True is requested for dropout sampling, dm_regression and
        # GET_PROBS; otherwise probs=False.
        if single_orientation:  # mainly needed for testing because of less RAM requirements
            data_loder_inference = DataLoaderInference(Config, data=data)
            if Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" or Config.GET_PROBS:
                seg, img_y = trainer.predict_img(
                    Config, model, data_loder_inference, probs=True,
                    scale_to_world_shape=False, only_prediction=True,
                    batch_size=inference_batch_size)
            else:
                seg, img_y = trainer.predict_img(
                    Config, model, data_loder_inference, probs=False,
                    scale_to_world_shape=False, only_prediction=True,
                    batch_size=inference_batch_size)
        else:
            # Predict along all 3 slice directions and fuse them by averaging.
            seg_xyz, gt = direction_merger.get_seg_single_img_3_directions(
                Config, model, data=data, scale_to_world_shape=False,
                only_prediction=True, batch_size=inference_batch_size)
            if Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" or Config.GET_PROBS:
                seg = direction_merger.mean_fusion(Config.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = direction_merger.mean_fusion(Config.THRESHOLD, seg_xyz, probs=False)

    elif Config.EXPERIMENT_TYPE == "peak_regression":
        # One model per part. Config.CLASSES / NR_OF_CLASSES are switched to
        # the part's bundle subset before each model is built.
        weights = {
            "Part1": "pretrained_weights_peak_regression_part1_v1.npz",
            "Part2": "pretrained_weights_peak_regression_part2_v1.npz",
            "Part3": "pretrained_weights_peak_regression_part3_v1.npz",
            "Part4": "pretrained_weights_peak_regression_part4_v1.npz",
        }
        if peak_regression_part == "All":
            parts = ["Part1", "Part2", "Part3", "Part4"]
            # NOTE(review): channel count uses the pretrained Config's NR_OF_CLASSES
            # times 3; verify this equals 3 * nr_of_bundles for the TOM config.
            seg_all = np.zeros((data.shape[0], data.shape[1], data.shape[2], Config.NR_OF_CLASSES * 3))
        else:
            parts = [peak_regression_part]
            Config.CLASSES = "All_" + peak_regression_part
            Config.NR_OF_CLASSES = 3 * len(
                exp_utils.get_bundle_names(Config.CLASSES)[1:])

        for idx, part in enumerate(parts):
            if manual_exp_name is not None:
                manual_exp_name_peaks = exp_utils.get_manual_exp_name_peaks(
                    manual_exp_name, part)
                Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
                    join(C.EXP_PATH, manual_exp_name_peaks), True)
            else:
                # NOTE(review): joined with C.TRACT_SEG_HOME while all other pretrained
                # weights use C.WEIGHTS_DIR -- confirm download_pretrained_weights()
                # stores the part files there.
                Config.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, weights[part])
            print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
            Config.CLASSES = "All_" + part
            Config.NR_OF_CLASSES = 3 * len(
                exp_utils.get_bundle_names(Config.CLASSES)[1:])
            utils.download_pretrained_weights(
                experiment_type=Config.EXPERIMENT_TYPE,
                dropout_sampling=Config.DROPOUT_SAMPLING, part=part)
            data_loder_inference = DataLoaderInference(Config, data=data)
            model = BaseModel(Config)
            seg, img_y = trainer.predict_img(Config, model, data_loder_inference,
                                             probs=True, scale_to_world_shape=False,
                                             only_prediction=True,
                                             batch_size=inference_batch_size)
            # Place this part's channels at offset idx * NR_OF_CLASSES (assumes
            # all parts have the same number of channels).
            if peak_regression_part == "All":
                seg_all[:, :, :, (idx * Config.NR_OF_CLASSES):(idx * Config.NR_OF_CLASSES + Config.NR_OF_CLASSES)] = seg

        if peak_regression_part == "All":
            Config.CLASSES = "All"
            Config.NR_OF_CLASSES = 3 * len(
                exp_utils.get_bundle_names(Config.CLASSES)[1:])
            seg = seg_all

        # quite fast
        if bundle_specific_threshold:
            seg = img_utils.remove_small_peaks_bundle_specific(
                seg, exp_utils.get_bundle_names(Config.CLASSES)[1:], len_thr=0.3)
        else:
            seg = img_utils.remove_small_peaks(seg, len_thr=peak_threshold)

        # 3 dir for Peaks -> bad results
        # seg_xyz, gt = direction_merger.get_seg_single_img_3_directions(Config, model, data=data,
        #                                                                 scale_to_world_shape=False,
        #                                                                 only_prediction=True,
        #                                                                 batch_size=inference_batch_size)
        # seg = direction_merger.mean_fusion(Config.THRESHOLD, seg_xyz, probs=True)

    if bundle_specific_threshold and Config.EXPERIMENT_TYPE == "tract_segmentation":
        seg = img_utils.probs_to_binary_bundle_specific(
            seg, exp_utils.get_bundle_names(Config.CLASSES)[1:])

    # Undo the crop/pad/scale from above. Remove the following two lines to keep super resolution.
    seg = dataset_utils.cut_and_scale_img_back_to_original_img(
        seg, transformation)  # quite slow
    seg = dataset_utils.add_original_zero_padding_again(
        seg, bbox, original_shape, Config.NR_OF_CLASSES)  # quite slow

    # NOTE(review): applied for every experiment type, including TOM / dm_regression output.
    if postprocess:
        seg = img_utils.postprocess_segmentations(seg, blob_thr=blob_size_thr, hole_closing=2)

    exp_utils.print_verbose(
        Config, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg
class Config:
    """Settings and Hyperparameters

    Default values for all experiments. The class body is evaluated top to bottom
    once, at definition time. Derived attributes (NR_OF_CLASSES, MULTI_PARENT_PATH,
    EXP_PATH) are computed from the values above them at that moment. A subclass
    that overrides CLASSES / EXP_NAME / EXP_MULTI_NAME therefore has to recompute
    them itself.
    """

    EXP_MULTI_NAME = ""  # CV parent dir name; leave empty for a single-bundle experiment
    EXP_NAME = "HCP_TEST"
    MODEL = "UNet_Pytorch_DeepSup"
    # tract_segmentation / endings_segmentation / dm_regression / peak_regression
    EXPERIMENT_TYPE = "tract_segmentation"

    DIM = "2D"  # 2D / 3D
    NUM_EPOCHS = 250
    EPOCH_MULTIPLIER = 1  # 2D: 1, 3D: 12 for lowRes, 3 for highRes
    DATA_AUGMENTATION = True
    DAUG_SCALE = True
    DAUG_NOISE = True
    DAUG_NOISE_VARIANCE = (0, 0.05)
    DAUG_ELASTIC_DEFORM = True
    DAUG_ALPHA = (90., 120.)
    DAUG_SIGMA = (9., 11.)
    DAUG_RESAMPLE = False  # does not change validation dice (if using Gaussian_blur) -> deactivate
    DAUG_RESAMPLE_LEGACY = False  # does not change validation dice (at least on AutoPTX) -> deactivate
    DAUG_GAUSSIAN_BLUR = True
    DAUG_BLUR_SIGMA = (0, 1)
    DAUG_ROTATE = False
    DAUG_ROTATE_ANGLE = (-0.2, 0.2)  # rotation in radians: 2*np.pi = 360 degree (-> 0.4 ~ 22 degree, 0.2 ~ 11 degree)
    DAUG_MIRROR = False
    DAUG_FLIP_PEAKS = False
    SPATIAL_TRANSFORM = "SpatialTransform"  # SpatialTransform / SpatialTransformPeaks
    # 1.0 slightly less overfitting than 0.4 but not much ("break-even" 20 epochs later)
    # 1.0: CPU bottleneck, 0.4: CPU not 100% all the time anymore, but GPU utilization still not 100%
    # 1.0: clearly more complete CA+FX on nonHCP than 0.2
    P_SAMP = 1.0  # use 1.0 for final model
    # Descriptive string only: it is not derived from the DAUG_* values above
    # (e.g. it says Rotate(-0.8,0.8) while DAUG_ROTATE is False).
    DAUG_INFO = "Elastic(90,120)(9,11) - Scale(0.9, 1.5) - CenterDist60 - " \
                "DownsampScipy(0.5,1) - Gaussian(0,0.05) - Rotate(-0.8,0.8)"
    DATASET = "HCP"  # HCP / HCP_32g / Schizo
    RESOLUTION = "1.25mm"  # 1.25mm (/ 2.5mm)
    # 12g90g270g / 270g_125mm_xyz / 270g_125mm_peaks / 90g_125mm_peaks / 32g_25mm_peaks / 32g_25mm_xyz
    FEATURES_FILENAME = "12g90g270g"
    LABELS_FILENAME = ""  # autofilled
    LOSS_FUNCTION = "default"  # default / soft_batch_dice
    OPTIMIZER = "Adamax"
    CLASSES = "All"  # All / 11 / 20 / CST_right
    NR_OF_GRADIENTS = 9
    # Number of bundles without the leading background entry (computed from CLASSES above).
    NR_OF_CLASSES = len(exp_utils.get_bundle_names(CLASSES)[1:])
    # NR_OF_CLASSES = 3 * len(exp_utils.get_bundle_names(CLASSES)[1:])

    INPUT_DIM = None  # (80, 80) / (144, 144)
    LOSS_WEIGHT = None  # None: no weighting
    LOSS_WEIGHT_LEN = -1  # -1: constant over all epochs
    SLICE_DIRECTION = "y"  # x, y, z (combined needs z)
    TRAINING_SLICE_DIRECTION = "xyz"  # y / xyz
    INFO = "-"
    BATCH_NORM = False
    WEIGHT_DECAY = 0
    USE_DROPOUT = False
    DROPOUT_SAMPLING = False
    # DATASET_FOLDER = "HCP_batches/270g_125mm_bundle_peaks_Y_subset"
    DATASET_FOLDER = "HCP_preproc"  # HCP / Schizo
    LABELS_FOLDER = "bundle_masks"  # bundle_masks / bundle_masks_dm
    # Both paths are fixed at class-definition time from EXP_MULTI_NAME / EXP_NAME above.
    MULTI_PARENT_PATH = join(C.EXP_PATH, EXP_MULTI_NAME)
    EXP_PATH = join(C.EXP_PATH, EXP_MULTI_NAME, EXP_NAME)  # default path
    BATCH_SIZE = 47  # Peak Prediction: 44  # Pytorch: 50  # Lasagne: 56  # Lasagne combined: 42  # Pytorch UpSample: 56
    LEARNING_RATE = 0.001  # 0.002  # LR find: 0.000143 ?
    # 0.001
    LR_SCHEDULE = True
    LR_SCHEDULE_MODE = "min"  # "min" / "max"
    LR_SCHEDULE_PATIENCE = 20
    UNET_NR_FILT = 64
    LOAD_WEIGHTS = False
    # WEIGHTS_PATH = join(C.EXP_PATH, "HCP100_45B_UNet_x_DM_lr002_slope2_dec992_ep800/best_weights_ep64.npz")
    WEIGHTS_PATH = ""  # if empty string: autoloading the best_weights in get_best_weights_path()
    TYPE = "single_direction"  # single_direction / combined
    CV_FOLD = 0
    # NOTE(review): class-level lists are shared by all instances and subclasses;
    # assign a new list instead of mutating these in place.
    VALIDATE_SUBJECTS = []
    TRAIN_SUBJECTS = []
    TEST_SUBJECTS = []
    TRAIN = True
    TEST = True
    SEGMENT = False
    GET_PROBS = False
    OUTPUT_MULTIPLE_FILES = False
    RESET_LAST_LAYER = False
    UPSAMPLE_TYPE = "bilinear"  # bilinear / nearest
    BEST_EPOCH_SELECTION = "f1"  # f1 / loss
    METRIC_TYPES = ["loss", "f1_macro"]
    FP16 = True

    # Peak_regression specific
    PEAK_DICE_THR = [0.95]
    PEAK_DICE_LEN_THR = 0.05
    FLIP_OUTPUT_PEAKS = True  # flip peaks along z axis to make them compatible with MITK

    # For TractSeg.py application
    PREDICT_IMG = False
    PREDICT_IMG_OUTPUT = None
    TRACTSEG_DIR = "tractseg_output"
    KEEP_INTERMEDIATE_FILES = False
    CSD_RESOLUTION = "LOW"  # HIGH / LOW
    NR_CPUS = -1

    # Rarely changed:
    # NOTE(review): a string here, while the other default Config uses np.int16 -- confirm consumers accept both.
    LABELS_TYPE = "int"
    THRESHOLD = 0.5  # Binary: 0.5, Regression: 0.01 ?
    TEST_TIME_DAUG = False
    USE_VISLOGGER = False  # only works with Python 3
    SAVE_WEIGHTS = True
    SEG_INPUT = "Peaks"  # Gradients / Peaks
    NR_SLICES = 1  # adapt manually: NR_OF_GRADIENTS in UNet.py and get_batch... in train() and in get_seg_prediction()
    PRINT_FREQ = 20  # 20
    NORMALIZE_DATA = True
    NORMALIZE_PER_CHANNEL = False
    BEST_EPOCH = 0
    VERBOSE = True
    CALC_F1 = True
    ONLY_VAL = False
class Config:
    """Settings and hyperparameters for training / testing / inference experiments.

    All settings are plain class attributes. Note that some of them are derived from
    other attributes once, when the class body is executed:
      - NR_OF_CLASSES is computed from CLASSES
      - MULTI_PARENT_PATH and EXP_PATH are computed from EXP_MULTI_NAME / EXP_NAME
    Changing CLASSES / EXP_NAME later (e.g. on an instance) does NOT update these derived
    values; they have to be reassigned explicitly.

    NOTE: VALIDATE_SUBJECTS / TRAIN_SUBJECTS / TEST_SUBJECTS / PEAK_DICE_THR are mutable lists
    defined on the class, i.e. shared by all instances. Reassign them instead of mutating in place.
    """
    EXP_MULTI_NAME = ""  # parent dir name for cross-validation runs; leave empty for a single-bundle experiment
    EXP_NAME = "HCP_TEST"
    MODEL = "UNet_Pytorch"
    # tract_segmentation / endings_segmentation / dm_regression / peak_regression
    EXPERIMENT_TYPE = "tract_segmentation"

    DIM = "2D"  # 2D / 3D
    NUM_EPOCHS = 250
    EPOCH_MULTIPLIER = 1  # 2D: 1, 3D: 12 for lowRes, 3 for highRes
    DATA_AUGMENTATION = False
    DAUG_SCALE = True
    DAUG_NOISE = True
    DAUG_ELASTIC_DEFORM = True
    DAUG_RESAMPLE = True
    DAUG_ROTATE = False
    DAUG_MIRROR = False
    DAUG_FLIP_PEAKS = False
    # Free-text description of the augmentation parameters (implicit string concatenation)
    DAUG_INFO = "Elastic(90,120)(9,11) - Scale(0.9, 1.5) - CenterDist60 - " \
                "DownsampScipy(0.5,1) - Gaussian(0,0.05) - Rotate(-0.8,0.8)"
    DATASET = "HCP"  # HCP / HCP_32g / Schizo
    RESOLUTION = "1.25mm"  # 1.25mm (/ 2.5mm)
    # 12g90g270g / 270g_125mm_xyz / 270g_125mm_peaks / 90g_125mm_peaks / 32g_25mm_peaks / 32g_25mm_xyz
    FEATURES_FILENAME = "12g90g270g"
    LABELS_FILENAME = ""  # autofilled
    LOSS_FUNCTION = "default"  # default / soft_batch_dice
    OPTIMIZER = "Adamax"
    CLASSES = "All"  # All / 11 / 20 / CST_right
    NR_OF_GRADIENTS = 9
    # Number of bundles; the first entry of get_bundle_names() is skipped ([1:]).
    # Computed once from CLASSES when the class is defined.
    NR_OF_CLASSES = len(exp_utils.get_bundle_names(CLASSES)[1:])
    # For peak_regression the output has 3 channels per bundle:
    # NR_OF_CLASSES = 3 * len(exp_utils.get_bundle_names(CLASSES)[1:])

    INPUT_DIM = None  # (80, 80) / (144, 144)
    LOSS_WEIGHT = 1  # 1: no weighting
    LOSS_WEIGHT_LEN = -1  # -1: constant over all epochs
    SLICE_DIRECTION = "y"  # x, y, z  (combined needs z)
    TRAINING_SLICE_DIRECTION = "xyz"  # y / xyz
    INFO = "-"
    BATCH_NORM = False
    WEIGHT_DECAY = 0
    USE_DROPOUT = False
    DROPOUT_SAMPLING = False
    # DATASET_FOLDER = "HCP_batches/270g_125mm_bundle_peaks_Y_subset"
    DATASET_FOLDER = "HCP"  # HCP / Schizo
    LABELS_FOLDER = "bundle_masks"  # bundle_masks / bundle_masks_dm
    # Derived once from EXP_MULTI_NAME / EXP_NAME when the class is defined
    MULTI_PARENT_PATH = join(C.EXP_PATH, EXP_MULTI_NAME)
    EXP_PATH = join(C.EXP_PATH, EXP_MULTI_NAME, EXP_NAME)  # default path
    # Other batch sizes used before: Peak Prediction: 44, Pytorch: 50, Lasagne: 56,
    # Lasagne combined: 42, Pytorch UpSample: 56
    BATCH_SIZE = 47
    LEARNING_RATE = 0.001  # also tried: 0.002; LR find suggested: 0.000143 ?
    LR_SCHEDULE = False
    UNET_NR_FILT = 64
    LOAD_WEIGHTS = False
    # WEIGHTS_PATH = join(C.EXP_PATH, "HCP100_45B_UNet_x_DM_lr002_slope2_dec992_ep800/best_weights_ep64.npz")
    WEIGHTS_PATH = ""  # if empty string: autoloading the best_weights in get_best_weights_path()
    TYPE = "single_direction"  # single_direction / combined
    CV_FOLD = 0
    VALIDATE_SUBJECTS = []
    TRAIN_SUBJECTS = []
    TEST_SUBJECTS = []
    TRAIN = True
    TEST = True
    SEGMENT = False
    GET_PROBS = False
    OUTPUT_MULTIPLE_FILES = False
    RESET_LAST_LAYER = False
    UPSAMPLE_TYPE = "bilinear"  # bilinear / nearest

    # Peak_regression specific
    PEAK_DICE_THR = [0.95]
    PEAK_DICE_LEN_THR = 0.05
    FLIP_OUTPUT_PEAKS = True  # flip peaks along z axis to make them compatible with MITK

    # For TractSeg.py application
    PREDICT_IMG = False
    PREDICT_IMG_OUTPUT = None
    TRACTSEG_DIR = "tractseg_output"
    KEEP_INTERMEDIATE_FILES = False
    CSD_RESOLUTION = "LOW"  # HIGH / LOW
    NR_CPUS = -1

    # Rarely changed:
    LABELS_TYPE = np.int16  # Binary: np.int16, Regression: np.float32
    THRESHOLD = 0.5  # Binary: 0.5, Regression: 0.01 ?
    TEST_TIME_DAUG = False
    USE_VISLOGGER = False  # only works with Python 3
    SAVE_WEIGHTS = True
    SEG_INPUT = "Peaks"  # Gradients / Peaks
    # adapt manually: NR_OF_GRADIENTS in UNet.py and get_batch... in train() and in get_seg_prediction()
    NR_SLICES = 1
    PRINT_FREQ = 20  # 20
    NORMALIZE_DATA = True
    NORMALIZE_PER_CHANNEL = False
    BEST_EPOCH = 0
    VERBOSE = True
    CALC_F1 = True
def _get_pretrained_weights_path(experiment_type, dropout_sampling, input_type, tract_definition):
    """Return the path of the pretrained weights file for the given model configuration.

    Returns None if no single weights file is defined for this combination (e.g. peak_regression with
    peaks input, where the per-part weights are chosen later). The caller then keeps the WEIGHTS_PATH
    that is already configured.

    Raises:
        ValueError: tract_definition 'AutoPTX' combined with an unsupported experiment type.
        FileNotFoundError: the AutoPTX weights file does not exist (these weights are not downloaded
            automatically).
    """
    if tract_definition == "TractQuerier+":
        if input_type == "peaks":
            if experiment_type == "tract_segmentation" and dropout_sampling:
                return join(C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_dropout_v2.npz")
            if experiment_type == "tract_segmentation":
                return join(C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_v2.npz")
            if experiment_type == "endings_segmentation":
                return join(C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v3.npz")
            if experiment_type == "dm_regression":
                return join(C.WEIGHTS_DIR, "pretrained_weights_dm_regression_v1.npz")
            return None

        # T1 input
        if experiment_type == "tract_segmentation":
            return join(C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models",
                        "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        if experiment_type == "endings_segmentation":
            return join(C.WEIGHTS_DIR, "pretrained_weights_endings_segmentation_v1.npz")
        if experiment_type == "peak_regression":
            return join(C.WEIGHTS_DIR, "pretrained_weights_peak_regression_v1.npz")
        return None

    # AutoPTX
    if experiment_type == "tract_segmentation":
        weights_path = join(C.WEIGHTS_DIR, "pretrained_weights_tract_segmentation_aPTX_v1.npz")
    elif experiment_type == "dm_regression":
        weights_path = join(C.WEIGHTS_DIR, "pretrained_weights_dm_regression_aPTX_v1.npz")
    else:
        raise ValueError("bundle_definition AutoPTX not supported in combination with this output type")

    # todo: remove when aPTX weights are loaded automatically
    if not os.path.exists(weights_path):
        raise FileNotFoundError("Could not find weights file: {}".format(weights_path))
    return weights_path


def run_tractseg(data, output_type="tract_segmentation",
                 single_orientation=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_postprocessing=True, get_probs=False, peak_threshold=0.1,
                 postprocess=False, peak_regression_part="All", input_type="peaks",
                 blob_size_thr=50, nr_cpus=-1, verbose=False, manual_exp_name=None,
                 inference_batch_size=1, tract_definition="TractQuerier+", bedpostX_input=False,
                 tract_segmentations_path=None, TOM_dilation=1):
    """
    Run TractSeg

    Args:
        data: input peaks (4D numpy array with shape [x,y,z,9])
        output_type: TractSeg can segment not only bundles, but also the end regions of bundles.
            Moreover it can create Tract Orientation Maps (TOM).
            'tract_segmentation' [DEFAULT]: Segmentation of bundles (72 bundles).
            'endings_segmentation': Segmentation of bundle end regions (72 bundles).
            'TOM': Tract Orientation Maps (72 bundles).
        single_orientation: Do not run model 3 times along x/y/z orientation with subsequent mean fusion.
        dropout_sampling: Create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142).
            The uncertainty map is returned as is (bundle specific postprocessing is skipped).
        threshold: Threshold for converting probability map to binary map
        bundle_specific_postprocessing: Use a lower threshold and hole closing for CA and FX if they are
            incomplete. Only applied for 'tract_segmentation' when neither get_probs nor dropout_sampling
            is set, because this step thresholds the probability map itself.
        get_probs: Output raw probability map instead of binary map
        peak_threshold: All peaks shorter than peak_threshold will be set to zero.
            NOTE: currently unused, the small-peak removal step is disabled.
        postprocess: Simple postprocessing of segmentations: Remove small blobs and fill holes
        peak_regression_part: Only relevant for output type 'TOM'. If set to 'All' (default) it will return
            all 72 bundles. If set to 'Part1'-'Part4' it will only run for a subset of the bundles to reduce
            memory load.
        input_type: "peaks" (default). Any other value selects the T1 weights.
        blob_size_thr: If setting postprocess to True, all blobs having a smaller number of voxels than
            specified in this threshold will be removed.
        nr_cpus: Number of CPUs to use. -1 means all available CPUs.
        verbose: Show debugging infos
        manual_exp_name: Name of experiment if you do not want to use a pretrained model but your own one
        inference_batch_size: batch size (higher: a bit faster but needs more RAM)
        tract_definition: Select which tract definitions to use. 'TractQuerier+' defines tracts mainly by
            their cortical start and end region. 'AutoPTX' defines tracts mainly by ROIs in white matter.
        bedpostX_input: Input peaks are generated by bedpostX (currently not used inside this function).
        tract_segmentations_path: Path to the tract segmentations that are used to mask the TOMs
            (only used for output type 'TOM').
        TOM_dilation: Dilation applied to the tract segmentations before using them to mask the TOMs.

    Returns:
        4D numpy array with the output of tractseg
        for tract_segmentation:     [x, y, z, nr_of_bundles]
        for endings_segmentation:   [x, y, z, 2*nr_of_bundles]
        for TOM:                    [x, y, z, 3*nr_of_bundles]

    Raises:
        ValueError: tract_definition 'AutoPTX' with an output type it does not support.
        FileNotFoundError: AutoPTX weights file not found.
    """
    start_time = time.time()

    if manual_exp_name is None:
        config = get_config_name(input_type, output_type, dropout_sampling=dropout_sampling,
                                 tract_definition=tract_definition)
        Config = getattr(importlib.import_module("tractseg.experiments.pretrained_models." + config),
                         "Config")()
    else:
        Config = exp_utils.load_config_from_txt(
            join(C.EXP_PATH, exp_utils.get_manual_exp_name_peaks(manual_exp_name, "Part1"),
                 "Hyperparameters.txt"))

    Config = exp_utils.get_correct_labels_type(Config)
    Config.VERBOSE = verbose
    Config.TRAIN = False
    Config.TEST = False
    Config.SEGMENT = False
    Config.GET_PROBS = get_probs
    Config.LOAD_WEIGHTS = True
    Config.DROPOUT_SAMPLING = dropout_sampling
    Config.THRESHOLD = threshold
    Config.NR_CPUS = nr_cpus
    Config.INPUT_DIM = exp_utils.get_correct_input_dim(Config)

    # Bundle specific postprocessing applies its own thresholds, so the network has to output
    # probabilities for it. It must not run when the caller asked for raw probabilities or for a
    # dropout uncertainty map, otherwise those outputs would silently come back binarized.
    apply_bundle_specific_pp = (Config.EXPERIMENT_TYPE == "tract_segmentation"
                                and bundle_specific_postprocessing
                                and not get_probs and not dropout_sampling)
    if apply_bundle_specific_pp:
        Config.GET_PROBS = True

    if manual_exp_name is not None and Config.EXPERIMENT_TYPE != "peak_regression":
        Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(join(C.EXP_PATH, manual_exp_name), True)
    else:
        weights_path = _get_pretrained_weights_path(Config.EXPERIMENT_TYPE, Config.DROPOUT_SAMPLING,
                                                    input_type, tract_definition)
        if weights_path is not None:
            Config.WEIGHTS_PATH = weights_path

    if Config.VERBOSE:
        print("Hyperparameters:")
        exp_utils.print_Configs(Config)

    data = np.nan_to_num(data)

    # runtime on HCP data: 0.9s
    data, _, bbox, original_shape = dataset_utils.crop_to_nonzero(data)
    # runtime on HCP data: 0.5s
    data, transformation = dataset_utils.pad_and_scale_img_to_square_img(data, target_size=Config.INPUT_DIM[0],
                                                                         nr_cpus=nr_cpus)

    if Config.EXPERIMENT_TYPE in ("tract_segmentation", "endings_segmentation", "dm_regression"):
        print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
        Config.NR_OF_CLASSES = len(exp_utils.get_bundle_names(Config.CLASSES)[1:])
        utils.download_pretrained_weights(experiment_type=Config.EXPERIMENT_TYPE,
                                          dropout_sampling=Config.DROPOUT_SAMPLING)
        model = BaseModel(Config, inference=True)

        # Non-binary output is needed for dropout uncertainty maps, for regression and whenever
        # probabilities were requested (including for bundle specific postprocessing).
        probs = bool(Config.DROPOUT_SAMPLING or Config.EXPERIMENT_TYPE == "dm_regression" or Config.GET_PROBS)

        if single_orientation:  # mainly needed for testing because of less RAM requirements
            data_loader_inference = DataLoaderInference(Config, data=data)
            seg, _ = trainer.predict_img(Config, model, data_loader_inference, probs=probs,
                                         scale_to_world_shape=False, only_prediction=True,
                                         batch_size=inference_batch_size)
        else:
            seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(Config, model, data=data,
                                                                          scale_to_world_shape=False,
                                                                          only_prediction=True,
                                                                          batch_size=inference_batch_size)
            seg = direction_merger.mean_fusion(Config.THRESHOLD, seg_xyz, probs=probs)

    elif Config.EXPERIMENT_TYPE == "peak_regression":
        weights = {
            "Part1": "pretrained_weights_peak_regression_part1_v2.npz",
            "Part2": "pretrained_weights_peak_regression_part2_v2.npz",
            "Part3": "pretrained_weights_peak_regression_part3_v2.npz",
            "Part4": "pretrained_weights_peak_regression_part4_v2.npz",
        }
        if peak_regression_part == "All":
            parts = ["Part1", "Part2", "Part3", "Part4"]
            # Size the buffer from the full bundle list (3 channels per bundle) so it matches the final
            # Config.NR_OF_CLASSES set below, independent of what the loaded config contained.
            nr_channels_all = 3 * len(exp_utils.get_bundle_names("All")[1:])
            seg_all = np.zeros((data.shape[0], data.shape[1], data.shape[2], nr_channels_all))
        else:
            parts = [peak_regression_part]
            Config.CLASSES = "All_" + peak_regression_part
            Config.NR_OF_CLASSES = 3 * len(exp_utils.get_bundle_names(Config.CLASSES)[1:])

        channel_offset = 0  # where the next part's channels go inside seg_all
        for part in parts:
            if manual_exp_name is not None:
                manual_exp_name_peaks = exp_utils.get_manual_exp_name_peaks(manual_exp_name, part)
                Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
                    join(C.EXP_PATH, manual_exp_name_peaks), True)
            else:
                Config.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, weights[part])
            print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
            Config.CLASSES = "All_" + part
            Config.NR_OF_CLASSES = 3 * len(exp_utils.get_bundle_names(Config.CLASSES)[1:])
            utils.download_pretrained_weights(experiment_type=Config.EXPERIMENT_TYPE,
                                              dropout_sampling=Config.DROPOUT_SAMPLING, part=part)
            model = BaseModel(Config, inference=True)

            if single_orientation:
                data_loader_inference = DataLoaderInference(Config, data=data)
                seg, _ = trainer.predict_img(Config, model, data_loader_inference, probs=True,
                                             scale_to_world_shape=False, only_prediction=True,
                                             batch_size=inference_batch_size)
            else:
                # 3 dir for Peaks -> bad results
                seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(Config, model, data=data,
                                                                              scale_to_world_shape=False,
                                                                              only_prediction=True,
                                                                              batch_size=inference_batch_size)
                seg = direction_merger.mean_fusion_peaks(seg_xyz)

            if peak_regression_part == "All":
                nr_channels_part = Config.NR_OF_CLASSES
                seg_all[:, :, :, channel_offset:channel_offset + nr_channels_part] = seg
                channel_offset += nr_channels_part

        if peak_regression_part == "All":
            Config.CLASSES = "All"
            Config.NR_OF_CLASSES = 3 * len(exp_utils.get_bundle_names(Config.CLASSES)[1:])
            seg = seg_all

    if apply_bundle_specific_pp:
        # Runtime ~4s
        seg = img_utils.bundle_specific_postprocessing(seg, exp_utils.get_bundle_names(Config.CLASSES)[1:])

    # runtime on HCP data: 5.1s
    seg = dataset_utils.cut_and_scale_img_back_to_original_img(seg, transformation, nr_cpus=nr_cpus)
    # runtime on HCP data: 1.6s
    seg = dataset_utils.add_original_zero_padding_again(seg, bbox, original_shape, Config.NR_OF_CLASSES)

    if Config.EXPERIMENT_TYPE == "peak_regression":
        seg = peak_utils.mask_and_normalize_peaks(seg, tract_segmentations_path,
                                                  exp_utils.get_bundle_names(Config.CLASSES)[1:],
                                                  TOM_dilation, nr_cpus=nr_cpus)

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and postprocess:
        # Runtime ~7s for 1.25mm resolution
        # Runtime ~1.5s for 2mm resolution
        st = time.time()
        seg = img_utils.postprocess_segmentations(seg, exp_utils.get_bundle_names(Config.CLASSES)[1:],
                                                  blob_thr=blob_size_thr, hole_closing=None)
        exp_utils.print_verbose(Config, "took: {}".format(time.time() - st))

    exp_utils.print_verbose(Config, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg