Code example #1
File: test_utils.py  Project: xpjiang/ANTsPy
 def test_ndimage_to_list(self):
     image = ants.image_read(ants.get_ants_data('r16'))
     image2 = ants.image_read(ants.get_ants_data('r64'))
     ants.set_spacing(image, (2, 2))
     ants.set_spacing(image2, (2, 2))
     imageTar = ants.make_image((*image2.shape, 2))
     ants.set_spacing(imageTar, (2, 2, 2))
     image3 = ants.list_to_ndimage(imageTar, [image, image2])
     self.assertEqual(image3.dimension, 3)
     ants.set_direction(image3, np.eye(3) * 2)
     images_unmerged = ants.ndimage_to_list(image3)
     self.assertEqual(len(images_unmerged), 2)
     self.assertEqual(images_unmerged[0].dimension, 2)
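The test above exercises the list/ndimage round trip on two 2-D slices. The same idea extends to time series; a minimal standalone sketch on synthetic data (not part of the original test suite):

# Hypothetical round trip: ten synthetic 3-D volumes merged into a 4-D image.
import ants

volumes = [ants.make_image((64, 64, 32), voxval=i) for i in range(10)]
target = ants.make_image((64, 64, 32, 10))
merged = ants.list_to_ndimage(target, volumes)    # 4-D image
unmerged = ants.ndimage_to_list(merged)           # back to ten 3-D volumes
assert len(unmerged) == 10 and unmerged[0].dimension == 3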
Code example #2
# The original snippet begins mid-script; the imports and argument handling
# below are a reconstruction (assumed, not from the source file).
import sys
import time

import ants
import antspynet

args = sys.argv
if len(args) != 3:
    raise ValueError("Usage: script.py <input_file> <output_file>")
input_file_name = args[1]
output_file_name = args[2]

start_time_total = time.time()

print("Reading ", input_file_name)
start_time = time.time()
input_image = ants.image_read(input_file_name)
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

dimension = len(input_image.shape)

input_image_list = list()
if dimension == 4:
    input_image_list = ants.ndimage_to_list(input_image)
elif dimension == 2:
    raise ValueError("Model for 3-D or 4-D images only.")
elif dimension == 3:
    input_image_list.append(input_image)

model = antspynet.create_deep_back_projection_network_model_3d(
    (*input_image_list[0].shape, 1),
    number_of_outputs=1,
    number_of_base_filters=64,
    number_of_feature_filters=256,
    number_of_back_projection_stages=7,
    convolution_kernel_size=(3, 3, 3),
    strides=(2, 2, 2),
    number_of_loss_functions=1)
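The snippet ends once the network is built. A plausible continuation is sketched below; loading weights via get_pretrained_network("mriSuperResolution") and the apply_super_resolution_model_to_image call are assumptions modeled on ANTsPyNet's mri_super_resolution utility, not part of the original script.

# Assumed continuation: run the model on each 3-D volume and write the result.
model.load_weights(antspynet.get_pretrained_network("mriSuperResolution"))

output_image_list = list()
for i in range(len(input_image_list)):
    output_image_list.append(
        antspynet.apply_super_resolution_model_to_image(
            input_image_list[i], model, target_range=(-127.5, 127.5)))

if dimension == 4:
    # Re-merge the super-resolved volumes into a 4-D image.
    target = ants.make_image((*output_image_list[0].shape, len(output_image_list)))
    output_image = ants.list_to_ndimage(target, output_image_list)
else:
    output_image = output_image_list[0]

ants.image_write(output_image, output_file_name)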
Code example #3
File: lung_extraction.py  Project: ANTsX/ANTsPyNet
def lung_extraction(image,
                    modality="proton",
                    antsxnet_cache_directory=None,
                    verbose=False):

    """
    Perform proton or CT lung extraction using U-net.

    Arguments
    ---------
    image : ANTsImage
        input image

    modality : string
        Modality image type.  Options include "ct", "proton", "protonLobes", 
        "maskLobes", and "ventilation".

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary of ANTs segmentation and probability images.  For
    modality="ventilation", a single probability image is returned.

    Example
    -------
    >>> output = lung_extraction(lung_image, modality="proton")
    """

    from ..architectures import create_unet_model_2d
    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import pad_or_crop_image_to_size

    if image.dimension != 3:
        raise ValueError( "Image dimension must be 3." )

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    image_mods = [modality]
    channel_size = len(image_mods)

    weights_file_name = None
    unet_model = None

    if modality == "proton":
        weights_file_name = get_pretrained_network("protonLungMri",
            antsxnet_cache_directory=antsxnet_cache_directory)

        classes = ("background", "left_lung", "right_lung")
        number_of_classification_labels = len(classes)

        reorient_template_file_name_path = get_antsxnet_data("protonLungTemplate",
            antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)

        resampled_image_size = reorient_template.shape

        unet_model = create_unet_model_3d((*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            number_of_layers=4, number_of_filters_at_base_layer=16, dropout_rate=0.0,
            convolution_kernel_size=(7, 7, 5), deconvolution_kernel_size=(7, 7, 5))
        unet_model.load_weights(weights_file_name)

        if verbose == True:
            print("Lung extraction:  normalizing image to the template.")

        center_of_mass_template = ants.get_center_of_mass(reorient_template * 0 + 1)
        center_of_mass_image = ants.get_center_of_mass(image * 0 + 1)
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_template), translation=translation)
        warped_image = ants.apply_ants_transform_to_image(xfrm, image, reorient_template)

        batchX = np.expand_dims(warped_image.numpy(), axis=0)
        batchX = np.expand_dims(batchX, axis=-1)
        batchX = (batchX - batchX.mean()) / batchX.std()

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        origin = warped_image.origin
        spacing = warped_image.spacing
        direction = warped_image.direction

        probability_images_array = list()
        for i in range(number_of_classification_labels):
            probability_images_array.append(
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                origin=origin, spacing=spacing, direction=direction))

        if verbose == True:
            print("Lung extraction:  renormalize probability mask to native space.")

        for i in range(number_of_classification_labels):
            probability_images_array[i] = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), probability_images_array[i], image)

        image_matrix = ants.image_list_to_matrix(probability_images_array, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        return_dict = {'segmentation_image' : segmentation_image,
                       'probability_images' : probability_images_array}
        return(return_dict)

    if modality == "protonLobes" or modality == "maskLobes":
        reorient_template_file_name_path = get_antsxnet_data("protonLungTemplate",
            antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)

        resampled_image_size = reorient_template.shape

        spatial_priors_file_name_path = get_antsxnet_data("protonLobePriors",
            antsxnet_cache_directory=antsxnet_cache_directory)
        spatial_priors = ants.image_read(spatial_priors_file_name_path)
        priors_image_list = ants.ndimage_to_list(spatial_priors)

        channel_size = 1 + len(priors_image_list)
        number_of_classification_labels = 1 + len(priors_image_list)

        unet_model = create_unet_model_3d((*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels, mode="classification", 
            number_of_filters_at_base_layer=16, number_of_layers=4,
            convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
            dropout_rate=0.0, weight_decay=0, additional_options=("attentionGating",))

        if modality == "protonLobes":
            penultimate_layer = unet_model.layers[-2].output
            outputs2 = Conv3D(filters=1,
                            kernel_size=(1, 1, 1),
                            activation='sigmoid',
                            kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)
            unet_model = Model(inputs=unet_model.input, outputs=[unet_model.output, outputs2])
            weights_file_name = get_pretrained_network("protonLobes",
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            weights_file_name = get_pretrained_network("maskLobes",
                antsxnet_cache_directory=antsxnet_cache_directory)

        unet_model.load_weights(weights_file_name)

        if verbose == True:
            print("Lung extraction:  normalizing image to the template.")

        center_of_mass_template = ants.get_center_of_mass(reorient_template * 0 + 1)
        center_of_mass_image = ants.get_center_of_mass(image * 0 + 1)
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_template), translation=translation)
        warped_image = ants.apply_ants_transform_to_image(xfrm, image, reorient_template)
        warped_array = warped_image.numpy()
        if modality == "protonLobes":
            warped_array = (warped_array - warped_array.mean()) / warped_array.std()
        else:
            warped_array[warped_array != 0] = 1
       
        batchX = np.zeros((1, *warped_array.shape, channel_size))
        batchX[0,:,:,:,0] = warped_array
        for i in range(len(priors_image_list)):
            batchX[0,:,:,:,i+1] = priors_image_list[i].numpy()

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        origin = warped_image.origin
        spacing = warped_image.spacing
        direction = warped_image.direction

        probability_images_array = list()
        for i in range(number_of_classification_labels):
            if modality == "protonLobes":
                probability_images_array.append(
                    ants.from_numpy(np.squeeze(predicted_data[0][0, :, :, :, i]),
                    origin=origin, spacing=spacing, direction=direction))
            else:
                probability_images_array.append(
                    ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                    origin=origin, spacing=spacing, direction=direction))

        if verbose == True:
            print("Lung extraction:  renormalize probability images to native space.")

        for i in range(number_of_classification_labels):
            probability_images_array[i] = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), probability_images_array[i], image)

        image_matrix = ants.image_list_to_matrix(probability_images_array, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        if modality == "protonLobes":
            whole_lung_mask = ants.from_numpy(np.squeeze(predicted_data[1][0, :, :, :, 0]),
                origin=origin, spacing=spacing, direction=direction)
            whole_lung_mask = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), whole_lung_mask, image)

            return_dict = {'segmentation_image' : segmentation_image,
                           'probability_images' : probability_images_array,
                           'whole_lung_mask_image' : whole_lung_mask}
            return(return_dict)
        else:
            return_dict = {'segmentation_image' : segmentation_image,
                           'probability_images' : probability_images_array}
            return(return_dict)


    elif modality == "ct":

        ################################
        #
        # Preprocess image
        #
        ################################

        if verbose == True:
            print("Preprocess CT image.")

        def closest_simplified_direction_matrix(direction):
            closest = np.floor(np.abs(direction) + 0.5)
            closest[direction < 0] *= -1.0
            return closest

        simplified_direction = closest_simplified_direction_matrix(image.direction)

        reference_image_size = (128, 128, 128)

        ct_preprocessed = ants.resample_image(image, reference_image_size, use_voxels=True, interp_type=0)
        ct_preprocessed[ct_preprocessed < -1000] = -1000
        ct_preprocessed[ct_preprocessed > 400] = 400
        ct_preprocessed.set_direction(simplified_direction)
        ct_preprocessed.set_origin((0, 0, 0))
        ct_preprocessed.set_spacing((1, 1, 1))

        ################################
        #
        # Reorient image
        #
        ################################

        reference_image = ants.make_image(reference_image_size,
                                          voxval=0,
                                          spacing=(1, 1, 1),
                                          origin=(0, 0, 0),
                                          direction=np.identity(3))
        center_of_mass_reference = np.floor(ants.get_center_of_mass(reference_image * 0 + 1))
        center_of_mass_image = np.floor(ants.get_center_of_mass(ct_preprocessed * 0 + 1))
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_reference)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_reference), translation=translation)
        ct_preprocessed = ((ct_preprocessed - ct_preprocessed.min()) /
            (ct_preprocessed.max() - ct_preprocessed.min()))
        ct_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm, ct_preprocessed, reference_image, interpolation="nearestneighbor")
        ct_preprocessed_warped = ((ct_preprocessed_warped - ct_preprocessed_warped.min()) /
            (ct_preprocessed_warped.max() - ct_preprocessed_warped.min())) - 0.5

        ################################
        #
        # Build models and load weights
        #
        ################################

        if verbose == True:
            print("Build model and load weights.")

        weights_file_name = get_pretrained_network("lungCtWithPriorsSegmentationWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)

        classes = ("background", "left lung", "right lung", "airways")
        number_of_classification_labels = len(classes)

        luna16_priors = ants.ndimage_to_list(ants.image_read(get_antsxnet_data("luna16LungPriors")))
        for i in range(len(luna16_priors)):
            luna16_priors[i] = ants.resample_image(luna16_priors[i], reference_image_size, use_voxels=True)
        channel_size = len(luna16_priors) + 1

        unet_model = create_unet_model_3d((*reference_image_size, channel_size),
            number_of_outputs=number_of_classification_labels, mode="classification",
            number_of_layers=4, number_of_filters_at_base_layer=16, dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5, additional_options=("attentionGating",))
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Do prediction and normalize to native space
        #
        ################################

        if verbose == True:
            print("Prediction.")

        batchX = np.zeros((1, *reference_image_size, channel_size))
        batchX[:,:,:,:,0] = ct_preprocessed_warped.numpy()

        for i in range(len(luna16_priors)):
            batchX[:,:,:,:,i+1] = luna16_priors[i].numpy() - 0.5

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        probability_images = list()
        for i in range(number_of_classification_labels):
            if verbose == True:
                print("Reconstructing image", classes[i])
            probability_image = ants.from_numpy(np.squeeze(predicted_data[:,:,:,:,i]),
                origin=ct_preprocessed_warped.origin, spacing=ct_preprocessed_warped.spacing,
                direction=ct_preprocessed_warped.direction)
            probability_image = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), probability_image, ct_preprocessed)
            probability_image = ants.resample_image(probability_image,
               resample_params=image.shape, use_voxels=True, interp_type=0)
            probability_image = ants.copy_image_info(image, probability_image)
            probability_images.append(probability_image)

        image_matrix = ants.image_list_to_matrix(probability_images, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        return_dict = {'segmentation_image' : segmentation_image,
                       'probability_images' : probability_images}
        return(return_dict)

    elif modality == "ventilation":

        ################################
        #
        # Preprocess image
        #
        ################################

        if verbose == True:
            print("Preprocess ventilation image.")

        template_size = (256, 256)

        image_modalities = ("Ventilation",)
        channel_size = len(image_modalities)

        preprocessed_image = (image - image.mean()) / image.std()
        ants.set_direction(preprocessed_image, np.identity(3))

        ################################
        #
        # Build models and load weights
        #
        ################################

        unet_model = create_unet_model_2d((*template_size, channel_size),
            number_of_outputs=1, mode='sigmoid',
            number_of_layers=4, number_of_filters_at_base_layer=32, dropout_rate=0.0,
            convolution_kernel_size=(3, 3), deconvolution_kernel_size=(2, 2),
            weight_decay=0)

        if verbose == True:
            print("Whole lung mask: retrieving model weights.")

        weights_file_name = get_pretrained_network("wholeLungMaskFromVentilation",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Extract slices
        #
        ################################

        spacing = ants.get_spacing(preprocessed_image)
        dimensions_to_predict = (spacing.index(max(spacing)),)

        total_number_of_slices = 0
        for d in range(len(dimensions_to_predict)):
            total_number_of_slices += preprocessed_image.shape[dimensions_to_predict[d]]

        batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

        slice_count = 0
        for d in range(len(dimensions_to_predict)):
            number_of_slices = preprocessed_image.shape[dimensions_to_predict[d]]

            if verbose == True:
                print("Extracting slices for dimension ", dimensions_to_predict[d], ".")

            for i in range(number_of_slices):
                ventilation_slice = pad_or_crop_image_to_size(ants.slice_image(preprocessed_image, dimensions_to_predict[d], i), template_size)
                batchX[slice_count,:,:,0] = ventilation_slice.numpy()
                slice_count += 1

        ################################
        #
        # Do prediction and then restack into the image
        #
        ################################

        if verbose == True:
            print("Prediction.")

        prediction = unet_model.predict(batchX, verbose=verbose)

        permutations = list()
        permutations.append((0, 1, 2))
        permutations.append((1, 0, 2))
        permutations.append((1, 2, 0))

        probability_image = ants.image_clone(image) * 0

        current_start_slice = 0
        for d in range(len(dimensions_to_predict)):
            current_end_slice = current_start_slice + preprocessed_image.shape[dimensions_to_predict[d]] - 1
            # Include the final slice for this dimension (the end index is inclusive).
            which_batch_slices = range(current_start_slice, current_end_slice + 1)

            prediction_per_dimension = prediction[which_batch_slices,:,:,0]
            prediction_array = np.transpose(np.squeeze(prediction_per_dimension), permutations[dimensions_to_predict[d]])
            prediction_image = ants.copy_image_info(image,
                pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                image.shape))
            probability_image = probability_image + (prediction_image - probability_image) / (d + 1)

            current_start_slice = current_end_slice + 1

        return(probability_image)

    else:
        raise ValueError("Unrecognized modality.")
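For reference, a minimal usage sketch for the CT branch (the file names are hypothetical):

# Hypothetical usage: extract lungs from a chest CT and save the labels.
import ants
from antspynet.utilities import lung_extraction

ct = ants.image_read("chest_ct.nii.gz")
lungs = lung_extraction(ct, modality="ct", verbose=True)
ants.image_write(lungs['segmentation_image'], "lung_segmentation.nii.gz")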
Code example #4
File: sfJointReg.py  Project: stnava/ANTsPyDocker
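# Excerpt from a longer registration script: powers_areal_mni_itk, ptImg, und,
# ch2, concatx2, boldseg, and boldfnsR are defined earlier in sfJointReg.py.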
networks = powers_areal_mni_itk['SystemName'].unique()
dfnpts = np.where( powers_areal_mni_itk['SystemName'] == networks[5] )
dfnImg = ants.mask_image(  ptImg, ptImg, level = dfnpts[0].tolist(), binarize=False )

# plot( und, ptImg, axis=3, window.overlay = range( ptImg ) )

bold2ch2 = ants.apply_transforms( ch2, und,  concatx2, whichtoinvert = ( True, False, True, False ) )


# Extracting canonical functional network maps
## preprocessing

csfAndWM = ( ants.threshold_image( boldseg, 1, 1 ) +
             ants.threshold_image( boldseg, 3, 3 ) ).morphology("erode",1)
bold = ants.image_read( boldfnsR )
boldList = ants.ndimage_to_list( bold )
avgBold = ants.get_average_of_timeseries( bold, range( 5 ) )
boldUndTX = ants.registration( und, avgBold, "SyN", regIterations = (15,4),
  synMetric = "CC", synSampling = 2, verbose = False )
boldUndTS = ants.apply_transforms( und, bold, boldUndTX['fwdtransforms'], imagetype = 3  )
motCorr = ants.motion_correction( boldUndTS, avgBold,
    type_of_transform="Rigid", verbose = True )
tr = ants.get_spacing( bold )[3]
highMotionTimes = np.where( motCorr['FD'] >= 0.5 )
goodtimes = np.where( motCorr['FD'] < 0.5 )
avgBold = ants.get_average_of_timeseries( motCorr['motion_corrected'], range( 5 ) )
#######################
nt = len(motCorr['FD'])
plt.plot(  range( nt ), motCorr['FD'] )
plt.show()
#################################################
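A natural next step, sketched here as an assumption rather than part of the original script, is to censor the high-motion frames flagged above by splitting the corrected series with ants.ndimage_to_list and re-merging only the good volumes:

# Hypothetical continuation: drop frames with FD >= 0.5.
motCorrList = ants.ndimage_to_list( motCorr['motion_corrected'] )
goodFrames = [ motCorrList[i] for i in goodtimes[0] ]
censoredTarget = ants.make_image( (*goodFrames[0].shape, len(goodFrames)) )
boldCensored = ants.list_to_ndimage( censoredTarget, goodFrames )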
Code example #5
def desikan_killiany_tourville_labeling(t1,
                                        do_preprocessing=True,
                                        return_probability_images=False,
                                        antsxnet_cache_directory=None,
                                        verbose=False):
    """
    Cortical and deep gray matter labeling using Desikan-Killiany-Tourville

    Perform DKT labeling using deep learning

    The labeling is as follows:

    Inner labels:
    Label 0: background
    Label 4: left lateral ventricle
    Label 5: left inferior lateral ventricle
    Label 6: left cerebellum exterior
    Label 7: left cerebellum white matter
    Label 10: left thalamus proper
    Label 11: left caudate
    Label 12: left putamen
    Label 13: left pallidum
    Label 14: 3rd ventricle
    Label 15: 4th ventricle
    Label 16: brain stem
    Label 17: left hippocampus
    Label 18: left amygdala
    Label 24: CSF
    Label 25: left lesion
    Label 26: left accumbens area
    Label 28: left ventral DC
    Label 30: left vessel
    Label 43: right lateral ventricle
    Label 44: right inferior lateral ventricle
    Label 45: right cerebellum exterior
    Label 46: right cerebellum white matter
    Label 49: right thalamus proper
    Label 50: right caudate
    Label 51: right putamen
    Label 52: right pallidum
    Label 53: right hippocampus
    Label 54: right amygdala
    Label 57: right lesion
    Label 58: right accumbens area
    Label 60: right ventral DC
    Label 62: right vessel
    Label 72: 5th ventricle
    Label 85: optic chiasm
    Label 91: left basal forebrain
    Label 92: right basal forebrain
    Label 630: cerebellar vermal lobules I-V
    Label 631: cerebellar vermal lobules VI-VII
    Label 632: cerebellar vermal lobules VIII-X

    Outer labels:
    Label 1002: left caudal anterior cingulate
    Label 1003: left caudal middle frontal
    Label 1005: left cuneus
    Label 1006: left entorhinal
    Label 1007: left fusiform
    Label 1008: left inferior parietal
    Label 1009: left inferior temporal
    Label 1010: left isthmus cingulate
    Label 1011: left lateral occipital
    Label 1012: left lateral orbitofrontal
    Label 1013: left lingual
    Label 1014: left medial orbitofrontal
    Label 1015: left middle temporal
    Label 1016: left parahippocampal
    Label 1017: left paracentral
    Label 1018: left pars opercularis
    Label 1019: left pars orbitalis
    Label 1020: left pars triangularis
    Label 1021: left pericalcarine
    Label 1022: left postcentral
    Label 1023: left posterior cingulate
    Label 1024: left precentral
    Label 1025: left precuneus
    Label 1026: left rostral anterior cingulate
    Label 1027: left rostral middle frontal
    Label 1028: left superior frontal
    Label 1029: left superior parietal
    Label 1030: left superior temporal
    Label 1031: left supramarginal
    Label 1034: left transverse temporal
    Label 1035: left insula
    Label 2002: right caudal anterior cingulate
    Label 2003: right caudal middle frontal
    Label 2005: right cuneus
    Label 2006: right entorhinal
    Label 2007: right fusiform
    Label 2008: right inferior parietal
    Label 2009: right inferior temporal
    Label 2010: right isthmus cingulate
    Label 2011: right lateral occipital
    Label 2012: right lateral orbitofrontal
    Label 2013: right lingual
    Label 2014: right medial orbitofrontal
    Label 2015: right middle temporal
    Label 2016: right parahippocampal
    Label 2017: right paracentral
    Label 2018: right pars opercularis
    Label 2019: right pars orbitalis
    Label 2020: right pars triangularis
    Label 2021: right pericalcarine
    Label 2022: right postcentral
    Label 2023: right posterior cingulate
    Label 2024: right precentral
    Label 2025: right precuneus
    Label 2026: right rostral anterior cingulate
    Label 2027: right rostral middle frontal
    Label 2028: right superior frontal
    Label 2029: right superior parietal
    Label 2030: right superior temporal
    Label 2031: right supramarginal
    Label 2034: right transverse temporal
    Label 2035: right insula

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is a raw
    image, these steps can be performed internally by setting
    do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    return_probability_images : boolean
        Whether to return the two sets of probability images for the inner and outer
        labels.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    DKT label image or, if return_probability_images is True, a dictionary
    containing the label image and the inner and outer probability images.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> dkt = desikan_killiany_tourville_labeling(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import categorical_focal_loss
    from ..utilities import preprocess_brain_image
    from ..utilities import crop_image_center

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    if do_preprocessing == True:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            do_brain_extraction=True,
            template="croppedMni152",
            template_transform_type="AffineFast",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing[
            "preprocessed_image"] * t1_preprocessing['brain_mask']

    ################################
    #
    # Download spatial priors for outer model
    #
    ################################

    spatial_priors_file_name_path = get_antsxnet_data(
        "priorDktLabels", antsxnet_cache_directory=antsxnet_cache_directory)
    spatial_priors = ants.image_read(spatial_priors_file_name_path)
    priors_image_list = ants.ndimage_to_list(spatial_priors)

    ################################
    #
    # Build outer model and load weights
    #
    ################################

    template_size = (96, 112, 96)
    labels = (0, 1002, 1003, *tuple(range(1005, 1032)), 1034, 1035, 2002, 2003,
              *tuple(range(2005, 2032)), 2034, 2035)
    channel_size = 1 + len(priors_image_list)

    unet_model = create_unet_model_3d((*template_size, channel_size),
                                      number_of_outputs=len(labels),
                                      number_of_layers=4,
                                      number_of_filters_at_base_layer=16,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      weight_decay=1e-5,
                                      add_attention_gating=True)

    weights_file_name = None
    weights_file_name = get_pretrained_network(
        "dktOuterWithSpatialPriors",
        antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose == True:
        print("Outer model Prediction.")

    downsampled_image = ants.resample_image(t1_preprocessed,
                                            template_size,
                                            use_voxels=True,
                                            interp_type=0)
    image_array = downsampled_image.numpy()
    image_array = (image_array - image_array.mean()) / image_array.std()

    batchX = np.zeros((1, *template_size, channel_size))
    batchX[0, :, :, :, 0] = image_array

    for i in range(len(priors_image_list)):
        resampled_prior_image = ants.resample_image(priors_image_list[i],
                                                    template_size,
                                                    use_voxels=True,
                                                    interp_type=0)
        batchX[0, :, :, :, i + 1] = resampled_prior_image.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    origin = downsampled_image.origin
    spacing = downsampled_image.spacing
    direction = downsampled_image.direction

    inner_probability_images = list()
    for i in range(len(labels)):
        probability_image = \
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
            origin=origin, spacing=spacing, direction=direction)
        resampled_image = ants.resample_image(probability_image,
                                              t1_preprocessed.shape,
                                              use_voxels=True,
                                              interp_type=0)
        if do_preprocessing == True:
            inner_probability_images.append(
                ants.apply_transforms(
                    fixed=t1,
                    moving=resampled_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose))
        else:
            inner_probability_images.append(resampled_image)

    image_matrix = ants.image_list_to_matrix(inner_probability_images,
                                             t1 * 0 + 1)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    dkt_label_image = ants.image_clone(segmentation_image)
    for i in range(len(labels)):
        dkt_label_image[segmentation_image == i] = labels[i]

    ################################
    #
    # Build inner model and load weights
    #
    ################################

    template_size = (160, 192, 160)
    labels = (0, 4, 6, 7, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 28, 30,
              43, 44, 45, 46, 49, 50, 51, 52, 53, 54, 58, 60, 91, 92, 630, 631,
              632)

    unet_model = create_unet_model_3d((*template_size, 1),
                                      number_of_outputs=len(labels),
                                      number_of_layers=4,
                                      number_of_filters_at_base_layer=8,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      weight_decay=1e-5,
                                      add_attention_gating=True)

    weights_file_name = get_pretrained_network(
        "dktInner", antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose == True:
        print("Prediction.")

    cropped_image = ants.crop_indices(t1_preprocessed, (12, 14, 0),
                                      (172, 206, 160))

    batchX = np.expand_dims(cropped_image.numpy(), axis=0)
    batchX = np.expand_dims(batchX, axis=-1)
    batchX = (batchX - batchX.mean()) / batchX.std()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    origin = cropped_image.origin
    spacing = cropped_image.spacing
    direction = cropped_image.direction

    outer_probability_images = list()
    for i in range(len(labels)):
        probability_image = \
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
            origin=origin, spacing=spacing, direction=direction)
        if i > 0:
            decropped_image = ants.decrop_image(probability_image,
                                                t1_preprocessed * 0)
        else:
            decropped_image = ants.decrop_image(probability_image,
                                                t1_preprocessed * 0 + 1)

        if do_preprocessing == True:
            outer_probability_images.append(
                ants.apply_transforms(
                    fixed=t1,
                    moving=decropped_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose))
        else:
            outer_probability_images.append(decropped_image)

    image_matrix = ants.image_list_to_matrix(outer_probability_images,
                                             t1 * 0 + 1)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    ################################
    #
    # Incorporate the inner model results into the final label image.
    # Note that we purposely prioritize the inner label results.
    #
    ################################

    for i in range(len(labels)):
        if labels[i] > 0:
            dkt_label_image[segmentation_image == i] = labels[i]

    if return_probability_images == True:
        return_dict = {
            'segmentation_image': dkt_label_image,
            'inner_probability_images': inner_probability_images,
            'outer_probability_images': outer_probability_images
        }
        return (return_dict)
    else:
        return (dkt_label_image)
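A minimal usage sketch (the input file name is hypothetical):

# Hypothetical usage: run DKT labeling on a raw T1 and save the label image.
import ants
from antspynet.utilities import desikan_killiany_tourville_labeling

t1 = ants.image_read("t1.nii.gz")
dkt = desikan_killiany_tourville_labeling(t1, do_preprocessing=True, verbose=True)
ants.image_write(dkt, "dkt_labels.nii.gz")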
Code example #6
def deep_flash(t1,
               t2=None,
               do_preprocessing=True,
               use_rank_intensity=True,
               antsxnet_cache_directory=None,
               verbose=False):
    """
    Hippocampal/Entorhinal segmentation using "Deep Flash"

    Perform hippocampal/entorhinal segmentation in T1 and T1/T2 images using
    labels from Mike Yassa's lab

    https://faculty.sites.uci.edu/myassa/

    The labeling is as follows:
    Label 0 :  background
    Label 5 :  left aLEC
    Label 6 :  right aLEC
    Label 7 :  left pMEC
    Label 8 :  right pMEC
    Label 9 :  left perirhinal
    Label 10:  right perirhinal
    Label 11:  left parahippocampal
    Label 12:  right parahippocampal
    Label 13:  left DG/CA2/CA3/CA4
    Label 14:  right DG/CA2/CA3/CA4
    Label 15:  left CA1
    Label 16:  right CA1
    Label 17:  left subiculum
    Label 18:  right subiculum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * affine registration to the "deep flash" template.
    These steps are performed on the input images if do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    t2 : ANTsImage
        Optional 3-D T2-weighted brain image.  If specified, it is assumed to be
        pre-aligned to the t1.

    do_preprocessing : boolean
        See description above.

    use_rank_intensity : boolean
        If false, use histogram matching with cropped template ROI.  Otherwise,
        use a rank intensity transform on the cropped ROI.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary consisting of the segmentation image and probability images for
    each label and foreground.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> flash = deep_flash(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import brain_extraction

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Options temporarily taken from the user
    #
    ################################

    # use_hierarchical_parcellation : boolean
    #     If True, use u-net model with additional outputs of the medial temporal lobe
    #     region, hippocampal, and entorhinal/perirhinal/parahippocampal regions.  Otherwise
    #     the only additional output is the medial temporal lobe.
    #
    # use_contralaterality : boolean
    #     Use both hemispherical models to also predict the corresponding contralateral
    #     segmentation and use both sets of priors to produce the results.  Mainly used
    #     for debugging.

    use_hierarchical_parcellation = True
    use_contralaterality = True

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    t1_mask = None
    t1_preprocessed_flipped = None
    t1_template = ants.image_read(
        get_antsxnet_data("deepFlashTemplateT1SkullStripped"))
    template_transforms = None
    if do_preprocessing:

        if verbose == True:
            print("Preprocessing T1.")

        # Brain extraction
        probability_mask = brain_extraction(
            t1_preprocessed,
            modality="t1",
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_mask = ants.threshold_image(probability_mask, 0.5, 1, 1, 0)
        t1_preprocessed = t1_preprocessed * t1_mask

        # Do bias correction
        t1_preprocessed = ants.n4_bias_field_correction(t1_preprocessed,
                                                        t1_mask,
                                                        shrink_factor=4,
                                                        verbose=verbose)

        # Warp to template
        registration = ants.registration(
            fixed=t1_template,
            moving=t1_preprocessed,
            type_of_transform="antsRegistrationSyNQuickRepro[a]",
            verbose=verbose)
        template_transforms = dict(fwdtransforms=registration['fwdtransforms'],
                                   invtransforms=registration['invtransforms'])
        t1_preprocessed = registration['warpedmovout']

    if use_contralaterality:
        t1_preprocessed_array = t1_preprocessed.numpy()
        t1_preprocessed_array_flipped = np.flip(t1_preprocessed_array, axis=0)
        t1_preprocessed_flipped = ants.from_numpy(
            t1_preprocessed_array_flipped,
            origin=t1_preprocessed.origin,
            spacing=t1_preprocessed.spacing,
            direction=t1_preprocessed.direction)

    t2_preprocessed = t2
    t2_preprocessed_flipped = None
    t2_template = None
    if t2 is not None:
        t2_template = ants.image_read(
            get_antsxnet_data("deepFlashTemplateT2SkullStripped"))
        t2_template = ants.copy_image_info(t1_template, t2_template)
        if do_preprocessing:

            if verbose == True:
                print("Preprocessing T2.")

            # Brain extraction
            t2_preprocessed = t2_preprocessed * t1_mask

            # Do bias correction
            t2_preprocessed = ants.n4_bias_field_correction(t2_preprocessed,
                                                            t1_mask,
                                                            shrink_factor=4,
                                                            verbose=verbose)

            # Warp to template
            t2_preprocessed = ants.apply_transforms(
                fixed=t1_template,
                moving=t2_preprocessed,
                transformlist=template_transforms['fwdtransforms'],
                verbose=verbose)

        if use_contralaterality:
            t2_preprocessed_array = t2_preprocessed.numpy()
            t2_preprocessed_array_flipped = np.flip(t2_preprocessed_array,
                                                    axis=0)
            t2_preprocessed_flipped = ants.from_numpy(
                t2_preprocessed_array_flipped,
                origin=t2_preprocessed.origin,
                spacing=t2_preprocessed.spacing,
                direction=t2_preprocessed.direction)

    probability_images = list()
    labels = (0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)
    image_size = (64, 64, 96)

    ################################
    #
    # Process left/right in split networks
    #
    ################################

    ################################
    #
    # Download spatial priors
    #
    ################################

    spatial_priors_file_name_path = get_antsxnet_data(
        "deepFlashPriors", antsxnet_cache_directory=antsxnet_cache_directory)
    spatial_priors = ants.image_read(spatial_priors_file_name_path)
    priors_image_list = ants.ndimage_to_list(spatial_priors)
    for i in range(len(priors_image_list)):
        priors_image_list[i] = ants.copy_image_info(t1_preprocessed,
                                                    priors_image_list[i])

    labels_left = labels[1::2]
    priors_image_left_list = priors_image_list[1::2]
    probability_images_left = list()
    foreground_probability_images_left = list()
    lower_bound_left = (76, 74, 56)
    upper_bound_left = (140, 138, 152)
    tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_left,
                                    upper_bound_left)
    origin_left = tmp_cropped.origin

    spacing = tmp_cropped.spacing
    direction = tmp_cropped.direction

    t1_template_roi_left = ants.crop_indices(t1_template, lower_bound_left,
                                             upper_bound_left)
    t1_template_roi_left = (t1_template_roi_left - t1_template_roi_left.min(
    )) / (t1_template_roi_left.max() - t1_template_roi_left.min()) * 2.0 - 1.0
    t2_template_roi_left = None
    if t2_template is not None:
        t2_template_roi_left = ants.crop_indices(t2_template, lower_bound_left,
                                                 upper_bound_left)
        t2_template_roi_left = (t2_template_roi_left -
                                t2_template_roi_left.min()) / (
                                    t2_template_roi_left.max() -
                                    t2_template_roi_left.min()) * 2.0 - 1.0

    labels_right = labels[2::2]
    priors_image_right_list = priors_image_list[2::2]
    probability_images_right = list()
    foreground_probability_images_right = list()
    lower_bound_right = (20, 74, 56)
    upper_bound_right = (84, 138, 152)
    tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_right,
                                    upper_bound_right)
    origin_right = tmp_cropped.origin

    t1_template_roi_right = ants.crop_indices(t1_template, lower_bound_right,
                                              upper_bound_right)
    t1_template_roi_right = (
        t1_template_roi_right - t1_template_roi_right.min()
    ) / (t1_template_roi_right.max() - t1_template_roi_right.min()) * 2.0 - 1.0
    t2_template_roi_right = None
    if t2_template is not None:
        t2_template_roi_right = ants.crop_indices(t2_template,
                                                  lower_bound_right,
                                                  upper_bound_right)
        t2_template_roi_right = (t2_template_roi_right -
                                 t2_template_roi_right.min()) / (
                                     t2_template_roi_right.max() -
                                     t2_template_roi_right.min()) * 2.0 - 1.0

    ################################
    #
    # Create model
    #
    ################################

    channel_size = 1 + len(labels_left)
    if t2 is not None:
        channel_size += 1

    number_of_classification_labels = 1 + len(labels_left)

    unet_model = create_unet_model_3d(
        (*image_size, channel_size),
        number_of_outputs=number_of_classification_labels,
        mode="classification",
        number_of_filters=(32, 64, 96, 128, 256),
        convolution_kernel_size=(3, 3, 3),
        deconvolution_kernel_size=(2, 2, 2),
        dropout_rate=0.0,
        weight_decay=0)

    penultimate_layer = unet_model.layers[-2].output

    # medial temporal lobe
    output1 = Conv3D(
        filters=1,
        kernel_size=(1, 1, 1),
        activation='sigmoid',
        kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

    if use_hierarchical_parcellation:

        # EC, perirhinal, and parahippo.
        output2 = Conv3D(
            filters=1,
            kernel_size=(1, 1, 1),
            activation='sigmoid',
            kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

        # Hippocampus
        output3 = Conv3D(
            filters=1,
            kernel_size=(1, 1, 1),
            activation='sigmoid',
            kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

        unet_model = Model(
            inputs=unet_model.input,
            outputs=[unet_model.output, output1, output2, output3])
    else:
        unet_model = Model(inputs=unet_model.input,
                           outputs=[unet_model.output, output1])

    ################################
    #
    # Left:  build model and load weights
    #
    ################################

    network_name = 'deepFlashLeftT1'
    if t2 is not None:
        network_name = 'deepFlashLeftBoth'

    if use_hierarchical_parcellation:
        network_name += "Hierarchical"

    if use_rank_intensity:
        network_name += "_ri"

    if verbose:
        print("DeepFlash: retrieving model weights (left).")
    weights_file_name = get_pretrained_network(
        network_name, antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Left:  do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction (left).")

    batchX = None
    if use_contralaterality:
        batchX = np.zeros((2, *image_size, channel_size))
    else:
        batchX = np.zeros((1, *image_size, channel_size))

    t1_cropped = ants.crop_indices(t1_preprocessed, lower_bound_left,
                                   upper_bound_left)
    if use_rank_intensity:
        t1_cropped = ants.rank_intensity(t1_cropped)
    else:
        t1_cropped = ants.histogram_match_image(t1_cropped,
                                                t1_template_roi_left, 255, 64,
                                                False)
    batchX[0, :, :, :, 0] = t1_cropped.numpy()
    if use_contralaterality:
        t1_cropped = ants.crop_indices(t1_preprocessed_flipped,
                                       lower_bound_left, upper_bound_left)
        if use_rank_intensity:
            t1_cropped = ants.rank_intensity(t1_cropped)
        else:
            t1_cropped = ants.histogram_match_image(t1_cropped,
                                                    t1_template_roi_left, 255,
                                                    64, False)
        batchX[1, :, :, :, 0] = t1_cropped.numpy()
    if t2 is not None:
        t2_cropped = ants.crop_indices(t2_preprocessed, lower_bound_left,
                                       upper_bound_left)
        if use_rank_intensity:
            t2_cropped = ants.rank_intensity(t2_cropped)
        else:
            t2_cropped = ants.histogram_match_image(t2_cropped,
                                                    t2_template_roi_left, 255,
                                                    64, False)
        batchX[0, :, :, :, 1] = t2_cropped.numpy()
        if use_contralaterality:
            t2_cropped = ants.crop_indices(t2_preprocessed_flipped,
                                           lower_bound_left, upper_bound_left)
            if use_rank_intensity:
                t2_cropped = ants.rank_intensity(t2_cropped)
            else:
                t2_cropped = ants.histogram_match_image(
                    t2_cropped, t2_template_roi_left, 255, 64, False)
            batchX[1, :, :, :, 1] = t2_cropped.numpy()

    for i in range(len(priors_image_left_list)):
        cropped_prior = ants.crop_indices(priors_image_left_list[i],
                                          lower_bound_left, upper_bound_left)
        for j in range(batchX.shape[0]):
            batchX[j, :, :, :, i +
                   (channel_size - len(labels_left))] = cropped_prior.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    for i in range(1 + len(labels_left)):
        for j in range(predicted_data[0].shape[0]):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[0][j, :, :, :, i]),
                origin=origin_left, spacing=spacing, direction=direction)
            if i > 0:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0)
            else:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0 + 1)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                probability_images_left.append(probability_image)
            else:  # flipped
                probability_images_right.append(probability_image)

    ################################
    #
    # Left:  do prediction of mtl, hippocampal, and ec regions and normalize to native space
    #
    ################################

    for i in range(1, len(predicted_data)):
        for j in range(predicted_data[i].shape[0]):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[i][j, :, :, :, 0]),
                origin=origin_left, spacing=spacing, direction=direction)
            probability_image = ants.decrop_image(probability_image,
                                                  t1_preprocessed * 0)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                foreground_probability_images_left.append(probability_image)
            else:
                foreground_probability_images_right.append(probability_image)

    ################################
    #
    # Right:  build model and load weights
    #
    ################################

    network_name = 'deepFlashRightT1'
    if t2 is not None:
        network_name = 'deepFlashRightBoth'

    if use_hierarchical_parcellation:
        network_name += "Hierarchical"

    if use_rank_intensity:
        network_name += "_ri"

    if verbose:
        print("DeepFlash: retrieving model weights (right).")
    weights_file_name = get_pretrained_network(
        network_name, antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Right:  do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction (right).")

    batchX = None
    if use_contralaterality:
        batchX = np.zeros((2, *image_size, channel_size))
    else:
        batchX = np.zeros((1, *image_size, channel_size))

    t1_cropped = ants.crop_indices(t1_preprocessed, lower_bound_right,
                                   upper_bound_right)
    if use_rank_intensity:
        t1_cropped = ants.rank_intensity(t1_cropped)
    else:
        t1_cropped = ants.histogram_match_image(t1_cropped,
                                                t1_template_roi_right, 255, 64,
                                                False)
    batchX[0, :, :, :, 0] = t1_cropped.numpy()
    if use_contralaterality:
        t1_cropped = ants.crop_indices(t1_preprocessed_flipped,
                                       lower_bound_right, upper_bound_right)
        if use_rank_intensity:
            t1_cropped = ants.rank_intensity(t1_cropped)
        else:
            t1_cropped = ants.histogram_match_image(t1_cropped,
                                                    t1_template_roi_right, 255,
                                                    64, False)
        batchX[1, :, :, :, 0] = t1_cropped.numpy()
    if t2 is not None:
        t2_cropped = ants.crop_indices(t2_preprocessed, lower_bound_right,
                                       upper_bound_right)
        if use_rank_intensity:
            t2_cropped = ants.rank_intensity(t2_cropped)
        else:
            t2_cropped = ants.histogram_match_image(t2_cropped,
                                                    t2_template_roi_right, 255,
                                                    64, False)
        batchX[0, :, :, :, 1] = t2_cropped.numpy()
        if use_contralaterality:
            t2_cropped = ants.crop_indices(t2_preprocessed_flipped,
                                           lower_bound_right,
                                           upper_bound_right)
            if use_rank_intensity:
                t2_cropped = ants.rank_intensity(t2_cropped)
            else:
                t2_cropped = ants.histogram_match_image(
                    t2_cropped, t2_template_roi_right, 255, 64, False)
            batchX[1, :, :, :, 1] = t2_cropped.numpy()

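    # Channel layout: the image channels (t1 and, optionally, t2) come first;
    # the trailing len(labels_right) channels carry the cropped spatial
    # priors, hence the offset of channel_size - len(labels_right) below.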
    for i in range(len(priors_image_right_list)):
        cropped_prior = ants.crop_indices(priors_image_right_list[i],
                                          lower_bound_right, upper_bound_right)
        for j in range(batchX.shape[0]):
            batchX[j, :, :, :,
                   channel_size - len(labels_right) + i] = cropped_prior.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    for i in range(1 + len(labels_right)):
        for j in range(predicted_data[0].shape[0]):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[0][j, :, :, :, i]),
                origin=origin_right, spacing=spacing, direction=direction)
            if i > 0:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0)
            else:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0 + 1)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

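            # j == 0 is the unflipped pass and contributes to the right
            # hemisphere; j == 1 is the flipped pass, whose prediction lands
            # in the left hemisphere and is averaged with the output of the
            # left-hemisphere network.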
            if j == 0:  # not flipped
                if use_contralaterality:
                    probability_images_right[i] = (
                        probability_images_right[i] + probability_image) / 2
                else:
                    probability_images_right.append(probability_image)
            else:  # flipped
                probability_images_left[i] = (probability_images_left[i] +
                                              probability_image) / 2

    ################################
    #
    # Right:  do prediction of MTL, hippocampal, and EC regions and normalize to native space
    #
    ################################

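    # predicted_data[0] holds the per-label probabilities handled above;
    # predicted_data[1:] are the auxiliary hierarchical outputs (medial
    # temporal lobe, other region, hippocampus), each a single foreground
    # probability channel.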
    for i in range(1, len(predicted_data)):
        for j in range(predicted_data[i].shape[0]):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[i][j, :, :, :, 0]),
                origin=origin_right, spacing=spacing, direction=direction)
            probability_image = ants.decrop_image(probability_image,
                                                  t1_preprocessed * 0)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                if use_contralaterality:
                    foreground_probability_images_right[i - 1] = (
                        foreground_probability_images_right[i - 1] +
                        probability_image) / 2
                else:
                    foreground_probability_images_right.append(
                        probability_image)
            else:
                foreground_probability_images_left[i - 1] = (
                    foreground_probability_images_left[i - 1] +
                    probability_image) / 2

    ################################
    #
    # Combine left/right probability images
    #
    ################################

    probability_background_image = ants.image_clone(t1) * 0
    for i in range(1, len(probability_images_left)):
        probability_background_image += probability_images_left[i]
    for i in range(1, len(probability_images_right)):
        probability_background_image += probability_images_right[i]

    probability_images.append(probability_background_image * -1 + 1)
    for i in range(1, len(probability_images_left)):
        probability_images.append(probability_images_left[i])
        probability_images.append(probability_images_right[i])

    ################################
    #
    # Convert probability images to segmentation
    #
    ################################

    # image_matrix = ants.image_list_to_matrix(probability_images, t1 * 0 + 1)
    # segmentation_matrix = np.argmax(image_matrix, axis=0)
    # segmentation_image = ants.matrix_to_images(
    #     np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

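    # Two-stage conversion: first decide background vs. foreground by
    # comparing the background probability with the sum of all foreground
    # probabilities, then assign each foreground voxel the argmax foreground
    # label (offset by 1 so that index 0 remains background).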
    image_matrix = ants.image_list_to_matrix(probability_images[1:],
                                             t1 * 0 + 1)
    background_foreground_matrix = np.stack([
        ants.image_list_to_matrix([probability_images[0]], t1 * 0 + 1),
        np.expand_dims(np.sum(image_matrix, axis=0), axis=0)
    ])
    foreground_matrix = np.argmax(background_foreground_matrix, axis=0)
    segmentation_matrix = (np.argmax(image_matrix, axis=0) +
                           1) * foreground_matrix
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    relabeled_image = ants.image_clone(segmentation_image)
    for i in range(len(labels)):
        relabeled_image[segmentation_image == i] = labels[i]

    foreground_probability_images = list()
    for i in range(len(foreground_probability_images_left)):
        foreground_probability_images.append(
            foreground_probability_images_left[i] +
            foreground_probability_images_right[i])

    return_dict = None
    if use_hierarchical_parcellation:
        return_dict = {
            'segmentation_image': relabeled_image,
            'probability_images': probability_images,
            'medial_temporal_lobe_probability_image':
                foreground_probability_images[0],
            'other_region_probability_image':
                foreground_probability_images[1],
            'hippocampal_probability_image':
                foreground_probability_images[2]
        }
    else:
        return_dict = {
            'segmentation_image': relabeled_image,
            'probability_images': probability_images,
            'medial_temporal_lobe_probability_image':
                foreground_probability_images[0]
        }

    return return_dict
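
A minimal usage sketch for deep_flash above, assuming ants and antspynet are
importable as below and that "t1.nii.gz" is a placeholder for a raw
T1-weighted image; pretrained weights are fetched on first use.

import ants
import antspynet

t1 = ants.image_read("t1.nii.gz")
flash = antspynet.deep_flash(t1, do_preprocessing=True, verbose=True)
seg = flash['segmentation_image']
mtl_probability = flash['medial_temporal_lobe_probability_image']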
Code example #7
def deep_flash_deprecated(t1,
                          do_preprocessing=True,
                          do_per_hemisphere=True,
                          which_hemisphere_models="new",
                          antsxnet_cache_directory=None,
                          verbose=False):
    """
    Hippocampal/Entorhinal segmentation using "Deep Flash"

    Perform hippocampal/entorhinal segmentation in T1 images using
    labels from Mike Yassa's lab

    https://faculty.sites.uci.edu/myassa/

    The labeling is as follows:
    Label 0 :  background
    Label 5 :  left aLEC
    Label 6 :  right aLEC
    Label 7 :  left pMEC
    Label 8 :  right pMEC
    Label 9 :  left perirhinal
    Label 10:  right perirhinal
    Label 11:  left parahippocampal
    Label 12:  right parahippocampal
    Label 13:  left DG/CA3
    Label 14:  right DG/CA3
    Label 15:  left CA1
    Label 16:  right CA1
    Label 17:  left subiculum
    Label 18:  right subiculum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e., set
    do_preprocessing=True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    do_per_hemisphere : boolean
        If True, do prediction based on separate networks per hemisphere.  Otherwise,
        use the single network trained for both hemispheres.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if None, the data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary consisting of the segmentation image and probability images
    for each label.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> flash = deep_flash_deprecated(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import preprocess_brain_image
    from ..utilities import pad_or_crop_image_to_size

    print("This function is deprecated.  Please update to deep_flash().")

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            template="croppedMni152",
            template_transform_type="antsRegistrationSyNQuickRepro[a]",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = (t1_preprocessing['preprocessed_image'] *
                           t1_preprocessing['brain_mask'])

    probability_images = list()
    labels = (0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)

    ################################
    #
    # Process left/right in same network
    #
    ################################

    if not do_per_hemisphere:

        ################################
        #
        # Build model and load weights
        #
        ################################

        template_size = (160, 192, 160)

        unet_model = create_unet_model_3d(
            (*template_size, 1),
            number_of_outputs=len(labels),
            number_of_layers=4,
            number_of_filters_at_base_layer=8,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5,
            additional_options=("attentionGating", ))

        if verbose:
            print("DeepFlash: retrieving model weights.")

        weights_file_name = get_pretrained_network(
            "deepFlash", antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Do prediction and normalize to native space
        #
        ################################

        if verbose:
            print("Prediction.")

        cropped_image = pad_or_crop_image_to_size(t1_preprocessed,
                                                  template_size)

        batchX = np.expand_dims(cropped_image.numpy(), axis=0)
        batchX = np.expand_dims(batchX, axis=-1)
        batchX = (batchX - batchX.mean()) / batchX.std()

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        origin = cropped_image.origin
        spacing = cropped_image.spacing
        direction = cropped_image.direction

        for i in range(len(labels)):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                origin=origin, spacing=spacing, direction=direction)
            if i > 0:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0)
            else:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0 + 1)

            if do_preprocessing:
                probability_images.append(
                    ants.apply_transforms(
                        fixed=t1,
                        moving=decropped_image,
                        transformlist=t1_preprocessing['template_transforms']
                        ['invtransforms'],
                        whichtoinvert=[True],
                        interpolator="linear",
                        verbose=verbose))
            else:
                probability_images.append(decropped_image)

    ################################
    #
    # Process left/right in split networks
    #
    ################################

    else:

        ################################
        #
        # Left:  download spatial priors
        #
        ################################

        spatial_priors_left_file_name_path = get_antsxnet_data(
            "priorDeepFlashLeftLabels",
            antsxnet_cache_directory=antsxnet_cache_directory)
        spatial_priors_left = ants.image_read(
            spatial_priors_left_file_name_path)
        priors_image_left_list = ants.ndimage_to_list(spatial_priors_left)

        ################################
        #
        # Left:  build model and load weights
        #
        ################################

        template_size = (64, 96, 96)
        labels_left = (0, 5, 7, 9, 11, 13, 15, 17)
        channel_size = 1 + len(labels_left)

        number_of_filters = 16
        network_name = ''
        if which_hemisphere_models == "old":
            network_name = "deepFlashLeft16"
        elif which_hemisphere_models == "new":
            network_name = "deepFlashLeft16new"
        else:
            raise ValueError("network_name must be \"old\" or \"new\".")

        unet_model = create_unet_model_3d(
            (*template_size, channel_size),
            number_of_outputs=len(labels_left),
            number_of_layers=4,
            number_of_filters_at_base_layer=number_of_filters,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5,
            additional_options=("attentionGating", ))

        if verbose:
            print("DeepFlash: retrieving model weights (left).")
        weights_file_name = get_pretrained_network(
            network_name, antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Left:  do prediction and normalize to native space
        #
        ################################

        if verbose:
            print("Prediction (left).")

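        # Fixed left-hemisphere ROI in the cropped MNI template space; note
        # that (94, 147, 96) - (30, 51, 0) matches template_size (64, 96, 96).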
        cropped_image = ants.crop_indices(t1_preprocessed, (30, 51, 0),
                                          (94, 147, 96))
        image_array = cropped_image.numpy()
        image_array = (image_array - image_array.mean()) / image_array.std()

        batchX = np.zeros((1, *template_size, channel_size))
        batchX[0, :, :, :, 0] = image_array

        for i in range(len(priors_image_left_list)):
            cropped_prior = ants.crop_indices(priors_image_left_list[i],
                                              (30, 51, 0), (94, 147, 96))
            batchX[0, :, :, :, i + 1] = cropped_prior.numpy()

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        origin = cropped_image.origin
        spacing = cropped_image.spacing
        direction = cropped_image.direction

        probability_images_left = list()
        for i in range(len(labels_left)):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                origin=origin, spacing=spacing, direction=direction)
            if i > 0:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0)
            else:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0 + 1)

            if do_preprocessing:
                probability_images_left.append(
                    ants.apply_transforms(
                        fixed=t1,
                        moving=decropped_image,
                        transformlist=t1_preprocessing['template_transforms']
                        ['invtransforms'],
                        whichtoinvert=[True],
                        interpolator="linear",
                        verbose=verbose))
            else:
                probability_images_left.append(decropped_image)

        ################################
        #
        # Right:  download spatial priors
        #
        ################################

        spatial_priors_right_file_name_path = get_antsxnet_data(
            "priorDeepFlashRightLabels",
            antsxnet_cache_directory=antsxnet_cache_directory)
        spatial_priors_right = ants.image_read(
            spatial_priors_right_file_name_path)
        priors_image_right_list = ants.ndimage_to_list(spatial_priors_right)

        ################################
        #
        # Right:  build model and load weights
        #
        ################################

        template_size = (64, 96, 96)
        labels_right = (0, 6, 8, 10, 12, 14, 16, 18)
        channel_size = 1 + len(labels_right)

        number_of_filters = 16
        network_name = ''
        if which_hemisphere_models == "old":
            network_name = "deepFlashRight16"
        elif which_hemisphere_models == "new":
            network_name = "deepFlashRight16new"
        else:
            raise ValueError("network_name must be \"old\" or \"new\".")

        unet_model = create_unet_model_3d(
            (*template_size, channel_size),
            number_of_outputs=len(labels_right),
            number_of_layers=4,
            number_of_filters_at_base_layer=number_of_filters,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5,
            additional_options=("attentionGating", ))

        weights_file_name = get_pretrained_network(
            network_name, antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Right:  do prediction and normalize to native space
        #
        ################################

        if verbose:
            print("Prediction (right).")

        cropped_image = ants.crop_indices(t1_preprocessed, (88, 51, 0),
                                          (152, 147, 96))
        image_array = cropped_image.numpy()
        image_array = (image_array - image_array.mean()) / image_array.std()

        batchX = np.zeros((1, *template_size, channel_size))
        batchX[0, :, :, :, 0] = image_array

        for i in range(len(priors_image_right_list)):
            cropped_prior = ants.crop_indices(priors_image_right_list[i],
                                              (88, 51, 0), (152, 147, 96))
            batchX[0, :, :, :, i + 1] = cropped_prior.numpy()

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        origin = cropped_image.origin
        spacing = cropped_image.spacing
        direction = cropped_image.direction

        probability_images_right = list()
        for i in range(len(labels_right)):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                origin=origin, spacing=spacing, direction=direction)
            if i > 0:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0)
            else:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0 + 1)

            if do_preprocessing:
                probability_images_right.append(
                    ants.apply_transforms(
                        fixed=t1,
                        moving=decropped_image,
                        transformlist=t1_preprocessing['template_transforms']
                        ['invtransforms'],
                        whichtoinvert=[True],
                        interpolator="linear",
                        verbose=verbose))
            else:
                probability_images_right.append(decropped_image)

        ################################
        #
        # Combine left/right probability images
        #
        ################################

        probability_background_image = ants.image_clone(t1) * 0
        for i in range(1, len(probability_images_left)):
            probability_background_image += probability_images_left[i]
        for i in range(1, len(probability_images_right)):
            probability_background_image += probability_images_right[i]

        probability_images.append(probability_background_image * -1 + 1)
        for i in range(1, len(probability_images_left)):
            probability_images.append(probability_images_left[i])
            probability_images.append(probability_images_right[i])

    ################################
    #
    # Convert probability images to segmentation
    #
    ################################

    # image_matrix = ants.image_list_to_matrix(probability_images, t1 * 0 + 1)
    # segmentation_matrix = np.argmax(image_matrix, axis=0)
    # segmentation_image = ants.matrix_to_images(
    #     np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    image_matrix = ants.image_list_to_matrix(probability_images[1:],
                                             t1 * 0 + 1)
    background_foreground_matrix = np.stack([
        ants.image_list_to_matrix([probability_images[0]], t1 * 0 + 1),
        np.expand_dims(np.sum(image_matrix, axis=0), axis=0)
    ])
    foreground_matrix = np.argmax(background_foreground_matrix, axis=0)
    segmentation_matrix = (np.argmax(image_matrix, axis=0) +
                           1) * foreground_matrix
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    relabeled_image = ants.image_clone(segmentation_image)
    for i in range(len(labels)):
        relabeled_image[segmentation_image == i] = labels[i]

    return_dict = {
        'segmentation_image': relabeled_image,
        'probability_images': probability_images
    }
    return return_dict
Code example #8
def deep_atropos(t1,
                 do_preprocessing=True,
                 use_spatial_priors=1,
                 antsxnet_cache_directory=None,
                 verbose=False):
    """
    Six-tissue segmentation.

    Perform Atropos-style six tissue segmentation using deep learning.

    The labeling is as follows:
    Label 0 :  background
    Label 1 :  CSF
    Label 2 :  gray matter
    Label 3 :  white matter
    Label 4 :  deep gray matter
    Label 5 :  brain stem
    Label 6 :  cerebellum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e., set
    do_preprocessing=True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    use_spatial_priors : integer
        Use MNI spatial tissue priors (0 or 1).  Currently, '0' (no priors)
        and '1' (cerebellar prior only) are the only two options.  Default
        is 1.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if None, the data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary consisting of the segmentation image and probability images
    for each label.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> atropos = deep_atropos(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import preprocess_brain_image
    from ..utilities import extract_image_patches
    from ..utilities import reconstruct_image_from_patches

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            template="croppedMni152",
            template_transform_type="antsRegistrationSyNQuickRepro[a]",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = (t1_preprocessing['preprocessed_image'] *
                           t1_preprocessing['brain_mask'])

    ################################
    #
    # Build model and load weights
    #
    ################################

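    # Patches are extracted at offsets 0 and shape - patch_size along each
    # axis (stride = shape - patch_size), i.e., the eight overlapping octants
    # of the volume, hence "octant" in the pretrained network names below.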
    patch_size = (112, 112, 112)
    stride_length = (t1_preprocessed.shape[0] - patch_size[0],
                     t1_preprocessed.shape[1] - patch_size[1],
                     t1_preprocessed.shape[2] - patch_size[2])

    classes = ("background", "csf", "gray matter", "white matter",
               "deep gray matter", "brain stem", "cerebellum")

    mni_priors = None
    channel_size = 1
    if use_spatial_priors != 0:
        mni_priors = ants.ndimage_to_list(
            ants.image_read(
                get_antsxnet_data(
                    "croppedMni152Priors",
                    antsxnet_cache_directory=antsxnet_cache_directory)))
        for i in range(len(mni_priors)):
            mni_priors[i] = ants.copy_image_info(t1_preprocessed,
                                                 mni_priors[i])
        channel_size = 2

    unet_model = create_unet_model_3d((*patch_size, channel_size),
                                      number_of_outputs=len(classes),
                                      mode="classification",
                                      number_of_layers=4,
                                      number_of_filters_at_base_layer=16,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      weight_decay=1e-5,
                                      additional_options=("attentionGating",))

    if verbose:
        print("DeepAtropos:  retrieving model weights.")

    weights_file_name = ''
    if use_spatial_priors == 0:
        weights_file_name = get_pretrained_network(
            "sixTissueOctantBrainSegmentation",
            antsxnet_cache_directory=antsxnet_cache_directory)
    elif use_spatial_priors == 1:
        weights_file_name = get_pretrained_network(
            "sixTissueOctantBrainSegmentationWithPriors1",
            antsxnet_cache_directory=antsxnet_cache_directory)
    else:
        raise ValueError("use_spatial_priors must be a 0 or 1")
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction.")

    t1_preprocessed = (t1_preprocessed -
                       t1_preprocessed.mean()) / t1_preprocessed.std()
    image_patches = extract_image_patches(t1_preprocessed,
                                          patch_size=patch_size,
                                          max_number_of_patches="all",
                                          stride_length=stride_length,
                                          return_as_array=True)
    batchX = np.zeros((*image_patches.shape, channel_size))
    batchX[:, :, :, :, 0] = image_patches
    if channel_size > 1:
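        # Only the cerebellar prior is added as the second channel;
        # mni_priors[6] corresponds to label 6 ("cerebellum"), consistent
        # with the use_spatial_priors=1 option documented above.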
        prior_patches = extract_image_patches(mni_priors[6],
                                              patch_size=patch_size,
                                              max_number_of_patches="all",
                                              stride_length=stride_length,
                                              return_as_array=True)
        batchX[:, :, :, :, 1] = prior_patches

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    probability_images = list()
    for i in range(len(classes)):
        if verbose:
            print("Reconstructing image", classes[i])
        reconstructed_image = reconstruct_image_from_patches(
            predicted_data[:, :, :, :, i],
            domain_image=t1_preprocessed,
            stride_length=stride_length)

        if do_preprocessing:
            probability_images.append(
                ants.apply_transforms(
                    fixed=t1,
                    moving=reconstructed_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose))
        else:
            probability_images.append(reconstructed_image)

    image_matrix = ants.image_list_to_matrix(probability_images, t1 * 0 + 1)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    return_dict = {
        'segmentation_image': segmentation_image,
        'probability_images': probability_images
    }
    return return_dict
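
A minimal usage sketch for deep_atropos above, assuming ants, antspynet, and
numpy are importable as below and that "t1.nii.gz" is a placeholder path;
pretrained weights are fetched on first use.

import ants
import antspynet
import numpy as np

t1 = ants.image_read("t1.nii.gz")
atropos = antspynet.deep_atropos(t1, do_preprocessing=True, verbose=True)
seg = atropos['segmentation_image']

# e.g., approximate gray-matter volume (label 2) in mm^3
gm_mask = ants.threshold_image(seg, 2, 2)
gm_volume = gm_mask.numpy().sum() * np.prod(gm_mask.spacing)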