def test_get_center_of_mass(self):
    """The center of mass has one coordinate per image dimension."""

    def check_length(candidate):
        # One coordinate is expected for every axis of the image.
        center = ants.get_center_of_mass(candidate)
        self.assertEqual(len(center), candidate.dimension)

    # 2d
    check_length(ants.image_read(ants.get_ants_data("r16")))

    r64 = ants.image_read(ants.get_ants_data("r64"))
    check_length(r64)
    check_length(r64.clone("unsigned int"))

    # 3d
    check_length(ants.image_read(ants.get_ants_data("mni")))
def _center_of_mass_alignment_transform(image, template):
    """Create the translation used to resample `image` into the `template` domain.

    The centers of mass are computed on constant-valued copies
    (``x * 0 + 1``) of both images, so intensities do not influence them.
    """
    center_of_mass_template = ants.get_center_of_mass(template * 0 + 1)
    center_of_mass_image = ants.get_center_of_mass(image * 0 + 1)
    translation = (np.asarray(center_of_mass_image)
                   - np.asarray(center_of_mass_template))
    return ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_template),
        translation=translation)


def _segmentation_from_probability_images(probability_images, image):
    """Label each voxel of `image` with the index of its most probable class."""
    mask = image * 0 + 1
    image_matrix = ants.image_list_to_matrix(probability_images, mask)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    return ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), mask)[0]


def lung_extraction(image,
                    modality="proton",
                    antsxnet_cache_directory=None,
                    verbose=False):
    """
    Perform proton, ct, or ventilation lung extraction using U-net.

    Arguments
    ---------
    image : ANTsImage
        input 3-D image

    modality : string
        Modality image type.  Options include "ct", "proton", "protonLobes",
        "maskLobes", and "ventilation".

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if None, "ANTsXNet" is passed on
        to the download utilities as the cache directory.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    For "proton", "maskLobes", and "ct": dictionary with the keys
    'segmentation_image' and 'probability_images'.  For "protonLobes" the
    dictionary additionally contains 'whole_lung_mask_image'.  For
    "ventilation" a single probability image is returned.

    Raises
    ------
    ValueError
        If the image is not 3-D or the modality is not recognized.

    Example
    -------
    >>> output = lung_extraction(lung_image, modality="proton")
    """

    # Validate the arguments before importing the network code or downloading
    # any data so that bad input fails fast.
    if image.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if modality not in ("proton", "protonLobes", "maskLobes", "ct", "ventilation"):
        raise ValueError("Unrecognized modality.")

    from ..architectures import create_unet_model_2d
    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import pad_or_crop_image_to_size

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    if modality == "proton":

        channel_size = 1

        weights_file_name = get_pretrained_network(
            "protonLungMri", antsxnet_cache_directory=antsxnet_cache_directory)

        classes = ("background", "left_lung", "right_lung")
        number_of_classification_labels = len(classes)

        reorient_template_file_name_path = get_antsxnet_data(
            "protonLungTemplate", antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)

        resampled_image_size = reorient_template.shape

        unet_model = create_unet_model_3d(
            (*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            number_of_layers=4,
            number_of_filters_at_base_layer=16,
            dropout_rate=0.0,
            convolution_kernel_size=(7, 7, 5),
            deconvolution_kernel_size=(7, 7, 5))
        unet_model.load_weights(weights_file_name)

        if verbose:
            print("Lung extraction: normalizing image to the template.")

        xfrm = _center_of_mass_alignment_transform(image, reorient_template)
        warped_image = ants.apply_ants_transform_to_image(
            xfrm, image, reorient_template)

        # Network input has shape (1, x, y, z, 1) and is standardized.
        batchX = np.expand_dims(warped_image.numpy(), axis=0)
        batchX = np.expand_dims(batchX, axis=-1)
        batchX = (batchX - batchX.mean()) / batchX.std()

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        origin = warped_image.origin
        spacing = warped_image.spacing
        direction = warped_image.direction

        probability_images_array = [
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                            origin=origin, spacing=spacing, direction=direction)
            for i in range(number_of_classification_labels)]

        if verbose:
            print("Lung extraction: renormalize probability mask to native space.")

        inverse_xfrm = ants.invert_ants_transform(xfrm)
        probability_images_array = [
            ants.apply_ants_transform_to_image(inverse_xfrm, probability_image, image)
            for probability_image in probability_images_array]

        segmentation_image = _segmentation_from_probability_images(
            probability_images_array, image)

        return {'segmentation_image' : segmentation_image,
                'probability_images' : probability_images_array}

    elif modality == "protonLobes" or modality == "maskLobes":

        reorient_template_file_name_path = get_antsxnet_data(
            "protonLungTemplate", antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)

        resampled_image_size = reorient_template.shape

        spatial_priors_file_name_path = get_antsxnet_data(
            "protonLobePriors", antsxnet_cache_directory=antsxnet_cache_directory)
        spatial_priors = ants.image_read(spatial_priors_file_name_path)
        priors_image_list = ants.ndimage_to_list(spatial_priors)

        # One channel for the image (or mask) plus one per spatial prior.
        channel_size = 1 + len(priors_image_list)
        number_of_classification_labels = 1 + len(priors_image_list)

        unet_model = create_unet_model_3d(
            (*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            mode="classification",
            number_of_filters_at_base_layer=16,
            number_of_layers=4,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            dropout_rate=0.0,
            weight_decay=0,
            additional_options=("attentionGating",))

        if modality == "protonLobes":
            # The protonLobes network has a second, single-channel sigmoid
            # output which is returned as the whole lung mask.
            penultimate_layer = unet_model.layers[-2].output
            outputs2 = Conv3D(filters=1,
                              kernel_size=(1, 1, 1),
                              activation='sigmoid',
                              kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)
            unet_model = Model(inputs=unet_model.input,
                               outputs=[unet_model.output, outputs2])
            weights_file_name = get_pretrained_network(
                "protonLobes", antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            weights_file_name = get_pretrained_network(
                "maskLobes", antsxnet_cache_directory=antsxnet_cache_directory)

        unet_model.load_weights(weights_file_name)

        if verbose:
            print("Lung extraction: normalizing image to the template.")

        xfrm = _center_of_mass_alignment_transform(image, reorient_template)
        warped_image = ants.apply_ants_transform_to_image(
            xfrm, image, reorient_template)
        warped_array = warped_image.numpy()
        if modality == "protonLobes":
            warped_array = (warped_array - warped_array.mean()) / warped_array.std()
        else:
            # maskLobes: every non-zero voxel is set to 1.
            warped_array[warped_array != 0] = 1

        batchX = np.zeros((1, *warped_array.shape, channel_size))
        batchX[0, :, :, :, 0] = warped_array
        for i, prior_image in enumerate(priors_image_list):
            batchX[0, :, :, :, i + 1] = prior_image.numpy()

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        origin = warped_image.origin
        spacing = warped_image.spacing
        direction = warped_image.direction

        # The two-output protonLobes model returns a list; its first entry
        # holds the class probabilities.
        if modality == "protonLobes":
            class_probabilities = predicted_data[0]
        else:
            class_probabilities = predicted_data

        probability_images_array = [
            ants.from_numpy(np.squeeze(class_probabilities[0, :, :, :, i]),
                            origin=origin, spacing=spacing, direction=direction)
            for i in range(number_of_classification_labels)]

        if verbose:
            print("Lung extraction: renormalize probability images to native space.")

        inverse_xfrm = ants.invert_ants_transform(xfrm)
        probability_images_array = [
            ants.apply_ants_transform_to_image(inverse_xfrm, probability_image, image)
            for probability_image in probability_images_array]

        segmentation_image = _segmentation_from_probability_images(
            probability_images_array, image)

        return_dict = {'segmentation_image' : segmentation_image,
                       'probability_images' : probability_images_array}

        if modality == "protonLobes":
            whole_lung_mask = ants.from_numpy(
                np.squeeze(predicted_data[1][0, :, :, :, 0]),
                origin=origin, spacing=spacing, direction=direction)
            whole_lung_mask = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), whole_lung_mask, image)
            return_dict['whole_lung_mask_image'] = whole_lung_mask

        return return_dict

    elif modality == "ct":

        ################################
        #
        # Preprocess image
        #
        ################################

        if verbose:
            print("Preprocess CT image.")

        def closest_simplified_direction_matrix(direction):
            # Round every entry to the nearest integer magnitude, keeping its sign.
            closest = np.floor(np.abs(direction) + 0.5)
            closest[direction < 0] *= -1.0
            return closest

        simplified_direction = closest_simplified_direction_matrix(image.direction)

        reference_image_size = (128, 128, 128)

        ct_preprocessed = ants.resample_image(image, reference_image_size,
                                              use_voxels=True, interp_type=0)
        # Clamp intensities to [-1000, 400].
        ct_preprocessed[ct_preprocessed < -1000] = -1000
        ct_preprocessed[ct_preprocessed > 400] = 400
        ct_preprocessed.set_direction(simplified_direction)
        ct_preprocessed.set_origin((0, 0, 0))
        ct_preprocessed.set_spacing((1, 1, 1))

        ################################
        #
        # Reorient image
        #
        ################################

        reference_image = ants.make_image(reference_image_size,
                                          voxval=0,
                                          spacing=(1, 1, 1),
                                          origin=(0, 0, 0),
                                          direction=np.identity(3))
        center_of_mass_reference = np.floor(
            ants.get_center_of_mass(reference_image * 0 + 1))
        center_of_mass_image = np.floor(
            ants.get_center_of_mass(ct_preprocessed * 0 + 1))
        translation = (np.asarray(center_of_mass_image)
                       - np.asarray(center_of_mass_reference))
        xfrm = ants.create_ants_transform(
            transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_reference),
            translation=translation)
        ct_preprocessed = ((ct_preprocessed - ct_preprocessed.min())
                           / (ct_preprocessed.max() - ct_preprocessed.min()))
        ct_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm, ct_preprocessed, reference_image, interpolation="nearestneighbor")
        # Rescale to [0, 1] and shift to [-0.5, 0.5].
        ct_preprocessed_warped = (
            (ct_preprocessed_warped - ct_preprocessed_warped.min())
            / (ct_preprocessed_warped.max() - ct_preprocessed_warped.min())) - 0.5

        ################################
        #
        # Build models and load weights
        #
        ################################

        if verbose:
            print("Build model and load weights.")

        weights_file_name = get_pretrained_network(
            "lungCtWithPriorsSegmentationWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)

        classes = ("background", "left lung", "right lung", "airways")
        number_of_classification_labels = len(classes)

        # Use the same cache directory as every other download in this function.
        luna16_priors = ants.ndimage_to_list(ants.image_read(
            get_antsxnet_data("luna16LungPriors",
                              antsxnet_cache_directory=antsxnet_cache_directory)))
        luna16_priors = [
            ants.resample_image(prior, reference_image_size, use_voxels=True)
            for prior in luna16_priors]
        channel_size = len(luna16_priors) + 1

        unet_model = create_unet_model_3d(
            (*reference_image_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            mode="classification",
            number_of_layers=4,
            number_of_filters_at_base_layer=16,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5,
            additional_options=("attentionGating",))
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Do prediction and normalize to native space
        #
        ################################

        if verbose:
            print("Prediction.")

        batchX = np.zeros((1, *reference_image_size, channel_size))
        batchX[:, :, :, :, 0] = ct_preprocessed_warped.numpy()
        for i, prior in enumerate(luna16_priors):
            batchX[:, :, :, :, i + 1] = prior.numpy() - 0.5

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        probability_images = list()
        for i in range(number_of_classification_labels):
            if verbose:
                print("Reconstructing image", classes[i])
            probability_image = ants.from_numpy(
                np.squeeze(predicted_data[:, :, :, :, i]),
                origin=ct_preprocessed_warped.origin,
                spacing=ct_preprocessed_warped.spacing,
                direction=ct_preprocessed_warped.direction)
            probability_image = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), probability_image, ct_preprocessed)
            # Back to the voxel grid and header of the input image.
            probability_image = ants.resample_image(
                probability_image, resample_params=image.shape,
                use_voxels=True, interp_type=0)
            probability_image = ants.copy_image_info(image, probability_image)
            probability_images.append(probability_image)

        segmentation_image = _segmentation_from_probability_images(
            probability_images, image)

        return {'segmentation_image' : segmentation_image,
                'probability_images' : probability_images}

    else:
        # modality == "ventilation" (the only remaining validated option)

        ################################
        #
        # Preprocess image
        #
        ################################

        if verbose:
            print("Preprocess ventilation image.")

        template_size = (256, 256)

        image_modalities = ("Ventilation",)
        channel_size = len(image_modalities)

        preprocessed_image = (image - image.mean()) / image.std()
        ants.set_direction(preprocessed_image, np.identity(3))

        ################################
        #
        # Build models and load weights
        #
        ################################

        unet_model = create_unet_model_2d(
            (*template_size, channel_size),
            number_of_outputs=1,
            mode='sigmoid',
            number_of_layers=4,
            number_of_filters_at_base_layer=32,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3),
            deconvolution_kernel_size=(2, 2),
            weight_decay=0)

        if verbose:
            print("Whole lung mask: retrieving model weights.")

        weights_file_name = get_pretrained_network(
            "wholeLungMaskFromVentilation",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Extract slices
        #
        ################################

        # Slices are taken along the axis with the largest spacing.
        spacing = ants.get_spacing(preprocessed_image)
        dimensions_to_predict = (spacing.index(max(spacing)),)

        total_number_of_slices = sum(
            preprocessed_image.shape[dimension] for dimension in dimensions_to_predict)

        batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

        slice_count = 0
        for dimension in dimensions_to_predict:
            number_of_slices = preprocessed_image.shape[dimension]

            if verbose:
                print("Extracting slices for dimension ", dimension, ".")

            for i in range(number_of_slices):
                ventilation_slice = pad_or_crop_image_to_size(
                    ants.slice_image(preprocessed_image, dimension, i), template_size)
                batchX[slice_count, :, :, 0] = ventilation_slice.numpy()
                slice_count += 1

        ################################
        #
        # Do prediction and then restack into the image
        #
        ################################

        if verbose:
            print("Prediction.")

        prediction = unet_model.predict(batchX, verbose=verbose)

        # permutations[d] moves the stacked slice axis (axis 0) back to axis d.
        permutations = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]

        probability_image = ants.image_clone(image) * 0

        current_start_slice = 0
        for d, dimension in enumerate(dimensions_to_predict):
            # Half-open interval [start, end): all slices of this dimension are
            # used.  (Subtracting 1 from the end index dropped the last slice.)
            current_end_slice = current_start_slice + preprocessed_image.shape[dimension]
            prediction_per_dimension = prediction[
                current_start_slice:current_end_slice, :, :, 0]
            prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                                            permutations[dimension])
            prediction_image = ants.copy_image_info(
                image,
                pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                          image.shape))
            # Running average over the predicted dimensions.
            probability_image = probability_image + (
                prediction_image - probability_image) / (d + 1)

            current_start_slice = current_end_slice

        return probability_image
def sysu_media_wmh_segmentation(flair,
                                t1=None,
                                do_preprocessing=True,
                                use_ensemble=True,
                                use_axial_slices_only=True,
                                antsxnet_cache_directory=None,
                                verbose=False):
    """
    Perform WMH segmentation using the winning submission in the MICCAI
    2017 challenge by the sysu_media team using FLAIR or T1/FLAIR.  The
    MICCAI challenge is discussed in

    https://pubmed.ncbi.nlm.nih.gov/30908194/

    with the sysu_media team's entry discussed in

    https://pubmed.ncbi.nlm.nih.gov/30125711/

    and the original implementation available here:

    https://github.com/hongweilibran/wmh_ibbmTum

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).  Optional; when given,
        the two-channel FLAIR/T1 models are used.

    do_preprocessing : boolean
        perform intensity truncation and n4 bias correction?

    use_ensemble : boolean
        If True, average the predictions of all 3 sets of weights.  Otherwise
        only the first set is used.

    use_axial_slices_only : boolean
        If True, use original implementation which was trained on axial slices.
        If False, use ANTsXNet variant implementation which applies the slice-by-slice
        models to all 3 dimensions and averages the results.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if None, "ANTsXNet" is passed on
        to the download utilities as the cache directory.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Raises
    ------
    ValueError
        If the FLAIR image is not 3-D.

    Example
    -------
    >>> image = ants.image_read("flair.nii.gz")
    >>> probability_mask = sysu_media_wmh_segmentation(image)
    """

    # Validate before importing the network code or downloading anything.
    if flair.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    from ..architectures import create_sysu_media_unet_model_2d
    from ..utilities import brain_extraction
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import pad_or_crop_image_to_size

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    flair_preprocessed = flair
    if do_preprocessing:
        flair_preprocessing = preprocess_brain_image(
            flair,
            truncate_intensity=(0.01, 0.99),
            do_brain_extraction=False,
            do_bias_correction=True,
            do_denoising=False,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        flair_preprocessed = flair_preprocessing["preprocessed_image"]

    number_of_channels = 1
    if t1 is not None:
        t1_preprocessed = t1
        if do_preprocessing:
            t1_preprocessing = preprocess_brain_image(
                t1,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            t1_preprocessed = t1_preprocessing["preprocessed_image"]
        number_of_channels = 2

    ################################
    #
    # Estimate mask
    #
    ################################

    if verbose:
        print("Estimating brain mask.")

    # NOTE(review): brain_extraction is called without antsxnet_cache_directory
    # or verbose -- confirm whether it should receive them.
    if t1 is not None:
        brain_mask = brain_extraction(t1, modality="t1")
    else:
        brain_mask = brain_extraction(flair, modality="flair")

    # All images are moved into a 200^3, 1 mm isotropic reference domain by
    # translating the brain mask's center of mass onto the reference's.
    reference_image = ants.make_image((200, 200, 200),
                                      voxval=1,
                                      spacing=(1, 1, 1),
                                      origin=(0, 0, 0),
                                      direction=np.identity(3))
    center_of_mass_reference = ants.get_center_of_mass(reference_image)
    center_of_mass_image = ants.get_center_of_mass(brain_mask)
    translation = (np.asarray(center_of_mass_image)
                   - np.asarray(center_of_mass_reference))
    xfrm = ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_reference),
        translation=translation)
    flair_preprocessed_warped = ants.apply_ants_transform_to_image(
        xfrm, flair_preprocessed, reference_image)
    brain_mask_warped = ants.threshold_image(
        ants.apply_ants_transform_to_image(xfrm, brain_mask, reference_image),
        0.5, 1.1, 1, 0)

    if t1 is not None:
        t1_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm, t1_preprocessed, reference_image)

    ################################
    #
    # Gaussian normalize intensity based on brain mask
    #
    ################################

    mean_flair = flair_preprocessed_warped[brain_mask_warped > 0].mean()
    std_flair = flair_preprocessed_warped[brain_mask_warped > 0].std()
    flair_preprocessed_warped = (flair_preprocessed_warped - mean_flair) / std_flair

    if number_of_channels == 2:
        mean_t1 = t1_preprocessed_warped[brain_mask_warped > 0].mean()
        std_t1 = t1_preprocessed_warped[brain_mask_warped > 0].std()
        t1_preprocessed_warped = (t1_preprocessed_warped - mean_t1) / std_t1

    ################################
    #
    # Build models and load weights
    #
    ################################

    number_of_models = 3 if use_ensemble else 1

    if number_of_channels == 1:
        network_name_prefix = "sysuMediaWmhFlairOnlyModel"
    else:
        network_name_prefix = "sysuMediaWmhFlairT1Model"

    unet_models = list()
    for i in range(number_of_models):
        weights_file_name = get_pretrained_network(
            network_name_prefix + str(i),
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model = create_sysu_media_unet_model_2d((200, 200, number_of_channels))
        unet_model.load_weights(weights_file_name)
        unet_models.append(unet_model)

    ################################
    #
    # Extract slices
    #
    ################################

    dimensions_to_predict = [2]
    if not use_axial_slices_only:
        dimensions_to_predict = list(range(3))

    total_number_of_slices = sum(
        flair_preprocessed_warped.shape[dimension]
        for dimension in dimensions_to_predict)

    batchX = np.zeros((total_number_of_slices, 200, 200, number_of_channels))

    slice_count = 0
    for dimension in dimensions_to_predict:
        number_of_slices = flair_preprocessed_warped.shape[dimension]

        if verbose:
            print("Extracting slices for dimension ", dimension, ".")

        for i in range(number_of_slices):
            flair_slice = pad_or_crop_image_to_size(
                ants.slice_image(flair_preprocessed_warped, dimension, i),
                (200, 200))
            batchX[slice_count, :, :, 0] = flair_slice.numpy()
            if number_of_channels == 2:
                t1_slice = pad_or_crop_image_to_size(
                    ants.slice_image(t1_preprocessed_warped, dimension, i),
                    (200, 200))
                batchX[slice_count, :, :, 1] = t1_slice.numpy()

            slice_count += 1

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose:
        print("Prediction.")

    # Average the predictions of all models in the ensemble.
    prediction = unet_models[0].predict(batchX, verbose=verbose)
    if number_of_models > 1:
        for i in range(1, number_of_models):
            prediction += unet_models[i].predict(batchX, verbose=verbose)
        prediction /= number_of_models

    # permutations[d] moves the stacked slice axis (axis 0) back to axis d.
    permutations = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]

    prediction_image_average = ants.image_clone(flair_preprocessed_warped) * 0

    current_start_slice = 0
    for d, dimension in enumerate(dimensions_to_predict):
        # Half-open interval [start, end): all slices of this dimension are
        # used.  (Subtracting 1 from the end index dropped the last slice.)
        current_end_slice = (current_start_slice
                             + flair_preprocessed_warped.shape[dimension])
        prediction_per_dimension = prediction[
            current_start_slice:current_end_slice, :, :, :]
        prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                                        permutations[dimension])
        prediction_image = ants.copy_image_info(
            flair_preprocessed_warped,
            pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                      flair_preprocessed_warped.shape))
        # Running average over the predicted dimensions.
        prediction_image_average = prediction_image_average + (
            prediction_image - prediction_image_average) / (d + 1)

        current_start_slice = current_end_slice

    # Map the averaged prediction back into the space of the input FLAIR image.
    probability_image = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), prediction_image_average, flair)

    return probability_image
# Read the input image and report how long the read took.
start_time = time.time()
image = ants.image_read(input_file_name)
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

# Read the reorientation template, timed the same way.
print("Reading reorientation template " + reorient_template_file_name)
start_time = time.time()
reorient_template = ants.image_read(reorient_template_file_name)
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

# Resample the image into the template domain using a translation-only
# Euler3DTransform built from the difference between the two centers of mass
# (computed on the images themselves), centered at the template's center of mass.
print("Normalizing to template")
start_time = time.time()
center_of_mass_template = ants.get_center_of_mass(reorient_template)
center_of_mass_image = ants.get_center_of_mass(image)
translation = np.asarray(center_of_mass_image) - np.asarray(
    center_of_mass_template)
xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
                                  center=np.asarray(center_of_mass_template),
                                  translation=translation)
warped_image = ants.apply_ants_transform_to_image(xfrm, image,
                                                  reorient_template)
# Standardize the warped intensities: subtract the mean, divide by the
# standard deviation.
warped_image = (warped_image - warped_image.mean()) / warped_image.std()

#########################################
#
# Perform initial (stage 1) segmentation
#
def lung_extraction(image,
                    modality="proton",
                    output_directory=None,
                    verbose=None):
    """
    Perform proton or ct lung extraction using U-net.

    Arguments
    ---------
    image : ANTsImage
        input 3-D image

    modality : string
        Modality image type.  Options include "ct" and "proton".

    output_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, they are only downloaded when
        they are not already present in this directory.  If None, the
        template is downloaded to a temporary file and the weights are
        obtained from get_pretrained_network without a target file name.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary of ANTs segmentation and probability images with the keys
    'segmentation_image' and 'probability_images'.

    Raises
    ------
    ValueError
        If the image is not 3-D or the modality is not recognized.

    Example
    -------
    >>> output = lung_extraction(lung_image, modality="proton")
    """

    # Validate the arguments before importing the network code or downloading
    # any data so that bad input fails fast instead of returning None.
    if image.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if modality not in ("proton", "ct"):
        raise ValueError("Unrecognized modality.")

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network

    # The two modalities share one pipeline and differ only in these settings.
    if modality == "proton":
        network_name = "protonLungMri"
        weights_base_name = "protonLungSegmentationWeights.h5"
        template_base_name = "protonLungTemplate.nii.gz"
        template_url = "https://ndownloader.figshare.com/files/22707338"
        classes = ("background", "left_lung", "right_lung")
        unet_options = dict(number_of_filters_at_base_layer=16,
                            convolution_kernel_size=(7, 7, 5),
                            deconvolution_kernel_size=(7, 7, 5))
    else:
        network_name = "ctHumanLung"
        weights_base_name = "ctLungSegmentationWeights.h5"
        template_base_name = "ctLungTemplate.nii.gz"
        template_url = "https://ndownloader.figshare.com/files/22707335"
        classes = ("background", "left_lung", "right_lung", "trachea")
        unet_options = dict(number_of_filters_at_base_layer=8,
                            convolution_kernel_size=(3, 3, 3),
                            deconvolution_kernel_size=(2, 2, 2))

    channel_size = 1
    number_of_classification_labels = len(classes)

    ################################
    #
    # Retrieve the model weights
    #
    ################################

    if output_directory is not None:
        weights_file_name = os.path.join(output_directory, weights_base_name)
        if not os.path.exists(weights_file_name):
            if verbose:
                print("Lung extraction:  downloading weights.")
            weights_file_name = get_pretrained_network(network_name,
                                                       weights_file_name)
    else:
        weights_file_name = get_pretrained_network(network_name)

    ################################
    #
    # Retrieve the reorientation template
    #
    ################################

    reorient_template_file_name = None
    if output_directory is not None:
        reorient_template_file_name = os.path.join(output_directory,
                                                   template_base_name)

    if (reorient_template_file_name is not None
            and os.path.exists(reorient_template_file_name)):
        reorient_template = ants.image_read(reorient_template_file_name)
    else:
        if verbose:
            print("Lung extraction:  downloading template.")
        response = requests.get(template_url)
        # Do not write an HTTP error page to disk as if it were the template.
        response.raise_for_status()

        file_descriptor, template_file_name = tempfile.mkstemp(suffix=".nii.gz")
        try:
            with os.fdopen(file_descriptor, 'wb') as f:
                f.write(response.content)
            reorient_template = ants.image_read(template_file_name)
            if reorient_template_file_name is not None:
                shutil.copy(template_file_name, reorient_template_file_name)
        finally:
            # The image has been read (and cached if requested); do not leave
            # the temporary download behind.
            os.remove(template_file_name)

    ################################
    #
    # Build model and load weights
    #
    ################################

    resampled_image_size = reorient_template.shape

    unet_model = create_unet_model_3d(
        (*resampled_image_size, channel_size),
        number_of_outputs=number_of_classification_labels,
        number_of_layers=4,
        dropout_rate=0.0,
        **unet_options)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Normalize to the template and predict
    #
    ################################

    if verbose:
        print("Lung extraction:  normalizing image to the template.")

    # Centers of mass are computed on constant-valued copies of the images,
    # so intensities do not influence the translation.
    center_of_mass_template = ants.get_center_of_mass(reorient_template * 0 + 1)
    center_of_mass_image = ants.get_center_of_mass(image * 0 + 1)
    translation = (np.asarray(center_of_mass_image)
                   - np.asarray(center_of_mass_template))
    xfrm = ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_template),
        translation=translation)
    warped_image = ants.apply_ants_transform_to_image(
        xfrm, image, reorient_template)

    # Network input has shape (1, x, y, z, 1) and is standardized.
    batchX = np.expand_dims(warped_image.numpy(), axis=0)
    batchX = np.expand_dims(batchX, axis=-1)
    batchX = (batchX - batchX.mean()) / batchX.std()

    predicted_data = unet_model.predict(batchX, verbose=0)

    origin = warped_image.origin
    spacing = warped_image.spacing
    direction = warped_image.direction

    probability_images_array = [
        ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                        origin=origin, spacing=spacing, direction=direction)
        for i in range(number_of_classification_labels)]

    ################################
    #
    # Map the results back to native space
    #
    ################################

    if verbose:
        print(
            "Lung extraction:  renormalize probability mask to native space."
        )

    inverse_xfrm = ants.invert_ants_transform(xfrm)
    probability_images_array = [
        ants.apply_ants_transform_to_image(inverse_xfrm, probability_image, image)
        for probability_image in probability_images_array]

    # Each voxel is labeled with the index of its most probable class.
    mask = image * 0 + 1
    image_matrix = ants.image_list_to_matrix(probability_images_array, mask)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), mask)[0]

    return {
        'segmentation_image': segmentation_image,
        'probability_images': probability_images_array
    }
def randomly_transform_image_data(
        reference_image,
        input_image_list,
        segmentation_image_list=None,
        number_of_simulations=10,
        transform_type='affine',
        sd_affine=0.02,
        deformation_transform_type="bspline",
        number_of_random_points=1000,
        sd_noise=10.0,
        number_of_fitting_levels=4,
        mesh_size=1,
        sd_smoothing=4.0,
        input_image_interpolator='linear',
        segmentation_image_interpolator='nearestNeighbor'):
    """
    Randomly transform image data (optional: with corresponding segmentations).

    Apply rigid, affine and/or deformable maps to an input set of training
    images.  The reference image domain defines the space in which this
    happens.

    Arguments
    ---------
    reference_image : ANTsImage
        Defines the spatial domain for all output images.  If the first image
        of a sampled subject fails ants.image_physical_space_consistency
        against the reference image, all images of that subject (and its
        segmentation, if any) are resampled to the reference image with
        ants.resample_image_to_target.  This could have unexpected
        consequences.  The caller's lists are never modified.

    input_image_list : list of lists of ANTsImages
        List of lists of input images to warp.  The internal list sets contain
        one or more images (per subject) which are assumed to be mutually
        aligned.  The outer list contains multiple subject lists which are
        randomly sampled (with replacement) to produce the output image list.

    segmentation_image_list : list of ANTsImages
        List of segmentation images corresponding to the input image list
        (optional).

    number_of_simulations : integer
        Number of output images.

    transform_type : string
        One of the following options: "translation", "rigid", "scaleShear",
        "affine", "deformation", "affineAndDeformation".

    sd_affine : float
        Parameter dictating deviation amount from identity for random linear
        transformations.

    deformation_transform_type : string
        "bspline" or "exponential".

    number_of_random_points : integer
        Number of displacement points for the deformation field.

    sd_noise : float
        Standard deviation of the displacement field.

    number_of_fitting_levels : integer
        Number of fitting levels (bspline deformation only).

    mesh_size : int or n-D tuple
        Determines fitting resolution (bspline deformation only).

    sd_smoothing : float
        Standard deviation of the Gaussian smoothing in mm (exponential field
        only).

    input_image_interpolator : string
        One of the following options "linear", "gaussian", "bspline".  Used
        when resampling input images to the reference image.

    segmentation_image_interpolator : string
        One of the following options "nearestNeighbor" or "genericLabel".
        Used when resampling segmentations to the reference image.

    Returns
    -------
    Dictionary with the keys

    * 'simulated_images': list (one entry per simulation) of lists of
      transformed images,
    * 'simulated_transforms': list of the composed transforms that were
      applied,
    * 'simulated_segmentation_images': list of transformed segmentations
      (only present when segmentation_image_list is given).

    Raises
    ------
    ValueError
        If transform_type is not one of the admissible options.

    Example
    -------
    >>> import ants
    >>> image1 = ants.image_read(ants.get_ants_data("r16"))
    >>> image2 = ants.image_read(ants.get_ants_data("r64"))
    >>> input_segmentations = list()
    >>> input_segmentations.append(ants.threshold_image(image1, "Otsu", 3))
    >>> input_segmentations.append(ants.threshold_image(image2, "Otsu", 3))
    >>> input_images = [[image1], [image2]]
    >>> data = antspynet.randomly_transform_image_data(image1,
    ...     input_images, input_segmentations, sd_affine=0.02,
    ...     transform_type="affineAndDeformation")
    """

    def polar_decomposition(X):
        # Left polar decomposition X = P * Z with P symmetric positive
        # semi-definite and Z orthogonal.  numpy.linalg.svd returns the right
        # singular vectors already transposed (Vh), so the orthogonal factor
        # is U * Vh (no additional transpose).
        U, d, Vh = np.linalg.svd(X, full_matrices=False)
        P = np.matmul(U, np.matmul(np.diag(d), np.transpose(U)))
        Z = np.matmul(U, Vh)
        if np.linalg.det(Z) < 0:
            # Turn the reflection into a proper rotation by flipping the
            # direction associated with the smallest singular value.  Negating
            # Z would not change the sign of the determinant in even
            # dimensions (e.g., 2-D images).
            U_flipped = U.copy()
            U_flipped[:, -1] *= -1.0
            Z = np.matmul(U_flipped, Vh)
        return {"P": P, "Z": Z, "Xtilde": np.matmul(P, Z)}

    def create_random_linear_transform(image,
                                       fixed_parameters,
                                       transform_type='affine',
                                       sd_affine=1.0):
        transform = ants.create_ants_transform(
            transform_type="AffineTransform",
            precision='float',
            dimension=image.dimension)
        ants.set_ants_transform_fixed_parameters(transform, fixed_parameters)
        identity_parameters = ants.get_ants_transform_parameters(transform)

        # The leading parameters hold the (dimension x dimension) matrix, the
        # trailing `dimension` parameters hold the translation.
        number_of_matrix_parameters = len(identity_parameters) - image.dimension

        random_epsilon = np.random.normal(loc=0,
                                          scale=sd_affine,
                                          size=len(identity_parameters))
        if transform_type == 'translation':
            # Only perturb the translation part.
            random_epsilon[:number_of_matrix_parameters] = 0

        random_parameters = identity_parameters + random_epsilon
        random_matrix = np.reshape(
            random_parameters[:number_of_matrix_parameters],
            (image.dimension, image.dimension))
        decomposition = polar_decomposition(random_matrix)

        if transform_type == "rigid":
            random_matrix = decomposition['Z']
        elif transform_type == "affine":
            random_matrix = decomposition['Xtilde']
        elif transform_type == "scaleShear":
            random_matrix = decomposition['P']

        random_parameters[:number_of_matrix_parameters] = np.reshape(
            random_matrix, number_of_matrix_parameters)
        ants.set_ants_transform_parameters(transform, random_parameters)
        return transform

    def create_random_displacement_field_transform(
            image,
            field_type="bspline",
            number_of_random_points=1000,
            sd_noise=10.0,
            number_of_fitting_levels=4,
            mesh_size=1,
            sd_smoothing=4.0):
        displacement_field = ants.simulate_displacement_field(
            image,
            field_type=field_type,
            number_of_random_points=number_of_random_points,
            sd_noise=sd_noise,
            enforce_stationary_boundary=True,
            number_of_fitting_levels=number_of_fitting_levels,
            mesh_size=mesh_size,
            sd_smoothing=sd_smoothing)
        return ants.transform_from_displacement_field(displacement_field)

    admissible_transforms = ("translation", "rigid", "scaleShear", "affine",
                             "affineAndDeformation", "deformation")
    if transform_type not in admissible_transforms:
        raise ValueError(
            "The specified transform is not a possible option.  Please see help menu."
        )

    # Get the fixed parameters from the reference image.
    fixed_parameters = ants.get_center_of_mass(reference_image)
    number_of_subjects = len(input_image_list)

    random_indices = np.random.choice(number_of_subjects,
                                      size=number_of_simulations,
                                      replace=True)

    simulated_image_list = list()
    simulated_segmentation_image_list = list()
    simulated_transforms = list()

    for i in range(number_of_simulations):
        subject_index = random_indices[i]
        single_subject_image_list = input_image_list[subject_index]
        single_subject_segmentation_image = None
        if segmentation_image_list is not None:
            single_subject_segmentation_image = segmentation_image_list[
                subject_index]

        if not ants.image_physical_space_consistency(
                reference_image, single_subject_image_list[0]):
            # Build a *new* list of resampled images.  Appending to the
            # sampled list would modify the caller's data and leave the
            # un-resampled images in place.
            single_subject_image_list = [
                ants.resample_image_to_target(
                    subject_image,
                    reference_image,
                    interp_type=input_image_interpolator)
                for subject_image in single_subject_image_list
            ]
            if single_subject_segmentation_image is not None:
                single_subject_segmentation_image = \
                    ants.resample_image_to_target(
                        single_subject_segmentation_image,
                        reference_image,
                        interp_type=segmentation_image_interpolator)

        transforms = list()

        if transform_type == 'deformation':
            deformable_transform = create_random_displacement_field_transform(
                reference_image, deformation_transform_type,
                number_of_random_points, sd_noise, number_of_fitting_levels,
                mesh_size, sd_smoothing)
            transforms.append(deformable_transform)
        elif transform_type == 'affineAndDeformation':
            deformable_transform = create_random_displacement_field_transform(
                reference_image, deformation_transform_type,
                number_of_random_points, sd_noise, number_of_fitting_levels,
                mesh_size, sd_smoothing)
            linear_transform = create_random_linear_transform(
                reference_image, fixed_parameters, 'affine', sd_affine)
            transforms.append(deformable_transform)
            transforms.append(linear_transform)
        else:
            linear_transform = create_random_linear_transform(
                reference_image, fixed_parameters, transform_type, sd_affine)
            transforms.append(linear_transform)

        simulated_transform = ants.compose_ants_transforms(transforms)
        simulated_transforms.append(simulated_transform)

        simulated_image_list.append([
            ants.apply_ants_transform_to_image(simulated_transform,
                                               subject_image,
                                               reference=reference_image)
            for subject_image in single_subject_image_list
        ])

        if single_subject_segmentation_image is not None:
            # NOTE(review): this warp uses the default interpolation of
            # apply_ants_transform_to_image; segmentation_image_interpolator
            # is only used for the resampling step above.  Confirm whether
            # label images should be warped with a label-preserving
            # interpolator here.
            simulated_segmentation_image_list.append(
                ants.apply_ants_transform_to_image(
                    simulated_transform,
                    single_subject_segmentation_image,
                    reference=reference_image))

    if segmentation_image_list is None:
        return {
            'simulated_images': simulated_image_list,
            'simulated_transforms': simulated_transforms
        }

    return {
        'simulated_images': simulated_image_list,
        'simulated_segmentation_images': simulated_segmentation_image_list,
        'simulated_transforms': simulated_transforms
    }
def claustrum_segmentation(t1,
                           do_preprocessing=True,
                           use_ensemble=True,
                           antsxnet_cache_directory=None,
                           verbose=False):
    """
    Claustrum segmentation

    Described here:

        https://arxiv.org/abs/2008.03465

    with the implementation available at:

        https://github.com/hongweilibran/claustrum_multi_view

    Arguments
    ---------
    t1 : ANTsImage
        input 3-D T1 brain image.

    do_preprocessing : boolean
        If true, the image is run through preprocess_brain_image (intensity
        truncation, T1 brain extraction, bias correction and denoising) and
        the resulting brain mask is used.  Otherwise the image is used as is
        and every non-zero voxel is treated as brain.

    use_ensemble : boolean
        If true, use all 3 sets of weights per view; otherwise only the first.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if this is None, "ANTsXNet" is
        passed on to the download utilities.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Claustrum segmentation probability image in the space of the input t1,
    masked by the brain mask.

    Raises
    ------
    ValueError
        If the input image is not 3-D.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> probability_mask = claustrum_segmentation(image)
    """

    from ..architectures import create_sysu_media_unet_model_2d
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import pad_or_crop_image_to_size

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    image_size = (180, 180)
    number_of_channels = 1
    number_of_models = 3 if use_ensemble else 1

    def load_models(view):
        # Build one 2-D U-net per ensemble member for the given view
        # ("axial" or "coronal") and load its pretrained weights.
        if verbose:
            print("Claustrum: retrieving " + view + " model weights.")
        models = list()
        for i in range(number_of_models):
            weights_file_name = get_pretrained_network(
                "claustrum_" + view + "_" + str(i),
                antsxnet_cache_directory=antsxnet_cache_directory)
            model = create_sysu_media_unet_model_2d(
                (*image_size, number_of_channels), anatomy="claustrum")
            model.load_weights(weights_file_name)
            models.append(model)
        return models

    def ensemble_predict(models, batch):
        # Average the predictions of all ensemble members.
        prediction = models[0].predict(batch, verbose=verbose)
        for model in models[1:]:
            prediction += model.predict(batch, verbose=verbose)
        if len(models) > 1:
            prediction /= len(models)
        return prediction

    ################################
    #
    # Preprocess images
    #
    ################################

    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing["preprocessed_image"]
        brain_mask = t1_preprocessing["brain_mask"]
    else:
        t1_preprocessed = ants.image_clone(t1)
        # Everything that is not exactly zero counts as brain.
        brain_mask = ants.threshold_image(t1, 0, 0, 0, 1)

    # Translate the image so that the brain mask's center of mass coincides
    # with the center of a fixed-size reference grid.
    reference_image = ants.make_image((170, 256, 256),
                                      voxval=1,
                                      spacing=(1, 1, 1),
                                      origin=(0, 0, 0),
                                      direction=np.identity(3))
    center_of_mass_reference = ants.get_center_of_mass(reference_image)
    center_of_mass_image = ants.get_center_of_mass(brain_mask)
    translation = np.asarray(center_of_mass_image) - np.asarray(
        center_of_mass_reference)
    xfrm = ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_reference),
        translation=translation)
    t1_preprocessed_warped = ants.apply_ants_transform_to_image(
        xfrm, t1_preprocessed, reference_image)
    brain_mask_warped = ants.threshold_image(
        ants.apply_ants_transform_to_image(xfrm, brain_mask, reference_image),
        0.5, 1.1, 1, 0)

    ################################
    #
    # Gaussian normalize intensity based on brain mask
    #
    ################################

    mean_t1 = t1_preprocessed_warped[brain_mask_warped > 0].mean()
    std_t1 = t1_preprocessed_warped[brain_mask_warped > 0].std()
    t1_preprocessed_warped = (t1_preprocessed_warped - mean_t1) / std_t1

    t1_preprocessed_warped = t1_preprocessed_warped * brain_mask_warped

    ################################
    #
    # Build models and load weights
    #
    ################################

    unet_axial_models = load_models("axial")
    unet_coronal_models = load_models("coronal")

    ################################
    #
    # Extract slices
    #
    ################################

    # Dimension 1 feeds the coronal models, dimension 2 the axial models.
    dimensions_to_predict = [1, 2]
    warped_shape = t1_preprocessed_warped.shape

    batch_coronal_X = np.zeros(
        (warped_shape[1], *image_size, number_of_channels))
    batch_axial_X = np.zeros(
        (warped_shape[2], *image_size, number_of_channels))

    for dimension in dimensions_to_predict:
        number_of_slices = warped_shape[dimension]

        if verbose:
            print("Extracting slices for dimension ", dimension, ".")

        for i in range(number_of_slices):
            t1_slice = pad_or_crop_image_to_size(
                ants.slice_image(t1_preprocessed_warped, dimension, i),
                image_size)
            # Slices are rotated into the orientation used for prediction;
            # the inverse rotation is applied to the predictions below.
            if dimension == 1:
                batch_coronal_X[i, :, :, 0] = np.rot90(t1_slice.numpy(), k=-1)
            else:
                batch_axial_X[i, :, :, 0] = np.rot90(t1_slice.numpy())

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose:
        print("Coronal prediction.")

    prediction_coronal = ensemble_predict(unet_coronal_models,
                                          batch_coronal_X)
    for i in range(warped_shape[1]):
        prediction_coronal[i, :, :, 0] = np.rot90(
            np.squeeze(prediction_coronal[i, :, :, 0]))

    if verbose:
        print("Axial prediction.")

    prediction_axial = ensemble_predict(unet_axial_models, batch_axial_X)
    for i in range(warped_shape[2]):
        prediction_axial[i, :, :, 0] = np.rot90(
            np.squeeze(prediction_axial[i, :, :, 0]), k=-1)

    if verbose:
        print("Restack image and transform back to native space.")

    # permutations[dimension] moves the slice axis of the prediction batch
    # back to position `dimension` of the 3-D image.
    permutations = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]

    prediction_image_average = ants.image_clone(t1_preprocessed_warped) * 0

    for d, dimension in enumerate(dimensions_to_predict):
        if dimension == 1:
            prediction_per_dimension = prediction_coronal
        else:
            prediction_per_dimension = prediction_axial

        prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                                        permutations[dimension])
        prediction_image = ants.copy_image_info(
            t1_preprocessed_warped,
            pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                      t1_preprocessed_warped.shape))
        # Running mean over the predicted views.
        prediction_image_average = prediction_image_average + (
            prediction_image - prediction_image_average) / (d + 1)

    probability_image = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), prediction_image_average,
        t1) * ants.threshold_image(brain_mask, 0.5, 1, 1, 0)

    return probability_image
def sysu_media_wmh_segmentation(flair,
                                t1=None,
                                use_ensemble=True,
                                antsxnet_cache_directory=None,
                                verbose=False):
    """
    Perform WMH segmentation using the winning submission in the MICCAI
    2017 challenge by the sysu_media team using FLAIR or T1/FLAIR.  The
    MICCAI challenge is discussed in

    https://pubmed.ncbi.nlm.nih.gov/30908194/

    with the sysu_media team's entry discussed in

     https://pubmed.ncbi.nlm.nih.gov/30125711/

    with the original implementation available here:

    https://github.com/hongweilibran/wmh_ibbmTum

    The original implementation used global thresholding as a quick
    brain extraction approach.  Due to possible generalization difficulties,
    we leave such post-processing steps to the user.  For brain or white
    matter masking see functions brain_extraction or deep_atropos,
    respectively.

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).  Optional; when given,
        the two-channel FLAIR/T1 models are used.

    use_ensemble : boolean
        If true, use all 3 sets of weights; otherwise only the first.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if this is None, "ANTsXNet" is
        passed on to the download utilities.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image carrying the header information of
    the input flair image.

    Raises
    ------
    ValueError
        If the flair image is not 3-D.

    Example
    -------
    >>> image = ants.image_read("flair.nii.gz")
    >>> probability_mask = sysu_media_wmh_segmentation(image)
    """

    from ..architectures import create_sysu_media_unet_model_2d
    from ..utilities import get_pretrained_network
    from ..utilities import pad_or_crop_image_to_size
    from ..utilities import preprocess_brain_image
    from ..utilities import binary_dice_coefficient

    if flair.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    image_size = (200, 200)

    ################################
    #
    # Preprocess images
    #
    ################################

    def closest_simplified_direction_matrix(direction):
        # Round every entry of the direction matrix to the nearest integer
        # (in magnitude) while keeping its sign.
        closest = (np.abs(direction) + 0.5).astype(int).astype(float)
        closest[direction < 0] *= -1.0
        return closest

    simplified_direction = closest_simplified_direction_matrix(flair.direction)

    def preprocess_and_reset_header(image):
        # Run the (no-op configured) preprocessing and overwrite the header
        # with the simplified direction, zero origin and unit spacing.
        preprocessing = preprocess_brain_image(
            image,
            truncate_intensity=None,
            brain_extraction_modality=None,
            do_bias_correction=False,
            do_denoising=False,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        preprocessed = preprocessing["preprocessed_image"]
        preprocessed.set_direction(simplified_direction)
        preprocessed.set_origin((0, 0, 0))
        preprocessed.set_spacing((1, 1, 1))
        return preprocessed

    flair_preprocessed = preprocess_and_reset_header(flair)
    number_of_channels = 1

    t1_preprocessed = None
    if t1 is not None:
        t1_preprocessed = preprocess_and_reset_header(t1)
        number_of_channels = 2

    ################################
    #
    # Reorient images
    #
    ################################

    reference_image = ants.make_image((256, 256, 256),
                                      voxval=0,
                                      spacing=(1, 1, 1),
                                      origin=(0, 0, 0),
                                      direction=np.identity(3))
    center_of_mass_reference = np.floor(
        ants.get_center_of_mass(reference_image * 0 + 1))
    center_of_mass_image = np.floor(
        ants.get_center_of_mass(flair_preprocessed))
    translation = np.asarray(center_of_mass_image) - np.asarray(
        center_of_mass_reference)
    xfrm = ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_reference),
        translation=translation)
    flair_preprocessed_warped = ants.apply_ants_transform_to_image(
        xfrm,
        flair_preprocessed,
        reference_image,
        interpolation="nearestneighbor")

    # Warp an all-ones image with the same transform to find the region of
    # the reference grid that is covered by the input, and crop to it.
    crop_image = ants.image_clone(flair_preprocessed) * 0 + 1
    crop_image_warped = ants.apply_ants_transform_to_image(
        xfrm, crop_image, reference_image, interpolation="nearestneighbor")
    flair_preprocessed_warped = ants.crop_image(flair_preprocessed_warped,
                                                crop_image_warped, 1)

    if t1 is not None:
        t1_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm,
            t1_preprocessed,
            reference_image,
            interpolation="nearestneighbor")
        t1_preprocessed_warped = ants.crop_image(t1_preprocessed_warped,
                                                 crop_image_warped, 1)

    ################################
    #
    # Gaussian normalize intensity
    #
    ################################

    # The statistics are taken from the un-warped preprocessed images.
    mean_flair = flair_preprocessed.mean()
    std_flair = flair_preprocessed.std()
    flair_preprocessed_warped = (flair_preprocessed_warped -
                                 mean_flair) / std_flair

    if number_of_channels == 2:
        mean_t1 = t1_preprocessed.mean()
        std_t1 = t1_preprocessed.std()
        t1_preprocessed_warped = (t1_preprocessed_warped - mean_t1) / std_t1

    ################################
    #
    # Build models and load weights
    #
    ################################

    number_of_models = 3 if use_ensemble else 1

    if verbose:
        print("White matter hyperintensity:  retrieving model weights.")

    if number_of_channels == 1:
        weights_prefix = "sysuMediaWmhFlairOnlyModel"
    else:
        weights_prefix = "sysuMediaWmhFlairT1Model"

    unet_models = list()
    for i in range(number_of_models):
        weights_file_name = get_pretrained_network(
            weights_prefix + str(i),
            antsxnet_cache_directory=antsxnet_cache_directory)

        unet_model = create_sysu_media_unet_model_2d(
            (*image_size, number_of_channels))
        unet_loss = binary_dice_coefficient(smoothing_factor=1.)
        unet_model.compile(optimizer=keras.optimizers.Adam(learning_rate=2e-4),
                           loss=unet_loss)
        unet_model.load_weights(weights_file_name)
        unet_models.append(unet_model)

    ################################
    #
    # Extract slices
    #
    ################################

    dimensions_to_predict = [2]
    warped_shape = flair_preprocessed_warped.shape

    total_number_of_slices = sum(warped_shape[dimension]
                                 for dimension in dimensions_to_predict)

    batchX = np.zeros(
        (total_number_of_slices, *image_size, number_of_channels))

    slice_count = 0
    for dimension in dimensions_to_predict:
        number_of_slices = warped_shape[dimension]

        if verbose:
            print("Extracting slices for dimension ", dimension, ".")

        for i in range(number_of_slices):
            flair_slice = pad_or_crop_image_to_size(
                ants.slice_image(flair_preprocessed_warped, dimension, i),
                image_size)
            batchX[slice_count, :, :, 0] = flair_slice.numpy()
            if number_of_channels == 2:
                t1_slice = pad_or_crop_image_to_size(
                    ants.slice_image(t1_preprocessed_warped, dimension, i),
                    image_size)
                batchX[slice_count, :, :, 1] = t1_slice.numpy()

            slice_count += 1

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose:
        print("Prediction.")

    # The two in-plane axes are swapped for prediction and swapped back
    # afterwards; the swapped batch is the same for every ensemble member.
    batchX_swapped = np.transpose(batchX, axes=(0, 2, 1, 3))

    prediction = unet_models[0].predict(batchX_swapped, verbose=verbose)
    for unet_model in unet_models[1:]:
        prediction += unet_model.predict(batchX_swapped, verbose=verbose)
    if number_of_models > 1:
        prediction /= number_of_models

    prediction = np.transpose(prediction, axes=(0, 2, 1, 3))

    # permutations[dimension] moves the slice axis of the prediction batch
    # back to position `dimension` of the 3-D image.
    permutations = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]

    prediction_image_average = ants.image_clone(flair_preprocessed_warped) * 0

    current_start_slice = 0
    for d, dimension in enumerate(dimensions_to_predict):
        current_end_slice = current_start_slice + warped_shape[dimension]
        prediction_per_dimension = prediction[
            current_start_slice:current_end_slice, :, :, :]
        prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                                        permutations[dimension])
        prediction_image = ants.copy_image_info(
            flair_preprocessed_warped,
            pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                      flair_preprocessed_warped.shape))
        # Running mean over the predicted dimensions.
        prediction_image_average = prediction_image_average + (
            prediction_image - prediction_image_average) / (d + 1)

        current_start_slice = current_end_slice

    probability_image = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), prediction_image_average,
        flair_preprocessed)
    # Restore the original header that was overwritten during preprocessing.
    probability_image = ants.copy_image_info(flair, probability_image)

    return probability_image
def brain_extraction(image,
                     modality="t1",
                     antsxnet_cache_directory=None,
                     verbose=False):
    """
    Perform brain extraction using U-net and ANTs-based training data.  "NoBrainer"
    is also possible where brain extraction uses U-net and FreeSurfer training data
    ported from the

    https://github.com/neuronets/nobrainer-models

    Arguments
    ---------
    image : ANTsImage
        input image (or list of images for multi-modal scenarios).  The
        "t1nobrainer" and "t1combined" options operate on the first image
        only.

    modality : string
        Modality image type.  Options include:
            * "t1": T1-weighted MRI---ANTs-trained.  Update from "t1v0".
            * "t1v0": T1-weighted MRI---ANTs-trained.
            * "t1nobrainer": T1-weighted MRI---FreeSurfer-trained: h/t Satra Ghosh and Jakub Kaczmarzyk.
            * "t1combined": Brian's combination of "t1" and "t1nobrainer".  One can also specify
                            "t1combined[X]" where X is the morphological radius.  X = 12 by default.
            * "flair": FLAIR MRI.
            * "t2": T2 MRI.
            * "bold": 3-D BOLD MRI.
            * "fa": Fractional anisotropy.
            * "t1t2infant": Combined T1-w/T2-w infant MRI h/t Martin Styner.
            * "t1infant": T1-w infant MRI h/t Martin Styner.
            * "t2infant": T2-w infant MRI h/t Martin Styner.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if this is None, "ANTsXNet" is
        passed on to the download utilities.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    For the ANTs-trained modalities: probability image of the last class
    (brain) in the space of the (first) input image.
    For "t1nobrainer": the predicted mask after ants.label_clusters.
    For "t1combined": the sum of the hole-filled NoBrainer mask, the eroded
    "t1" mask and the "t1" mask.

    Raises
    ------
    ValueError
        If the (first) input image is not 3-D or the modality is unknown.

    Example
    -------
    >>> probability_brain_mask = brain_extraction(brain_image, modality="t1")
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..architectures import create_nobrainer_unet_model_3d

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    # Normalize the input to a list of channels.  A one-element list is
    # handled like a single image instead of being wrapped a second time.
    input_images = image if isinstance(image, list) else [image]
    channel_size = len(input_images)
    primary_image = input_images[0]

    if primary_image.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if "t1combined" in modality:

        brain_extraction_t1 = brain_extraction(
            primary_image,
            modality="t1",
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        brain_mask = ants.iMath_get_largest_component(
            ants.threshold_image(brain_extraction_t1, 0.5, 10000))

        # Need to change with voxel resolution
        morphological_radius = 12
        if '[' in modality and ']' in modality:
            morphological_radius = int(modality.split("[")[1].split("]")[0])

        # Restrict the NoBrainer input to a dilated version of the "t1" mask.
        brain_extraction_t1nobrainer = brain_extraction(
            primary_image *
            ants.iMath_MD(brain_mask, radius=morphological_radius),
            modality="t1nobrainer",
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        brain_extraction_combined = ants.iMath_fill_holes(
            ants.iMath_get_largest_component(brain_extraction_t1nobrainer *
                                             brain_mask))

        brain_extraction_combined = brain_extraction_combined + ants.iMath_ME(
            brain_mask, morphological_radius) + brain_mask

        return brain_extraction_combined

    if modality != "t1nobrainer":

        #####################
        #
        # ANTs-based
        #
        #####################

        weights_file_name_prefixes = {
            "t1v0": "brainExtraction",
            "t1": "brainExtractionT1",
            "t2": "brainExtractionT2",
            "flair": "brainExtractionFLAIR",
            "bold": "brainExtractionBOLD",
            "fa": "brainExtractionFA",
            "t1t2infant": "brainExtractionInfantT1T2",
            "t1infant": "brainExtractionInfantT1",
            "t2infant": "brainExtractionInfantT2",
        }
        weights_file_name_prefix = weights_file_name_prefixes.get(modality)
        if weights_file_name_prefix is None:
            raise ValueError("Unknown modality type.")

        weights_file_name = get_pretrained_network(
            weights_file_name_prefix,
            antsxnet_cache_directory=antsxnet_cache_directory)

        if verbose:
            print("Brain extraction:  retrieving template.")
        reorient_template_file_name_path = get_antsxnet_data(
            "S_template3", antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)
        resampled_image_size = reorient_template.shape

        # The updated "t1" network has an additional "head" class.
        if modality == "t1":
            classes = ("background", "head", "brain")
        else:
            classes = ("background", "brain")
        number_of_classification_labels = len(classes)

        unet_model = create_unet_model_3d(
            (*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            number_of_layers=4,
            number_of_filters_at_base_layer=8,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5)

        unet_model.load_weights(weights_file_name)

        if verbose:
            print("Brain extraction:  normalizing image to the template.")

        # Align the centers of mass of the first image and the template.
        center_of_mass_template = ants.get_center_of_mass(reorient_template)
        center_of_mass_image = ants.get_center_of_mass(primary_image)
        translation = np.asarray(center_of_mass_image) - np.asarray(
            center_of_mass_template)
        xfrm = ants.create_ants_transform(
            transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_template),
            translation=translation)

        batchX = np.zeros((1, *resampled_image_size, channel_size))
        for i, input_image in enumerate(input_images):
            warped_image = ants.apply_ants_transform_to_image(
                xfrm, input_image, reorient_template)
            warped_array = warped_image.numpy()
            # Per-channel zero-mean / unit-variance normalization.
            batchX[0, :, :, :, i] = (
                warped_array - warped_array.mean()) / warped_array.std()

        if verbose:
            print("Brain extraction:  prediction and decoding.")

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        # Only the probability map of the last class is returned, so only
        # that channel is converted back into an image.
        last_label = number_of_classification_labels - 1
        probability_image_template_space = ants.from_numpy(
            np.squeeze(predicted_data[0, :, :, :, last_label]),
            origin=reorient_template.origin,
            spacing=reorient_template.spacing,
            direction=reorient_template.direction)

        if verbose:
            print(
                "Brain extraction:  renormalize probability mask to native space."
            )
        probability_image = ants.apply_ants_transform_to_image(
            ants.invert_ants_transform(xfrm),
            probability_image_template_space, primary_image)

        return probability_image

    #####################
    #
    # NoBrainer
    #
    #####################

    if verbose:
        print("NoBrainer:  generating network.")

    model = create_nobrainer_unet_model_3d((None, None, None, 1))

    weights_file_name = get_pretrained_network(
        "brainExtractionNoBrainer",
        antsxnet_cache_directory=antsxnet_cache_directory)
    model.load_weights(weights_file_name)

    if verbose:
        print(
            "NoBrainer:  preprocessing (intensity truncation and resampling)."
        )

    image_array = primary_image.numpy()
    image_robust_range = np.quantile(image_array[image_array != 0],
                                     (0.02, 0.98))
    threshold_value = 0.10 * (image_robust_range[1] -
                              image_robust_range[0]) + image_robust_range[0]

    # Zero out everything at or below the threshold.
    thresholded_mask = ants.threshold_image(primary_image, -10000,
                                            threshold_value, 0, 1)
    thresholded_image = primary_image * thresholded_mask

    image_resampled = ants.resample_image(thresholded_image,
                                          (256, 256, 256),
                                          use_voxels=True)
    image_array = np.expand_dims(image_resampled.numpy(), axis=0)
    image_array = np.expand_dims(image_array, axis=-1)

    if verbose:
        print("NoBrainer:  predicting mask.")

    brain_mask_array = np.squeeze(model.predict(image_array, verbose=verbose))
    brain_mask_resampled = ants.copy_image_info(
        image_resampled, ants.from_numpy(brain_mask_array))
    brain_mask_image = ants.resample_image(brain_mask_resampled,
                                           primary_image.shape,
                                           use_voxels=True,
                                           interp_type=1)

    # Minimum cluster size in voxels, derived from a fixed volume and the
    # voxel volume of the input image.
    spacing = ants.get_spacing(primary_image)
    spacing_product = spacing[0] * spacing[1] * spacing[2]
    minimum_brain_volume = round(649933.7 / spacing_product)
    brain_mask_labeled = ants.label_clusters(brain_mask_image,
                                             minimum_brain_volume)

    return brain_mask_labeled