Code Example #1
 def crop(self, img_dir, label_dir):
     img_pix = ants.image_read(img_dir)
     img = img_pix.numpy()
     label_pix = ants.image_read(label_dir)
     label = label_pix.numpy()
     # Bounding box of all labeled voxels (labels 1 to 10)
     label_bounding_box = np.where(label > 0)
     x_min = np.min(label_bounding_box[0])
     y_min = np.min(label_bounding_box[1])
     z_min = np.min(label_bounding_box[2])
     x_max = np.max(label_bounding_box[0])
     y_max = np.max(label_bounding_box[1])
     z_max = np.max(label_bounding_box[2])
     print(x_min, y_min, z_min, x_max, y_max, z_max, img_dir)
     cropped_label = ants.crop_indices(label_pix, (x_min, y_min, z_min),
                                       (x_max, y_max, z_max))
     #  cropped = ants.crop_image(img,label,3)
     cropped_img = ants.crop_indices(img_pix, (x_min, y_min, z_min),
                                     (x_max, y_max, z_max))
     if not os.path.exists("data/QSM_masked_cropped/"):
         os.makedirs("data/QSM_masked_cropped/")
     if not os.path.exists("data/label_cropped/"):
         os.makedirs("data/label_cropped/")
     # Extract the base file name, e.g. 'data/QSM_masked/N001.nii.gz' -> 'N001'
     img_name = re.split(r'[/.]', img_dir)[2]
     label_name = re.split(r'[/.]', label_dir)[2]
     ants.image_write(
         cropped_img,
         "data/QSM_masked_cropped/" + img_name + "_cropped.nii.gz")
     ants.image_write(
         cropped_label,
         "data/label_cropped/" + label_name + "_cropped.nii.gz")
Code Example #2
def cde_mot_rigidreg(fixed,
                     images,
                     Fimg='..',
                     spacing=[1.6, 1.6, 8],
                     crop=False,
                     saveprog=False,
                     savesuff='',
                     savedir='.'):
    import ants
    import os
    import datetime
    import time

    print('>> Starting rigid registration <<')

    if isinstance(fixed, str): fi = ants.image_read(Fimg + os.sep + fixed)
    else: fi = fixed
    if crop: fi = ants.crop_indices(fi, [0, 0, 1], fi.shape)

    if savesuff: savesuff = '_' + savesuff

    mvd = []
    cnt = 0
    last_pct = 0  # last progress percentage printed

    for i in images:
        cnt = cnt + 1

        if isinstance(i, str): img = ants.image_read(Fimg + os.sep + i)
        else: img = i

        img.set_spacing(spacing)
        if crop: img = ants.crop_indices(img, [0, 0, 1], img.shape)
        fi.set_spacing(spacing)

        # Actual ants registration step
        #-----------------------------------------------------------------------
        moved = ants.registration(fi, img, type_of_transform='QuickRigid')

        if saveprog:
            savename = savedir + os.sep + str(cnt).zfill(4) + savesuff + '.tif'
            ants.image_write(moved["warpedmovout"], savename)
            mvd.append(savename)

        else:
            mvd.append(moved["warpedmovout"])

        # Report progress at every 5% step (the original float-modulo test
        # was unreliable)
        pct = 100 * cnt // len(images)
        if pct % 5 == 0 and pct != last_pct:
            last_pct = pct
            st = datetime.datetime.fromtimestamp(time.time()).strftime(
                '%Y-%m-%d %H:%M:%S')
            print('Completed ' + str(pct) + '% at ' + st)

    print('All done with rigid registration')
    if saveprog: print('The returned list contains tif file names')

    return mvd
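A hypothetical call, assuming a directory of tif volumes (the paths below are made up, and savedir must already exist):

import os

Fimg = 'data/frames'  # hypothetical input directory
frames = sorted(f for f in os.listdir(Fimg) if f.endswith('.tif'))
moved = cde_mot_rigidreg(frames[0], frames[1:],
                         Fimg=Fimg,
                         spacing=[1.6, 1.6, 8],
                         saveprog=True,
                         savesuff='rigid',
                         savedir='data/moved')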
Code Example #3
def cde_mot_meancalc(imgs,
                     Fimg,
                     noimages=100,
                     delfirst=True,
                     crop=False,
                     plot='do'):
    import numpy as np
    import ants
    import os

    print('I found ' + str(len(imgs)) + ' images')

    # Load subsection of tifs
    #---------------------------------------------------------------------------
    maxno = np.min([len(imgs), noimages])
    loadi = np.linspace(0, len(imgs) - 1, maxno)
    loadi = loadi.astype(int)
    print("Of these I'm loading " + str(maxno))
    if delfirst:
        loadi = np.delete(loadi, 0)
        print("I'm ignoring the first volume")

    # Load initial image for dimensions
    #---------------------------------------------------------------------------
    if isinstance(imgs[0], str):
        templ = ants.image_read(Fimg + os.sep + imgs[0])

    elif isinstance(imgs[0], ants.core.ants_image.ANTsImage):
        templ = imgs[0]

    if crop:
        templ = ants.crop_indices(templ, [0, 0, 1], templ.shape)

    mean_arr = np.multiply(templ.numpy(), 0)
    imglist = []

    for i in loadi:

        if isinstance(imgs[0], str):
            img = ants.image_read(Fimg + os.sep + imgs[i])
        elif isinstance(imgs[0], ants.core.ants_image.ANTsImage):
            img = imgs[i]
        if crop: img = ants.crop_indices(img, [0, 0, 1], img.shape)

        mean_arr = mean_arr + img.numpy() / maxno
        imglist.append(img)

    # Note: ants.from_numpy called without geometry arguments uses default
    # spacing/origin; pass the template's if they need to be preserved
    mimg = ants.from_numpy(mean_arr)
    if plot == 'do':
        ants.plot(mimg, axis=2, slices=range(mimg.shape[2]), figsize=3)

    return mimg, imglist
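The two helpers chain naturally for simple motion correction: average a subset of frames into a template, then rigidly register every frame to that mean. A sketch, reusing the hypothetical frames and Fimg from the previous example:

mimg, _ = cde_mot_meancalc(frames, Fimg, noimages=100, plot='no')
aligned = cde_mot_rigidreg(mimg, frames, Fimg=Fimg)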
Code Example #4
 def crop_onpic(self, img_dir, label_dir, lowerind, upperind, path):
     img_pix = ants.image_read(img_dir)
     label_pix = ants.image_read(label_dir)
     cropped_img = ants.crop_indices(img_pix, lowerind, upperind)
     cropped_label = ants.crop_indices(label_pix, lowerind, upperind)
     img_name = re.split(r'[/.]', img_dir)[-3]
     label_name = re.split(r'[/.]', label_dir)[-3]
     ants.image_write(
         cropped_img,
         path + "QSM_masked_cropped/" + img_name + "_cropped.nii.gz")
     ants.image_write(
         cropped_label,
         path + "label_cropped/" + label_name + "_cropped.nii.gz")
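A hypothetical call with explicit crop bounds; cropper stands in for an instance of the enclosing class, the paths are made up, and the output directories under path must already exist:

cropper.crop_onpic('data/QSM_masked/N001.nii.gz',
                   'data/label/N001.nii.gz',
                   lowerind=(30, 30, 10),
                   upperind=(160, 190, 120),
                   path='data/')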
Code Example #5
    def test_crop_indices_example(self):
        fi = ants.image_read(ants.get_ants_data("r16"))
        cropped = ants.crop_indices(fi, (10, 10), (100, 100))
        cropped = ants.smooth_image(cropped, 5)
        decropped = ants.decrop_image(cropped, fi)

        # image not float
        cropped = ants.crop_indices(fi.clone('unsigned int'), (10, 10),
                                    (100, 100))

        # image dim not equal to indices; each bad call must be checked in
        # its own context, otherwise the second is never reached
        with self.assertRaises(Exception):
            cropped = ants.crop_indices(fi, (10, 10, 10), (100, 100))
        with self.assertRaises(Exception):
            cropped = ants.crop_indices(fi, (10, 10), (100, 100, 100))
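The test exercises the crop/decrop round trip: ants.decrop_image pads a cropped image back into the geometry of a full-size reference. A minimal sketch of that behavior:

import ants

fi = ants.image_read(ants.get_ants_data("r16"))
cropped = ants.crop_indices(fi, (10, 10), (100, 100))
restored = ants.decrop_image(cropped, fi * 0)  # zeros outside the cropped region
assert restored.shape == fi.shape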
Code Example #6
import ants
import numpy as np


def crop_image_center(image, crop_size):
    """
    Crop the center of an image.

    Arguments
    ---------
    image : ANTsImage
        Input image

    crop_size: n-D tuple (depending on dimensionality).
        Width, height, depth (if 3-D), and time (if 4-D) of crop region.

    Returns
    -------
    A cropped ANTsImage.

    Example
    -------
    >>> import ants
    >>> image = ants.image_read(ants.get_ants_data('r16'))
    >>> cropped_image = crop_image_center(image, crop_size=(64, 64))
    """

    image_size = np.array(image.shape)

    if len(image_size) != len(crop_size):
        raise ValueError("crop_size does not match image size.")

    if (np.asarray(crop_size) > np.asarray(image_size)).any():
        raise ValueError("A crop_size dimension is larger than image_size.")

    start_index = (np.floor(
        0.5 * (np.asarray(image_size) - np.asarray(crop_size)))).astype(int)
    end_index = start_index + np.asarray(crop_size).astype(int)

    # image_clone(...) * 1 makes a fresh copy (and likely promotes the pixel
    # type to float) before cropping
    cropped_image = ants.crop_indices(
        ants.image_clone(image) * 1, start_index, end_index)

    return (cropped_image)
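For a concrete trace of the index arithmetic: on the 256x256 "r16" sample image with crop_size=(64, 64), start_index works out to (96, 96) and end_index to (160, 160):

import ants

image = ants.image_read(ants.get_ants_data('r16'))  # 256 x 256
cropped = crop_image_center(image, crop_size=(64, 64))
print(cropped.shape)  # expected: (64, 64)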
Code Example #7
def desikan_killiany_tourville_labeling(t1,
                                        do_preprocessing=True,
                                        return_probability_images=False,
                                        antsxnet_cache_directory=None,
                                        verbose=False):
    """
    Cortical and deep gray matter labeling using Desikan-Killiany-Tourville

    Perform DKT labeling using deep learning

    The labeling is as follows:

    Inner labels:
    Label 0: background
    Label 4: left lateral ventricle
    Label 5: left inferior lateral ventricle
    Label 6: left cerebellum exterior
    Label 7: left cerebellum white matter
    Label 10: left thalamus proper
    Label 11: left caudate
    Label 12: left putamen
    Label 13: left pallidum
    Label 14: 3rd ventricle
    Label 15: 4th ventricle
    Label 16: brain stem
    Label 17: left hippocampus
    Label 18: left amygdala
    Label 24: CSF
    Label 25: left lesion
    Label 26: left accumbens area
    Label 28: left ventral DC
    Label 30: left vessel
    Label 43: right lateral ventricle
    Label 44: right inferior lateral ventricle
    Label 45: right cerebellum exterior
    Label 46: right cerebellum white matter
    Label 49: right thalamus proper
    Label 50: right caudate
    Label 51: right putamen
    Label 52: right pallidum
    Label 53: right hippocampus
    Label 54: right amygdala
    Label 57: right lesion
    Label 58: right accumbens area
    Label 60: right ventral DC
    Label 62: right vessel
    Label 72: 5th ventricle
    Label 85: optic chiasm
    Label 91: left basal forebrain
    Label 92: right basal forebrain
    Label 630: cerebellar vermal lobules I-V
    Label 631: cerebellar vermal lobules VI-VII
    Label 632: cerebellar vermal lobules VIII-X

    Outer labels:
    Label 1002: left caudal anterior cingulate
    Label 1003: left caudal middle frontal
    Label 1005: left cuneus
    Label 1006: left entorhinal
    Label 1007: left fusiform
    Label 1008: left inferior parietal
    Label 1009: left inferior temporal
    Label 1010: left isthmus cingulate
    Label 1011: left lateral occipital
    Label 1012: left lateral orbitofrontal
    Label 1013: left lingual
    Label 1014: left medial orbitofrontal
    Label 1015: left middle temporal
    Label 1016: left parahippocampal
    Label 1017: left paracentral
    Label 1018: left pars opercularis
    Label 1019: left pars orbitalis
    Label 1020: left pars triangularis
    Label 1021: left pericalcarine
    Label 1022: left postcentral
    Label 1023: left posterior cingulate
    Label 1024: left precentral
    Label 1025: left precuneus
    Label 1026: left rostral anterior cingulate
    Label 1027: left rostral middle frontal
    Label 1028: left superior frontal
    Label 1029: left superior parietal
    Label 1030: left superior temporal
    Label 1031: left supramarginal
    Label 1034: left transverse temporal
    Label 1035: left insula
    Label 2002: right caudal anterior cingulate
    Label 2003: right caudal middle frontal
    Label 2005: right cuneus
    Label 2006: right entorhinal
    Label 2007: right fusiform
    Label 2008: right inferior parietal
    Label 2009: right inferior temporal
    Label 2010: right isthmus cingulate
    Label 2011: right lateral occipital
    Label 2012: right lateral orbitofrontal
    Label 2013: right lingual
    Label 2014: right medial orbitofrontal
    Label 2015: right middle temporal
    Label 2016: right parahippocampal
    Label 2017: right paracentral
    Label 2018: right pars opercularis
    Label 2019: right pars orbitalis
    Label 2020: right pars triangularis
    Label 2021: right pericalcarine
    Label 2022: right postcentral
    Label 2023: right posterior cingulate
    Label 2024: right precentral
    Label 2025: right precuneus
    Label 2026: right rostral anterior cingulate
    Label 2027: right rostral middle frontal
    Label 2028: right superior frontal
    Label 2029: right superior parietal
    Label 2030: right superior temporal
    Label 2031: right supramarginal
    Label 2034: right transverse temporal
    Label 2035: right insula

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e., set
    do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    return_probability_images : boolean
        Whether to return the two sets of probability images for the inner and outer
        labels.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    List consisting of the segmentation image and probability images for
    each label.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> dkt = desikan_killiany_tourville_labeling(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import categorical_focal_loss
    from ..utilities import preprocess_brain_image
    from ..utilities import crop_image_center

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            do_brain_extraction=True,
            template="croppedMni152",
            template_transform_type="AffineFast",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing[
            "preprocessed_image"] * t1_preprocessing['brain_mask']

    ################################
    #
    # Download spatial priors for outer model
    #
    ################################

    spatial_priors_file_name_path = get_antsxnet_data(
        "priorDktLabels", antsxnet_cache_directory=antsxnet_cache_directory)
    spatial_priors = ants.image_read(spatial_priors_file_name_path)
    priors_image_list = ants.ndimage_to_list(spatial_priors)

    ################################
    #
    # Build outer model and load weights
    #
    ################################

    template_size = (96, 112, 96)
    labels = (0, 1002, 1003, *tuple(range(1005, 1032)), 1034, 1035, 2002, 2003,
              *tuple(range(2005, 2032)), 2034, 2035)
    channel_size = 1 + len(priors_image_list)

    unet_model = create_unet_model_3d((*template_size, channel_size),
                                      number_of_outputs=len(labels),
                                      number_of_layers=4,
                                      number_of_filters_at_base_layer=16,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      weight_decay=1e-5,
                                      add_attention_gating=True)

    weights_file_name = get_pretrained_network(
        "dktOuterWithSpatialPriors",
        antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Outer model prediction.")

    downsampled_image = ants.resample_image(t1_preprocessed,
                                            template_size,
                                            use_voxels=True,
                                            interp_type=0)
    image_array = downsampled_image.numpy()
    image_array = (image_array - image_array.mean()) / image_array.std()

    batchX = np.zeros((1, *template_size, channel_size))
    batchX[0, :, :, :, 0] = image_array

    for i in range(len(priors_image_list)):
        resampled_prior_image = ants.resample_image(priors_image_list[i],
                                                    template_size,
                                                    use_voxels=True,
                                                    interp_type=0)
        batchX[0, :, :, :, i + 1] = resampled_prior_image.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    origin = downsampled_image.origin
    spacing = downsampled_image.spacing
    direction = downsampled_image.direction

    # Note: despite its name, this list holds the outer (cortical) model's
    # probability images
    inner_probability_images = list()
    for i in range(len(labels)):
        probability_image = \
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
            origin=origin, spacing=spacing, direction=direction)
        resampled_image = ants.resample_image(probability_image,
                                              t1_preprocessed.shape,
                                              use_voxels=True,
                                              interp_type=0)
        if do_preprocessing:
            inner_probability_images.append(
                ants.apply_transforms(
                    fixed=t1,
                    moving=resampled_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose))
        else:
            inner_probability_images.append(resampled_image)

    image_matrix = ants.image_list_to_matrix(inner_probability_images,
                                             t1 * 0 + 1)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    dkt_label_image = ants.image_clone(segmentation_image)
    for i in range(len(labels)):
        dkt_label_image[segmentation_image == i] = labels[i]

    ################################
    #
    # Build inner model and load weights
    #
    ################################

    template_size = (160, 192, 160)
    labels = (0, 4, 6, 7, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 28, 30,
              43, 44, 45, 46, 49, 50, 51, 52, 53, 54, 58, 60, 91, 92, 630, 631,
              632)

    unet_model = create_unet_model_3d((*template_size, 1),
                                      number_of_outputs=len(labels),
                                      number_of_layers=4,
                                      number_of_filters_at_base_layer=8,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      weight_decay=1e-5,
                                      add_attention_gating=True)

    weights_file_name = get_pretrained_network(
        "dktInner", antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Inner model prediction.")

    cropped_image = ants.crop_indices(t1_preprocessed, (12, 14, 0),
                                      (172, 206, 160))

    batchX = np.expand_dims(cropped_image.numpy(), axis=0)
    batchX = np.expand_dims(batchX, axis=-1)
    batchX = (batchX - batchX.mean()) / batchX.std()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    origin = cropped_image.origin
    spacing = cropped_image.spacing
    direction = cropped_image.direction

    # Note: despite its name, this list holds the inner (subcortical) model's
    # probability images
    outer_probability_images = list()
    for i in range(len(labels)):
        probability_image = \
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
            origin=origin, spacing=spacing, direction=direction)
        if i > 0:
            decropped_image = ants.decrop_image(probability_image,
                                                t1_preprocessed * 0)
        else:
            decropped_image = ants.decrop_image(probability_image,
                                                t1_preprocessed * 0 + 1)

        if do_preprocessing:
            outer_probability_images.append(
                ants.apply_transforms(
                    fixed=t1,
                    moving=decropped_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose))
        else:
            outer_probability_images.append(decropped_image)

    image_matrix = ants.image_list_to_matrix(outer_probability_images,
                                             t1 * 0 + 1)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    ################################
    #
    # Incorporate the inner model results into the final label image.
    # Note that we purposely prioritize the inner label results.
    #
    ################################

    for i in range(len(labels)):
        if labels[i] > 0:
            dkt_label_image[segmentation_image == i] = labels[i]

    if return_probability_images:
        return_dict = {
            'segmentation_image': dkt_label_image,
            'inner_probability_images': inner_probability_images,
            'outer_probability_images': outer_probability_images
        }
        return (return_dict)
    else:
        return (dkt_label_image)
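A hypothetical end-to-end call on a raw T1 (the file names are made up):

import ants

t1 = ants.image_read("t1.nii.gz")
dkt = desikan_killiany_tourville_labeling(t1, do_preprocessing=True, verbose=True)
ants.image_write(dkt, "t1_dkt.nii.gz")

# With return_probability_images=True a dict is returned instead:
# result = desikan_killiany_tourville_labeling(t1, return_probability_images=True)
# seg = result['segmentation_image']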
Code Example #8
# Centroid (in voxel indices) of the initial-stage prediction
centroid_indices = np.where(prediction_initial_stage == 1)
centroid = list()
centroid.append(int(np.mean(centroid_indices[0])))
centroid.append(int(np.mean(centroid_indices[1])))
centroid.append(int(np.mean(centroid_indices[2])))

lower = list()
lower.append(centroid[0] - int(0.5 * shape_refine_stage[0]))
lower.append(centroid[1] - int(0.5 * shape_refine_stage[1]))
lower.append(centroid[2] - int(0.5 * shape_refine_stage[2]))
upper = list()
upper.append(lower[0] + shape_refine_stage[0])
upper.append(lower[1] + shape_refine_stage[1])
upper.append(lower[2] + shape_refine_stage[2])

mask_trimmed = ants.crop_indices(mask_initial_stage, lower, upper)
image_trimmed = ants.crop_indices(image_resampled, lower, upper)
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

# Build model and load weights for second pass
print("    Refine step 2: load weights.")
start_time = time.time()
model_refine_stage = antspynet.create_hippmapp3r_unet_model_3d(
    (*shape_refine_stage, 1), False)
weights_file_name = "./hippMapp3rRefineWeights.h5"

if not os.path.exists(weights_file_name):
    weights_file_name = antspynet.get_pretrained_network(
        "hippMapp3rRefine", weights_file_name)
Code Example #9
def deep_flash(t1,
               t2=None,
               do_preprocessing=True,
               use_rank_intensity=True,
               antsxnet_cache_directory=None,
               verbose=False):
    """
    Hippocampal/Entorhinal segmentation using "Deep Flash"

    Perform hippocampal/entorhinal segmentation in T1 and T1/T2 images using
    labels from Mike Yassa's lab

    https://faculty.sites.uci.edu/myassa/

    The labeling is as follows:
    Label 0 :  background
    Label 5 :  left aLEC
    Label 6 :  right aLEC
    Label 7 :  left pMEC
    Label 8 :  right pMEC
    Label 9 :  left perirhinal
    Label 10:  right perirhinal
    Label 11:  left parahippocampal
    Label 12:  right parahippocampal
    Label 13:  left DG/CA2/CA3/CA4
    Label 14:  right DG/CA2/CA3/CA4
    Label 15:  left CA1
    Label 16:  right CA1
    Label 17:  left subiculum
    Label 18:  right subiculum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * affine registration to the "deep flash" template.
    These steps are performed on the input images if do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    t2 : ANTsImage
        Optional 3-D T2-weighted brain image.  If specified, it is assumed to be
        pre-aligned to the t1.

    do_preprocessing : boolean
        See description above.

    use_rank_intensity : boolean
        If false, use histogram matching with cropped template ROI.  Otherwise,
        use a rank intensity transform on the cropped ROI.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    List consisting of the segmentation image and probability images for
    each label and foreground.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> flash = deep_flash(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import brain_extraction

    # Keras symbols used below (imported at module level in the original)
    from tensorflow.keras import regularizers
    from tensorflow.keras.layers import Conv3D
    from tensorflow.keras.models import Model

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Options temporarily taken from the user
    #
    ################################

    # use_hierarchical_parcellation : boolean
    #     If True, use u-net model with additional outputs of the medial temporal lobe
    #     region, hippocampal, and entorhinal/perirhinal/parahippocampal regions.  Otherwise
    #     the only additional output is the medial temporal lobe.
    #
    # use_contralaterality : boolean
    #     Use both hemispherical models to also predict the corresponding contralateral
    #     segmentation and use both sets of priors to produce the results.  Mainly used
    #     for debugging.

    use_hierarchical_parcellation = True
    use_contralaterality = True

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    t1_mask = None
    t1_preprocessed_flipped = None
    t1_template = ants.image_read(
        get_antsxnet_data("deepFlashTemplateT1SkullStripped"))
    template_transforms = None
    if do_preprocessing:

        if verbose:
            print("Preprocessing T1.")

        # Brain extraction
        probability_mask = brain_extraction(
            t1_preprocessed,
            modality="t1",
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_mask = ants.threshold_image(probability_mask, 0.5, 1, 1, 0)
        t1_preprocessed = t1_preprocessed * t1_mask

        # Do bias correction
        t1_preprocessed = ants.n4_bias_field_correction(t1_preprocessed,
                                                        t1_mask,
                                                        shrink_factor=4,
                                                        verbose=verbose)

        # Warp to template
        registration = ants.registration(
            fixed=t1_template,
            moving=t1_preprocessed,
            type_of_transform="antsRegistrationSyNQuickRepro[a]",
            verbose=verbose)
        template_transforms = dict(fwdtransforms=registration['fwdtransforms'],
                                   invtransforms=registration['invtransforms'])
        t1_preprocessed = registration['warpedmovout']

    if use_contralaterality:
        t1_preprocessed_array = t1_preprocessed.numpy()
        t1_preprocessed_array_flipped = np.flip(t1_preprocessed_array, axis=0)
        t1_preprocessed_flipped = ants.from_numpy(
            t1_preprocessed_array_flipped,
            origin=t1_preprocessed.origin,
            spacing=t1_preprocessed.spacing,
            direction=t1_preprocessed.direction)

    t2_preprocessed = t2
    t2_preprocessed_flipped = None
    t2_template = None
    if t2 is not None:
        t2_template = ants.image_read(
            get_antsxnet_data("deepFlashTemplateT2SkullStripped"))
        t2_template = ants.copy_image_info(t1_template, t2_template)
        if do_preprocessing:

            if verbose:
                print("Preprocessing T2.")

            # Brain extraction
            t2_preprocessed = t2_preprocessed * t1_mask

            # Do bias correction
            t2_preprocessed = ants.n4_bias_field_correction(t2_preprocessed,
                                                            t1_mask,
                                                            shrink_factor=4,
                                                            verbose=verbose)

            # Warp to template
            t2_preprocessed = ants.apply_transforms(
                fixed=t1_template,
                moving=t2_preprocessed,
                transformlist=template_transforms['fwdtransforms'],
                verbose=verbose)

        if use_contralaterality:
            t2_preprocessed_array = t2_preprocessed.numpy()
            t2_preprocessed_array_flipped = np.flip(t2_preprocessed_array,
                                                    axis=0)
            t2_preprocessed_flipped = ants.from_numpy(
                t2_preprocessed_array_flipped,
                origin=t2_preprocessed.origin,
                spacing=t2_preprocessed.spacing,
                direction=t2_preprocessed.direction)

    probability_images = list()
    labels = (0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)
    image_size = (64, 64, 96)

    ################################
    #
    # Process left/right in split networks
    #
    ################################

    ################################
    #
    # Download spatial priors
    #
    ################################

    spatial_priors_file_name_path = get_antsxnet_data(
        "deepFlashPriors", antsxnet_cache_directory=antsxnet_cache_directory)
    spatial_priors = ants.image_read(spatial_priors_file_name_path)
    priors_image_list = ants.ndimage_to_list(spatial_priors)
    for i in range(len(priors_image_list)):
        priors_image_list[i] = ants.copy_image_info(t1_preprocessed,
                                                    priors_image_list[i])

    labels_left = labels[1::2]
    priors_image_left_list = priors_image_list[1::2]
    probability_images_left = list()
    foreground_probability_images_left = list()
    lower_bound_left = (76, 74, 56)
    upper_bound_left = (140, 138, 152)
    tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_left,
                                    upper_bound_left)
    origin_left = tmp_cropped.origin

    spacing = tmp_cropped.spacing
    direction = tmp_cropped.direction

    t1_template_roi_left = ants.crop_indices(t1_template, lower_bound_left,
                                             upper_bound_left)
    t1_template_roi_left = (t1_template_roi_left - t1_template_roi_left.min(
    )) / (t1_template_roi_left.max() - t1_template_roi_left.min()) * 2.0 - 1.0
    t2_template_roi_left = None
    if t2_template is not None:
        t2_template_roi_left = ants.crop_indices(t2_template, lower_bound_left,
                                                 upper_bound_left)
        t2_template_roi_left = (t2_template_roi_left -
                                t2_template_roi_left.min()) / (
                                    t2_template_roi_left.max() -
                                    t2_template_roi_left.min()) * 2.0 - 1.0

    labels_right = labels[2::2]
    priors_image_right_list = priors_image_list[2::2]
    probability_images_right = list()
    foreground_probability_images_right = list()
    lower_bound_right = (20, 74, 56)
    upper_bound_right = (84, 138, 152)
    tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_right,
                                    upper_bound_right)
    origin_right = tmp_cropped.origin

    t1_template_roi_right = ants.crop_indices(t1_template, lower_bound_right,
                                              upper_bound_right)
    t1_template_roi_right = (
        t1_template_roi_right - t1_template_roi_right.min()
    ) / (t1_template_roi_right.max() - t1_template_roi_right.min()) * 2.0 - 1.0
    t2_template_roi_right = None
    if t2_template is not None:
        t2_template_roi_right = ants.crop_indices(t2_template,
                                                  lower_bound_right,
                                                  upper_bound_right)
        t2_template_roi_right = (t2_template_roi_right -
                                 t2_template_roi_right.min()) / (
                                     t2_template_roi_right.max() -
                                     t2_template_roi_right.min()) * 2.0 - 1.0

    ################################
    #
    # Create model
    #
    ################################

    channel_size = 1 + len(labels_left)
    if t2 is not None:
        channel_size += 1

    number_of_classification_labels = 1 + len(labels_left)

    unet_model = create_unet_model_3d(
        (*image_size, channel_size),
        number_of_outputs=number_of_classification_labels,
        mode="classification",
        number_of_filters=(32, 64, 96, 128, 256),
        convolution_kernel_size=(3, 3, 3),
        deconvolution_kernel_size=(2, 2, 2),
        dropout_rate=0.0,
        weight_decay=0)

    penultimate_layer = unet_model.layers[-2].output

    # medial temporal lobe
    output1 = Conv3D(
        filters=1,
        kernel_size=(1, 1, 1),
        activation='sigmoid',
        kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

    if use_hierarchical_parcellation:

        # EC, perirhinal, and parahippo.
        output2 = Conv3D(
            filters=1,
            kernel_size=(1, 1, 1),
            activation='sigmoid',
            kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

        # Hippocampus
        output3 = Conv3D(
            filters=1,
            kernel_size=(1, 1, 1),
            activation='sigmoid',
            kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

        unet_model = Model(
            inputs=unet_model.input,
            outputs=[unet_model.output, output1, output2, output3])
    else:
        unet_model = Model(inputs=unet_model.input,
                           outputs=[unet_model.output, output1])

    ################################
    #
    # Left:  build model and load weights
    #
    ################################

    network_name = 'deepFlashLeftT1'
    if t2 is not None:
        network_name = 'deepFlashLeftBoth'

    if use_hierarchical_parcellation:
        network_name += "Hierarchical"

    if use_rank_intensity:
        network_name += "_ri"

    if verbose:
        print("DeepFlash: retrieving model weights (left).")
    weights_file_name = get_pretrained_network(
        network_name, antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Left:  do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction (left).")

    batchX = None
    if use_contralaterality:
        batchX = np.zeros((2, *image_size, channel_size))
    else:
        batchX = np.zeros((1, *image_size, channel_size))

    t1_cropped = ants.crop_indices(t1_preprocessed, lower_bound_left,
                                   upper_bound_left)
    if use_rank_intensity:
        t1_cropped = ants.rank_intensity(t1_cropped)
    else:
        t1_cropped = ants.histogram_match_image(t1_cropped,
                                                t1_template_roi_left, 255, 64,
                                                False)
    batchX[0, :, :, :, 0] = t1_cropped.numpy()
    if use_contralaterality:
        t1_cropped = ants.crop_indices(t1_preprocessed_flipped,
                                       lower_bound_left, upper_bound_left)
        if use_rank_intensity:
            t1_cropped = ants.rank_intensity(t1_cropped)
        else:
            t1_cropped = ants.histogram_match_image(t1_cropped,
                                                    t1_template_roi_left, 255,
                                                    64, False)
        batchX[1, :, :, :, 0] = t1_cropped.numpy()
    if t2 is not None:
        t2_cropped = ants.crop_indices(t2_preprocessed, lower_bound_left,
                                       upper_bound_left)
        if use_rank_intensity:
            t2_cropped = ants.rank_intensity(t2_cropped)
        else:
            t2_cropped = ants.histogram_match_image(t2_cropped,
                                                    t2_template_roi_left, 255,
                                                    64, False)
        batchX[0, :, :, :, 1] = t2_cropped.numpy()
        if use_contralaterality:
            t2_cropped = ants.crop_indices(t2_preprocessed_flipped,
                                           lower_bound_left, upper_bound_left)
            if use_rank_intensity:
                t2_cropped = ants.rank_intensity(t2_cropped)
            else:
                t2_cropped = ants.histogram_match_image(
                    t2_cropped, t2_template_roi_left, 255, 64, False)
            batchX[1, :, :, :, 1] = t2_cropped.numpy()

    for i in range(len(priors_image_left_list)):
        cropped_prior = ants.crop_indices(priors_image_left_list[i],
                                          lower_bound_left, upper_bound_left)
        for j in range(batchX.shape[0]):
            batchX[j, :, :, :, i +
                   (channel_size - len(labels_left))] = cropped_prior.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    for i in range(1 + len(labels_left)):
        for j in range(predicted_data[0].shape[0]):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[0][j, :, :, :, i]),
                origin=origin_left, spacing=spacing, direction=direction)
            if i > 0:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0)
            else:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0 + 1)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                probability_images_left.append(probability_image)
            else:  # flipped
                probability_images_right.append(probability_image)

    ################################
    #
    # Left:  do prediction of mtl, hippocampal, and ec regions and normalize to native space
    #
    ################################

    for i in range(1, len(predicted_data)):
        for j in range(predicted_data[i].shape[0]):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[i][j, :, :, :, 0]),
                origin=origin_left, spacing=spacing, direction=direction)
            probability_image = ants.decrop_image(probability_image,
                                                  t1_preprocessed * 0)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                foreground_probability_images_left.append(probability_image)
            else:
                foreground_probability_images_right.append(probability_image)

    ################################
    #
    # Right:  build model and load weights
    #
    ################################

    network_name = 'deepFlashRightT1'
    if t2 is not None:
        network_name = 'deepFlashRightBoth'

    if use_hierarchical_parcellation:
        network_name += "Hierarchical"

    if use_rank_intensity:
        network_name += "_ri"

    if verbose:
        print("DeepFlash: retrieving model weights (right).")
    weights_file_name = get_pretrained_network(
        network_name, antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Right:  do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction (right).")

    batchX = None
    if use_contralaterality:
        batchX = np.zeros((2, *image_size, channel_size))
    else:
        batchX = np.zeros((1, *image_size, channel_size))

    t1_cropped = ants.crop_indices(t1_preprocessed, lower_bound_right,
                                   upper_bound_right)
    if use_rank_intensity:
        t1_cropped = ants.rank_intensity(t1_cropped)
    else:
        t1_cropped = ants.histogram_match_image(t1_cropped,
                                                t1_template_roi_right, 255, 64,
                                                False)
    batchX[0, :, :, :, 0] = t1_cropped.numpy()
    if use_contralaterality:
        t1_cropped = ants.crop_indices(t1_preprocessed_flipped,
                                       lower_bound_right, upper_bound_right)
        if use_rank_intensity:
            t1_cropped = ants.rank_intensity(t1_cropped)
        else:
            t1_cropped = ants.histogram_match_image(t1_cropped,
                                                    t1_template_roi_right, 255,
                                                    64, False)
        batchX[1, :, :, :, 0] = t1_cropped.numpy()
    if t2 is not None:
        t2_cropped = ants.crop_indices(t2_preprocessed, lower_bound_right,
                                       upper_bound_right)
        if use_rank_intensity:
            t2_cropped = ants.rank_intensity(t2_cropped)
        else:
            t2_cropped = ants.histogram_match_image(t2_cropped,
                                                    t2_template_roi_right, 255,
                                                    64, False)
        batchX[0, :, :, :, 1] = t2_cropped.numpy()
        if use_contralaterality:
            t2_cropped = ants.crop_indices(t2_preprocessed_flipped,
                                           lower_bound_right,
                                           upper_bound_right)
            if use_rank_intensity:
                t2_cropped = ants.rank_intensity(t2_cropped)
            else:
                t2_cropped = ants.histogram_match_image(
                    t2_cropped, t2_template_roi_right, 255, 64, False)
            batchX[1, :, :, :, 1] = t2_cropped.numpy()

    for i in range(len(priors_image_right_list)):
        cropped_prior = ants.crop_indices(priors_image_right_list[i],
                                          lower_bound_right, upper_bound_right)
        for j in range(batchX.shape[0]):
            batchX[j, :, :, :, i +
                   (channel_size - len(labels_right))] = cropped_prior.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    for i in range(1 + len(labels_right)):
        for j in range(predicted_data[0].shape[0]):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[0][j, :, :, :, i]),
                origin=origin_right, spacing=spacing, direction=direction)
            if i > 0:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0)
            else:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0 + 1)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                if use_contralaterality:
                    probability_images_right[i] = (
                        probability_images_right[i] + probability_image) / 2
                else:
                    probability_images_right.append(probability_image)
            else:  # flipped
                probability_images_left[i] = (probability_images_left[i] +
                                              probability_image) / 2

    ################################
    #
    # Right:  do prediction of mtl, hippocampal, and ec regions and normalize to native space
    #
    ################################

    for i in range(1, len(predicted_data)):
        for j in range(predicted_data[i].shape[0]):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[i][j, :, :, :, 0]),
                origin=origin_right, spacing=spacing, direction=direction)
            probability_image = ants.decrop_image(probability_image,
                                                  t1_preprocessed * 0)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                if use_contralaterality:
                    foreground_probability_images_right[
                        i - 1] = (foreground_probability_images_right[i - 1] +
                                  probability_image) / 2
                else:
                    foreground_probability_images_right.append(
                        probability_image)
            else:
                foreground_probability_images_left[
                    i - 1] = (foreground_probability_images_left[i - 1] +
                              probability_image) / 2

    ################################
    #
    # Combine priors
    #
    ################################

    probability_background_image = ants.image_clone(t1) * 0
    for i in range(1, len(probability_images_left)):
        probability_background_image += probability_images_left[i]
    for i in range(1, len(probability_images_right)):
        probability_background_image += probability_images_right[i]

    probability_images.append(probability_background_image * -1 + 1)
    for i in range(1, len(probability_images_left)):
        probability_images.append(probability_images_left[i])
        probability_images.append(probability_images_right[i])

    ################################
    #
    # Convert probability images to segmentation
    #
    ################################

    # image_matrix = ants.image_list_to_matrix(probability_images, t1 * 0 + 1)
    # segmentation_matrix = np.argmax(image_matrix, axis=0)
    # segmentation_image = ants.matrix_to_images(
    #     np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    image_matrix = ants.image_list_to_matrix(
        probability_images[1:(len(probability_images))], t1 * 0 + 1)
    background_foreground_matrix = np.stack([
        ants.image_list_to_matrix([probability_images[0]], t1 * 0 + 1),
        np.expand_dims(np.sum(image_matrix, axis=0), axis=0)
    ])
    foreground_matrix = np.argmax(background_foreground_matrix, axis=0)
    segmentation_matrix = (np.argmax(image_matrix, axis=0) +
                           1) * foreground_matrix
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    relabeled_image = ants.image_clone(segmentation_image)
    for i in range(len(labels)):
        relabeled_image[segmentation_image == i] = labels[i]

    foreground_probability_images = list()
    for i in range(len(foreground_probability_images_left)):
        foreground_probability_images.append(
            foreground_probability_images_left[i] +
            foreground_probability_images_right[i])

    return_dict = None
    if use_hierarchical_parcellation:
        return_dict = {
            'segmentation_image':
            relabeled_image,
            'probability_images':
            probability_images,
            'medial_temporal_lobe_probability_image':
            foreground_probability_images[0],
            'other_region_probability_image':
            foreground_probability_images[1],
            'hippocampal_probability_image':
            foreground_probability_images[2]
        }
    else:
        return_dict = {
            'segmentation_image':
            relabeled_image,
            'probability_images':
            probability_images,
            'medial_temporal_lobe_probability_image':
            foreground_probability_images[0]
        }

    return (return_dict)
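A hypothetical usage with an optional pre-aligned T2 (the file names are made up):

import ants

t1 = ants.image_read("t1.nii.gz")
t2 = ants.image_read("t2.nii.gz")  # optional; must be pre-aligned to the t1
flash = deep_flash(t1, t2=t2, verbose=True)
ants.image_write(flash['segmentation_image'], "deep_flash_seg.nii.gz")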
Code Example #10
def deep_flash_deprecated(t1,
                          do_preprocessing=True,
                          do_per_hemisphere=True,
                          which_hemisphere_models="new",
                          antsxnet_cache_directory=None,
                          verbose=False):
    """
    Hippocampal/Entorhinal segmentation using "Deep Flash"

    Perform hippocampal/entorhinal segmentation in T1 images using
    labels from Mike Yassa's lab

    https://faculty.sites.uci.edu/myassa/

    The labeling is as follows:
    Label 0 :  background
    Label 5 :  left aLEC
    Label 6 :  right aLEC
    Label 7 :  left pMEC
    Label 8 :  right pMEC
    Label 9 :  left perirhinal
    Label 10:  right perirhinal
    Label 11:  left parahippocampal
    Label 12:  right parahippocampal
    Label 13:  left DG/CA3
    Label 14:  right DG/CA3
    Label 15:  left CA1
    Label 16:  right CA1
    Label 17:  left subiculum
    Label 18:  right subiculum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e., set
    do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    do_per_hemisphere : boolean
        If True, do prediction based on separate networks per hemisphere.  Otherwise,
        use the single network trained for both hemispheres.

    which_hemisphere_models : string
        Either "old" or "new"; selects which set of per-hemisphere weights to
        load when do_per_hemisphere is True.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    List consisting of the segmentation image and probability images for
    each label.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> flash = deep_flash_deprecated(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import preprocess_brain_image
    from ..utilities import pad_or_crop_image_to_size

    print("This function is deprecated.  Please update to deep_flash().")

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            template="croppedMni152",
            template_transform_type="antsRegistrationSyNQuickRepro[a]",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing[
            "preprocessed_image"] * t1_preprocessing['brain_mask']

    probability_images = list()
    labels = (0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)

    ################################
    #
    # Process left/right in same network
    #
    ################################

    if not do_per_hemisphere:

        ################################
        #
        # Build model and load weights
        #
        ################################

        template_size = (160, 192, 160)

        unet_model = create_unet_model_3d(
            (*template_size, 1),
            number_of_outputs=len(labels),
            number_of_layers=4,
            number_of_filters_at_base_layer=8,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5,
            additional_options=("attentionGating", ))

        if verbose:
            print("DeepFlash: retrieving model weights.")

        weights_file_name = get_pretrained_network(
            "deepFlash", antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Do prediction and normalize to native space
        #
        ################################

        if verbose:
            print("Prediction.")

        cropped_image = pad_or_crop_image_to_size(t1_preprocessed,
                                                  template_size)

        batchX = np.expand_dims(cropped_image.numpy(), axis=0)
        batchX = np.expand_dims(batchX, axis=-1)
        batchX = (batchX - batchX.mean()) / batchX.std()

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        origin = cropped_image.origin
        spacing = cropped_image.spacing
        direction = cropped_image.direction

        for i in range(len(labels)):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                origin=origin, spacing=spacing, direction=direction)
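            # Decrop back to the full field of view: outside the crop, the
            # background channel (i == 0) defaults to 1, foreground channels to 0.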
            if i > 0:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0)
            else:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0 + 1)

            if do_preprocessing:
                probability_images.append(
                    ants.apply_transforms(
                        fixed=t1,
                        moving=decropped_image,
                        transformlist=t1_preprocessing['template_transforms']
                        ['invtransforms'],
                        whichtoinvert=[True],
                        interpolator="linear",
                        verbose=verbose))
            else:
                probability_images.append(decropped_image)

    ################################
    #
    # Process left/right in split networks
    #
    ################################

    else:

        ################################
        #
        # Left:  download spatial priors
        #
        ################################

        spatial_priors_left_file_name_path = get_antsxnet_data(
            "priorDeepFlashLeftLabels",
            antsxnet_cache_directory=antsxnet_cache_directory)
        spatial_priors_left = ants.image_read(
            spatial_priors_left_file_name_path)
        priors_image_left_list = ants.ndimage_to_list(spatial_priors_left)

        ################################
        #
        # Left:  build model and load weights
        #
        ################################

        template_size = (64, 96, 96)
        labels_left = (0, 5, 7, 9, 11, 13, 15, 17)
        channel_size = 1 + len(labels_left)

        number_of_filters = 16
        network_name = ''
        if which_hemisphere_models == "old":
            network_name = "deepFlashLeft16"
        elif which_hemisphere_models == "new":
            network_name = "deepFlashLeft16new"
        else:
            raise ValueError("network_name must be \"old\" or \"new\".")

        unet_model = create_unet_model_3d(
            (*template_size, channel_size),
            number_of_outputs=len(labels_left),
            number_of_layers=4,
            number_of_filters_at_base_layer=number_of_filters,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5,
            additional_options=("attentionGating", ))

        if verbose:
            print("DeepFlash: retrieving model weights (left).")
        weights_file_name = get_pretrained_network(
            network_name, antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Left:  do prediction and normalize to native space
        #
        ################################

        if verbose:
            print("Prediction (left).")

        cropped_image = ants.crop_indices(t1_preprocessed, (30, 51, 0),
                                          (94, 147, 96))
        image_array = cropped_image.numpy()
        image_array = (image_array - image_array.mean()) / image_array.std()

        batchX = np.zeros((1, *template_size, channel_size))
        batchX[0, :, :, :, 0] = image_array
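        # Channel 0 holds the image; the remaining channels hold the cropped
        # spatial priors (filled in below).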

        for i in range(len(priors_image_left_list)):
            cropped_prior = ants.crop_indices(priors_image_left_list[i],
                                              (30, 51, 0), (94, 147, 96))
            batchX[0, :, :, :, i + 1] = cropped_prior.numpy()

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        origin = cropped_image.origin
        spacing = cropped_image.spacing
        direction = cropped_image.direction

        probability_images_left = list()
        for i in range(len(labels_left)):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                origin=origin, spacing=spacing, direction=direction)
            if i > 0:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0)
            else:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0 + 1)

            if do_preprocessing:
                probability_images_left.append(
                    ants.apply_transforms(
                        fixed=t1,
                        moving=decropped_image,
                        transformlist=t1_preprocessing['template_transforms']
                        ['invtransforms'],
                        whichtoinvert=[True],
                        interpolator="linear",
                        verbose=verbose))
            else:
                probability_images_left.append(decropped_image)

        ################################
        #
        # Right:  download spatial priors
        #
        ################################

        spatial_priors_right_file_name_path = get_antsxnet_data(
            "priorDeepFlashRightLabels",
            antsxnet_cache_directory=antsxnet_cache_directory)
        spatial_priors_right = ants.image_read(
            spatial_priors_right_file_name_path)
        priors_image_right_list = ants.ndimage_to_list(spatial_priors_right)

        ################################
        #
        # Right:  build model and load weights
        #
        ################################

        template_size = (64, 96, 96)
        labels_right = (0, 6, 8, 10, 12, 14, 16, 18)
        channel_size = 1 + len(labels_right)

        number_of_filters = 16
        network_name = ''
        if which_hemisphere_models == "old":
            network_name = "deepFlashRight16"
        elif which_hemisphere_models == "new":
            network_name = "deepFlashRight16new"
        else:
            raise ValueError("network_name must be \"old\" or \"new\".")

        unet_model = create_unet_model_3d(
            (*template_size, channel_size),
            number_of_outputs=len(labels_right),
            number_of_layers=4,
            number_of_filters_at_base_layer=number_of_filters,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5,
            additional_options=("attentionGating", ))

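        if verbose:
            print("DeepFlash: retrieving model weights (right).")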
        weights_file_name = get_pretrained_network(
            network_name, antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Right:  do prediction and normalize to native space
        #
        ################################

        if verbose:
            print("Prediction (right).")

        cropped_image = ants.crop_indices(t1_preprocessed, (88, 51, 0),
                                          (152, 147, 96))
        image_array = cropped_image.numpy()
        image_array = (image_array - image_array.mean()) / image_array.std()

        batchX = np.zeros((1, *template_size, channel_size))
        batchX[0, :, :, :, 0] = image_array

        for i in range(len(priors_image_right_list)):
            cropped_prior = ants.crop_indices(priors_image_right_list[i],
                                              (88, 51, 0), (152, 147, 96))
            batchX[0, :, :, :, i + 1] = cropped_prior.numpy()

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        origin = cropped_image.origin
        spacing = cropped_image.spacing
        direction = cropped_image.direction

        probability_images_right = list()
        for i in range(len(labels_right)):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                origin=origin, spacing=spacing, direction=direction)
            if i > 0:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0)
            else:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0 + 1)

            if do_preprocessing:
                probability_images_right.append(
                    ants.apply_transforms(
                        fixed=t1,
                        moving=decropped_image,
                        transformlist=t1_preprocessing['template_transforms']
                        ['invtransforms'],
                        whichtoinvert=[True],
                        interpolator="linear",
                        verbose=verbose))
            else:
                probability_images_right.append(decropped_image)

        ################################
        #
        # Combine the left/right probability images
        #
        ################################

        probability_background_image = ants.image_clone(t1) * 0
        for i in range(1, len(probability_images_left)):
            probability_background_image += probability_images_left[i]
        for i in range(1, len(probability_images_right)):
            probability_background_image += probability_images_right[i]

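        # The background probability is the complement of the summed
        # foreground probabilities.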
        probability_images.append(probability_background_image * -1 + 1)
        for i in range(1, len(probability_images_left)):
            probability_images.append(probability_images_left[i])
            probability_images.append(probability_images_right[i])

    ################################
    #
    # Convert probability images to segmentation
    #
    ################################

    # image_matrix = ants.image_list_to_matrix(probability_images, t1 * 0 + 1)
    # segmentation_matrix = np.argmax(image_matrix, axis=0)
    # segmentation_image = ants.matrix_to_images(
    #     np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

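    # A voxel is foreground when its summed foreground probabilities exceed
    # the background probability; its label is then the argmax over the
    # foreground channels, offset by one to skip the background.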
    image_matrix = ants.image_list_to_matrix(probability_images[1:],
                                             t1 * 0 + 1)
    background_foreground_matrix = np.stack([
        ants.image_list_to_matrix([probability_images[0]], t1 * 0 + 1),
        np.expand_dims(np.sum(image_matrix, axis=0), axis=0)
    ])
    foreground_matrix = np.argmax(background_foreground_matrix, axis=0)
    segmentation_matrix = (np.argmax(image_matrix, axis=0) +
                           1) * foreground_matrix
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    relabeled_image = ants.image_clone(segmentation_image)
    for i in range(len(labels)):
        relabeled_image[segmentation_image == i] = labels[i]

    return_dict = {
        'segmentation_image': relabeled_image,
        'probability_images': probability_images
    }
    return (return_dict)
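
For reference, a minimal end-to-end call might look like the sketch below
(the input path is hypothetical, and the returned keys are assumed to match
the return_dict constructed above):

import ants
import antspynet

t1 = ants.image_read("t1.nii.gz")  # hypothetical input path
flash = antspynet.deep_flash(t1, verbose=True)

# Write the hard segmentation plus one probability image per label.
ants.image_write(flash['segmentation_image'],
                 "deep_flash_segmentation.nii.gz")
for i, prob in enumerate(flash['probability_images']):
    ants.image_write(prob, "deep_flash_probability_%02d.nii.gz" % i)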
Code Example #11
def crop_mask(self, mask_dir, lowerind, upperind):
    # Crop the mask image to the given index bounds.  Note the output path
    # is hard-coded to a single subject (N037).
    mask_pix = ants.image_read(mask_dir)
    cropped_mask = ants.crop_indices(mask_pix, lowerind, upperind)
    ants.image_write(cropped_mask,
                     "data/mask/mask_of_N037_QSM_cropped.nii.gz")
    print('mask cropped')
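
The method above hard-codes its output path to a single subject.  A
standalone variant that derives the output name from the input path,
mirroring the parsing used in Code Example #1, might look like this (the
directory layout is hypothetical):

import os
import re

import ants

def crop_mask(mask_dir, lowerind, upperind, out_dir="data/mask"):
    # Crop the mask image to the given index bounds.
    mask_pix = ants.image_read(mask_dir)
    cropped_mask = ants.crop_indices(mask_pix, lowerind, upperind)
    os.makedirs(out_dir, exist_ok=True)
    # Assumes paths shaped like "data/mask/<name>.nii.gz", as in Code Example #1.
    mask_name = re.split(r'[/.]', mask_dir)[-3]
    ants.image_write(cropped_mask,
                     os.path.join(out_dir, mask_name + "_cropped.nii.gz"))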
Code Example #12
def hippmapp3r_segmentation(t1,
                            do_preprocessing=True,
                            antsxnet_cache_directory=None,
                            verbose=False):
    """
    Perform HippMapp3r (hippocampal) segmentation described in

    https://www.ncbi.nlm.nih.gov/pubmed/31609046

    with models and architecture ported from

    https://github.com/mgoubran/HippMapp3r

    Additional documentation and attribution resources found at

    https://hippmapp3r.readthedocs.io/en/latest/

    Preprocessing consists of:
       * N4 bias correction and
       * brain extraction.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e., set
    do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        input image

    do_preprocessing : boolean
        See description above.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if this is None, these data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    ANTs labeled hippocampal image.

    Example
    -------
    >>> mask = hippmapp3r_segmentation(t1)
    """

    from ..architectures import create_hippmapp3r_unet_model_3d
    from ..utilities import preprocess_brain_image
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    if verbose:
        print("*************  Preprocessing  ***************")
        print("")

    t1_preprocessed = t1
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=None,
            brain_extraction_modality="t1",
            template=None,
            do_bias_correction=True,
            do_denoising=False,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing[
            "preprocessed_image"] * t1_preprocessing['brain_mask']

    if verbose:
        print("*************  Initial stage segmentation  ***************")
        print("")

    # Normalize to mprage_hippmapp3r space
    if verbose:
        print("    HippMapp3r: template normalization.")

    template_file_name_path = get_antsxnet_data(
        "mprage_hippmapp3r", antsxnet_cache_directory=antsxnet_cache_directory)
    template_image = ants.image_read(template_file_name_path)

    registration = ants.registration(
        fixed=template_image,
        moving=t1_preprocessed,
        type_of_transform="antsRegistrationSyNQuickRepro[t]",
        verbose=verbose)
    image = registration['warpedmovout']
    transforms = dict(fwdtransforms=registration['fwdtransforms'],
                      invtransforms=registration['invtransforms'])

    # Threshold at 10% of the robust range (the 2nd-98th percentile of
    # non-zero voxels), following fslmaths' robust-range convention
    if verbose:
        print("    HippMapp3r: threshold.")

    image_array = image.numpy()
    image_robust_range = np.quantile(image_array[np.where(image_array != 0)],
                                     (0.02, 0.98))
    threshold_value = 0.10 * (image_robust_range[1] -
                              image_robust_range[0]) + image_robust_range[0]
    thresholded_mask = ants.threshold_image(image, -10000, threshold_value, 0,
                                            1)
    thresholded_image = image * thresholded_mask

    # Standardize image
    if verbose:
        print("    HippMapp3r: standardize.")

    mean_image = np.mean(thresholded_image[thresholded_mask == 1])
    sd_image = np.std(thresholded_image[thresholded_mask == 1])
    image_normalized = (image - mean_image) / sd_image
    image_normalized = image_normalized * thresholded_mask

    # Trim and resample image
    if verbose:
        print("    HippMapp3r: trim and resample to (160, 160, 128).")

    image_cropped = ants.crop_image(image_normalized, thresholded_mask, 1)
    shape_initial_stage = (160, 160, 128)
    image_resampled = ants.resample_image(image_cropped,
                                          shape_initial_stage,
                                          use_voxels=True,
                                          interp_type=1)

    if verbose:
        print("    HippMapp3r: generate first network and download weights.")

    model_initial_stage = create_hippmapp3r_unet_model_3d(
        (*shape_initial_stage, 1), do_first_network=True)

    initial_stage_weights_file_name = get_pretrained_network(
        "hippMapp3rInitial", antsxnet_cache_directory=antsxnet_cache_directory)
    model_initial_stage.load_weights(initial_stage_weights_file_name)

    if verbose:
        print("    HippMapp3r: prediction.")

    data_initial_stage = np.expand_dims(image_resampled.numpy(), axis=0)
    data_initial_stage = np.expand_dims(data_initial_stage, axis=-1)
    mask_array = model_initial_stage.predict(data_initial_stage,
                                             verbose=verbose)
    mask_image_resampled = ants.copy_image_info(
        image_resampled, ants.from_numpy(np.squeeze(mask_array)))
    mask_image = ants.resample_image(mask_image_resampled,
                                     image.shape,
                                     use_voxels=True,
                                     interp_type=0)
    mask_image[mask_image >= 0.5] = 1
    mask_image[mask_image < 0.5] = 0

    #########################################
    #
    # Perform refined (stage 2) segmentation
    #
    #########################################

    if verbose:
        print("")
        print("")
        print("*************  Refine stage segmentation  ***************")
        print("")

    mask_array = np.squeeze(mask_array)
    centroid_indices = np.where(mask_array == 1)
    centroid = np.zeros((3, ))
    centroid[0] = centroid_indices[0].mean()
    centroid[1] = centroid_indices[1].mean()
    centroid[2] = centroid_indices[2].mean()

    shape_refine_stage = (112, 112, 64)
    lower = (np.floor(centroid - 0.5 * np.array(shape_refine_stage)) -
             1).astype(int)
    upper = (lower + np.array(shape_refine_stage)).astype(int)
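    # lower/upper define a (112, 112, 64) box approximately centered on the
    # initial-stage centroid; the refine network is applied to this crop.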

    image_trimmed = ants.crop_indices(image_resampled, lower, upper)

    if verbose:
        print("    HippMapp3r: generate second network and download weights.")

    model_refine_stage = create_hippmapp3r_unet_model_3d(
        (*shape_refine_stage, 1), do_first_network=False)

    refine_stage_weights_file_name = get_pretrained_network(
        "hippMapp3rRefine", antsxnet_cache_directory=antsxnet_cache_directory)
    model_refine_stage.load_weights(refine_stage_weights_file_name)

    data_refine_stage = np.expand_dims(image_trimmed.numpy(), axis=0)
    data_refine_stage = np.expand_dims(data_refine_stage, axis=-1)

    if verbose:
        print("    HippMapp3r: Monte Carlo iterations (SpatialDropout).")

    number_of_mci_iterations = 30
    prediction_refine_stage = np.zeros(shape_refine_stage)
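    # Incremental running mean over Monte Carlo forward passes; the seed is
    # varied per pass so the SpatialDropout sampling differs each time.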
    for i in range(number_of_mci_iterations):
        tf.random.set_seed(i)
        if verbose:
            print("        Monte Carlo iteration", i + 1, "out of",
                  number_of_mci_iterations)
        prediction_refine_stage = \
            (np.squeeze(model_refine_stage.predict(data_refine_stage, verbose=verbose)) + \
             i * prediction_refine_stage ) / (i + 1)

    prediction_refine_stage_array = np.zeros(image_resampled.shape)
    prediction_refine_stage_array[lower[0]:upper[0], lower[1]:upper[1],
                                  lower[2]:upper[2]] = prediction_refine_stage
    probability_mask_refine_stage_resampled = ants.copy_image_info(
        image_resampled, ants.from_numpy(prediction_refine_stage_array))

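    # Binarize the probability map at 0.5, label connected components of at
    # least 10 voxels, and discard all but the first two clusters (the
    # putative left and right hippocampi).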
    segmentation_image_resampled = ants.label_clusters(ants.threshold_image(
        probability_mask_refine_stage_resampled, 0.0, 0.5, 0, 1),
                                                       min_cluster_size=10)
    segmentation_image_resampled[segmentation_image_resampled > 2] = 0
    geom = ants.label_geometry_measures(segmentation_image_resampled)
    if len(geom['VolumeInMillimeters']) < 2:
        raise ValueError("Error: left and right hippocampus not found.")

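    # Order the two clusters by centroid x-coordinate so the output labels
    # map to consistent hemispheres.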
    if geom['Centroid_x'][0] < geom['Centroid_x'][1]:
        segmentation_image_resampled[segmentation_image_resampled == 1] = 3
        segmentation_image_resampled[segmentation_image_resampled == 2] = 1
        segmentation_image_resampled[segmentation_image_resampled == 3] = 2

    segmentation_image = ants.apply_transforms(
        fixed=t1,
        moving=segmentation_image_resampled,
        transformlist=transforms['invtransforms'],
        whichtoinvert=[True],
        interpolator="genericLabel",
        verbose=verbose)

    return (segmentation_image)
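
A minimal usage sketch (the input path is hypothetical): read a raw T1, let
the internal preprocessing handle bias correction and brain extraction, and
report per-label volumes with the same ants.label_geometry_measures call the
function uses internally.

import ants
import antspynet

t1 = ants.image_read("t1.nii.gz")
hipp = antspynet.hippmapp3r_segmentation(t1, do_preprocessing=True,
                                         verbose=True)

# Per-label volumes in cubic millimeters; which label is which hemisphere
# follows the centroid-based reordering inside the function.
geom = ants.label_geometry_measures(hipp)
print(geom[['Label', 'VolumeInMillimeters']])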