def sysu_media_wmh_segmentation(flair,
                                t1=None,
                                do_preprocessing=True,
                                use_ensemble=True,
                                use_axial_slices_only=True,
                                antsxnet_cache_directory=None,
                                verbose=False):

    """
    Perform WMH segmentation using the winning submission in the MICCAI
    2017 challenge by the sysu_media team using FLAIR or T1/FLAIR.  The
    MICCAI challenge is discussed in

    https://pubmed.ncbi.nlm.nih.gov/30908194/

    and the sysu_media team's entry is discussed in

    https://pubmed.ncbi.nlm.nih.gov/30125711/

    with the original implementation available here:

    https://github.com/hongweilibran/wmh_ibbmTum

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    do_preprocessing : boolean
        perform N4 bias correction.

    use_ensemble : boolean
        whether to use all 3 sets of weights.

    use_axial_slices_only : boolean
        If True, use the original implementation, which was trained on axial
        slices.  If False, use the ANTsXNet variant, which applies the
        slice-by-slice models to all 3 dimensions and averages the results.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if None, these data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> image = ants.image_read("flair.nii.gz")
    >>> probability_mask = sysu_media_wmh_segmentation(image)
    """

    from ..architectures import create_sysu_media_unet_model_2d
    from ..utilities import brain_extraction
    from ..utilities import crop_image_center
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import pad_or_crop_image_to_size

    if flair.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    flair_preprocessed = flair
    if do_preprocessing:
        flair_preprocessing = preprocess_brain_image(flair,
            truncate_intensity=(0.01, 0.99),
            do_brain_extraction=False,
            do_bias_correction=True,
            do_denoising=False,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        flair_preprocessed = flair_preprocessing["preprocessed_image"]

    number_of_channels = 1
    if t1 is not None:
        t1_preprocessed = t1
        if do_preprocessing:
            t1_preprocessing = preprocess_brain_image(t1,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            t1_preprocessed = t1_preprocessing["preprocessed_image"]
        number_of_channels = 2

    ################################
    #
    # Estimate mask
    #
    ################################

    brain_mask = None
    if verbose:
        print("Estimating brain mask.")
    if t1 is not None:
        brain_mask = brain_extraction(t1, modality="t1")
    else:
        brain_mask = brain_extraction(flair, modality="flair")

    reference_image = ants.make_image((200, 200, 200),
                                      voxval=1,
                                      spacing=(1, 1, 1),
                                      origin=(0, 0, 0),
                                      direction=np.identity(3))

    center_of_mass_reference = ants.get_center_of_mass(reference_image)
    center_of_mass_image = ants.get_center_of_mass(brain_mask)
    translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_reference)
    xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_reference), translation=translation)

    flair_preprocessed_warped = ants.apply_ants_transform_to_image(
        xfrm, flair_preprocessed, reference_image)
    brain_mask_warped = ants.threshold_image(
        ants.apply_ants_transform_to_image(xfrm, brain_mask, reference_image),
        0.5, 1.1, 1, 0)

    if t1 is not None:
        t1_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm, t1_preprocessed, reference_image)

    ################################
    #
    # Gaussian normalize intensity based on brain mask
    #
    ################################

    mean_flair = flair_preprocessed_warped[brain_mask_warped > 0].mean()
    std_flair = flair_preprocessed_warped[brain_mask_warped > 0].std()
    flair_preprocessed_warped = (flair_preprocessed_warped - mean_flair) / std_flair

    if number_of_channels == 2:
        mean_t1 = t1_preprocessed_warped[brain_mask_warped > 0].mean()
        std_t1 = t1_preprocessed_warped[brain_mask_warped > 0].std()
        t1_preprocessed_warped = (t1_preprocessed_warped - mean_t1) / std_t1

    ################################
    #
    # Build models and load weights
    #
    ################################

    number_of_models = 1
    if use_ensemble:
        number_of_models = 3

    unet_models = list()
    for i in range(number_of_models):
        if number_of_channels == 1:
            weights_file_name = get_pretrained_network("sysuMediaWmhFlairOnlyModel" + str(i),
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            weights_file_name = get_pretrained_network("sysuMediaWmhFlairT1Model" + str(i),
                antsxnet_cache_directory=antsxnet_cache_directory)
        unet_models.append(create_sysu_media_unet_model_2d((200, 200, number_of_channels)))
        unet_models[i].load_weights(weights_file_name)

    ################################
    #
    # Extract slices
    #
    ################################

    dimensions_to_predict = [2]
    if not use_axial_slices_only:
        dimensions_to_predict = list(range(3))

    total_number_of_slices = 0
    for d in range(len(dimensions_to_predict)):
        total_number_of_slices += flair_preprocessed_warped.shape[dimensions_to_predict[d]]

    batchX = np.zeros((total_number_of_slices, 200, 200, number_of_channels))

    slice_count = 0
    for d in range(len(dimensions_to_predict)):
        number_of_slices = flair_preprocessed_warped.shape[dimensions_to_predict[d]]
        if verbose:
            print("Extracting slices for dimension ", dimensions_to_predict[d], ".")
        for i in range(number_of_slices):
            flair_slice = pad_or_crop_image_to_size(
                ants.slice_image(flair_preprocessed_warped, dimensions_to_predict[d], i),
                (200, 200))
            batchX[slice_count,:,:,0] = flair_slice.numpy()
            if number_of_channels == 2:
                t1_slice = pad_or_crop_image_to_size(
                    ants.slice_image(t1_preprocessed_warped, dimensions_to_predict[d], i),
                    (200, 200))
                batchX[slice_count,:,:,1] = t1_slice.numpy()
            slice_count += 1

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose:
        print("Prediction.")

    prediction = unet_models[0].predict(batchX, verbose=verbose)
    if number_of_models > 1:
        for i in range(1, number_of_models, 1):
            prediction += unet_models[i].predict(batchX, verbose=verbose)
    prediction /= number_of_models

    permutations = list()
    permutations.append((0, 1, 2))
    permutations.append((1, 0, 2))
    permutations.append((1, 2, 0))

    prediction_image_average = ants.image_clone(flair_preprocessed_warped) * 0

    current_start_slice = 0
    for d in range(len(dimensions_to_predict)):
        # Exclusive end index so that every slice along this dimension,
        # including the last one, is restacked into the prediction.
        current_end_slice = current_start_slice + flair_preprocessed_warped.shape[dimensions_to_predict[d]]
        which_batch_slices = range(current_start_slice, current_end_slice)
        prediction_per_dimension = prediction[which_batch_slices,:,:,:]
        prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
            permutations[dimensions_to_predict[d]])
        prediction_image = ants.copy_image_info(flair_preprocessed_warped,
            pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                      flair_preprocessed_warped.shape))
        prediction_image_average = prediction_image_average + \
            (prediction_image - prediction_image_average) / (d + 1)
        current_start_slice = current_end_slice

    probability_image = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), prediction_image_average, flair)

    return(probability_image)
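# -----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library):  calling the
# FLAIR-only and FLAIR/T1 variants of sysu_media_wmh_segmentation.  The file
# names are hypothetical placeholders for any 3-D NIfTI inputs.
#
#   import ants
#   flair = ants.image_read("flair.nii.gz")  # hypothetical input file
#   t1 = ants.image_read("t1.nii.gz")        # hypothetical input file
#   # Original behavior:  3-model ensemble applied to axial slices only.
#   wmh_flair_only = sysu_media_wmh_segmentation(flair, verbose=True)
#   # ANTsXNet variant:  FLAIR/T1 inputs, averaged over all three slice
#   # directions, which can reduce slice-direction artifacts.
#   wmh = sysu_media_wmh_segmentation(flair, t1=t1, use_ensemble=True,
#                                     use_axial_slices_only=False)
#   ants.image_write(wmh, "wmh_probability.nii.gz")
# -----------------------------------------------------------------------------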
def lung_extraction(image,
                    modality="proton",
                    antsxnet_cache_directory=None,
                    verbose=False):

    """
    Perform proton or CT lung extraction using a U-net.

    Arguments
    ---------
    image : ANTsImage
        input image

    modality : string
        Modality image type.  Options include "ct", "proton", "protonLobes",
        "maskLobes", and "ventilation".

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if None, these data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary of ANTs segmentation and probability images.

    Example
    -------
    >>> output = lung_extraction(lung_image, modality="proton")
    """

    from ..architectures import create_unet_model_2d
    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import pad_or_crop_image_to_size

    if image.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    image_mods = [modality]
    channel_size = len(image_mods)

    weights_file_name = None
    unet_model = None

    if modality == "proton":
        weights_file_name = get_pretrained_network("protonLungMri",
            antsxnet_cache_directory=antsxnet_cache_directory)

        classes = ("background", "left_lung", "right_lung")
        number_of_classification_labels = len(classes)

        reorient_template_file_name_path = get_antsxnet_data("protonLungTemplate",
            antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)

        resampled_image_size = reorient_template.shape

        unet_model = create_unet_model_3d((*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            number_of_layers=4, number_of_filters_at_base_layer=16, dropout_rate=0.0,
            convolution_kernel_size=(7, 7, 5), deconvolution_kernel_size=(7, 7, 5))
        unet_model.load_weights(weights_file_name)

        if verbose:
            print("Lung extraction:  normalizing image to the template.")

        center_of_mass_template = ants.get_center_of_mass(reorient_template * 0 + 1)
        center_of_mass_image = ants.get_center_of_mass(image * 0 + 1)
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_template), translation=translation)
        warped_image = ants.apply_ants_transform_to_image(xfrm, image, reorient_template)

        batchX = np.expand_dims(warped_image.numpy(), axis=0)
        batchX = np.expand_dims(batchX, axis=-1)
        batchX = (batchX - batchX.mean()) / batchX.std()

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        origin = warped_image.origin
        spacing = warped_image.spacing
        direction = warped_image.direction

        probability_images_array = list()
        for i in range(number_of_classification_labels):
            probability_images_array.append(
                ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                                origin=origin, spacing=spacing, direction=direction))

        if verbose:
            print("Lung extraction:  renormalize probability mask to native space.")

        for i in range(number_of_classification_labels):
            probability_images_array[i] = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), probability_images_array[i], image)

        image_matrix = ants.image_list_to_matrix(probability_images_array, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        return_dict = {'segmentation_image' : segmentation_image,
                       'probability_images' : probability_images_array}
        return(return_dict)

    if modality == "protonLobes" or modality == "maskLobes":
        reorient_template_file_name_path = get_antsxnet_data("protonLungTemplate",
            antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)

        resampled_image_size = reorient_template.shape

        spatial_priors_file_name_path = get_antsxnet_data("protonLobePriors",
            antsxnet_cache_directory=antsxnet_cache_directory)
        spatial_priors = ants.image_read(spatial_priors_file_name_path)
        priors_image_list = ants.ndimage_to_list(spatial_priors)

        channel_size = 1 + len(priors_image_list)
        number_of_classification_labels = 1 + len(priors_image_list)

        unet_model = create_unet_model_3d((*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels, mode="classification",
            number_of_filters_at_base_layer=16, number_of_layers=4,
            convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
            dropout_rate=0.0, weight_decay=0, additional_options=("attentionGating",))

        if modality == "protonLobes":
            penultimate_layer = unet_model.layers[-2].output
            outputs2 = Conv3D(filters=1,
                              kernel_size=(1, 1, 1),
                              activation='sigmoid',
                              kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)
            unet_model = Model(inputs=unet_model.input, outputs=[unet_model.output, outputs2])
            weights_file_name = get_pretrained_network("protonLobes",
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            weights_file_name = get_pretrained_network("maskLobes",
                antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        if verbose:
            print("Lung extraction:  normalizing image to the template.")

        center_of_mass_template = ants.get_center_of_mass(reorient_template * 0 + 1)
        center_of_mass_image = ants.get_center_of_mass(image * 0 + 1)
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_template), translation=translation)
        warped_image = ants.apply_ants_transform_to_image(xfrm, image, reorient_template)
        warped_array = warped_image.numpy()

        if modality == "protonLobes":
            warped_array = (warped_array - warped_array.mean()) / warped_array.std()
        else:
            warped_array[warped_array != 0] = 1

        batchX = np.zeros((1, *warped_array.shape, channel_size))
        batchX[0,:,:,:,0] = warped_array
        for i in range(len(priors_image_list)):
            batchX[0,:,:,:,i+1] = priors_image_list[i].numpy()

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        origin = warped_image.origin
        spacing = warped_image.spacing
        direction = warped_image.direction

        probability_images_array = list()
        for i in range(number_of_classification_labels):
            if modality == "protonLobes":
                probability_images_array.append(
                    ants.from_numpy(np.squeeze(predicted_data[0][0, :, :, :, i]),
                                    origin=origin, spacing=spacing, direction=direction))
            else:
                probability_images_array.append(
                    ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                                    origin=origin, spacing=spacing, direction=direction))

        if verbose:
            print("Lung extraction:  renormalize probability images to native space.")

        for i in range(number_of_classification_labels):
            probability_images_array[i] = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), probability_images_array[i], image)

        image_matrix = ants.image_list_to_matrix(probability_images_array, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        if modality == "protonLobes":
            whole_lung_mask = ants.from_numpy(np.squeeze(predicted_data[1][0, :, :, :, 0]),
                origin=origin, spacing=spacing, direction=direction)
            whole_lung_mask = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), whole_lung_mask, image)

            return_dict = {'segmentation_image' : segmentation_image,
                           'probability_images' : probability_images_array,
                           'whole_lung_mask_image' : whole_lung_mask}
            return(return_dict)
        else:
            return_dict = {'segmentation_image' : segmentation_image,
                           'probability_images' : probability_images_array}
            return(return_dict)

    elif modality == "ct":

        ################################
        #
        # Preprocess image
        #
        ################################

        if verbose:
            print("Preprocess CT image.")

        def closest_simplified_direction_matrix(direction):
            closest = np.floor(np.abs(direction) + 0.5)
            closest[direction < 0] *= -1.0
            return closest

        simplified_direction = closest_simplified_direction_matrix(image.direction)

        reference_image_size = (128, 128, 128)

        ct_preprocessed = ants.resample_image(image, reference_image_size,
            use_voxels=True, interp_type=0)
        ct_preprocessed[ct_preprocessed < -1000] = -1000
        ct_preprocessed[ct_preprocessed > 400] = 400
        ct_preprocessed.set_direction(simplified_direction)
        ct_preprocessed.set_origin((0, 0, 0))
        ct_preprocessed.set_spacing((1, 1, 1))

        ################################
        #
        # Reorient image
        #
        ################################

        reference_image = ants.make_image(reference_image_size,
                                          voxval=0,
                                          spacing=(1, 1, 1),
                                          origin=(0, 0, 0),
                                          direction=np.identity(3))
        center_of_mass_reference = np.floor(ants.get_center_of_mass(reference_image * 0 + 1))
        center_of_mass_image = np.floor(ants.get_center_of_mass(ct_preprocessed * 0 + 1))
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_reference)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_reference), translation=translation)

        ct_preprocessed = ((ct_preprocessed - ct_preprocessed.min()) /
                           (ct_preprocessed.max() - ct_preprocessed.min()))
        ct_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm, ct_preprocessed, reference_image, interpolation="nearestneighbor")
        ct_preprocessed_warped = ((ct_preprocessed_warped - ct_preprocessed_warped.min()) /
            (ct_preprocessed_warped.max() - ct_preprocessed_warped.min())) - 0.5

        ################################
        #
        # Build models and load weights
        #
        ################################

        if verbose:
            print("Build model and load weights.")

        weights_file_name = get_pretrained_network("lungCtWithPriorsSegmentationWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)

        classes = ("background", "left lung", "right lung", "airways")
        number_of_classification_labels = len(classes)

        luna16_priors = ants.ndimage_to_list(ants.image_read(get_antsxnet_data("luna16LungPriors")))
        for i in range(len(luna16_priors)):
            luna16_priors[i] = ants.resample_image(luna16_priors[i], reference_image_size,
                use_voxels=True)
        channel_size = len(luna16_priors) + 1

        unet_model = create_unet_model_3d((*reference_image_size, channel_size),
            number_of_outputs=number_of_classification_labels, mode="classification",
            number_of_layers=4, number_of_filters_at_base_layer=16, dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5, additional_options=("attentionGating",))
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Do prediction and normalize to native space
        #
        ################################

        if verbose:
            print("Prediction.")

        batchX = np.zeros((1, *reference_image_size, channel_size))
        batchX[:,:,:,:,0] = ct_preprocessed_warped.numpy()
        for i in range(len(luna16_priors)):
            batchX[:,:,:,:,i+1] = luna16_priors[i].numpy() - 0.5

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        probability_images = list()
        for i in range(number_of_classification_labels):
            if verbose:
                print("Reconstructing image", classes[i])
            probability_image = ants.from_numpy(np.squeeze(predicted_data[:,:,:,:,i]),
                origin=ct_preprocessed_warped.origin,
                spacing=ct_preprocessed_warped.spacing,
                direction=ct_preprocessed_warped.direction)
            probability_image = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), probability_image, ct_preprocessed)
            probability_image = ants.resample_image(probability_image,
                resample_params=image.shape, use_voxels=True, interp_type=0)
            probability_image = ants.copy_image_info(image, probability_image)
            probability_images.append(probability_image)

        image_matrix = ants.image_list_to_matrix(probability_images, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        return_dict = {'segmentation_image' : segmentation_image,
                       'probability_images' : probability_images}
        return(return_dict)

    elif modality == "ventilation":

        ################################
        #
        # Preprocess image
        #
        ################################

        if verbose:
            print("Preprocess ventilation image.")

        template_size = (256, 256)

        image_modalities = ("Ventilation",)
        channel_size = len(image_modalities)

        preprocessed_image = (image - image.mean()) / image.std()
        ants.set_direction(preprocessed_image, np.identity(3))

        ################################
        #
        # Build models and load weights
        #
        ################################

        unet_model = create_unet_model_2d((*template_size, channel_size),
            number_of_outputs=1, mode='sigmoid',
            number_of_layers=4, number_of_filters_at_base_layer=32, dropout_rate=0.0,
            convolution_kernel_size=(3, 3), deconvolution_kernel_size=(2, 2),
            weight_decay=0)

        if verbose:
            print("Whole lung mask:  retrieving model weights.")

        weights_file_name = get_pretrained_network("wholeLungMaskFromVentilation",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Extract slices
        #
        ################################

        spacing = ants.get_spacing(preprocessed_image)
        dimensions_to_predict = (spacing.index(max(spacing)),)

        total_number_of_slices = 0
        for d in range(len(dimensions_to_predict)):
            total_number_of_slices += preprocessed_image.shape[dimensions_to_predict[d]]

        batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

        slice_count = 0
        for d in range(len(dimensions_to_predict)):
            number_of_slices = preprocessed_image.shape[dimensions_to_predict[d]]
            if verbose:
                print("Extracting slices for dimension ", dimensions_to_predict[d], ".")
            for i in range(number_of_slices):
                ventilation_slice = pad_or_crop_image_to_size(
                    ants.slice_image(preprocessed_image, dimensions_to_predict[d], i),
                    template_size)
                batchX[slice_count,:,:,0] = ventilation_slice.numpy()
                slice_count += 1

        ################################
        #
        # Do prediction and then restack into the image
        #
        ################################

        if verbose:
            print("Prediction.")

        prediction = unet_model.predict(batchX, verbose=verbose)

        permutations = list()
        permutations.append((0, 1, 2))
        permutations.append((1, 0, 2))
        permutations.append((1, 2, 0))

        probability_image = ants.image_clone(image) * 0

        current_start_slice = 0
        for d in range(len(dimensions_to_predict)):
            # Exclusive end index so the last slice along this dimension is
            # included in the restacked prediction.
            current_end_slice = current_start_slice + preprocessed_image.shape[dimensions_to_predict[d]]
            which_batch_slices = range(current_start_slice, current_end_slice)
            prediction_per_dimension = prediction[which_batch_slices,:,:,0]
            prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                permutations[dimensions_to_predict[d]])
            prediction_image = ants.copy_image_info(image,
                pad_or_crop_image_to_size(ants.from_numpy(prediction_array), image.shape))
            probability_image = probability_image + (prediction_image - probability_image) / (d + 1)
            current_start_slice = current_end_slice

        return(probability_image)

    else:
        raise ValueError("Unrecognized modality.")
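# -----------------------------------------------------------------------------
# Usage sketch (illustrative only):  the "modality" argument dispatches to
# very different pipelines, so the shape of the returned value differs.  File
# names below are hypothetical placeholders.
#
#   import ants
#   ct = ants.image_read("chest_ct.nii.gz")  # hypothetical input file
#   ct_output = lung_extraction(ct, modality="ct", verbose=True)
#   # "ct" returns a dictionary with a label image plus per-class
#   # probabilities:  background, left lung, right lung, airways.
#   ants.image_write(ct_output['segmentation_image'], "lung_labels.nii.gz")
#
#   ventilation = ants.image_read("ventilation.nii.gz")  # hypothetical
#   # "ventilation" returns a single whole-lung probability image rather
#   # than a dictionary.
#   lung_probability = lung_extraction(ventilation, modality="ventilation")
# -----------------------------------------------------------------------------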
def brain_age(t1,
              do_preprocessing=True,
              number_of_simulations=0,
              sd_affine=0.01,
              antsxnet_cache_directory=None,
              verbose=False):

    """
    Estimate BrainAge from a T1-weighted MR image using the DeepBrainNet
    architecture and weights described here:

    https://github.com/vishnubashyam/DeepBrainNet

    and described in the following article:

    https://academic.oup.com/brain/article-abstract/doi/10.1093/brain/awaa160/5863667?redirectedFrom=fulltext

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    do_preprocessing = True

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    number_of_simulations : integer
        Number of random affine perturbations to transform the input.

    sd_affine : float
        Define the standard deviation of the affine transformation parameter.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if None, these data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary consisting of the predicted age and the per-slice brain age
    predictions.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> deep = brain_age(image)
    >>> print("Predicted age: ", deep['predicted_age'])
    """

    from ..utilities import preprocess_brain_image
    from ..utilities import get_pretrained_network
    from ..utilities import randomly_transform_image_data

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            template="croppedMni152",
            template_transform_type="antsRegistrationSyNQuickRepro[a]",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing["preprocessed_image"] * t1_preprocessing['brain_mask']

    t1_preprocessed = ((t1_preprocessed - t1_preprocessed.min()) /
                       (t1_preprocessed.max() - t1_preprocessed.min()))

    ################################
    #
    # Load model and weights
    #
    ################################

    model_weights_file_name = get_pretrained_network("brainAgeDeepBrainNet",
        antsxnet_cache_directory=antsxnet_cache_directory)
    model = keras.models.load_model(model_weights_file_name)

    # The paper only specifies that 80 slices are used for prediction.  I
    # just picked a reasonable range spanning the center of the brain.
    which_slices = list(range(45, 125))

    batchX = np.zeros((len(which_slices), *t1_preprocessed.shape[0:2], 3))

    input_image = list()
    input_image.append(t1_preprocessed)

    input_image_list = list()
    input_image_list.append(input_image)

    if number_of_simulations > 0:
        data_augmentation = randomly_transform_image_data(
            reference_image=t1_preprocessed,
            input_image_list=input_image_list,
            number_of_simulations=number_of_simulations,
            transform_type='affine',
            sd_affine=sd_affine,
            input_image_interpolator='linear')

    brain_age_per_slice = None
    for i in range(number_of_simulations + 1):
        batch_image = t1_preprocessed
        if i > 0:
            batch_image = data_augmentation['simulated_images'][i-1][0]

        # Replicate each gray-scale slice across the three input channels.
        for j in range(len(which_slices)):
            slice_array = (ants.slice_image(batch_image, axis=2, idx=which_slices[j])).numpy()
            batchX[j,:,:,0] = slice_array
            batchX[j,:,:,1] = slice_array
            batchX[j,:,:,2] = slice_array

        if verbose:
            print("Brain age (DeepBrainNet): predicting brain age per slice (batch = ", i, ")")

        if i == 0:
            brain_age_per_slice = model.predict(batchX, verbose=verbose)
        else:
            prediction = model.predict(batchX, verbose=verbose)
            brain_age_per_slice = brain_age_per_slice + (prediction - brain_age_per_slice) / (i+1)

    predicted_age = statistics.median(brain_age_per_slice)[0]

    return_dict = {'predicted_age' : predicted_age,
                   'brain_age_per_slice' : brain_age_per_slice}
    return(return_dict)
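# -----------------------------------------------------------------------------
# Usage sketch (illustrative only):  number_of_simulations > 0 averages the
# DeepBrainNet prediction over random affine perturbations of the input,
# which can make the estimate more robust to registration differences.  The
# file name is a hypothetical placeholder.
#
#   import ants
#   t1 = ants.image_read("t1.nii.gz")  # hypothetical input file
#   deep = brain_age(t1, number_of_simulations=5, sd_affine=0.01, verbose=True)
#   print("Predicted age:", deep['predicted_age'])
#   # deep['brain_age_per_slice'] holds the 80 per-slice estimates whose
#   # median is the reported age.
# -----------------------------------------------------------------------------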
def el_bicho(ventilation_image,
             mask,
             use_coarse_slices_only=True,
             antsxnet_cache_directory=None,
             verbose=False):

    """
    Perform functional lung segmentation using hyperpolarized gases.

    https://pubmed.ncbi.nlm.nih.gov/30195415/

    Arguments
    ---------
    ventilation_image : ANTsImage
        input ventilation image.

    mask : ANTsImage
        input mask.

    use_coarse_slices_only : boolean
        If True, apply the network only in the dimension of greatest slice
        thickness.  If False, apply to all dimensions and average the results.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if None, these data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Ventilation segmentation and corresponding probability images

    Example
    -------
    >>> image = ants.image_read("ventilation.nii.gz")
    >>> mask = ants.image_read("mask.nii.gz")
    >>> lung_seg = el_bicho(image, mask, use_coarse_slices_only=True, verbose=False)
    """

    from ..architectures import create_unet_model_2d
    from ..utilities import get_pretrained_network
    from ..utilities import pad_or_crop_image_to_size

    if ventilation_image.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if ventilation_image.shape != mask.shape:
        raise ValueError("Ventilation image and mask are not the same size.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess image
    #
    ################################

    template_size = (256, 256)
    classes = (0, 1, 2, 3, 4)
    number_of_classification_labels = len(classes)

    image_modalities = ("Ventilation", "Mask")
    channel_size = len(image_modalities)

    preprocessed_image = (ventilation_image - ventilation_image.mean()) / ventilation_image.std()
    ants.set_direction(preprocessed_image, np.identity(3))

    mask_identity = ants.image_clone(mask)
    ants.set_direction(mask_identity, np.identity(3))

    ################################
    #
    # Build models and load weights
    #
    ################################

    unet_model = create_unet_model_2d((*template_size, channel_size),
        number_of_outputs=number_of_classification_labels,
        number_of_layers=4, number_of_filters_at_base_layer=32, dropout_rate=0.0,
        convolution_kernel_size=(3, 3), deconvolution_kernel_size=(2, 2),
        weight_decay=1e-5, additional_options=("attentionGating",))

    if verbose:
        print("El Bicho:  retrieving model weights.")

    weights_file_name = get_pretrained_network("elBicho",
        antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Extract slices
    #
    ################################

    spacing = ants.get_spacing(preprocessed_image)
    dimensions_to_predict = (spacing.index(max(spacing)),)
    if not use_coarse_slices_only:
        dimensions_to_predict = list(range(3))

    total_number_of_slices = 0
    for d in range(len(dimensions_to_predict)):
        total_number_of_slices += preprocessed_image.shape[dimensions_to_predict[d]]

    batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

    slice_count = 0
    for d in range(len(dimensions_to_predict)):
        number_of_slices = preprocessed_image.shape[dimensions_to_predict[d]]
        if verbose:
            print("Extracting slices for dimension ", dimensions_to_predict[d], ".")
        for i in range(number_of_slices):
            ventilation_slice = pad_or_crop_image_to_size(
                ants.slice_image(preprocessed_image, dimensions_to_predict[d], i),
                template_size)
            batchX[slice_count,:,:,0] = ventilation_slice.numpy()
            mask_slice = pad_or_crop_image_to_size(
                ants.slice_image(mask_identity, dimensions_to_predict[d], i),
                template_size)
            batchX[slice_count,:,:,1] = mask_slice.numpy()
            slice_count += 1

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose:
        print("Prediction.")

    prediction = unet_model.predict(batchX, verbose=verbose)

    permutations = list()
    permutations.append((0, 1, 2))
    permutations.append((1, 0, 2))
    permutations.append((1, 2, 0))

    probability_images = list()
    for l in range(number_of_classification_labels):
        probability_images.append(ants.image_clone(mask) * 0)

    current_start_slice = 0
    for d in range(len(dimensions_to_predict)):
        # Exclusive end index so the last slice along this dimension is
        # included in the restacked prediction.
        current_end_slice = current_start_slice + preprocessed_image.shape[dimensions_to_predict[d]]
        which_batch_slices = range(current_start_slice, current_end_slice)
        for l in range(number_of_classification_labels):
            prediction_per_dimension = prediction[which_batch_slices,:,:,l]
            prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                permutations[dimensions_to_predict[d]])
            prediction_image = ants.copy_image_info(ventilation_image,
                pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                          ventilation_image.shape))
            probability_images[l] = probability_images[l] + \
                (prediction_image - probability_images[l]) / (d + 1)
        current_start_slice = current_end_slice

    ################################
    #
    # Convert probability images to segmentation
    #
    ################################

    image_matrix = ants.image_list_to_matrix(
        probability_images[1:(len(probability_images))], mask * 0 + 1)
    background_foreground_matrix = np.stack([
        ants.image_list_to_matrix([probability_images[0]], mask * 0 + 1),
        np.expand_dims(np.sum(image_matrix, axis=0), axis=0)])
    foreground_matrix = np.argmax(background_foreground_matrix, axis=0)
    segmentation_matrix = (np.argmax(image_matrix, axis=0) + 1) * foreground_matrix
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), mask * 0 + 1)[0]

    return_dict = {'segmentation_image': segmentation_image,
                   'probability_images': probability_images}
    return(return_dict)
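# -----------------------------------------------------------------------------
# Usage sketch (illustrative only):  el_bicho expects a lung mask along with
# the ventilation image; one plausible (hypothetical) pipeline derives the
# mask from lung_extraction above, thresholding the whole-lung probability at
# an assumed cutoff of 0.5.  File names are placeholders.
#
#   import ants
#   ventilation = ants.image_read("ventilation.nii.gz")  # hypothetical input
#   whole_lung_probability = lung_extraction(ventilation, modality="ventilation")
#   mask = ants.threshold_image(whole_lung_probability, 0.5, 1.1, 1, 0)
#   lung_seg = el_bicho(ventilation, mask, use_coarse_slices_only=True)
#   # lung_seg['segmentation_image'] labels the ventilation classes 0-4;
#   # lung_seg['probability_images'] holds one probability image per class.
# -----------------------------------------------------------------------------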
def claustrum_segmentation(t1,
                           do_preprocessing=True,
                           use_ensemble=True,
                           antsxnet_cache_directory=None,
                           verbose=False):

    """
    Claustrum segmentation

    Described here:

    https://arxiv.org/abs/2008.03465

    with the implementation available at:

    https://github.com/hongweilibran/claustrum_multi_view

    Arguments
    ---------
    t1 : ANTsImage
        input 3-D T1 brain image.

    do_preprocessing : boolean
        perform N4 bias correction.

    use_ensemble : boolean
        whether to use all 3 sets of weights.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if None, these data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Claustrum segmentation probability image

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> probability_mask = claustrum_segmentation(image)
    """

    from ..architectures import create_sysu_media_unet_model_2d
    from ..utilities import brain_extraction
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import pad_or_crop_image_to_size

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    image_size = (180, 180)

    ################################
    #
    # Preprocess images
    #
    ################################

    number_of_channels = 1
    t1_preprocessed = ants.image_clone(t1)
    brain_mask = ants.threshold_image(t1, 0, 0, 0, 1)
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing["preprocessed_image"]
        brain_mask = t1_preprocessing["brain_mask"]

    reference_image = ants.make_image((170, 256, 256),
                                      voxval=1,
                                      spacing=(1, 1, 1),
                                      origin=(0, 0, 0),
                                      direction=np.identity(3))
    center_of_mass_reference = ants.get_center_of_mass(reference_image)
    center_of_mass_image = ants.get_center_of_mass(brain_mask)
    translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_reference)
    xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_reference), translation=translation)
    t1_preprocessed_warped = ants.apply_ants_transform_to_image(
        xfrm, t1_preprocessed, reference_image)
    brain_mask_warped = ants.threshold_image(
        ants.apply_ants_transform_to_image(xfrm, brain_mask, reference_image),
        0.5, 1.1, 1, 0)

    ################################
    #
    # Gaussian normalize intensity based on brain mask
    #
    ################################

    mean_t1 = t1_preprocessed_warped[brain_mask_warped > 0].mean()
    std_t1 = t1_preprocessed_warped[brain_mask_warped > 0].std()
    t1_preprocessed_warped = (t1_preprocessed_warped - mean_t1) / std_t1
    t1_preprocessed_warped = t1_preprocessed_warped * brain_mask_warped

    ################################
    #
    # Build models and load weights
    #
    ################################

    number_of_models = 1
    if use_ensemble:
        number_of_models = 3

    if verbose:
        print("Claustrum:  retrieving axial model weights.")

    unet_axial_models = list()
    for i in range(number_of_models):
        weights_file_name = get_pretrained_network("claustrum_axial_" + str(i),
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_axial_models.append(create_sysu_media_unet_model_2d(
            (*image_size, number_of_channels), anatomy="claustrum"))
        unet_axial_models[i].load_weights(weights_file_name)

    if verbose:
        print("Claustrum:  retrieving coronal model weights.")

    unet_coronal_models = list()
    for i in range(number_of_models):
        weights_file_name = get_pretrained_network("claustrum_coronal_" + str(i),
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_coronal_models.append(create_sysu_media_unet_model_2d(
            (*image_size, number_of_channels), anatomy="claustrum"))
        unet_coronal_models[i].load_weights(weights_file_name)

    ################################
    #
    # Extract slices
    #
    ################################

    dimensions_to_predict = [1, 2]

    batch_coronal_X = np.zeros((t1_preprocessed_warped.shape[1],
                                *image_size, number_of_channels))
    batch_axial_X = np.zeros((t1_preprocessed_warped.shape[2],
                              *image_size, number_of_channels))

    for d in range(len(dimensions_to_predict)):
        number_of_slices = t1_preprocessed_warped.shape[dimensions_to_predict[d]]
        if verbose:
            print("Extracting slices for dimension ", dimensions_to_predict[d], ".")
        for i in range(number_of_slices):
            t1_slice = pad_or_crop_image_to_size(
                ants.slice_image(t1_preprocessed_warped, dimensions_to_predict[d], i),
                image_size)
            if dimensions_to_predict[d] == 1:  # coronal
                batch_coronal_X[i,:,:,0] = np.rot90(t1_slice.numpy(), k=-1)
            else:                              # axial
                batch_axial_X[i,:,:,0] = np.rot90(t1_slice.numpy())

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose:
        print("Coronal prediction.")

    prediction_coronal = unet_coronal_models[0].predict(batch_coronal_X, verbose=verbose)
    if number_of_models > 1:
        for i in range(1, number_of_models, 1):
            prediction_coronal += unet_coronal_models[i].predict(batch_coronal_X, verbose=verbose)
    prediction_coronal /= number_of_models

    for i in range(t1_preprocessed_warped.shape[1]):
        prediction_coronal[i,:,:,0] = np.rot90(np.squeeze(prediction_coronal[i,:,:,0]))

    if verbose:
        print("Axial prediction.")

    prediction_axial = unet_axial_models[0].predict(batch_axial_X, verbose=verbose)
    if number_of_models > 1:
        for i in range(1, number_of_models, 1):
            prediction_axial += unet_axial_models[i].predict(batch_axial_X, verbose=verbose)
    prediction_axial /= number_of_models

    for i in range(t1_preprocessed_warped.shape[2]):
        prediction_axial[i,:,:,0] = np.rot90(np.squeeze(prediction_axial[i,:,:,0]), k=-1)

    if verbose:
        print("Restack image and transform back to native space.")

    permutations = list()
    permutations.append((0, 1, 2))
    permutations.append((1, 0, 2))
    permutations.append((1, 2, 0))

    prediction_image_average = ants.image_clone(t1_preprocessed_warped) * 0

    for d in range(len(dimensions_to_predict)):
        which_batch_slices = range(t1_preprocessed_warped.shape[dimensions_to_predict[d]])
        prediction_per_dimension = None
        if dimensions_to_predict[d] == 1:
            prediction_per_dimension = prediction_coronal[which_batch_slices,:,:,:]
        else:
            prediction_per_dimension = prediction_axial[which_batch_slices,:,:,:]
        prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
            permutations[dimensions_to_predict[d]])
        prediction_image = ants.copy_image_info(t1_preprocessed_warped,
            pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                      t1_preprocessed_warped.shape))
        prediction_image_average = prediction_image_average + \
            (prediction_image - prediction_image_average) / (d + 1)

    probability_image = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), prediction_image_average, t1) * \
        ants.threshold_image(brain_mask, 0.5, 1, 1, 0)

    return(probability_image)
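# -----------------------------------------------------------------------------
# Usage sketch (illustrative only):  the function returns a probability
# image, so a binary claustrum mask requires an explicit threshold (the 0.5
# below is an assumed, not validated, cutoff).  The file name is hypothetical.
#
#   import ants
#   import numpy as np
#   t1 = ants.image_read("t1.nii.gz")  # hypothetical input file
#   claustrum_probability = claustrum_segmentation(t1, use_ensemble=True)
#   claustrum_mask = ants.threshold_image(claustrum_probability, 0.5, 1.0, 1, 0)
#   print("Claustrum volume (mm^3):",
#         claustrum_mask.numpy().sum() * np.prod(ants.get_spacing(t1)))
# -----------------------------------------------------------------------------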
def ew_david(flair,
             t1,
             do_preprocessing=True,
             do_slicewise=True,
             antsxnet_cache_directory=None,
             verbose=False):

    """
    Perform white matter hyperintensity probabilistic segmentation using
    deep learning

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    do_preprocessing = True

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    do_preprocessing : boolean
        perform N4 bias correction.

    do_slicewise : boolean
        apply the 2-D model along the direction of maximal slice thickness.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if None, these data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> flair = ants.image_read("flair.nii.gz")
    >>> t1 = ants.image_read("t1.nii.gz")
    >>> probability_mask = ew_david(flair, t1)
    """

    from ..architectures import create_unet_model_2d
    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import extract_image_patches
    from ..utilities import reconstruct_image_from_patches
    from ..utilities import pad_or_crop_image_to_size

    if flair.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    if not do_slicewise:

        ################################
        #
        # Preprocess images
        #
        ################################

        t1_preprocessed = t1
        t1_preprocessing = None
        if do_preprocessing:
            t1_preprocessing = preprocess_brain_image(t1,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=True,
                template="croppedMni152",
                template_transform_type="AffineFast",
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            t1_preprocessed = t1_preprocessing["preprocessed_image"] * t1_preprocessing['brain_mask']

        flair_preprocessed = flair
        if do_preprocessing:
            flair_preprocessing = preprocess_brain_image(flair,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            flair_preprocessed = ants.apply_transforms(fixed=t1_preprocessed,
                moving=flair_preprocessing["preprocessed_image"],
                transformlist=t1_preprocessing['template_transforms']['fwdtransforms'])
            flair_preprocessed = flair_preprocessed * t1_preprocessing['brain_mask']

        ################################
        #
        # Build model and load weights
        #
        ################################

        patch_size = (112, 112, 112)
        stride_length = (t1_preprocessed.shape[0] - patch_size[0],
                         t1_preprocessed.shape[1] - patch_size[1],
                         t1_preprocessed.shape[2] - patch_size[2])

        classes = ("background", "wmh")
        number_of_classification_labels = len(classes)
        labels = (0, 1)

        image_modalities = ("T1", "FLAIR")
        channel_size = len(image_modalities)

        unet_model = create_unet_model_3d((*patch_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            number_of_layers=4, number_of_filters_at_base_layer=16, dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5, nn_unet_activation_style=False, add_attention_gating=True)

        weights_file_name = get_pretrained_network("ewDavidWmhSegmentationWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Do prediction and normalize to native space
        #
        ################################

        if verbose:
            print("ew_david:  prediction.")

        batchX = np.zeros((8, *patch_size, channel_size))

        t1_preprocessed = (t1_preprocessed - t1_preprocessed.mean()) / t1_preprocessed.std()
        t1_patches = extract_image_patches(t1_preprocessed,
            patch_size=patch_size, max_number_of_patches="all",
            stride_length=stride_length, return_as_array=True)
        batchX[:,:,:,:,0] = t1_patches

        flair_preprocessed = (flair_preprocessed - flair_preprocessed.mean()) / flair_preprocessed.std()
        flair_patches = extract_image_patches(flair_preprocessed,
            patch_size=patch_size, max_number_of_patches="all",
            stride_length=stride_length, return_as_array=True)
        batchX[:,:,:,:,1] = flair_patches

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        probability_images = list()
        for i in range(len(labels)):
            if verbose:
                print("Reconstructing image", classes[i])
            reconstructed_image = reconstruct_image_from_patches(predicted_data[:,:,:,:,i],
                domain_image=t1_preprocessed, stride_length=stride_length)

            if do_preprocessing:
                probability_images.append(ants.apply_transforms(fixed=t1,
                    moving=reconstructed_image,
                    transformlist=t1_preprocessing['template_transforms']['invtransforms'],
                    whichtoinvert=[True], interpolator="linear", verbose=verbose))
            else:
                probability_images.append(reconstructed_image)

        return(probability_images[1])

    else:  # do_slicewise

        ################################
        #
        # Preprocess images
        #
        ################################

        t1_preprocessed = t1
        t1_preprocessing = None
        if do_preprocessing:
            t1_preprocessing = preprocess_brain_image(t1,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            t1_preprocessed = t1_preprocessing["preprocessed_image"]

        flair_preprocessed = flair
        if do_preprocessing:
            flair_preprocessing = preprocess_brain_image(flair,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            flair_preprocessed = flair_preprocessing["preprocessed_image"]

        resampling_params = list(ants.get_spacing(flair_preprocessed))

        do_resampling = False
        for d in range(len(resampling_params)):
            if resampling_params[d] < 0.8:
                resampling_params[d] = 1.0
                do_resampling = True

        resampling_params = tuple(resampling_params)

        if do_resampling:
            flair_preprocessed = ants.resample_image(flair_preprocessed,
                resampling_params, use_voxels=False, interp_type=0)
            t1_preprocessed = ants.resample_image(t1_preprocessed,
                resampling_params, use_voxels=False, interp_type=0)

        flair_preprocessed = (flair_preprocessed - flair_preprocessed.mean()) / flair_preprocessed.std()
        t1_preprocessed = (t1_preprocessed - t1_preprocessed.mean()) / t1_preprocessed.std()

        ################################
        #
        # Build model and load weights
        #
        ################################

        template_size = (256, 256)

        classes = ("background", "wmh")
        number_of_classification_labels = len(classes)
        labels = (0, 1)

        image_modalities = ("T1", "FLAIR")
        channel_size = len(image_modalities)

        unet_model = create_unet_model_2d((*template_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            number_of_layers=4, number_of_filters_at_base_layer=32, dropout_rate=0.0,
            convolution_kernel_size=(3, 3), deconvolution_kernel_size=(2, 2),
            weight_decay=1e-5, nn_unet_activation_style=True, add_attention_gating=True)

        if verbose:
            print("ewDavid:  retrieving model weights.")

        weights_file_name = get_pretrained_network("ewDavidWmhSegmentationSlicewiseWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Extract slices
        #
        ################################

        use_coarse_slices_only = True

        spacing = ants.get_spacing(flair_preprocessed)
        dimensions_to_predict = (spacing.index(max(spacing)),)
        if not use_coarse_slices_only:
            dimensions_to_predict = list(range(3))

        total_number_of_slices = 0
        for d in range(len(dimensions_to_predict)):
            total_number_of_slices += flair_preprocessed.shape[dimensions_to_predict[d]]

        batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

        slice_count = 0
        for d in range(len(dimensions_to_predict)):
            number_of_slices = flair_preprocessed.shape[dimensions_to_predict[d]]
            if verbose:
                print("Extracting slices for dimension ", dimensions_to_predict[d], ".")
            for i in range(number_of_slices):
                flair_slice = pad_or_crop_image_to_size(
                    ants.slice_image(flair_preprocessed, dimensions_to_predict[d], i),
                    template_size)
                batchX[slice_count,:,:,0] = flair_slice.numpy()
                t1_slice = pad_or_crop_image_to_size(
                    ants.slice_image(t1_preprocessed, dimensions_to_predict[d], i),
                    template_size)
                batchX[slice_count,:,:,1] = t1_slice.numpy()
                slice_count += 1

        ################################
        #
        # Do prediction and then restack into the image
        #
        ################################

        if verbose:
            print("Prediction.")

        prediction = unet_model.predict(batchX, verbose=verbose)

        permutations = list()
        permutations.append((0, 1, 2))
        permutations.append((1, 0, 2))
        permutations.append((1, 2, 0))

        prediction_image_average = ants.image_clone(flair_preprocessed) * 0

        current_start_slice = 0
        for d in range(len(dimensions_to_predict)):
            # Exclusive end index so the last slice along this dimension is
            # included in the restacked prediction.
            current_end_slice = current_start_slice + flair_preprocessed.shape[dimensions_to_predict[d]]
            which_batch_slices = range(current_start_slice, current_end_slice)
            prediction_per_dimension = prediction[which_batch_slices,:,:,1]
            prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                permutations[dimensions_to_predict[d]])
            prediction_image = ants.copy_image_info(flair_preprocessed,
                pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                          flair_preprocessed.shape))
            prediction_image_average = prediction_image_average + \
                (prediction_image - prediction_image_average) / (d + 1)
            current_start_slice = current_end_slice

        if do_resampling:
            prediction_image_average = ants.resample_image_to_target(prediction_image_average, flair)

        return(prediction_image_average)
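# -----------------------------------------------------------------------------
# Usage sketch (illustrative only):  the do_slicewise flag selects between
# the 3-D patch-based model (registered to MNI internally) and the 2-D
# slicewise model applied along the direction of maximal slice thickness.
# File names are hypothetical placeholders.
#
#   import ants
#   flair = ants.image_read("flair.nii.gz")  # hypothetical input file
#   t1 = ants.image_read("t1.nii.gz")        # hypothetical input file
#   # 2-D slicewise model (default).
#   wmh_slicewise = ew_david(flair, t1, do_slicewise=True, verbose=True)
#   # 3-D volumetric model on MNI-registered patches.
#   wmh_volumetric = ew_david(flair, t1, do_slicewise=False)
# -----------------------------------------------------------------------------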
def tid_neural_image_assessment(image, mask=None, patch_size=101, stride_length=None, padding_size=0, dimensions_to_predict=0, antsxnet_cache_directory=None, which_model="tidsQualityAssessment", verbose=False): """ Perform MOS-based assessment of an image. Use a ResNet architecture to estimate image quality in 2D or 3D using subjective QC image databases described in https://www.sciencedirect.com/science/article/pii/S0923596514001490 or https://doi.org/10.1109/TIP.2020.2967829 where the image assessment is either "global", i.e., a single number or an image based on the specified patch size. In the 3-D case, neighboring slices are used for each estimate. Note that parameters should be kept as consistent as possible in order to enable comparison. Patch size should be roughly 1/12th to 1/4th of image size to enable locality. A global estimate can be gained by setting patch_size = "global". Arguments --------- image : ANTsImage (2-D or 3-D) input image. mask : ANTsImage (2-D or 3-D) optional mask for designating calculation ROI. patch_size : integer prime number of patch_size. 101 is good. Otherwise, choose "global" for a single global estimate of quality. stride_length : integer or vector of image dimension length optional value to speed up computation (typically less than patch size). padding_size : positive or negative integer or vector of image dimension length de(padding) to remove edge effects. dimensions_to_predict : integer or vector if image dimension is 3, this parameter specifies which dimensions should be used for prediction. If more than one dimension is specified, the results are averaged. antsxnet_cache_directory : string Destination directory for storing the downloaded template and model weights. Since these can be resused, if is None, these data will be downloaded to ~/.keras/ANTsXNet/. which_model : string model type e.g. string tidsQualityAssessment, koniqMS, koniqMS2 or koniqMS3 where the former predicts mean opinion score (MOS) and MOS standard deviation and the latter koniq models predict mean opinion score (MOS) and sharpness. verbose : boolean Print progress to the screen. Returns ------- List of QC results predicting both both human rater's mean and standard deviation of the MOS ("mean opinion scores") or sharpness depending on the selected network. Both aggregate and spatial scores are returned, the latter in the form of an image. 
    Example
    -------
    >>> image = ants.image_read(ants.get_data("r16"))
    >>> mask = ants.get_mask(image)
    >>> tid = tid_neural_image_assessment(image, mask=mask, patch_size=101, stride_length=7)
    """

    from ..utilities import get_pretrained_network
    from ..utilities import pad_or_crop_image_to_size
    from ..utilities import extract_image_patches
    from ..utilities import reconstruct_image_from_patches

    def is_prime(n):
        # Deterministic 6k +/- 1 trial division primality test.
        if n == 2 or n == 3:
            return True
        if n < 2 or n % 2 == 0:
            return False
        if n < 9:
            return True
        if n % 3 == 0:
            return False
        r = int(n**0.5)
        f = 5
        while f <= r:
            if n % f == 0:
                return False
            if n % (f + 2) == 0:
                return False
            f += 6
        return True

    valid_models = ("tidsQualityAssessment", "koniqMS", "koniqMS2", "koniqMS3")
    if which_model not in valid_models:
        raise ValueError("which_model must be one of " + str(valid_models) + ".")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    if verbose == True:
        print("Neural QA:  retrieving model and weights.")

    is_koniq = "koniq" in which_model

    model_and_weights_file_name = get_pretrained_network(which_model,
        antsxnet_cache_directory=antsxnet_cache_directory)
    tid_model = tf.keras.models.load_model(model_and_weights_file_name, compile=False)

    padding_size_vector = padding_size
    if isinstance(padding_size, int):
        padding_size_vector = np.repeat(padding_size, image.dimension)
    elif len(padding_size) == 1:
        padding_size_vector = np.repeat(padding_size[0], image.dimension)

    if isinstance(dimensions_to_predict, int):
        dimensions_to_predict = (dimensions_to_predict,)

    padded_image_size = image.shape + padding_size_vector
    padded_image = pad_or_crop_image_to_size(image, padded_image_size)

    number_of_channels = 3

    if stride_length is None and patch_size != "global":
        # patch_size is a scalar here, so halve it directly.
        stride_length = round(patch_size / 2)
        if image.dimension == 3:
            stride_length = (stride_length, stride_length, 1)

    ###############
    #
    # Global
    #
    ###############

    if patch_size == 'global':

        if which_model == "tidsQualityAssessment":
            evaluation_image = ants.iMath(padded_image, "Normalize") * 255
        if is_koniq:
            evaluation_image = ants.iMath(padded_image, "Normalize") * 2.0 - 1.0

        if image.dimension == 2:
            batchX = np.zeros((1, *evaluation_image.shape, number_of_channels))
            for k in range(3):
                batchX[0, :, :, k] = evaluation_image.numpy()

            predicted_data = tid_model.predict(batchX, verbose=verbose)

            if which_model == "tidsQualityAssessment":
                return_dict = {'MOS': None,
                               'MOS.standardDeviation': None,
                               'MOS.mean': predicted_data[0, 0],
                               'MOS.standardDeviationMean': predicted_data[0, 1]}
                return(return_dict)
            elif is_koniq:
                return_dict = {'MOS.mean': predicted_data[0, 0],
                               'sharpness.mean': predicted_data[0, 1]}
                return(return_dict)

        elif image.dimension == 3:
            mos_mean = 0
            mos_standard_deviation = 0

            for d in range(len(dimensions_to_predict)):
                not_padded_image_size = list(padded_image_size)
                del not_padded_image_size[dimensions_to_predict[d]]
                number_of_slices = padded_image_size[dimensions_to_predict[d]]

                batchX = np.zeros((number_of_slices, *tuple(not_padded_image_size),
                                   number_of_channels))

                # Each slice is presented with its two neighbors as a 3-channel
                # input; neighbor indices are clamped at the volume boundaries.
                batchX[0, :, :, 0] = (ants.slice_image(evaluation_image,
                    axis=dimensions_to_predict[d], idx=0)).numpy()
                batchX[0, :, :, 1] = (ants.slice_image(evaluation_image,
                    axis=dimensions_to_predict[d], idx=0)).numpy()
                batchX[0, :, :, 2] = (ants.slice_image(evaluation_image,
                    axis=dimensions_to_predict[d], idx=1)).numpy()
                for i in range(1, number_of_slices - 1):
                    batchX[i, :, :, 0] = (ants.slice_image(evaluation_image,
                        axis=dimensions_to_predict[d], idx=i - 1)).numpy()
                    batchX[i, :, :, 1] = (ants.slice_image(evaluation_image,
                        axis=dimensions_to_predict[d], idx=i)).numpy()
                    batchX[i, :, :, 2] = (ants.slice_image(evaluation_image,
                        axis=dimensions_to_predict[d], idx=i + 1)).numpy()
                batchX[number_of_slices - 1, :, :, 0] = (ants.slice_image(evaluation_image,
                    axis=dimensions_to_predict[d], idx=number_of_slices - 2)).numpy()
                batchX[number_of_slices - 1, :, :, 1] = (ants.slice_image(evaluation_image,
                    axis=dimensions_to_predict[d], idx=number_of_slices - 1)).numpy()
                batchX[number_of_slices - 1, :, :, 2] = (ants.slice_image(evaluation_image,
                    axis=dimensions_to_predict[d], idx=number_of_slices - 1)).numpy()

                predicted_data = tid_model.predict(batchX, verbose=verbose)
                # Average the per-slice predictions for this dimension.
                mos_mean += np.mean(predicted_data[:, 0])
                mos_standard_deviation += np.mean(predicted_data[:, 1])

            mos_mean /= len(dimensions_to_predict)
            mos_standard_deviation /= len(dimensions_to_predict)

            if which_model == "tidsQualityAssessment":
                return_dict = {'MOS.mean': mos_mean,
                               'MOS.standardDeviationMean': mos_standard_deviation}
                return(return_dict)
            elif is_koniq:
                return_dict = {'MOS.mean': mos_mean,
                               'sharpness.mean': mos_standard_deviation}
                return(return_dict)

    ###############
    #
    # Patchwise
    #
    ###############

    else:
        evaluation_image = padded_image

        if not is_prime(patch_size):
            print("patch_size should be a prime number:  13, 17, 19, 23, 29, 31, "
                  "37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97...")

        stride_length_vector = stride_length
        if isinstance(stride_length, int):
            if image.dimension == 2:
                stride_length_vector = (stride_length, stride_length)
        elif len(stride_length) == 1:
            if image.dimension == 2:
                stride_length_vector = (stride_length[0], stride_length[0])

        patch_size_vector = (patch_size, patch_size)

        if image.dimension == 2:
            dimensions_to_predict = (1,)

        permutations = list()
        mos = image * 0
        mos_standard_deviation = image * 0

        for d in range(len(dimensions_to_predict)):

            if image.dimension == 3:
                permutations.append((0, 1, 2))
                permutations.append((0, 2, 1))
                permutations.append((1, 2, 0))
                if dimensions_to_predict[d] == 0:
                    patch_size_vector = (patch_size, patch_size, number_of_channels)
                    if isinstance(stride_length, int):
                        stride_length_vector = (stride_length, stride_length, 1)
                elif dimensions_to_predict[d] == 1:
                    patch_size_vector = (patch_size, number_of_channels, patch_size)
                    if isinstance(stride_length, int):
                        stride_length_vector = (stride_length, 1, stride_length)
                elif dimensions_to_predict[d] == 2:
                    patch_size_vector = (number_of_channels, patch_size, patch_size)
                    if isinstance(stride_length, int):
                        stride_length_vector = (1, stride_length, stride_length)
                else:
                    raise ValueError("dimensions_to_predict elements should be "
                                     "0, 1, and/or 2 for a 3-D image.")

            patches = extract_image_patches(evaluation_image,
                                            patch_size=patch_size_vector,
                                            stride_length=stride_length_vector,
                                            return_as_array=False)

            batchX = np.zeros((len(patches), patch_size, patch_size, number_of_channels))

            is_good_patch = np.repeat(False, len(patches))
            for i in range(len(patches)):
                if patches[i].var() > 0:
                    is_good_patch[i] = True
                    patch_image = patches[i]
                    patch_image = patch_image - patch_image.min()
                    if patch_image.max() > 0:
                        if which_model == "tidsQualityAssessment":
                            patch_image = patch_image / patch_image.max() * 255
                        elif is_koniq:
                            patch_image = patch_image / patch_image.max() * 2.0 - 1.0
                    if image.dimension == 2:
                        for j in range(number_of_channels):
                            batchX[i, :, :, j] = patch_image.numpy()
                    elif image.dimension == 3:
                        batchX[i, :, :, :] = np.transpose(
                            np.squeeze(patch_image.numpy()),
                            permutations[dimensions_to_predict[d]])

            good_batchX = batchX[is_good_patch, :, :, :]
            predicted_data = tid_model.predict(good_batchX, verbose=verbose)

            patches_mos = list()
            patches_mos_standard_deviation = list()

            zero_patch_image = patch_image * 0

            count = 0
            for i in range(len(patches)):
                if is_good_patch[i]:
                    patches_mos.append(zero_patch_image + predicted_data[count, 0])
                    patches_mos_standard_deviation.append(
                        zero_patch_image + predicted_data[count, 1])
                    count += 1
                else:
                    patches_mos.append(zero_patch_image)
                    patches_mos_standard_deviation.append(zero_patch_image)

            mos += pad_or_crop_image_to_size(
                reconstruct_image_from_patches(patches_mos, evaluation_image,
                                               stride_length=stride_length_vector),
                image.shape)
            mos_standard_deviation += pad_or_crop_image_to_size(
                reconstruct_image_from_patches(patches_mos_standard_deviation,
                                               evaluation_image,
                                               stride_length=stride_length_vector),
                image.shape)

        mos /= len(dimensions_to_predict)
        mos_standard_deviation /= len(dimensions_to_predict)

        if mask is None:
            if which_model == "tidsQualityAssessment":
                return_dict = {'MOS': mos,
                               'MOS.standardDeviation': mos_standard_deviation,
                               'MOS.mean': mos.mean(),
                               'MOS.standardDeviationMean': mos_standard_deviation.mean()}
                return(return_dict)
            elif is_koniq:
                return_dict = {'MOS': mos,
                               'sharpness': mos_standard_deviation,
                               'MOS.mean': mos.mean(),
                               'sharpness.mean': mos_standard_deviation.mean()}
                return(return_dict)
        else:
            if which_model == "tidsQualityAssessment":
                return_dict = {'MOS': mos * mask,
                               'MOS.standardDeviation': mos_standard_deviation * mask,
                               'MOS.mean': (mos[mask >= 0.5]).mean(),
                               'MOS.standardDeviationMean': (mos_standard_deviation[mask >= 0.5]).mean()}
                return(return_dict)
            elif is_koniq:
                return_dict = {'MOS': mos * mask,
                               'sharpness': mos_standard_deviation * mask,
                               'MOS.mean': (mos[mask >= 0.5]).mean(),
                               'sharpness.mean': (mos_standard_deviation[mask >= 0.5]).mean()}
                return(return_dict)
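# Usage sketch for tid_neural_image_assessment (not part of the original
# source): it assumes the package is installed and importable as antspynet
# and that "t1.nii.gz" exists locally.  "global" mode returns only scalar
# summaries, whereas patchwise mode also returns voxelwise MOS images.
#
#   import ants
#   import antspynet
#
#   t1 = ants.image_read("t1.nii.gz")
#   qa_global = antspynet.tid_neural_image_assessment(
#       t1, which_model="tidsQualityAssessment", patch_size="global")
#   print(qa_global['MOS.mean'])
#   qa_patch = antspynet.tid_neural_image_assessment(
#       t1, which_model="tidsQualityAssessment", patch_size=101, stride_length=51)
#   ants.image_write(qa_patch['MOS'], "mos.nii.gz")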
def sysu_media_wmh_segmentation(flair,
                                t1=None,
                                use_ensemble=True,
                                antsxnet_cache_directory=None,
                                verbose=False):

    """
    Perform WMH segmentation using the winning submission in the MICCAI
    2017 challenge by the sysu_media team using FLAIR or T1/FLAIR.  The
    MICCAI challenge is discussed in

    https://pubmed.ncbi.nlm.nih.gov/30908194/

    The sysu_media team's entry is discussed in

    https://pubmed.ncbi.nlm.nih.gov/30125711/

    and the original implementation is available here:

    https://github.com/hongweilibran/wmh_ibbmTum

    The original implementation used global thresholding as a quick brain
    extraction approach.  Due to possible generalization difficulties, we
    leave such post-processing steps to the user.  For brain or white matter
    masking, see the functions brain_extraction or deep_atropos, respectively.

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    use_ensemble : boolean
        whether to use all 3 sets of weights.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if None, these data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> image = ants.image_read("flair.nii.gz")
    >>> probability_mask = sysu_media_wmh_segmentation(image)
    """

    from ..architectures import create_sysu_media_unet_model_2d
    from ..utilities import get_pretrained_network
    from ..utilities import pad_or_crop_image_to_size
    from ..utilities import preprocess_brain_image
    from ..utilities import binary_dice_coefficient

    if flair.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    image_size = (200, 200)

    ################################
    #
    # Preprocess images
    #
    ################################

    def closest_simplified_direction_matrix(direction):
        # Snap each direction cosine to the nearest of {-1, 0, 1}.
        closest = (np.abs(direction) + 0.5).astype(int).astype(float)
        closest[direction < 0] *= -1.0
        return closest

    simplified_direction = closest_simplified_direction_matrix(flair.direction)

    flair_preprocessing = preprocess_brain_image(flair,
        truncate_intensity=None,
        brain_extraction_modality=None,
        do_bias_correction=False,
        do_denoising=False,
        antsxnet_cache_directory=antsxnet_cache_directory,
        verbose=verbose)
    flair_preprocessed = flair_preprocessing["preprocessed_image"]
    flair_preprocessed.set_direction(simplified_direction)
    flair_preprocessed.set_origin((0, 0, 0))
    flair_preprocessed.set_spacing((1, 1, 1))
    number_of_channels = 1

    t1_preprocessed = None
    if t1 is not None:
        t1_preprocessing = preprocess_brain_image(t1,
            truncate_intensity=None,
            brain_extraction_modality=None,
            do_bias_correction=False,
            do_denoising=False,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing["preprocessed_image"]
        t1_preprocessed.set_direction(simplified_direction)
        t1_preprocessed.set_origin((0, 0, 0))
        t1_preprocessed.set_spacing((1, 1, 1))
        number_of_channels = 2

    ################################
    #
    # Reorient images
    #
    ################################

    reference_image = ants.make_image((256, 256, 256),
                                      voxval=0,
                                      spacing=(1, 1, 1),
                                      origin=(0, 0, 0),
                                      direction=np.identity(3))
    center_of_mass_reference = np.floor(
        ants.get_center_of_mass(reference_image * 0 + 1))
    center_of_mass_image = np.floor(
        ants.get_center_of_mass(flair_preprocessed))
    translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_reference)
    xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_reference),
        translation=translation)
    flair_preprocessed_warped = ants.apply_ants_transform_to_image(
        xfrm, flair_preprocessed, reference_image, interpolation="nearestneighbor")
    crop_image = ants.image_clone(flair_preprocessed) * 0 + 1
    crop_image_warped = ants.apply_ants_transform_to_image(
        xfrm, crop_image, reference_image, interpolation="nearestneighbor")
    flair_preprocessed_warped = ants.crop_image(flair_preprocessed_warped,
                                                crop_image_warped, 1)

    if t1 is not None:
        t1_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm, t1_preprocessed, reference_image, interpolation="nearestneighbor")
        t1_preprocessed_warped = ants.crop_image(t1_preprocessed_warped,
                                                 crop_image_warped, 1)

    ################################
    #
    # Gaussian normalize intensity
    #
    ################################

    mean_flair = flair_preprocessed.mean()
    std_flair = flair_preprocessed.std()
    if number_of_channels == 2:
        mean_t1 = t1_preprocessed.mean()
        std_t1 = t1_preprocessed.std()

    flair_preprocessed_warped = (flair_preprocessed_warped - mean_flair) / std_flair
    if number_of_channels == 2:
        t1_preprocessed_warped = (t1_preprocessed_warped - mean_t1) / std_t1

    ################################
    #
    # Build models and load weights
    #
    ################################

    number_of_models = 1
    if use_ensemble == True:
        number_of_models = 3

    if verbose == True:
        print("White matter hyperintensity:  retrieving model weights.")

    unet_models = list()
    for i in range(number_of_models):
        if number_of_channels == 1:
            weights_file_name = get_pretrained_network(
                "sysuMediaWmhFlairOnlyModel" + str(i),
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            weights_file_name = get_pretrained_network(
                "sysuMediaWmhFlairT1Model" + str(i),
                antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model = create_sysu_media_unet_model_2d(
            (*image_size, number_of_channels))
        unet_loss = binary_dice_coefficient(smoothing_factor=1.)
        unet_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=2e-4),
                           loss=unet_loss)
        unet_model.load_weights(weights_file_name)
        unet_models.append(unet_model)

    ################################
    #
    # Extract slices
    #
    ################################

    dimensions_to_predict = [2]

    total_number_of_slices = 0
    for d in range(len(dimensions_to_predict)):
        total_number_of_slices += flair_preprocessed_warped.shape[dimensions_to_predict[d]]

    batchX = np.zeros((total_number_of_slices, *image_size, number_of_channels))

    slice_count = 0
    for d in range(len(dimensions_to_predict)):
        number_of_slices = flair_preprocessed_warped.shape[dimensions_to_predict[d]]

        if verbose == True:
            print("Extracting slices for dimension", dimensions_to_predict[d], ".")

        for i in range(number_of_slices):
            flair_slice = pad_or_crop_image_to_size(
                ants.slice_image(flair_preprocessed_warped, dimensions_to_predict[d], i),
                image_size)
            batchX[slice_count, :, :, 0] = flair_slice.numpy()
            if number_of_channels == 2:
                t1_slice = pad_or_crop_image_to_size(
                    ants.slice_image(t1_preprocessed_warped, dimensions_to_predict[d], i),
                    image_size)
                batchX[slice_count, :, :, 1] = t1_slice.numpy()
            slice_count += 1

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose == True:
        print("Prediction.")

    prediction = unet_models[0].predict(np.transpose(batchX, axes=(0, 2, 1, 3)),
                                        verbose=verbose)
    if number_of_models > 1:
        for i in range(1, number_of_models, 1):
            prediction += unet_models[i].predict(np.transpose(batchX, axes=(0, 2, 1, 3)),
                                                 verbose=verbose)
    prediction /= number_of_models
    prediction = np.transpose(prediction, axes=(0, 2, 1, 3))

    permutations = list()
    permutations.append((0, 1, 2))
    permutations.append((1, 0, 2))
    permutations.append((1, 2, 0))

    prediction_image_average = ants.image_clone(flair_preprocessed_warped) * 0

    current_start_slice = 0
    for d in range(len(dimensions_to_predict)):
        current_end_slice = current_start_slice + \
            flair_preprocessed_warped.shape[dimensions_to_predict[d]]
        which_batch_slices = range(current_start_slice, current_end_slice)
        prediction_per_dimension = prediction[which_batch_slices, :, :, :]
        prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                                        permutations[dimensions_to_predict[d]])
        prediction_image = ants.copy_image_info(flair_preprocessed_warped,
            pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                      flair_preprocessed_warped.shape))
        # Online (running) average over the predicted dimensions.
        prediction_image_average = prediction_image_average + \
            (prediction_image - prediction_image_average) / (d + 1)
        current_start_slice = current_end_slice

    probability_image = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), prediction_image_average,
        flair_preprocessed)
    probability_image = ants.copy_image_info(flair, probability_image)

    return(probability_image)
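# Usage sketch for sysu_media_wmh_segmentation (not part of the original
# source): it assumes the package is importable as antspynet and that
# "flair.nii.gz" and "t1.nii.gz" exist locally.  Since post-processing is
# left to the user, a simple 0.5 threshold is shown as one possibility.
#
#   import ants
#   import antspynet
#
#   flair = ants.image_read("flair.nii.gz")
#   t1 = ants.image_read("t1.nii.gz")
#   probability = antspynet.sysu_media_wmh_segmentation(
#       flair, t1=t1, use_ensemble=True, verbose=True)
#   wmh_mask = ants.threshold_image(probability, 0.5, 1.0, 1, 0)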
def ew_david(flair,
             t1,
             do_preprocessing=True,
             which_model="sysu",
             which_axes=2,
             number_of_simulations=0,
             sd_affine=0.01,
             antsxnet_cache_directory=None,
             verbose=False):

    """
    Perform white matter hyperintensity probabilistic segmentation using
    deep learning.

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * intensity truncation,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e., set
    do_preprocessing=True.

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    do_preprocessing : boolean
        perform n4 bias correction, intensity truncation, brain extraction.

    which_model : string
        one of:
            * "sysu" -- same as the original sysu network (without site-specific preprocessing),
            * "sysu-ri" -- same as "sysu" but using ranked intensity scaling for input images,
            * "sysuWithAttention" -- "sysu" with attention gating,
            * "sysuWithAttentionAndSite" -- "sysu" with attention gating with site branch (see "sysuWithSite"),
            * "sysuPlus" -- "sysu" with attention gating and nn-Unet activation,
            * "sysuPlusSeg" -- "sysuPlus" with deep_atropos segmentation in an additional channel,
            * "sysuWithSite" -- "sysu" with global pooling on encoding channels to predict "site", and
            * "sysuPlusSegWithSite" -- "sysuPlusSeg" combined with "sysuWithSite".
        In addition to both modalities, all models have T1-only and FLAIR-only
        variants except for "sysuPlusSeg" (which only has a T1-only variant)
        and "sysu-ri" (which has neither single-modality variant).

    which_axes : string or scalar or tuple/vector
        apply the 2-D model to 1 or more axes.  In addition to a scalar or
        vector, e.g., which_axes = (0, 2), one can use "max" for the axis
        with maximum anisotropy or "all" for all axes.

    number_of_simulations : integer
        Number of random affine perturbations to transform the input.

    sd_affine : float
        Define the standard deviation of the affine transformation parameter.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if None, these data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> flair = ants.image_read("flair.nii.gz")
    >>> t1 = ants.image_read("t1.nii.gz")
    >>> probability_mask = ew_david(flair, t1)
    """

    from ..architectures import create_unet_model_2d
    from ..utilities import deep_atropos
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import randomly_transform_image_data
    from ..utilities import pad_or_crop_image_to_size

    do_t1_only = False
    do_flair_only = False

    if flair is None and t1 is not None:
        do_t1_only = True
    elif flair is not None and t1 is None:
        do_flair_only = True

    use_t1_segmentation = False
    if "Seg" in which_model:
        if do_flair_only:
            raise ValueError("Segmentation requires T1.")
        else:
            use_t1_segmentation = True

    if use_t1_segmentation and do_preprocessing == False:
        raise ValueError("Using the T1 segmentation requires do_preprocessing=True.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    do_slicewise = True

    if do_slicewise == False:

        raise ValueError("Not available.")

        # ################################
        # #
        # # Preprocess images
        # #
        # ################################
        #
        # t1_preprocessed = t1
        # t1_preprocessing = None
        # if do_preprocessing == True:
        #     t1_preprocessing = preprocess_brain_image(t1,
        #         truncate_intensity=(0.01, 0.99),
        #         brain_extraction_modality="t1",
        #         template="croppedMni152",
        #         template_transform_type="antsRegistrationSyNQuickRepro[a]",
        #         do_bias_correction=True,
        #         do_denoising=False,
        #         antsxnet_cache_directory=antsxnet_cache_directory,
        #         verbose=verbose)
        #     t1_preprocessed = t1_preprocessing["preprocessed_image"] * t1_preprocessing['brain_mask']
        #
        # flair_preprocessed = flair
        # if do_preprocessing == True:
        #     flair_preprocessing = preprocess_brain_image(flair,
        #         truncate_intensity=(0.01, 0.99),
        #         brain_extraction_modality="t1",
        #         do_bias_correction=True,
        #         do_denoising=False,
        #         antsxnet_cache_directory=antsxnet_cache_directory,
        #         verbose=verbose)
        #     flair_preprocessed = ants.apply_transforms(fixed=t1_preprocessed,
        #         moving=flair_preprocessing["preprocessed_image"],
        #         transformlist=t1_preprocessing['template_transforms']['fwdtransforms'])
        #     flair_preprocessed = flair_preprocessed * t1_preprocessing['brain_mask']
        #
        # ################################
        # #
        # # Build model and load weights
        # #
        # ################################
        #
        # patch_size = (112, 112, 112)
        # stride_length = (t1_preprocessed.shape[0] - patch_size[0],
        #                  t1_preprocessed.shape[1] - patch_size[1],
        #                  t1_preprocessed.shape[2] - patch_size[2])
        #
        # classes = ("background", "wmh")
        # number_of_classification_labels = len(classes)
        # labels = (0, 1)
        #
        # image_modalities = ("T1", "FLAIR")
        # channel_size = len(image_modalities)
        #
        # unet_model = create_unet_model_3d((*patch_size, channel_size),
        #     number_of_outputs=number_of_classification_labels,
        #     number_of_layers=4, number_of_filters_at_base_layer=16, dropout_rate=0.0,
        #     convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
        #     weight_decay=1e-5, additional_options=("attentionGating"))
        #
        # weights_file_name = get_pretrained_network("ewDavidWmhSegmentationWeights",
        #     antsxnet_cache_directory=antsxnet_cache_directory)
        # unet_model.load_weights(weights_file_name)
        #
        # ################################
        # #
        # # Do prediction and normalize to native space
        # #
        # ################################
        #
        # if verbose == True:
        #     print("ew_david:  prediction.")
        #
        # batchX = np.zeros((8, *patch_size, channel_size))
        #
        # t1_preprocessed = (t1_preprocessed - t1_preprocessed.mean()) / t1_preprocessed.std()
        # t1_patches = extract_image_patches(t1_preprocessed, patch_size=patch_size,
        #     max_number_of_patches="all", stride_length=stride_length,
        #     return_as_array=True)
        # batchX[:,:,:,:,0] = t1_patches
        #
        # flair_preprocessed = (flair_preprocessed - flair_preprocessed.mean()) / flair_preprocessed.std()
        # flair_patches = extract_image_patches(flair_preprocessed, patch_size=patch_size,
        #     max_number_of_patches="all", stride_length=stride_length,
        #     return_as_array=True)
        # batchX[:,:,:,:,1] = flair_patches
        #
        # predicted_data = unet_model.predict(batchX, verbose=verbose)
        #
        # probability_images = list()
        # for i in range(len(labels)):
        #     print("Reconstructing image", classes[i])
        #     reconstructed_image = reconstruct_image_from_patches(predicted_data[:,:,:,:,i],
        #         domain_image=t1_preprocessed, stride_length=stride_length)
        #     if do_preprocessing == True:
        #         probability_images.append(ants.apply_transforms(fixed=t1,
        #             moving=reconstructed_image,
        #             transformlist=t1_preprocessing['template_transforms']['invtransforms'],
        #             whichtoinvert=[True], interpolator="linear", verbose=verbose))
        #     else:
        #         probability_images.append(reconstructed_image)
        #
        # return(probability_images[1])

    else:  # do_slicewise

        ################################
        #
        # Preprocess images
        #
        ################################

        use_rank_intensity_scaling = False
        if "-ri" in which_model:
            use_rank_intensity_scaling = True

        t1_preprocessed = None
        t1_preprocessing = None
        brain_mask = None
        if t1 is not None:
            if do_preprocessing == True:
                t1_preprocessing = preprocess_brain_image(t1,
                    truncate_intensity=(0.01, 0.995),
                    brain_extraction_modality="t1",
                    do_bias_correction=False,
                    do_denoising=False,
                    antsxnet_cache_directory=antsxnet_cache_directory,
                    verbose=verbose)
                brain_mask = ants.threshold_image(t1_preprocessing["brain_mask"],
                                                  0.5, 1, 1, 0)
                t1_preprocessed = t1_preprocessing["preprocessed_image"]

        t1_segmentation = None
        if use_t1_segmentation:
            atropos_seg = deep_atropos(t1, do_preprocessing=True, verbose=verbose)
            t1_segmentation = atropos_seg['segmentation_image']

        flair_preprocessed = None
        if flair is not None:
            flair_preprocessed = flair
            if do_preprocessing == True:
                if brain_mask is None:
                    flair_preprocessing = preprocess_brain_image(flair,
                        truncate_intensity=(0.01, 0.995),
                        brain_extraction_modality="flair",
                        do_bias_correction=False,
                        do_denoising=False,
                        antsxnet_cache_directory=antsxnet_cache_directory,
                        verbose=verbose)
                    brain_mask = ants.threshold_image(flair_preprocessing["brain_mask"],
                                                      0.5, 1, 1, 0)
                else:
                    flair_preprocessing = preprocess_brain_image(flair,
                        truncate_intensity=None,
                        brain_extraction_modality=None,
                        do_bias_correction=False,
                        do_denoising=False,
                        antsxnet_cache_directory=antsxnet_cache_directory,
                        verbose=verbose)
                flair_preprocessed = flair_preprocessing["preprocessed_image"]

        if t1_preprocessed is not None:
            t1_preprocessed = t1_preprocessed * brain_mask
        if flair_preprocessed is not None:
            flair_preprocessed = flair_preprocessed * brain_mask

        if t1_preprocessed is not None:
            resampling_params = list(ants.get_spacing(t1_preprocessed))
        else:
            resampling_params = list(ants.get_spacing(flair_preprocessed))

        # Resample any sub-0.8 mm dimension to 1 mm.
        do_resampling = False
        for d in range(len(resampling_params)):
            if resampling_params[d] < 0.8:
                resampling_params[d] = 1.0
                do_resampling = True

        resampling_params = tuple(resampling_params)

        if do_resampling:
            if flair_preprocessed is not None:
                flair_preprocessed = ants.resample_image(flair_preprocessed,
                    resampling_params, use_voxels=False, interp_type=0)
            if t1_preprocessed is not None:
                t1_preprocessed = ants.resample_image(t1_preprocessed,
                    resampling_params, use_voxels=False, interp_type=0)
            if t1_segmentation is not None:
                t1_segmentation = ants.resample_image(t1_segmentation,
                    resampling_params, use_voxels=False, interp_type=1)
            if brain_mask is not None:
                brain_mask = ants.resample_image(brain_mask,
                    resampling_params, use_voxels=False, interp_type=1)

        ################################
        #
        # Build model and load weights
        #
        ################################

        template_size = (208, 208)

        image_modalities = ("T1", "FLAIR")
        if do_flair_only:
            image_modalities = ("FLAIR",)
        elif do_t1_only:
            image_modalities = ("T1",)
        if use_t1_segmentation:
            image_modalities = (*image_modalities, "T1Seg")
        channel_size = len(image_modalities)

        unet_model = None
        if which_model == "sysu" or which_model == "sysu-ri":
            unet_model = create_unet_model_2d((*template_size, channel_size),
                number_of_outputs=1, mode="sigmoid",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0, convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("initialConvolutionKernelSize[5]",))
        elif which_model == "sysuWithAttention":
            unet_model = create_unet_model_2d((*template_size, channel_size),
                number_of_outputs=1, mode="sigmoid",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0, convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("attentionGating",
                                    "initialConvolutionKernelSize[5]"))
        elif which_model == "sysuWithAttentionAndSite":
            unet_model = create_unet_model_2d((*template_size, channel_size),
                number_of_outputs=1, mode="sigmoid",
                scalar_output_size=3, scalar_output_activation="softmax",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0, convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("attentionGating",
                                    "initialConvolutionKernelSize[5]"))
        elif which_model == "sysuWithSite":
            unet_model = create_unet_model_2d((*template_size, channel_size),
                number_of_outputs=1, mode="sigmoid",
                scalar_output_size=3, scalar_output_activation="softmax",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0, convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("initialConvolutionKernelSize[5]",))
        elif which_model == "sysuPlusSegWithSite":
            unet_model = create_unet_model_2d((*template_size, channel_size),
                number_of_outputs=1, mode="sigmoid",
                scalar_output_size=3, scalar_output_activation="softmax",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0, convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("nnUnetActivationStyle", "attentionGating",
                                    "initialConvolutionKernelSize[5]"))
        else:
            # "sysuPlus" and "sysuPlusSeg"
            unet_model = create_unet_model_2d((*template_size, channel_size),
                number_of_outputs=1, mode="sigmoid",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0, convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("nnUnetActivationStyle", "attentionGating",
                                    "initialConvolutionKernelSize[5]"))

        if verbose == True:
            print("ewDavid:  retrieving model weights.")

        weights_file_name = None
        if which_model == "sysu" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network("ewDavidSysu",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysu-ri" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network("ewDavidSysuRankedIntensity",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysu" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network("ewDavidSysuT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysu" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network("ewDavidSysuFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttention" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network("ewDavidSysuWithAttention",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttention" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network("ewDavidSysuWithAttentionT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttention" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network("ewDavidSysuWithAttentionFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttentionAndSite" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network("ewDavidSysuWithAttentionAndSite",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttentionAndSite" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network("ewDavidSysuWithAttentionAndSiteT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttentionAndSite" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network("ewDavidSysuWithAttentionAndSiteFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlus" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network("ewDavidSysuPlus",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlus" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network("ewDavidSysuPlusT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlus" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network("ewDavidSysuPlusFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlusSeg" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network("ewDavidSysuPlusSeg",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlusSeg" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network("ewDavidSysuPlusSegT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlusSegWithSite" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network("ewDavidSysuPlusSegWithSite",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlusSegWithSite" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network("ewDavidSysuPlusSegWithSiteT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithSite" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network("ewDavidSysuWithSite",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithSite" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network("ewDavidSysuWithSiteT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithSite" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network("ewDavidSysuWithSiteFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            raise ValueError("Incorrect model specification or image combination.")

        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Data augmentation and extract slices
        #
        ################################

        wmh_probability_image = None
        if t1 is not None:
            wmh_probability_image = ants.image_clone(t1_preprocessed) * 0
        else:
            wmh_probability_image = ants.image_clone(flair_preprocessed) * 0

        wmh_site = np.array([0, 0, 0])

        data_augmentation = None
        if number_of_simulations > 0:
            if do_flair_only:
                data_augmentation = randomly_transform_image_data(
                    reference_image=flair_preprocessed,
                    input_image_list=[[flair_preprocessed]],
                    number_of_simulations=number_of_simulations,
                    transform_type='affine',
                    sd_affine=sd_affine,
                    input_image_interpolator='linear')
            elif do_t1_only:
                if use_t1_segmentation:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[t1_preprocessed]],
                        segmentation_image_list=[t1_segmentation],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear',
                        segmentation_image_interpolator='nearestNeighbor')
                else:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[t1_preprocessed]],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear')
            else:
                if use_t1_segmentation:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[flair_preprocessed, t1_preprocessed]],
                        segmentation_image_list=[t1_segmentation],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear',
                        segmentation_image_interpolator='nearestNeighbor')
                else:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[flair_preprocessed, t1_preprocessed]],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear')

        dimensions_to_predict = list((0,))
        if which_axes == "max":
            spacing = ants.get_spacing(wmh_probability_image)
            dimensions_to_predict = (spacing.index(max(spacing)),)
        elif which_axes == "all":
            dimensions_to_predict = list(range(3))
        else:
            if isinstance(which_axes, int):
                dimensions_to_predict = list((which_axes,))
            else:
                dimensions_to_predict = list(which_axes)

        total_number_of_slices = 0
        for d in range(len(dimensions_to_predict)):
            total_number_of_slices += wmh_probability_image.shape[dimensions_to_predict[d]]

        batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

        for n in range(number_of_simulations + 1):

            batch_flair = flair_preprocessed
            batch_t1 = t1_preprocessed
            batch_t1_segmentation = t1_segmentation
            batch_brain_mask = brain_mask

            if n > 0:

                if do_flair_only:
                    batch_flair = data_augmentation['simulated_images'][n - 1][0]
                    batch_brain_mask = ants.apply_ants_transform_to_image(
                        data_augmentation['simulated_transforms'][n - 1],
                        brain_mask, flair_preprocessed,
                        interpolation="nearestneighbor")
                elif do_t1_only:
                    batch_t1 = data_augmentation['simulated_images'][n - 1][0]
                    batch_brain_mask = ants.apply_ants_transform_to_image(
                        data_augmentation['simulated_transforms'][n - 1],
                        brain_mask, t1_preprocessed,
                        interpolation="nearestneighbor")
                else:
                    batch_flair = data_augmentation['simulated_images'][n - 1][0]
                    batch_t1 = data_augmentation['simulated_images'][n - 1][1]
                    batch_brain_mask = ants.apply_ants_transform_to_image(
                        data_augmentation['simulated_transforms'][n - 1],
                        brain_mask, flair_preprocessed,
                        interpolation="nearestneighbor")
                if use_t1_segmentation:
                    batch_t1_segmentation = \
                        data_augmentation['simulated_segmentation_images'][n - 1]

            if use_rank_intensity_scaling:
                if batch_t1 is not None:
                    batch_t1 = ants.rank_intensity(batch_t1, batch_brain_mask) - 0.5
                if batch_flair is not None:
                    batch_flair = ants.rank_intensity(batch_flair, batch_brain_mask) - 0.5
            else:
                if batch_t1 is not None:
                    batch_t1 = (batch_t1 - batch_t1[batch_brain_mask == 1].mean()) / \
                        batch_t1[batch_brain_mask == 1].std()
                if batch_flair is not None:
                    batch_flair = (batch_flair - batch_flair[batch_brain_mask == 1].mean()) / \
                        batch_flair[batch_brain_mask == 1].std()

            slice_count = 0
            for d in range(len(dimensions_to_predict)):

                number_of_slices = None
                if batch_t1 is not None:
                    number_of_slices = batch_t1.shape[dimensions_to_predict[d]]
                else:
                    number_of_slices = batch_flair.shape[dimensions_to_predict[d]]

                if verbose == True:
                    print("Extracting slices for dimension", dimensions_to_predict[d])

                for i in range(number_of_slices):

                    brain_mask_slice = pad_or_crop_image_to_size(
                        ants.slice_image(batch_brain_mask, dimensions_to_predict[d], i),
                        template_size)

                    channel_count = 0
                    if batch_flair is not None:
                        flair_slice = pad_or_crop_image_to_size(
                            ants.slice_image(batch_flair, dimensions_to_predict[d], i),
                            template_size)
                        flair_slice[brain_mask_slice == 0] = 0
                        batchX[slice_count, :, :, channel_count] = flair_slice.numpy()
                        channel_count += 1
                    if batch_t1 is not None:
                        t1_slice = pad_or_crop_image_to_size(
                            ants.slice_image(batch_t1, dimensions_to_predict[d], i),
                            template_size)
                        t1_slice[brain_mask_slice == 0] = 0
                        batchX[slice_count, :, :, channel_count] = t1_slice.numpy()
                        channel_count += 1
                    if t1_segmentation is not None:
                        t1_segmentation_slice = pad_or_crop_image_to_size(
                            ants.slice_image(batch_t1_segmentation,
                                             dimensions_to_predict[d], i),
                            template_size)
                        t1_segmentation_slice[brain_mask_slice == 0] = 0
                        # Rescale the six-label segmentation to [-0.5, 0.5].
                        batchX[slice_count, :, :, channel_count] = \
                            t1_segmentation_slice.numpy() / 6 - 0.5

                    slice_count += 1

            ################################
            #
            # Do prediction and then restack into the image
            #
            ################################

            if verbose == True:
                if n == 0:
                    print("Prediction")
                else:
                    print("Prediction (simulation " + str(n) + ")")

            prediction = unet_model.predict(batchX, verbose=verbose)

            permutations = list()
            permutations.append((0, 1, 2))
            permutations.append((1, 0, 2))
            permutations.append((1, 2, 0))

            prediction_image_average = ants.image_clone(wmh_probability_image) * 0

            current_start_slice = 0
            for d in range(len(dimensions_to_predict)):
                current_end_slice = current_start_slice + \
                    wmh_probability_image.shape[dimensions_to_predict[d]]
                which_batch_slices = range(current_start_slice, current_end_slice)
                if isinstance(prediction, list):
                    prediction_per_dimension = prediction[0][which_batch_slices, :, :, 0]
                else:
                    prediction_per_dimension = prediction[which_batch_slices, :, :, 0]
                prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                                                permutations[dimensions_to_predict[d]])
                prediction_image = ants.copy_image_info(wmh_probability_image,
                    pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                              wmh_probability_image.shape))
                prediction_image_average = prediction_image_average + \
                    (prediction_image - prediction_image_average) / (d + 1)
                current_start_slice = current_end_slice

            # Online average over the original image and its simulations.
            wmh_probability_image = wmh_probability_image + \
                (prediction_image_average - wmh_probability_image) / (n + 1)
            if isinstance(prediction, list):
                wmh_site = wmh_site + \
                    (np.mean(prediction[1], axis=0) - wmh_site) / (n + 1)

        if do_resampling:
            if t1 is not None:
                wmh_probability_image = ants.resample_image_to_target(
                    wmh_probability_image, t1)
            if flair is not None:
                wmh_probability_image = ants.resample_image_to_target(
                    wmh_probability_image, flair)

        if isinstance(prediction, list):
            return([wmh_probability_image, wmh_site])
        else:
            return(wmh_probability_image)
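# Usage sketch for ew_david (not part of the original source): it assumes the
# package is importable as antspynet and that "flair.nii.gz" and "t1.nii.gz"
# exist locally.  This variant predicts along all three axes and averages
# five random affine perturbations into the final estimate.
#
#   import ants
#   import antspynet
#
#   flair = ants.image_read("flair.nii.gz")
#   t1 = ants.image_read("t1.nii.gz")
#   probability = antspynet.ew_david(flair, t1,
#                                    which_model="sysuWithAttention",
#                                    which_axes="all",
#                                    number_of_simulations=5,
#                                    sd_affine=0.01,
#                                    verbose=True)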
import os

import ants
import numpy as np

root = '/gsfs0/data/poskanzc/Sherlock/preproc/fmriprep/'
all_subjects = ['sub-01', 'sub-02', 'sub-03', 'sub-04', 'sub-05', 'sub-06',
                'sub-07', 'sub-08', 'sub-10', 'sub-11', 'sub-12', 'sub-13',
                'sub-14', 'sub-15', 'sub-16']
mask_suffix = '_CSF_WM_mask_union_bin_shrinked.nii.gz'

for sub in all_subjects:
    sub_dir = os.path.join(root, sub, 'func')
    sub_out_dir = os.path.join(root, sub, sub + '_ROIs')
    mask_path = os.path.join(sub_out_dir, sub + mask_suffix)

    # Read the 4-D preprocessed functional run and take the first volume
    # as the registration target.
    epi = ants.image_read(os.path.join(sub_dir,
        sub + '_task-sherlockPart1_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz'))
    epi3d = ants.slice_image(epi, axis=3, idx=0)

    # Rigidly register the CSF/WM noise ROI mask to the functional grid.
    noise_roi = ants.image_read(mask_path)
    reg = ants.registration(fixed=epi3d, moving=noise_roi,
                            type_of_transform='Rigid',
                            reg_iterations=[100, 100, 20])
    registered_mask = reg['warpedmovout']
    registered_mask.to_filename(os.path.join(sub_out_dir,
        sub + '_CSF_WM_mask_union_shrinked_funcSize.nii.gz'))
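# A possible follow-up inside the subject loop (a sketch, not part of the
# original script): binarize the registered mask and pull a mean CSF/WM
# nuisance time series from the 4-D run with ants.timeseries_to_matrix.
#
#   binary_mask = ants.threshold_image(registered_mask, 0.5, 1e9, 1, 0)
#   ts = ants.timeseries_to_matrix(epi, binary_mask)   # (time, voxels)
#   nuisance_mean = ts.mean(axis=1)                    # one regressor per TR
#   np.savetxt(os.path.join(sub_out_dir, sub + '_CSF_WM_mean_ts.txt'),
#              nuisance_mean)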