def test_ndimage_to_list(self):
    """Round-trip two 2-D images through list_to_ndimage / ndimage_to_list."""
    # Load the two sample slices and give both the same 2-D spacing.
    first_slice = ants.image_read(ants.get_ants_data('r16'))
    second_slice = ants.image_read(ants.get_ants_data('r64'))
    for sample in (first_slice, second_slice):
        ants.set_spacing(sample, (2, 2))

    # Target volume: the shape of the second slice plus a stacking axis of
    # length 2.
    stacked_shape = (*second_slice.shape, 2)
    target_volume = ants.make_image(stacked_shape)
    ants.set_spacing(target_volume, (2, 2, 2))

    # Merge the slices into the target and check the result is 3-D.
    merged = ants.list_to_ndimage(target_volume, [first_slice, second_slice])
    self.assertEqual(merged.dimension, 3)
    ants.set_direction(merged, np.eye(3) * 2)

    # Splitting again must give back two 2-D images.
    recovered = ants.ndimage_to_list(merged)
    self.assertEqual(len(recovered), 2)
    self.assertEqual(recovered[0].dimension, 2)
output_file_name = args[2]

start_time_total = time.time()

# Read the input image and report how long the read took.
print("Reading ", input_file_name)
start_time = time.time()
input_image = ants.image_read(input_file_name)
end_time = time.time()
elapsed_time = end_time - start_time
print(" (elapsed time: ", elapsed_time, " seconds)")

# The model is 3-D: a 4-D input is split into a list of 3-D images and a 3-D
# input becomes a one-element list.  Every other rank is rejected here;
# previously only 2-D inputs raised and any other rank failed further down
# with an IndexError on the empty list.
dimension = len(input_image.shape)
if dimension == 4:
    input_image_list = ants.ndimage_to_list(input_image)
elif dimension == 3:
    input_image_list = [input_image]
else:
    raise ValueError("Model for 3-D or 4-D images only.")

# The network input is the shape of the first 3-D image plus one channel.
model = antspynet.create_deep_back_projection_network_model_3d(
    (*input_image_list[0].shape, 1),
    number_of_outputs=1,
    number_of_base_filters=64,
    number_of_feature_filters=256,
    number_of_back_projection_stages=7,
    convolution_kernel_size=(3, 3, 3),
    strides=(2, 2, 2),
    number_of_loss_functions=1)
def _lung_align_to_template(image, template):
    """Build a translation-only Euler3DTransform matching the centers of mass
    of the (constant-valued) image and template domains, and apply it.

    Returns the tuple (transform, image warped into the template space).
    """
    center_of_mass_template = ants.get_center_of_mass(template * 0 + 1)
    center_of_mass_image = ants.get_center_of_mass(image * 0 + 1)
    translation = (np.asarray(center_of_mass_image)
                   - np.asarray(center_of_mass_template))
    xfrm = ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_template),
        translation=translation)
    warped_image = ants.apply_ants_transform_to_image(xfrm, image, template)
    return xfrm, warped_image


def _lung_probabilities_to_native_space(predictions, number_of_labels,
                                        warped_image, xfrm, image):
    """Turn channel `i` of `predictions[0, ..., i]` into an ANTs image in the
    warped space and map it back onto `image` with the inverse of `xfrm`."""
    inverse_xfrm = ants.invert_ants_transform(xfrm)
    probability_images = []
    for i in range(number_of_labels):
        probability_image = ants.from_numpy(
            np.squeeze(predictions[0, :, :, :, i]),
            origin=warped_image.origin,
            spacing=warped_image.spacing,
            direction=warped_image.direction)
        probability_images.append(
            ants.apply_ants_transform_to_image(
                inverse_xfrm, probability_image, image))
    return probability_images


def _lung_argmax_segmentation(probability_images, reference_image):
    """Label every voxel of `reference_image` with the index of the
    probability image that is largest there."""
    mask = reference_image * 0 + 1
    image_matrix = ants.image_list_to_matrix(probability_images, mask)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    return ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), mask)[0]


def lung_extraction(image,
                    modality="proton",
                    antsxnet_cache_directory=None,
                    verbose=False):
    """
    Perform lung extraction using U-net.

    Arguments
    ---------
    image : ANTsImage
        input 3-D image

    modality : string
        Modality image type.  Options include "ct", "proton", "protonLobes",
        "maskLobes", and "ventilation".

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if is None, these data will be
        downloaded to a ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    For "proton", "maskLobes" and "ct": dictionary with the keys
    'segmentation_image' and 'probability_images'.  For "protonLobes" the
    dictionary additionally holds 'whole_lung_mask_image'.  For
    "ventilation": a single ANTs probability image.

    Raises
    ------
    ValueError
        If the image is not 3-D or the modality is not recognized.

    Example
    -------
    >>> output = lung_extraction(lung_image, modality="proton")
    """

    # Validate the arguments first so that bad input fails before any model
    # code is imported or any data is downloaded.
    if image.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    supported_modalities = ("proton", "protonLobes", "maskLobes",
                            "ct", "ventilation")
    if modality not in supported_modalities:
        raise ValueError("Unrecognized modality.")

    from ..architectures import create_unet_model_2d
    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import pad_or_crop_image_to_size

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    if modality == "proton":

        channel_size = 1

        weights_file_name = get_pretrained_network(
            "protonLungMri",
            antsxnet_cache_directory=antsxnet_cache_directory)

        classes = ("background", "left_lung", "right_lung")
        number_of_classification_labels = len(classes)

        reorient_template_file_name_path = get_antsxnet_data(
            "protonLungTemplate",
            antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)
        resampled_image_size = reorient_template.shape

        unet_model = create_unet_model_3d(
            (*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            number_of_layers=4,
            number_of_filters_at_base_layer=16,
            dropout_rate=0.0,
            convolution_kernel_size=(7, 7, 5),
            deconvolution_kernel_size=(7, 7, 5))
        unet_model.load_weights(weights_file_name)

        if verbose:
            print("Lung extraction:  normalizing image to the template.")

        xfrm, warped_image = _lung_align_to_template(image, reorient_template)

        # Add batch and channel axes, then standardize the intensities.
        batchX = np.expand_dims(warped_image.numpy(), axis=0)
        batchX = np.expand_dims(batchX, axis=-1)
        batchX = (batchX - batchX.mean()) / batchX.std()

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        if verbose:
            print("Lung extraction:  renormalize probability mask to native space.")

        probability_images_array = _lung_probabilities_to_native_space(
            predicted_data, number_of_classification_labels,
            warped_image, xfrm, image)

        segmentation_image = _lung_argmax_segmentation(
            probability_images_array, image)

        return {'segmentation_image': segmentation_image,
                'probability_images': probability_images_array}

    elif modality in ("protonLobes", "maskLobes"):

        reorient_template_file_name_path = get_antsxnet_data(
            "protonLungTemplate",
            antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)
        resampled_image_size = reorient_template.shape

        spatial_priors_file_name_path = get_antsxnet_data(
            "protonLobePriors",
            antsxnet_cache_directory=antsxnet_cache_directory)
        spatial_priors = ants.image_read(spatial_priors_file_name_path)
        priors_image_list = ants.ndimage_to_list(spatial_priors)

        # One input channel for the image plus one per prior; one output
        # label per prior plus one more.
        channel_size = 1 + len(priors_image_list)
        number_of_classification_labels = 1 + len(priors_image_list)

        unet_model = create_unet_model_3d(
            (*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            mode="classification",
            number_of_filters_at_base_layer=16,
            number_of_layers=4,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            dropout_rate=0.0,
            weight_decay=0,
            additional_options=("attentionGating",))

        if modality == "protonLobes":
            # The proton lobes network has a second, single-channel sigmoid
            # output attached to the penultimate layer (used below as the
            # whole lung mask).
            penultimate_layer = unet_model.layers[-2].output
            outputs2 = Conv3D(filters=1,
                              kernel_size=(1, 1, 1),
                              activation='sigmoid',
                              kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)
            unet_model = Model(inputs=unet_model.input,
                               outputs=[unet_model.output, outputs2])
            weights_file_name = get_pretrained_network(
                "protonLobes",
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            weights_file_name = get_pretrained_network(
                "maskLobes",
                antsxnet_cache_directory=antsxnet_cache_directory)

        unet_model.load_weights(weights_file_name)

        if verbose:
            print("Lung extraction:  normalizing image to the template.")

        xfrm, warped_image = _lung_align_to_template(image, reorient_template)

        warped_array = warped_image.numpy()
        if modality == "protonLobes":
            warped_array = (warped_array - warped_array.mean()) / warped_array.std()
        else:
            # Mask input:  binarize anything non-zero.
            warped_array[warped_array != 0] = 1

        batchX = np.zeros((1, *warped_array.shape, channel_size))
        batchX[0, :, :, :, 0] = warped_array
        for i, prior_image in enumerate(priors_image_list):
            batchX[0, :, :, :, i + 1] = prior_image.numpy()

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        # The two-output protonLobes model returns a list; its first entry
        # holds the per-label probabilities.
        if modality == "protonLobes":
            label_predictions = predicted_data[0]
        else:
            label_predictions = predicted_data

        if verbose:
            print("Lung extraction:  renormalize probability images to native space.")

        probability_images_array = _lung_probabilities_to_native_space(
            label_predictions, number_of_classification_labels,
            warped_image, xfrm, image)

        segmentation_image = _lung_argmax_segmentation(
            probability_images_array, image)

        return_dict = {'segmentation_image': segmentation_image,
                       'probability_images': probability_images_array}

        if modality == "protonLobes":
            whole_lung_mask = ants.from_numpy(
                np.squeeze(predicted_data[1][0, :, :, :, 0]),
                origin=warped_image.origin,
                spacing=warped_image.spacing,
                direction=warped_image.direction)
            whole_lung_mask = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), whole_lung_mask, image)
            return_dict['whole_lung_mask_image'] = whole_lung_mask

        return return_dict

    elif modality == "ct":

        ################################
        #
        # Preprocess image
        #
        ################################

        if verbose:
            print("Preprocess CT image.")

        def closest_simplified_direction_matrix(direction):
            # Round each entry's magnitude to the nearest integer and
            # restore its sign.
            closest = np.floor(np.abs(direction) + 0.5)
            closest[direction < 0] *= -1.0
            return closest

        simplified_direction = closest_simplified_direction_matrix(image.direction)

        reference_image_size = (128, 128, 128)

        ct_preprocessed = ants.resample_image(image, reference_image_size,
                                              use_voxels=True, interp_type=0)
        # Clamp intensities to [-1000, 400].
        ct_preprocessed[ct_preprocessed < -1000] = -1000
        ct_preprocessed[ct_preprocessed > 400] = 400
        ct_preprocessed.set_direction(simplified_direction)
        ct_preprocessed.set_origin((0, 0, 0))
        ct_preprocessed.set_spacing((1, 1, 1))

        ################################
        #
        # Reorient image
        #
        ################################

        reference_image = ants.make_image(reference_image_size,
                                          voxval=0,
                                          spacing=(1, 1, 1),
                                          origin=(0, 0, 0),
                                          direction=np.identity(3))
        center_of_mass_reference = np.floor(
            ants.get_center_of_mass(reference_image * 0 + 1))
        center_of_mass_image = np.floor(
            ants.get_center_of_mass(ct_preprocessed * 0 + 1))
        translation = (np.asarray(center_of_mass_image)
                       - np.asarray(center_of_mass_reference))
        xfrm = ants.create_ants_transform(
            transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_reference),
            translation=translation)

        # Rescale to [0, 1] before warping, then to [-0.5, 0.5] afterwards.
        ct_preprocessed = ((ct_preprocessed - ct_preprocessed.min())
                           / (ct_preprocessed.max() - ct_preprocessed.min()))
        ct_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm, ct_preprocessed, reference_image,
            interpolation="nearestneighbor")
        ct_preprocessed_warped = (
            (ct_preprocessed_warped - ct_preprocessed_warped.min())
            / (ct_preprocessed_warped.max() - ct_preprocessed_warped.min())) - 0.5

        ################################
        #
        # Build models and load weights
        #
        ################################

        if verbose:
            print("Build model and load weights.")

        weights_file_name = get_pretrained_network(
            "lungCtWithPriorsSegmentationWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)

        classes = ("background", "left lung", "right lung", "airways")
        number_of_classification_labels = len(classes)

        # Use the same cache directory as every other download in this
        # function (it was previously omitted for the priors).
        luna16_priors = ants.ndimage_to_list(ants.image_read(
            get_antsxnet_data(
                "luna16LungPriors",
                antsxnet_cache_directory=antsxnet_cache_directory)))
        for i in range(len(luna16_priors)):
            luna16_priors[i] = ants.resample_image(
                luna16_priors[i], reference_image_size, use_voxels=True)
        channel_size = len(luna16_priors) + 1

        unet_model = create_unet_model_3d(
            (*reference_image_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            mode="classification",
            number_of_layers=4,
            number_of_filters_at_base_layer=16,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5,
            additional_options=("attentionGating",))
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Do prediction and normalize to native space
        #
        ################################

        if verbose:
            print("Prediction.")

        batchX = np.zeros((1, *reference_image_size, channel_size))
        batchX[:, :, :, :, 0] = ct_preprocessed_warped.numpy()
        for i, prior_image in enumerate(luna16_priors):
            batchX[:, :, :, :, i + 1] = prior_image.numpy() - 0.5

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        inverse_xfrm = ants.invert_ants_transform(xfrm)

        probability_images = []
        for i in range(number_of_classification_labels):
            if verbose:
                print("Reconstructing image", classes[i])
            probability_image = ants.from_numpy(
                np.squeeze(predicted_data[:, :, :, :, i]),
                origin=ct_preprocessed_warped.origin,
                spacing=ct_preprocessed_warped.spacing,
                direction=ct_preprocessed_warped.direction)
            probability_image = ants.apply_ants_transform_to_image(
                inverse_xfrm, probability_image, ct_preprocessed)
            # Back to the voxel grid and header of the original input.
            probability_image = ants.resample_image(
                probability_image, resample_params=image.shape,
                use_voxels=True, interp_type=0)
            probability_image = ants.copy_image_info(image, probability_image)
            probability_images.append(probability_image)

        segmentation_image = _lung_argmax_segmentation(probability_images, image)

        return {'segmentation_image': segmentation_image,
                'probability_images': probability_images}

    else:  # modality == "ventilation" (validated above)

        ################################
        #
        # Preprocess image
        #
        ################################

        if verbose:
            print("Preprocess ventilation image.")

        template_size = (256, 256)

        image_modalities = ("Ventilation",)
        channel_size = len(image_modalities)

        preprocessed_image = (image - image.mean()) / image.std()
        ants.set_direction(preprocessed_image, np.identity(3))

        ################################
        #
        # Build models and load weights
        #
        ################################

        unet_model = create_unet_model_2d(
            (*template_size, channel_size),
            number_of_outputs=1,
            mode='sigmoid',
            number_of_layers=4,
            number_of_filters_at_base_layer=32,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3),
            deconvolution_kernel_size=(2, 2),
            weight_decay=0)

        if verbose:
            print("Whole lung mask: retrieving model weights.")

        weights_file_name = get_pretrained_network(
            "wholeLungMaskFromVentilation",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Extract slices
        #
        ################################

        # Slices are taken along the axis with the largest spacing.
        spacing = ants.get_spacing(preprocessed_image)
        dimensions_to_predict = (spacing.index(max(spacing)),)

        total_number_of_slices = sum(
            preprocessed_image.shape[dimension]
            for dimension in dimensions_to_predict)

        batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

        slice_count = 0
        for dimension in dimensions_to_predict:
            number_of_slices = preprocessed_image.shape[dimension]

            if verbose:
                print("Extracting slices for dimension ", dimension, ".")

            for i in range(number_of_slices):
                ventilation_slice = pad_or_crop_image_to_size(
                    ants.slice_image(preprocessed_image, dimension, i),
                    template_size)
                batchX[slice_count, :, :, 0] = ventilation_slice.numpy()
                slice_count += 1

        ################################
        #
        # Do prediction and then restack into the image
        #
        ################################

        if verbose:
            print("Prediction.")

        prediction = unet_model.predict(batchX, verbose=int(verbose))

        permutations = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]

        probability_image = ants.image_clone(image) * 0

        current_start_slice = 0
        for d, dimension in enumerate(dimensions_to_predict):
            # Half-open interval [start, end) covering every slice of this
            # dimension.  The previous inclusive-style bound (end = start +
            # n - 1 used with range()) dropped the last slice.
            current_end_slice = current_start_slice + preprocessed_image.shape[dimension]
            prediction_per_dimension = prediction[
                current_start_slice:current_end_slice, :, :, 0]
            prediction_array = np.transpose(
                np.squeeze(prediction_per_dimension), permutations[dimension])
            prediction_image = ants.copy_image_info(
                image,
                pad_or_crop_image_to_size(
                    ants.from_numpy(prediction_array), image.shape))
            # Running average over the predicted dimensions.
            probability_image = probability_image + (
                prediction_image - probability_image) / (d + 1)

            current_start_slice = current_end_slice

        return probability_image
# Indices of the rows whose 'SystemName' equals the sixth unique system name,
# then mask the point image down to those labels (original values are kept).
networks = powers_areal_mni_itk['SystemName'].unique()
dfnpts = np.where(powers_areal_mni_itk['SystemName'] == networks[5])
dfnImg = ants.mask_image(
    ptImg,
    ptImg,
    level=dfnpts[0].tolist(),
    binarize=False,
)

# plot( und, ptImg, axis=3, window.overlay = range( ptImg ) )
bold2ch2 = ants.apply_transforms(
    ch2,
    und,
    concatx2,
    whichtoinvert=(True, False, True, False),
)

# Extracting canonical functional network maps
## preprocessing

# Union of segmentation labels 1 and 3, eroded by one voxel.
csfAndWM = (
    ants.threshold_image(boldseg, 1, 1)
    + ants.threshold_image(boldseg, 3, 3)
).morphology("erode", 1)

bold = ants.image_read(boldfnsR)
boldList = ants.ndimage_to_list(bold)

# Register the average of the first five volumes to `und`, then resample the
# whole time series with the forward transforms.
avgBold = ants.get_average_of_timeseries(bold, range(5))
boldUndTX = ants.registration(
    und,
    avgBold,
    "SyN",
    regIterations=(15, 4),
    synMetric="CC",
    synSampling=2,
    verbose=False,
)
boldUndTS = ants.apply_transforms(
    und,
    bold,
    boldUndTX['fwdtransforms'],
    imagetype=3,
)

motCorr = ants.motion_correction(
    boldUndTS,
    avgBold,
    type_of_transform="Rigid",
    verbose=True,
)

# Fourth spacing component of the 4-D image (presumably the repetition
# time -- confirm).
tr = ants.get_spacing(bold)[3]

# Split the time points at an 'FD' value of 0.5.
fd_values = motCorr['FD']
highMotionTimes = np.where(fd_values >= 0.5)
goodtimes = np.where(fd_values < 0.5)

avgBold = ants.get_average_of_timeseries(motCorr['motion_corrected'], range(5))

#######################
nt = len(fd_values)
plt.plot(range(nt), fd_values)
plt.show()
#################################################
def desikan_killiany_tourville_labeling(t1,
                                        do_preprocessing=True,
                                        return_probability_images=False,
                                        antsxnet_cache_directory=None,
                                        verbose=False):
    """
    Cortical and deep gray matter labeling using Desikan-Killiany-Tourville

    Perform DKT labeling using deep learning

    The labeling is as follows:

    Inner labels:
    Label 0: background
    Label 4: left lateral ventricle
    Label 5: left inferior lateral ventricle
    Label 6: left cerebellum exterior
    Label 7: left cerebellum white matter
    Label 10: left thalamus proper
    Label 11: left caudate
    Label 12: left putamen
    Label 13: left pallidum
    Label 15: 4th ventricle
    Label 16: brain stem
    Label 17: left hippocampus
    Label 18: left amygdala
    Label 24: CSF
    Label 25: left lesion
    Label 26: left accumbens area
    Label 28: left ventral DC
    Label 30: left vessel
    Label 43: right lateral ventricle
    Label 44: right inferior lateral ventricle
    Label 45: right cerebellum exterior
    Label 46: right cerebellum white matter
    Label 49: right thalamus proper
    Label 50: right caudate
    Label 51: right putamen
    Label 52: right pallidum
    Label 53: right hippocampus
    Label 54: right amygdala
    Label 57: right lesion
    Label 58: right accumbens area
    Label 60: right ventral DC
    Label 62: right vessel
    Label 72: 5th ventricle
    Label 85: optic chiasm
    Label 91: left basal forebrain
    Label 92: right basal forebrain
    Label 630: cerebellar vermal lobules I-V
    Label 631: cerebellar vermal lobules VI-VII
    Label 632: cerebellar vermal lobules VIII-X

    Outer labels:
    Label 1002: left caudal anterior cingulate
    Label 1003: left caudal middle frontal
    Label 1005: left cuneus
    Label 1006: left entorhinal
    Label 1007: left fusiform
    Label 1008: left inferior parietal
    Label 1009: left inferior temporal
    Label 1010: left isthmus cingulate
    Label 1011: left lateral occipital
    Label 1012: left lateral orbitofrontal
    Label 1013: left lingual
    Label 1014: left medial orbitofrontal
    Label 1015: left middle temporal
    Label 1016: left parahippocampal
    Label 1017: left paracentral
    Label 1018: left pars opercularis
    Label 1019: left pars orbitalis
    Label 1020: left pars triangularis
    Label 1021: left pericalcarine
    Label 1022: left postcentral
    Label 1023: left posterior cingulate
    Label 1024: left precentral
    Label 1025: left precuneus
    Label 1026: left rostral anterior cingulate
    Label 1027: left rostral middle frontal
    Label 1028: left superior frontal
    Label 1029: left superior parietal
    Label 1030: left superior temporal
    Label 1031: left supramarginal
    Label 1034: left transverse temporal
    Label 1035: left insula
    Label 2002: right caudal anterior cingulate
    Label 2003: right caudal middle frontal
    Label 2005: right cuneus
    Label 2006: right entorhinal
    Label 2007: right fusiform
    Label 2008: right inferior parietal
    Label 2009: right inferior temporal
    Label 2010: right isthmus cingulate
    Label 2011: right lateral occipital
    Label 2012: right lateral orbitofrontal
    Label 2013: right lingual
    Label 2014: right medial orbitofrontal
    Label 2015: right middle temporal
    Label 2016: right parahippocampal
    Label 2017: right paracentral
    Label 2018: right pars opercularis
    Label 2019: right pars orbitalis
    Label 2020: right pars triangularis
    Label 2021: right pericalcarine
    Label 2022: right postcentral
    Label 2023: right posterior cingulate
    Label 2024: right precentral
    Label 2025: right precuneus
    Label 2026: right rostral anterior cingulate
    Label 2027: right rostral middle frontal
    Label 2028: right superior frontal
    Label 2029: right superior parietal
    Label 2030: right superior temporal
    Label 2031: right supramarginal
    Label 2034: right transverse temporal
    Label 2035: right insula

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    do_preprocessing = True

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    return_probability_images : boolean
        Whether to return the two sets of probability images for the inner and
        outer labels.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if is None, these data will be
        downloaded to a ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    The DKT label image or, if return_probability_images is True, a
    dictionary with the keys 'segmentation_image',
    'inner_probability_images' and 'outer_probability_images'.

    Raises
    ------
    ValueError
        If the input image is not 3-D.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> dkt = desikan_killiany_tourville_labeling(image)
    """

    # Validate before importing model code or downloading anything.
    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import preprocess_brain_image

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    t1_preprocessing = None
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            do_brain_extraction=True,
            template="croppedMni152",
            template_transform_type="AffineFast",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = (t1_preprocessing["preprocessed_image"]
                           * t1_preprocessing['brain_mask'])

    def to_native_space(probability_image):
        # Undo the template registration when it was performed here;
        # otherwise the image is already in the input space.
        if not do_preprocessing:
            return probability_image
        return ants.apply_transforms(
            fixed=t1,
            moving=probability_image,
            transformlist=t1_preprocessing['template_transforms']['invtransforms'],
            whichtoinvert=[True],
            interpolator="linear",
            verbose=verbose)

    def argmax_segmentation(probability_images):
        # Per voxel, the index of the largest probability image.
        mask = t1 * 0 + 1
        image_matrix = ants.image_list_to_matrix(probability_images, mask)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        return ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), mask)[0]

    ################################
    #
    # Download spatial priors for outer model
    #
    ################################

    spatial_priors_file_name_path = get_antsxnet_data(
        "priorDktLabels",
        antsxnet_cache_directory=antsxnet_cache_directory)
    spatial_priors = ants.image_read(spatial_priors_file_name_path)
    priors_image_list = ants.ndimage_to_list(spatial_priors)

    ################################
    #
    # Build outer model and load weights
    #
    ################################

    template_size = (96, 112, 96)
    labels = (0, 1002, 1003, *range(1005, 1032), 1034, 1035,
              2002, 2003, *range(2005, 2032), 2034, 2035)
    channel_size = 1 + len(priors_image_list)

    unet_model = create_unet_model_3d(
        (*template_size, channel_size),
        number_of_outputs=len(labels),
        number_of_layers=4,
        number_of_filters_at_base_layer=16,
        dropout_rate=0.0,
        convolution_kernel_size=(3, 3, 3),
        deconvolution_kernel_size=(2, 2, 2),
        weight_decay=1e-5,
        add_attention_gating=True)

    weights_file_name = get_pretrained_network(
        "dktOuterWithSpatialPriors",
        antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Outer model Prediction.")

    downsampled_image = ants.resample_image(
        t1_preprocessed, template_size, use_voxels=True, interp_type=0)
    image_array = downsampled_image.numpy()
    image_array = (image_array - image_array.mean()) / image_array.std()

    batchX = np.zeros((1, *template_size, channel_size))
    batchX[0, :, :, :, 0] = image_array

    for i, prior_image in enumerate(priors_image_list):
        resampled_prior_image = ants.resample_image(
            prior_image, template_size, use_voxels=True, interp_type=0)
        batchX[0, :, :, :, i + 1] = resampled_prior_image.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    origin = downsampled_image.origin
    spacing = downsampled_image.spacing
    direction = downsampled_image.direction

    # Results of the outer model.  These used to be collected under the
    # name (and returned under the key) of the inner model, and vice versa.
    outer_probability_images = []
    for i in range(len(labels)):
        probability_image = ants.from_numpy(
            np.squeeze(predicted_data[0, :, :, :, i]),
            origin=origin, spacing=spacing, direction=direction)
        resampled_image = ants.resample_image(
            probability_image, t1_preprocessed.shape,
            use_voxels=True, interp_type=0)
        outer_probability_images.append(to_native_space(resampled_image))

    segmentation_image = argmax_segmentation(outer_probability_images)

    # Map the argmax indices to the outer label values.
    dkt_label_image = ants.image_clone(segmentation_image)
    for i, label in enumerate(labels):
        dkt_label_image[segmentation_image == i] = label

    ################################
    #
    # Build inner model and load weights
    #
    ################################

    template_size = (160, 192, 160)
    labels = (0, 4, 6, 7, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 28,
              30, 43, 44, 45, 46, 49, 50, 51, 52, 53, 54, 58, 60, 91, 92,
              630, 631, 632)

    unet_model = create_unet_model_3d(
        (*template_size, 1),
        number_of_outputs=len(labels),
        number_of_layers=4,
        number_of_filters_at_base_layer=8,
        dropout_rate=0.0,
        convolution_kernel_size=(3, 3, 3),
        deconvolution_kernel_size=(2, 2, 2),
        weight_decay=1e-5,
        add_attention_gating=True)

    weights_file_name = get_pretrained_network(
        "dktInner",
        antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction.")

    # Crop to the inner model's input size.
    cropped_image = ants.crop_indices(
        t1_preprocessed, (12, 14, 0), (172, 206, 160))

    batchX = np.expand_dims(cropped_image.numpy(), axis=0)
    batchX = np.expand_dims(batchX, axis=-1)
    batchX = (batchX - batchX.mean()) / batchX.std()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    origin = cropped_image.origin
    spacing = cropped_image.spacing
    direction = cropped_image.direction

    inner_probability_images = []
    for i in range(len(labels)):
        probability_image = ants.from_numpy(
            np.squeeze(predicted_data[0, :, :, :, i]),
            origin=origin, spacing=spacing, direction=direction)
        # Outside the crop the background channel (i == 0) is filled with 1
        # and every other channel with 0.
        if i > 0:
            decropped_image = ants.decrop_image(
                probability_image, t1_preprocessed * 0)
        else:
            decropped_image = ants.decrop_image(
                probability_image, t1_preprocessed * 0 + 1)
        inner_probability_images.append(to_native_space(decropped_image))

    segmentation_image = argmax_segmentation(inner_probability_images)

    ################################
    #
    # Incorporate the inner model results into the final label image.
    # Note that we purposely prioritize the inner label results.
    #
    ################################

    for i, label in enumerate(labels):
        if label > 0:
            dkt_label_image[segmentation_image == i] = label

    if return_probability_images:
        return {
            'segmentation_image': dkt_label_image,
            'inner_probability_images': inner_probability_images,
            'outer_probability_images': outer_probability_images
        }

    return dkt_label_image
def deep_flash(t1,
               t2=None,
               do_preprocessing=True,
               use_rank_intensity=True,
               antsxnet_cache_directory=None,
               verbose=False):
    """
    Hippocampal/Entorhinal segmentation using "Deep Flash"

    Perform hippocampal/entorhinal segmentation in T1 and T1/T2 images using
    labels from Mike Yassa's lab

    https://faculty.sites.uci.edu/myassa/

    The labeling is as follows:
    Label 0 :  background
    Label 5 :  left aLEC
    Label 6 :  right aLEC
    Label 7 :  left pMEC
    Label 8 :  right pMEC
    Label 9 :  left perirhinal
    Label 10:  right perirhinal
    Label 11:  left parahippocampal
    Label 12:  right parahippocampal
    Label 13:  left DG/CA2/CA3/CA4
    Label 14:  right DG/CA2/CA3/CA4
    Label 15:  left CA1
    Label 16:  right CA1
    Label 17:  left subiculum
    Label 18:  right subiculum

    Preprocessing on the training data consisted of:
       * brain extraction,
       * n4 bias correction,
       * affine registration to the "deep flash" template.
    which is performed on the input images if do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    t2 : ANTsImage
        Optional 3-D T2-weighted brain image.  If specified, it is assumed to be
        pre-aligned to the t1.

    do_preprocessing : boolean
        See description above.

    use_rank_intensity : boolean
        If false, use histogram matching with cropped template ROI.  Otherwise,
        use a rank intensity transform on the cropped ROI.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, these data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary with the following keys:
        'segmentation_image' : label image using the labeling listed above.
        'probability_images' : list with the background probability image
            followed by the per-label probability images, interleaved
            left/right.
        'medial_temporal_lobe_probability_image' : summed left + right
            probability image of the medial temporal lobe output.
        'other_region_probability_image' : summed left + right probability
            image of the EC/perirhinal/parahippocampal output.
        'hippocampal_probability_image' : summed left + right probability
            image of the hippocampal output.
    The last two keys are present only for the hierarchical model, which is
    currently always used.

    Raises
    ------
    ValueError
        If t1 is not a 3-D image.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> flash = deep_flash(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import brain_extraction

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Options temporarily taken from the user
    #
    ################################

    # use_hierarchical_parcellation : boolean
    #     If True, use u-net model with additional outputs of the medial temporal lobe
    #     region, hippocampal, and entorhinal/perirhinal/parahippocampal regions.  Otherwise
    #     the only additional output is the medial temporal lobe.
    #
    # use_contralaterality : boolean
    #     Use both hemispherical models to also predict the corresponding contralateral
    #     segmentation and use both sets of priors to produce the results.  Mainly used
    #     for debugging.

    use_hierarchical_parcellation = True
    use_contralaterality = True

    ################################
    #
    # Small local helpers
    #
    ################################

    def flip_left_right(image):
        # Mirror the voxel array along the first axis while keeping the
        # original header (origin/spacing/direction).
        return ants.from_numpy(np.flip(image.numpy(), axis=0),
                               origin=image.origin,
                               spacing=image.spacing,
                               direction=image.direction)

    def rescale_to_unit_range(image):
        # Linearly map the intensities of the image to [-1, 1].
        return ((image - image.min()) /
                (image.max() - image.min()) * 2.0 - 1.0)

    def normalized_crop(image, template_roi, lower_bound, upper_bound):
        # Crop the hemisphere ROI and normalize its intensities.
        cropped = ants.crop_indices(image, lower_bound, upper_bound)
        if use_rank_intensity:
            return ants.rank_intensity(cropped)
        return ants.histogram_match_image(cropped, template_roi,
                                          255, 64, False)

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    t1_mask = None
    t1_preprocessed_flipped = None
    # The cache directory is forwarded so the template lands in the same
    # place as the priors and the weights.
    t1_template = ants.image_read(
        get_antsxnet_data("deepFlashTemplateT1SkullStripped",
                          antsxnet_cache_directory=antsxnet_cache_directory))
    template_transforms = None

    if do_preprocessing:

        if verbose:
            print("Preprocessing T1.")

        # Brain extraction
        probability_mask = brain_extraction(
            t1_preprocessed,
            modality="t1",
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_mask = ants.threshold_image(probability_mask, 0.5, 1, 1, 0)
        t1_preprocessed = t1_preprocessed * t1_mask

        # Do bias correction
        t1_preprocessed = ants.n4_bias_field_correction(t1_preprocessed,
                                                        t1_mask,
                                                        shrink_factor=4,
                                                        verbose=verbose)

        # Warp to template
        registration = ants.registration(
            fixed=t1_template,
            moving=t1_preprocessed,
            type_of_transform="antsRegistrationSyNQuickRepro[a]",
            verbose=verbose)
        template_transforms = dict(
            fwdtransforms=registration['fwdtransforms'],
            invtransforms=registration['invtransforms'])
        t1_preprocessed = registration['warpedmovout']

    if use_contralaterality:
        t1_preprocessed_flipped = flip_left_right(t1_preprocessed)

    t2_preprocessed = t2
    t2_preprocessed_flipped = None
    t2_template = None

    if t2 is not None:
        t2_template = ants.image_read(
            get_antsxnet_data(
                "deepFlashTemplateT2SkullStripped",
                antsxnet_cache_directory=antsxnet_cache_directory))
        t2_template = ants.copy_image_info(t1_template, t2_template)

        if do_preprocessing:

            if verbose:
                print("Preprocessing T2.")

            # Brain extraction (the T2 is assumed pre-aligned to the T1, so
            # the T1 mask is reused).
            t2_preprocessed = t2_preprocessed * t1_mask

            # Do bias correction
            t2_preprocessed = ants.n4_bias_field_correction(t2_preprocessed,
                                                            t1_mask,
                                                            shrink_factor=4,
                                                            verbose=verbose)

            # Warp to template
            t2_preprocessed = ants.apply_transforms(
                fixed=t1_template,
                moving=t2_preprocessed,
                transformlist=template_transforms['fwdtransforms'],
                verbose=verbose)

        if use_contralaterality:
            t2_preprocessed_flipped = flip_left_right(t2_preprocessed)

    probability_images = list()
    labels = (0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)
    image_size = (64, 64, 96)

    ################################
    #
    # Process left/right in split networks
    #
    ################################

    ################################
    #
    # Download spatial priors
    #
    ################################

    spatial_priors_file_name_path = get_antsxnet_data(
        "deepFlashPriors", antsxnet_cache_directory=antsxnet_cache_directory)
    spatial_priors = ants.image_read(spatial_priors_file_name_path)
    priors_image_list = [
        ants.copy_image_info(t1_preprocessed, prior)
        for prior in ants.ndimage_to_list(spatial_priors)
    ]

    # Odd entries of 'labels' are the left labels, even (non-zero) entries
    # are the right labels; the priors are split the same way.
    labels_left = labels[1::2]
    priors_image_left_list = priors_image_list[1::2]
    probability_images_left = list()
    foreground_probability_images_left = list()
    lower_bound_left = (76, 74, 56)
    upper_bound_left = (140, 138, 152)
    tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_left,
                                    upper_bound_left)
    origin_left = tmp_cropped.origin

    # Spacing and direction of the left crop are reused for the right crop.
    spacing = tmp_cropped.spacing
    direction = tmp_cropped.direction

    t1_template_roi_left = rescale_to_unit_range(
        ants.crop_indices(t1_template, lower_bound_left, upper_bound_left))
    t2_template_roi_left = None
    if t2_template is not None:
        t2_template_roi_left = rescale_to_unit_range(
            ants.crop_indices(t2_template, lower_bound_left,
                              upper_bound_left))

    labels_right = labels[2::2]
    priors_image_right_list = priors_image_list[2::2]
    probability_images_right = list()
    foreground_probability_images_right = list()
    lower_bound_right = (20, 74, 56)
    upper_bound_right = (84, 138, 152)
    tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_right,
                                    upper_bound_right)
    origin_right = tmp_cropped.origin

    t1_template_roi_right = rescale_to_unit_range(
        ants.crop_indices(t1_template, lower_bound_right, upper_bound_right))
    t2_template_roi_right = None
    if t2_template is not None:
        t2_template_roi_right = rescale_to_unit_range(
            ants.crop_indices(t2_template, lower_bound_right,
                              upper_bound_right))

    ################################
    #
    # Create model
    #
    ################################

    # Input channels: T1, optional T2, then one channel per spatial prior.
    channel_size = 1 + len(labels_left)
    if t2 is not None:
        channel_size += 1

    number_of_classification_labels = 1 + len(labels_left)

    unet_model = create_unet_model_3d(
        (*image_size, channel_size),
        number_of_outputs=number_of_classification_labels,
        mode="classification",
        number_of_filters=(32, 64, 96, 128, 256),
        convolution_kernel_size=(3, 3, 3),
        deconvolution_kernel_size=(2, 2, 2),
        dropout_rate=0.0,
        weight_decay=0)

    penultimate_layer = unet_model.layers[-2].output

    def sigmoid_head():
        # Single-channel sigmoid output attached to the penultimate layer.
        return Conv3D(filters=1,
                      kernel_size=(1, 1, 1),
                      activation='sigmoid',
                      kernel_regularizer=regularizers.l2(0.0))(
                          penultimate_layer)

    # The heads are created in a fixed order (mtl, other, hippocampus) so
    # the layer order matches the original construction.

    # medial temporal lobe
    output1 = sigmoid_head()

    if use_hierarchical_parcellation:

        # EC, perirhinal, and parahippo.
        output2 = sigmoid_head()

        # Hippocampus
        output3 = sigmoid_head()

        unet_model = Model(
            inputs=unet_model.input,
            outputs=[unet_model.output, output1, output2, output3])
    else:
        unet_model = Model(inputs=unet_model.input,
                           outputs=[unet_model.output, output1])

    ################################
    #
    # Helpers shared by the left and right passes
    #
    ################################

    def load_hemisphere_weights(hemisphere):
        # 'hemisphere' is "Left" or "Right".
        network_name = "deepFlash" + hemisphere + (
            "Both" if t2 is not None else "T1")
        if use_hierarchical_parcellation:
            network_name += "Hierarchical"
        if use_rank_intensity:
            network_name += "_ri"

        if verbose:
            print("DeepFlash: retrieving model weights (" +
                  hemisphere.lower() + ").")
        weights_file_name = get_pretrained_network(
            network_name, antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

    def build_batch(lower_bound, upper_bound, t1_template_roi,
                    t2_template_roi, priors, hemisphere_labels):
        # Batch entry 0 is the input image; entry 1 (contralaterality only)
        # is the left/right flipped image cropped with the same bounds.
        batch_size = 2 if use_contralaterality else 1
        batch = np.zeros((batch_size, *image_size, channel_size))

        batch[0, :, :, :, 0] = normalized_crop(
            t1_preprocessed, t1_template_roi, lower_bound,
            upper_bound).numpy()
        if use_contralaterality:
            batch[1, :, :, :, 0] = normalized_crop(
                t1_preprocessed_flipped, t1_template_roi, lower_bound,
                upper_bound).numpy()

        if t2 is not None:
            batch[0, :, :, :, 1] = normalized_crop(
                t2_preprocessed, t2_template_roi, lower_bound,
                upper_bound).numpy()
            if use_contralaterality:
                batch[1, :, :, :, 1] = normalized_crop(
                    t2_preprocessed_flipped, t2_template_roi, lower_bound,
                    upper_bound).numpy()

        # The priors occupy the trailing channels; the same (unflipped)
        # priors are used for every batch entry.
        prior_offset = channel_size - len(hemisphere_labels)
        for i, prior in enumerate(priors):
            cropped_prior = ants.crop_indices(prior, lower_bound,
                                              upper_bound).numpy()
            for j in range(batch.shape[0]):
                batch[j, :, :, :, i + prior_offset] = cropped_prior
        return batch

    def to_native_space(probability_array, origin, is_background,
                        is_flipped):
        # Turn one predicted ROI volume into an image in the space of the
        # input t1.
        probability_image = ants.from_numpy(np.squeeze(probability_array),
                                            origin=origin,
                                            spacing=spacing,
                                            direction=direction)
        # Outside the ROI the background probability is 1 and every other
        # probability is 0.
        canvas = t1_preprocessed * 0
        if is_background:
            canvas = canvas + 1
        probability_image = ants.decrop_image(probability_image, canvas)

        if is_flipped:
            probability_image = flip_left_right(probability_image)

        if do_preprocessing:
            probability_image = ants.apply_transforms(
                fixed=t1,
                moving=probability_image,
                transformlist=template_transforms['invtransforms'],
                whichtoinvert=[True],
                interpolator="linear",
                verbose=verbose)
        return probability_image

    ################################
    #
    # Left:  build model and load weights
    #
    ################################

    load_hemisphere_weights("Left")

    ################################
    #
    # Left:  do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction (left).")

    batchX = build_batch(lower_bound_left, upper_bound_left,
                         t1_template_roi_left, t2_template_roi_left,
                         priors_image_left_list, labels_left)

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    for i in range(1 + len(labels_left)):
        for j in range(predicted_data[0].shape[0]):
            probability_image = to_native_space(
                predicted_data[0][j, :, :, :, i],
                origin_left,
                is_background=(i == 0),
                is_flipped=(j == 1))

            if j == 0:  # not flipped
                probability_images_left.append(probability_image)
            else:  # flipped: this is the contralateral (right) estimate
                probability_images_right.append(probability_image)

    ################################
    #
    # Left:  do prediction of mtl, hippocampal, and ec regions and normalize to native space
    #
    ################################

    for i in range(1, len(predicted_data)):
        for j in range(predicted_data[i].shape[0]):
            probability_image = to_native_space(
                predicted_data[i][j, :, :, :, 0],
                origin_left,
                is_background=False,
                is_flipped=(j == 1))

            if j == 0:  # not flipped
                foreground_probability_images_left.append(probability_image)
            else:
                foreground_probability_images_right.append(probability_image)

    ################################
    #
    # Right:  build model and load weights
    #
    ################################

    load_hemisphere_weights("Right")

    ################################
    #
    # Right:  do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction (right).")

    batchX = build_batch(lower_bound_right, upper_bound_right,
                         t1_template_roi_right, t2_template_roi_right,
                         priors_image_right_list, labels_right)

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    for i in range(1 + len(labels_right)):
        for j in range(predicted_data[0].shape[0]):
            probability_image = to_native_space(
                predicted_data[0][j, :, :, :, i],
                origin_right,
                is_background=(i == 0),
                is_flipped=(j == 1))

            if j == 0:  # not flipped
                if use_contralaterality:
                    # Average with the contralateral estimate produced by
                    # the left model.
                    probability_images_right[i] = (
                        probability_images_right[i] + probability_image) / 2
                else:
                    probability_images_right.append(probability_image)
            else:  # flipped
                probability_images_left[i] = (probability_images_left[i] +
                                              probability_image) / 2

    ################################
    #
    # Right:  do prediction of mtl, hippocampal, and ec regions and normalize to native space
    #
    ################################

    for i in range(1, len(predicted_data)):
        for j in range(predicted_data[i].shape[0]):
            probability_image = to_native_space(
                predicted_data[i][j, :, :, :, 0],
                origin_right,
                is_background=False,
                is_flipped=(j == 1))

            if j == 0:  # not flipped
                if use_contralaterality:
                    foreground_probability_images_right[i - 1] = (
                        foreground_probability_images_right[i - 1] +
                        probability_image) / 2
                else:
                    foreground_probability_images_right.append(
                        probability_image)
            else:
                foreground_probability_images_left[i - 1] = (
                    foreground_probability_images_left[i - 1] +
                    probability_image) / 2

    ################################
    #
    # Combine priors
    #
    ################################

    # Background probability is one minus the sum of all foreground labels.
    probability_background_image = ants.image_clone(t1) * 0
    for i in range(1, len(probability_images_left)):
        probability_background_image += probability_images_left[i]
    for i in range(1, len(probability_images_right)):
        probability_background_image += probability_images_right[i]

    probability_images.append(probability_background_image * -1 + 1)
    for i in range(1, len(probability_images_left)):
        probability_images.append(probability_images_left[i])
        probability_images.append(probability_images_right[i])

    ################################
    #
    # Convert probability images to segmentation
    #
    ################################

    image_matrix = ants.image_list_to_matrix(probability_images[1:],
                                             t1 * 0 + 1)
    # A voxel is foreground when the summed label probability exceeds the
    # background probability; only then is the arg-max label assigned.
    background_foreground_matrix = np.stack([
        ants.image_list_to_matrix([probability_images[0]], t1 * 0 + 1),
        np.expand_dims(np.sum(image_matrix, axis=0), axis=0)
    ])
    foreground_matrix = np.argmax(background_foreground_matrix, axis=0)
    segmentation_matrix = (np.argmax(image_matrix, axis=0) +
                           1) * foreground_matrix
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    # Map consecutive indices back to the published label values.
    relabeled_image = ants.image_clone(segmentation_image)
    for i, label in enumerate(labels):
        relabeled_image[segmentation_image == i] = label

    foreground_probability_images = list()
    for i in range(len(foreground_probability_images_left)):
        foreground_probability_images.append(
            foreground_probability_images_left[i] +
            foreground_probability_images_right[i])

    return_dict = {
        'segmentation_image': relabeled_image,
        'probability_images': probability_images,
        'medial_temporal_lobe_probability_image':
        foreground_probability_images[0]
    }
    if use_hierarchical_parcellation:
        return_dict['other_region_probability_image'] = \
            foreground_probability_images[1]
        return_dict['hippocampal_probability_image'] = \
            foreground_probability_images[2]

    return (return_dict)
def deep_flash_deprecated(t1,
                          do_preprocessing=True,
                          do_per_hemisphere=True,
                          which_hemisphere_models="new",
                          antsxnet_cache_directory=None,
                          verbose=False):
    """
    Hippocampal/Entorhinal segmentation using "Deep Flash"

    Deprecated:  please use deep_flash() instead.

    Perform hippocampal/entorhinal segmentation in T1 images using
    labels from Mike Yassa's lab

    https://faculty.sites.uci.edu/myassa/

    The labeling is as follows:
    Label 0 :  background
    Label 5 :  left aLEC
    Label 6 :  right aLEC
    Label 7 :  left pMEC
    Label 8 :  right pMEC
    Label 9 :  left perirhinal
    Label 10:  right perirhinal
    Label 11:  left parahippocampal
    Label 12:  right parahippocampal
    Label 13:  left DG/CA3
    Label 14:  right DG/CA3
    Label 15:  left CA1
    Label 16:  right CA1
    Label 17:  left subiculum
    Label 18:  right subiculum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    do_preprocessing = True

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    do_per_hemisphere : boolean
        If True, do prediction based on separate networks per hemisphere.
        Otherwise, use the single network trained for both hemispheres.

    which_hemisphere_models : string
        Either "old" or "new".  Selects which set of per-hemisphere weights
        is loaded.  Only used if do_per_hemisphere = True.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, these data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary with the following keys:
        'segmentation_image' : label image using the labeling listed above.
        'probability_images' : list of probability images, starting with
            the background.

    Raises
    ------
    ValueError
        If t1 is not a 3-D image or, for the per-hemisphere models, if
        which_hemisphere_models is neither "old" nor "new".

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> flash = deep_flash_deprecated(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import preprocess_brain_image
    from ..utilities import pad_or_crop_image_to_size

    print("This function is deprecated.  Please update to deep_flash().")

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    t1_preprocessing = None
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            template="croppedMni152",
            template_transform_type="antsRegistrationSyNQuickRepro[a]",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing[
            "preprocessed_image"] * t1_preprocessing['brain_mask']

    probability_images = list()
    labels = (0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)

    ################################
    #
    # Helpers shared by all prediction paths
    #
    ################################

    def build_unet(input_size, number_of_outputs, number_of_filters):
        return create_unet_model_3d(
            input_size,
            number_of_outputs=number_of_outputs,
            number_of_layers=4,
            number_of_filters_at_base_layer=number_of_filters,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5,
            additional_options=("attentionGating", ))

    def to_native_space(predicted_data, cropped_image, number_of_outputs):
        # Convert each predicted channel into an image in the space of the
        # input t1.
        native_images = list()
        for i in range(number_of_outputs):
            probability_image = ants.from_numpy(
                np.squeeze(predicted_data[0, :, :, :, i]),
                origin=cropped_image.origin,
                spacing=cropped_image.spacing,
                direction=cropped_image.direction)
            # Outside the cropped region the background (channel 0) has
            # probability 1 and every other label probability 0.
            canvas = t1_preprocessed * 0
            if i == 0:
                canvas = canvas + 1
            decropped_image = ants.decrop_image(probability_image, canvas)

            if do_preprocessing:
                decropped_image = ants.apply_transforms(
                    fixed=t1,
                    moving=decropped_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)
            native_images.append(decropped_image)
        return native_images

    ################################
    #
    # Process left/right in same network
    #
    ################################

    if not do_per_hemisphere:

        ################################
        #
        # Build model and load weights
        #
        ################################

        template_size = (160, 192, 160)

        unet_model = build_unet((*template_size, 1), len(labels), 8)

        if verbose:
            print("DeepFlash: retrieving model weights.")

        weights_file_name = get_pretrained_network(
            "deepFlash", antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Do prediction and normalize to native space
        #
        ################################

        if verbose:
            print("Prediction.")

        cropped_image = pad_or_crop_image_to_size(t1_preprocessed,
                                                  template_size)

        batchX = np.expand_dims(cropped_image.numpy(), axis=0)
        batchX = np.expand_dims(batchX, axis=-1)
        batchX = (batchX - batchX.mean()) / batchX.std()

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        probability_images.extend(
            to_native_space(predicted_data, cropped_image, len(labels)))

    ################################
    #
    # Process left/right in split networks
    #
    ################################

    else:

        # Fail fast on an invalid model selection, before anything is
        # downloaded.
        network_suffixes = {"old": "16", "new": "16new"}
        if which_hemisphere_models not in network_suffixes:
            raise ValueError(
                "which_hemisphere_models must be \"old\" or \"new\".")
        network_suffix = network_suffixes[which_hemisphere_models]

        template_size = (64, 96, 96)
        number_of_filters = 16

        def predict_hemisphere(side, hemisphere_labels, lower_bound,
                               upper_bound):
            # 'side' is "Left" or "Right".  Runs the complete pipeline for
            # one hemisphere:  download priors, build the model, load the
            # weights, predict, and map the result back to native space.
            spatial_priors_file_name_path = get_antsxnet_data(
                "priorDeepFlash" + side + "Labels",
                antsxnet_cache_directory=antsxnet_cache_directory)
            spatial_priors = ants.image_read(spatial_priors_file_name_path)
            priors_image_list = ants.ndimage_to_list(spatial_priors)

            # Input channels: the T1 followed by one channel per prior.
            channel_size = 1 + len(hemisphere_labels)

            unet_model = build_unet((*template_size, channel_size),
                                    len(hemisphere_labels),
                                    number_of_filters)

            if verbose:
                print("DeepFlash: retrieving model weights (" +
                      side.lower() + ").")
            weights_file_name = get_pretrained_network(
                "deepFlash" + side + network_suffix,
                antsxnet_cache_directory=antsxnet_cache_directory)
            unet_model.load_weights(weights_file_name)

            if verbose:
                print("Prediction (" + side.lower() + ").")

            cropped_image = ants.crop_indices(t1_preprocessed, lower_bound,
                                              upper_bound)
            image_array = cropped_image.numpy()
            image_array = (image_array -
                           image_array.mean()) / image_array.std()

            batchX = np.zeros((1, *template_size, channel_size))
            batchX[0, :, :, :, 0] = image_array

            for i, prior in enumerate(priors_image_list):
                cropped_prior = ants.crop_indices(prior, lower_bound,
                                                  upper_bound)
                batchX[0, :, :, :, i + 1] = cropped_prior.numpy()

            predicted_data = unet_model.predict(batchX, verbose=verbose)

            return to_native_space(predicted_data, cropped_image,
                                   len(hemisphere_labels))

        ################################
        #
        # Left:  priors, model, weights, and prediction
        #
        ################################

        labels_left = (0, 5, 7, 9, 11, 13, 15, 17)
        probability_images_left = predict_hemisphere(
            "Left", labels_left, (30, 51, 0), (94, 147, 96))

        ################################
        #
        # Right:  priors, model, weights, and prediction
        #
        ################################

        labels_right = (0, 6, 8, 10, 12, 14, 16, 18)
        probability_images_right = predict_hemisphere(
            "Right", labels_right, (88, 51, 0), (152, 147, 96))

        ################################
        #
        # Combine priors
        #
        ################################

        # Background probability is one minus the sum of all foreground
        # labels of both hemispheres.
        probability_background_image = ants.image_clone(t1) * 0
        for i in range(1, len(probability_images_left)):
            probability_background_image += probability_images_left[i]
        for i in range(1, len(probability_images_right)):
            probability_background_image += probability_images_right[i]

        probability_images.append(probability_background_image * -1 + 1)
        for i in range(1, len(probability_images_left)):
            probability_images.append(probability_images_left[i])
            probability_images.append(probability_images_right[i])

    ################################
    #
    # Convert probability images to segmentation
    #
    ################################

    image_matrix = ants.image_list_to_matrix(probability_images[1:],
                                             t1 * 0 + 1)
    # A voxel is foreground when the summed label probability exceeds the
    # background probability; only then is the arg-max label assigned.
    background_foreground_matrix = np.stack([
        ants.image_list_to_matrix([probability_images[0]], t1 * 0 + 1),
        np.expand_dims(np.sum(image_matrix, axis=0), axis=0)
    ])
    foreground_matrix = np.argmax(background_foreground_matrix, axis=0)
    segmentation_matrix = (np.argmax(image_matrix, axis=0) +
                           1) * foreground_matrix
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    # Map consecutive indices back to the published label values.
    relabeled_image = ants.image_clone(segmentation_image)
    for i, label in enumerate(labels):
        relabeled_image[segmentation_image == i] = label

    return_dict = {
        'segmentation_image': relabeled_image,
        'probability_images': probability_images
    }
    return (return_dict)
def deep_atropos(t1,
                 do_preprocessing=True,
                 use_spatial_priors=1,
                 antsxnet_cache_directory=None,
                 verbose=False):
    """
    Six-tissue segmentation.

    Perform Atropos-style six tissue segmentation using deep learning.

    The labeling is as follows:
    Label 0 :  background
    Label 1 :  CSF
    Label 2 :  gray matter
    Label 3 :  white matter
    Label 4 :  deep gray matter
    Label 5 :  brain stem
    Label 6 :  cerebellum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    do_preprocessing = True

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    use_spatial_priors : integer
        Use MNI spatial tissue priors (0 or 1).  Currently, '0' (no priors)
        and '1' (cerebellar prior only) are the only two options.  Default
        is 1.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if this is None, these data will
        be downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary with two entries:
        'segmentation_image' : label image obtained by taking, per voxel,
            the class with the largest probability.
        'probability_images' : list with one probability image per label,
            ordered as in the labeling above.

    Raises
    ------
    ValueError
        If ``t1`` is not 3-D or ``use_spatial_priors`` is neither 0 nor 1.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> segmentation = deep_atropos(image)
    """

    # Validate the cheap arguments first so that bad input fails before any
    # package import, download, preprocessing, or model construction.
    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if use_spatial_priors not in (0, 1):
        raise ValueError("use_spatial_priors must be a 0 or 1")

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import preprocess_brain_image
    from ..utilities import extract_image_patches
    from ..utilities import reconstruct_image_from_patches

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    t1_preprocessing = None
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            template="croppedMni152",
            template_transform_type="antsRegistrationSyNQuickRepro[a]",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        # Keep only the brain: multiply by the extracted brain mask.
        t1_preprocessed = (t1_preprocessing["preprocessed_image"] *
                           t1_preprocessing['brain_mask'])

    ################################
    #
    # Build model and load weights
    #
    ################################

    patch_size = (112, 112, 112)
    # Stride is the image extent minus the patch extent along each axis.
    stride_length = tuple(
        extent - patch
        for extent, patch in zip(t1_preprocessed.shape, patch_size))

    classes = ("background", "csf", "gray matter", "white matter",
               "deep gray matter", "brain stem", "cerebellum")

    mni_priors = None
    channel_size = 1
    if use_spatial_priors != 0:
        prior_images = ants.ndimage_to_list(
            ants.image_read(
                get_antsxnet_data(
                    "croppedMni152Priors",
                    antsxnet_cache_directory=antsxnet_cache_directory)))
        # Give every prior the header information of the preprocessed T1.
        mni_priors = [
            ants.copy_image_info(t1_preprocessed, prior)
            for prior in prior_images
        ]
        channel_size = 2

    unet_model = create_unet_model_3d((*patch_size, channel_size),
                                      number_of_outputs=len(classes),
                                      mode="classification",
                                      number_of_layers=4,
                                      number_of_filters_at_base_layer=16,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      weight_decay=1e-5,
                                      # Passed as a plain string, exactly as
                                      # the parenthesized original evaluated.
                                      additional_options="attentionGating")

    if verbose:
        print("DeepAtropos: retrieving model weights.")

    # use_spatial_priors was validated above, so it is either 0 or 1 here.
    if use_spatial_priors == 0:
        network_name = "sixTissueOctantBrainSegmentation"
    else:
        network_name = "sixTissueOctantBrainSegmentationWithPriors1"
    weights_file_name = get_pretrained_network(
        network_name, antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction.")

    # Standardize intensities (zero mean, unit standard deviation).
    t1_preprocessed = ((t1_preprocessed - t1_preprocessed.mean()) /
                       t1_preprocessed.std())
    image_patches = extract_image_patches(t1_preprocessed,
                                          patch_size=patch_size,
                                          max_number_of_patches="all",
                                          stride_length=stride_length,
                                          return_as_array=True)
    batchX = np.zeros((*image_patches.shape, channel_size))
    batchX[:, :, :, :, 0] = image_patches
    if channel_size > 1:
        # Only prior index 6 is fed to the network as the second channel
        # (presumably the cerebellum prior, per the docstring -- confirm
        # against the ordering of "croppedMni152Priors").
        prior_patches = extract_image_patches(mni_priors[6],
                                              patch_size=patch_size,
                                              max_number_of_patches="all",
                                              stride_length=stride_length,
                                              return_as_array=True)
        batchX[:, :, :, :, 1] = prior_patches

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    probability_images = list()
    for i, class_name in enumerate(classes):
        if verbose:
            print("Reconstructing image", class_name)

        reconstructed_image = reconstruct_image_from_patches(
            predicted_data[:, :, :, :, i],
            domain_image=t1_preprocessed,
            stride_length=stride_length)

        if do_preprocessing:
            # Map the probability image back to the space of the input T1
            # using the inverse of the template transform.
            probability_images.append(
                ants.apply_transforms(
                    fixed=t1,
                    moving=reconstructed_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose))
        else:
            probability_images.append(reconstructed_image)

    # All-ones mask over the T1 domain, built once and reused below.
    unit_mask = t1 * 0 + 1
    image_matrix = ants.image_list_to_matrix(probability_images, unit_mask)
    # Per-voxel label is the index of the most probable class.
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), unit_mask)[0]

    return {
        'segmentation_image': segmentation_image,
        'probability_images': probability_images
    }