Code example #1
File: test_segmentation.py  Project: zwmJohn/ANTsPy
    def test_example(self):
        ref = ants.image_read(ants.get_ants_data('r16'))
        ref = ants.resample_image(ref, (50, 50), 1, 0)
        ref = ants.iMath(ref, 'Normalize')
        mi = ants.image_read(ants.get_ants_data('r27'))
        mi2 = ants.image_read(ants.get_ants_data('r30'))
        mi3 = ants.image_read(ants.get_ants_data('r62'))
        mi4 = ants.image_read(ants.get_ants_data('r64'))
        mi5 = ants.image_read(ants.get_ants_data('r85'))
        refmask = ants.get_mask(ref)
        refmask = ants.iMath(refmask, 'ME', 2)  # just to speed things up
        ilist = [mi, mi2, mi3, mi4, mi5]
        seglist = [None] * len(ilist)
        for i in range(len(ilist)):
            ilist[i] = ants.iMath(ilist[i], 'Normalize')
            mytx = ants.registration(fixed=ref,
                                     moving=ilist[i],
                                     type_of_transform='Affine')
            mywarpedimage = ants.apply_transforms(
                fixed=ref,
                moving=ilist[i],
                transformlist=mytx['fwdtransforms'])
            ilist[i] = mywarpedimage
            seg = ants.threshold_image(ilist[i], 'Otsu', 3)
            seglist[i] = seg + ants.threshold_image(seg, 1, 3).morphology(
                operation='dilate', radius=3)

        r = 2
        pp = ants.joint_label_fusion(ref,
                                     refmask,
                                     ilist,
                                     r_search=2,
                                     label_list=seglist,
                                     rad=[r] * ref.dimension)
        pp = ants.joint_label_fusion(ref, refmask, ilist, r_search=2, rad=2)
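
A note on the result: ants.joint_label_fusion returns a dictionary (the second call above overwrites pp while exercising the scalar rad shorthand). A hedged sketch of inspecting the first call's output; the key names follow the ANTsPy docs, so verify them against your installed version:

# hedged sketch: inspect the fused output of the first call above
fused_labels = pp['segmentation']    # fused label image (key name assumed)
fused_intensity = pp['intensity']    # fused intensity image (key name assumed)
print(fused_labels.unique())         # which labels survived the fusion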
Code example #2
File: test_viz.py  Project: zwmJohn/ANTsPy
    def test_vol_example(self):
        ch2i = ants.image_read(ants.get_ants_data("mni"))
        ch2seg = ants.threshold_image(ch2i, "Otsu", 3)
        wm = ants.threshold_image(ch2seg, 3, 3)
        kimg = ants.weingarten_image_curvature(ch2i, 1.5).smooth_image(1)
        rp = [(90, 180, 90), (90, 180, 270), (90, 180, 180)]  # unused in this call
        filename = mktemp(suffix='.png')  # mktemp comes from the tempfile module
        result = ants.vol(wm, [kimg],
                          quantlimits=(0.01, 0.99),
                          filename=filename)
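
A side note: mktemp here is tempfile.mktemp, which is deprecated as race-prone; outside a throwaway test, a safer way to produce the output path is tempfile.mkstemp (a sketch, not part of the original test):

import os
import tempfile

# create the PNG path safely, then close the descriptor so ants.vol can write to it
fd, filename = tempfile.mkstemp(suffix='.png')
os.close(fd)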
Code example #3
File: test_viz.py  Project: zwmJohn/ANTsPy
    def test_surf_example(self):
        ch2i = ants.image_read(ants.get_ants_data("ch2"))
        ch2seg = ants.threshold_image(ch2i, "Otsu", 3)
        wm = ants.threshold_image(ch2seg, 3, 3)
        wm2 = wm.smooth_image(1).threshold_image(0.5, 1e15)
        kimg = ants.weingarten_image_curvature(ch2i, 1.5).smooth_image(1)
        wmz = wm2.iMath("MD", 3)
        rp = [(90, 180, 90), (90, 180, 270), (90, 180, 180)]
        filename = mktemp(suffix='.png')
        ants.surf(x=wm2,
                  y=[kimg],
                  z=[wmz],
                  inflation_factor=255,
                  overlay_limits=(-0.3, 0.3),
                  verbose=True,
                  rotation_params=rp,
                  filename=filename)
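
For orientation, the ANTsPy viz docs describe x as the surface image, y as the list of overlays, and z as the list of overlay masks, and each tuple in rotation_params yields one rendered view; read this as an assumption to check against your version. A minimal single-view variant of the call above:

# hedged sketch: render just one view of the same surface
ants.surf(x=wm2, y=[kimg], z=[wmz],
          overlay_limits=(-0.3, 0.3),
          rotation_params=(90, 180, 90),  # a single (x, y, z) rotation
          filename=mktemp(suffix='.png'))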
Code example #4
    def __brain_extraction(self):
        # https://antsx.github.io/ANTsPyNet/docs/build/html/utilities.html#applications
        if self._modality == "t1":
            image = self._t1_n4
        elif self._modality == "t2":
            image = self._t2_n4
        else:
            sys.exit("invalid contrast specified for brain extraction")

        prob = brain_extraction(image, modality=self._modality)
        # mask can be obtained as:
        mask = ants.threshold_image(prob,
                                    low_thresh=0.5,
                                    high_thresh=1.0,
                                    inval=1,
                                    outval=0,
                                    binary=True)
        return mask
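
For context, brain_extraction is the ANTsPyNet utility linked in the comment above; the usual next step is simply to multiply the mask into the image. A hedged, self-contained sketch (the input path is hypothetical):

import ants
from antspynet.utilities import brain_extraction

# hedged sketch: skull-strip a T1 with a mask built as in __brain_extraction
t1 = ants.image_read("t1.nii.gz")           # hypothetical input path
prob = brain_extraction(t1, modality="t1")  # voxelwise brain probability
mask = ants.threshold_image(prob, 0.5, 1.0, 1, 0)
brain = t1 * mask                           # skull-stripped image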
Code example #5
File: test_segmentation.py  Project: stnava/ANTsPy
    def test_example(self):
        fi = ants.image_read(ants.get_ants_data('r16'))
        seg = ants.kmeans_segmentation(fi, 3)
        mask = ants.threshold_image(seg['segmentation'], 1, 1e15)
        priorseg = ants.prior_based_segmentation(fi, seg['probabilityimages'],
                                                 mask, 0.25, 0.1, 3)
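
The 'segmentation' and 'probabilityimages' keys used above are the standard kmeans_segmentation outputs; assuming prior_based_segmentation exposes a matching 'segmentation' entry (as its ANTsPy docs suggest), the refined result can be eyeballed like this:

# hedged sketch: overlay the prior-based segmentation on the input image
ants.plot(fi, overlay=priorseg['segmentation'],  # key name assumed
          overlay_alpha=0.5)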
Code example #6
def desikan_killiany_tourville_labeling(t1,
                                        do_preprocessing=True,
                                        return_probability_images=False,
                                        do_lobar_parcellation=False,
                                        antsxnet_cache_directory=None,
                                        verbose=False):
    """
    Cortical and deep gray matter labeling using Desikan-Killiany-Tourville

    Perform DKT labeling using deep learning

    The labeling is as follows:

    Inner labels:
    Label 0: background
    Label 4: left lateral ventricle
    Label 5: left inferior lateral ventricle
    Label 6: left cerebellum exterior
    Label 7: left cerebellum white matter
    Label 10: left thalamus proper
    Label 11: left caudate
    Label 12: left putamen
    Label 13: left pallidum
    Label 15: 4th ventricle
    Label 16: brain stem
    Label 17: left hippocampus
    Label 18: left amygdala
    Label 24: CSF
    Label 25: left lesion
    Label 26: left accumbens area
    Label 28: left ventral DC
    Label 30: left vessel
    Label 43: right lateral ventricle
    Label 44: right inferior lateral ventricle
    Label 45: right cerebellum exterior
    Label 46: right cerebellum white matter
    Label 49: right thalamus proper
    Label 50: right caudate
    Label 51: right putamen
    Label 52: right pallidum
    Label 53: right hippocampus
    Label 54: right amygdala
    Label 57: right lesion
    Label 58: right accumbens area
    Label 60: right ventral DC
    Label 62: right vessel
    Label 72: 5th ventricle
    Label 85: optic chiasm
    Label 91: left basal forebrain
    Label 92: right basal forebrain
    Label 630: cerebellar vermal lobules I-V
    Label 631: cerebellar vermal lobules VI-VII
    Label 632: cerebellar vermal lobules VIII-X

    Outer labels:
    Label 1002: left caudal anterior cingulate
    Label 1003: left caudal middle frontal
    Label 1005: left cuneus
    Label 1006: left entorhinal
    Label 1007: left fusiform
    Label 1008: left inferior parietal
    Label 1009: left inferior temporal
    Label 1010: left isthmus cingulate
    Label 1011: left lateral occipital
    Label 1012: left lateral orbitofrontal
    Label 1013: left lingual
    Label 1014: left medial orbitofrontal
    Label 1015: left middle temporal
    Label 1016: left parahippocampal
    Label 1017: left paracentral
    Label 1018: left pars opercularis
    Label 1019: left pars orbitalis
    Label 1020: left pars triangularis
    Label 1021: left pericalcarine
    Label 1022: left postcentral
    Label 1023: left posterior cingulate
    Label 1024: left precentral
    Label 1025: left precuneus
    Label 1026: left rostral anterior cingulate
    Label 1027: left rostral middle frontal
    Label 1028: left superior frontal
    Label 1029: left superior parietal
    Label 1030: left superior temporal
    Label 1031: left supramarginal
    Label 1034: left transverse temporal
    Label 1035: left insula
    Label 2002: right caudal anterior cingulate
    Label 2003: right caudal middle frontal
    Label 2005: right cuneus
    Label 2006: right entorhinal
    Label 2007: right fusiform
    Label 2008: right inferior parietal
    Label 2009: right inferior temporal
    Label 2010: right isthmus cingulate
    Label 2011: right lateral occipital
    Label 2012: right lateral orbitofrontal
    Label 2013: right lingual
    Label 2014: right medial orbitofrontal
    Label 2015: right middle temporal
    Label 2016: right parahippocampal
    Label 2017: right paracentral
    Label 2018: right pars opercularis
    Label 2019: right pars orbitalis
    Label 2020: right pars triangularis
    Label 2021: right pericalcarine
    Label 2022: right postcentral
    Label 2023: right posterior cingulate
    Label 2024: right precentral
    Label 2025: right precuneus
    Label 2026: right rostral anterior cingulate
    Label 2027: right rostral middle frontal
    Label 2028: right superior frontal
    Label 2029: right superior parietal
    Label 2030: right superior temporal
    Label 2031: right supramarginal
    Label 2034: right transverse temporal
    Label 2035: right insula

    Performing the lobar parcellation is based on the FreeSurfer division
    described at https://surfer.nmr.mgh.harvard.edu/fswiki/CorticalParcellation:

    Frontal lobe:
    Label 1002:  left caudal anterior cingulate
    Label 1003:  left caudal middle frontal
    Label 1012:  left lateral orbitofrontal
    Label 1014:  left medial orbitofrontal
    Label 1017:  left paracentral
    Label 1018:  left pars opercularis
    Label 1019:  left pars orbitalis
    Label 1020:  left pars triangularis
    Label 1024:  left precentral
    Label 1026:  left rostral anterior cingulate
    Label 1027:  left rostral middle frontal
    Label 1028:  left superior frontal
    Label 2002:  right caudal anterior cingulate
    Label 2003:  right caudal middle frontal
    Label 2012:  right lateral orbitofrontal
    Label 2014:  right medial orbitofrontal
    Label 2017:  right paracentral
    Label 2018:  right pars opercularis
    Label 2019:  right pars orbitalis
    Label 2020:  right pars triangularis
    Label 2024:  right precentral
    Label 2026:  right rostral anterior cingulate
    Label 2027:  right rostral middle frontal
    Label 2028:  right superior frontal

    Parietal:
    Label 1008:  left inferior parietal
    Label 1010:  left isthmus cingulate
    Label 1022:  left postcentral
    Label 1023:  left posterior cingulate
    Label 1025:  left precuneus
    Label 1029:  left superior parietal
    Label 1031:  left supramarginal
    Label 2008:  right inferior parietal
    Label 2010:  right isthmus cingulate
    Label 2022:  right postcentral
    Label 2023:  right posterior cingulate
    Label 2025:  right precuneus
    Label 2029:  right superior parietal
    Label 2031:  right supramarginal

    Temporal:
    Label 1006:  left entorhinal
    Label 1007:  left fusiform
    Label 1009:  left inferior temporal
    Label 1015:  left middle temporal
    Label 1016:  left parahippocampal
    Label 1030:  left superior temporal
    Label 1034:  left transverse temporal
    Label 2006:  right entorhinal
    Label 2007:  right fusiform
    Label 2009:  right inferior temporal
    Label 2015:  right middle temporal
    Label 2016:  right parahippocampal
    Label 2030:  right superior temporal
    Label 2034:  right transverse temporal

    Occipital:
    Label 1005:  left cuneus
    Label 1011:  left lateral occipital
    Label 1013:  left lingual
    Label 1021:  left pericalcarine
    Label 2005:  right cuneus
    Label 2011:  right lateral occipital
    Label 2013:  right lingual
    Label 2021:  right pericalcarine

    Other outer labels:
    Label 1035:  left insula
    Label 2035:  right insula

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e., set
    do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    return_probability_images : boolean
        Whether to return the two sets of probability images for the inner and outer
        labels.

    do_lobar_parcellation : boolean
        Perform lobar parcellation (also divided by hemisphere).

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary containing the DKT segmentation image and, depending on the
    options, the lobar parcellation and the inner/outer probability images;
    if neither option is requested, just the segmentation image is returned.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> dkt = desikan_killiany_tourville_labeling(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import preprocess_brain_image
    from ..utilities import deep_atropos

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    template_transform_type = "antsRegistrationSyNQuickRepro[a]"
    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            template="croppedMni152",
            template_transform_type=template_transform_type,
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = (t1_preprocessing["preprocessed_image"] *
                           t1_preprocessing["brain_mask"])

    ################################
    #
    # Download spatial priors for outer model
    #
    ################################

    spatial_priors_file_name_path = get_antsxnet_data(
        "priorDktLabels", antsxnet_cache_directory=antsxnet_cache_directory)
    spatial_priors = ants.image_read(spatial_priors_file_name_path)
    priors_image_list = ants.ndimage_to_list(spatial_priors)

    ################################
    #
    # Build outer model and load weights
    #
    ################################

    template_size = (96, 112, 96)
    labels = (0, 1002, 1003, *tuple(range(1005, 1032)), 1034, 1035, 2002, 2003,
              *tuple(range(2005, 2032)), 2034, 2035)
    channel_size = 1 + len(priors_image_list)

    unet_model = create_unet_model_3d((*template_size, channel_size),
                                      number_of_outputs=len(labels),
                                      number_of_layers=4,
                                      number_of_filters_at_base_layer=16,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      weight_decay=1e-5,
                                      additional_options=("attentionGating"))

    weights_file_name = get_pretrained_network(
        "dktOuterWithSpatialPriors",
        antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Outer model prediction.")

    downsampled_image = ants.resample_image(t1_preprocessed,
                                            template_size,
                                            use_voxels=True,
                                            interp_type=0)
    image_array = downsampled_image.numpy()
    image_array = (image_array - image_array.mean()) / image_array.std()

    batchX = np.zeros((1, *template_size, channel_size))
    batchX[0, :, :, :, 0] = image_array

    for i in range(len(priors_image_list)):
        resampled_prior_image = ants.resample_image(priors_image_list[i],
                                                    template_size,
                                                    use_voxels=True,
                                                    interp_type=0)
        batchX[0, :, :, :, i + 1] = resampled_prior_image.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    origin = downsampled_image.origin
    spacing = downsampled_image.spacing
    direction = downsampled_image.direction

    inner_probability_images = list()
    for i in range(len(labels)):
        probability_image = ants.from_numpy(
            np.squeeze(predicted_data[0, :, :, :, i]),
            origin=origin, spacing=spacing, direction=direction)
        resampled_image = ants.resample_image(probability_image,
                                              t1_preprocessed.shape,
                                              use_voxels=True,
                                              interp_type=0)
        if do_preprocessing:
            inner_probability_images.append(
                ants.apply_transforms(
                    fixed=t1,
                    moving=resampled_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose))
        else:
            inner_probability_images.append(resampled_image)

    image_matrix = ants.image_list_to_matrix(inner_probability_images,
                                             t1 * 0 + 1)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    dkt_label_image = ants.image_clone(segmentation_image)
    for i in range(len(labels)):
        dkt_label_image[segmentation_image == i] = labels[i]

    ################################
    #
    # Build inner model and load weights
    #
    ################################

    template_size = (160, 192, 160)
    labels = (0, 4, 6, 7, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 28, 30,
              43, 44, 45, 46, 49, 50, 51, 52, 53, 54, 58, 60, 91, 92, 630, 631,
              632)

    unet_model = create_unet_model_3d((*template_size, 1),
                                      number_of_outputs=len(labels),
                                      number_of_layers=4,
                                      number_of_filters_at_base_layer=8,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      weight_decay=1e-5,
                                      additional_options=("attentionGating"))

    weights_file_name = get_pretrained_network(
        "dktInner", antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Inner model prediction.")

    cropped_image = ants.crop_indices(t1_preprocessed, (12, 14, 0),
                                      (172, 206, 160))

    batchX = np.expand_dims(cropped_image.numpy(), axis=0)
    batchX = np.expand_dims(batchX, axis=-1)
    batchX = (batchX - batchX.mean()) / batchX.std()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    origin = cropped_image.origin
    spacing = cropped_image.spacing
    direction = cropped_image.direction

    outer_probability_images = list()
    for i in range(len(labels)):
        probability_image = ants.from_numpy(
            np.squeeze(predicted_data[0, :, :, :, i]),
            origin=origin, spacing=spacing, direction=direction)
        if i > 0:
            decropped_image = ants.decrop_image(probability_image,
                                                t1_preprocessed * 0)
        else:
            decropped_image = ants.decrop_image(probability_image,
                                                t1_preprocessed * 0 + 1)

        if do_preprocessing:
            outer_probability_images.append(
                ants.apply_transforms(
                    fixed=t1,
                    moving=decropped_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose))
        else:
            outer_probability_images.append(decropped_image)

    image_matrix = ants.image_list_to_matrix(outer_probability_images,
                                             t1 * 0 + 1)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    ################################
    #
    # Incorporate the inner model results into the final label image.
    # Note that we purposely prioritize the inner label results.
    #
    ################################

    for i in range(len(labels)):
        if labels[i] > 0:
            dkt_label_image[segmentation_image == i] = labels[i]

    if do_lobar_parcellation:

        if verbose:
            print("Doing lobar parcellation.")

        ################################
        #
        # Lobar/hemisphere parcellation
        #
        ################################

        # Consolidate lobar cortical labels

        if verbose:
            print("   Consolidating cortical labels.")

        frontal_labels = (1002, 1003, 1012, 1014, 1017, 1018, 1019, 1020, 1024,
                          1026, 1027, 1028, 2002, 2003, 2012, 2014, 2017, 2018,
                          2019, 2020, 2024, 2026, 2027, 2028)
        parietal_labels = (1008, 1010, 1022, 1023, 1025, 1029, 1031, 2008,
                           2010, 2022, 2023, 2025, 2029, 2031)
        temporal_labels = (1006, 1007, 1009, 1015, 1016, 1030, 1034, 2006,
                           2007, 2009, 2015, 2016, 2030, 2034)
        occipital_labels = (1005, 1011, 1013, 1021, 2005, 2011, 2013, 2021)

        lobar_labels = list()
        lobar_labels.append(frontal_labels)
        lobar_labels.append(parietal_labels)
        lobar_labels.append(temporal_labels)
        lobar_labels.append(occipital_labels)

        dkt_lobes = ants.image_clone(dkt_label_image)
        dkt_lobes[dkt_lobes < 1000] = 0

        for i in range(len(lobar_labels)):
            for j in range(len(lobar_labels[i])):
                dkt_lobes[dkt_lobes == lobar_labels[i][j]] = i + 1

        dkt_lobes[dkt_lobes > len(lobar_labels)] = 0

        six_tissue = deep_atropos(
            t1_preprocessed,
            do_preprocessing=False,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        atropos_seg = six_tissue['segmentation_image']
        if do_preprocessing:
            atropos_seg = ants.apply_transforms(
                fixed=t1,
                moving=atropos_seg,
                transformlist=t1_preprocessing['template_transforms']
                ['invtransforms'],
                whichtoinvert=[True],
                interpolator="genericLabel",
                verbose=verbose)

        brain_mask = ants.image_clone(atropos_seg)
        # `or` between ANTsImage comparisons is not elementwise, so zero out
        # the unwanted labels (CSF, brain stem, cerebellum) one at a time
        brain_mask[brain_mask == 1] = 0
        brain_mask[brain_mask == 5] = 0
        brain_mask[brain_mask == 6] = 0
        brain_mask = ants.threshold_image(brain_mask, 0, 0, 0, 1)

        lobar_parcellation = ants.iMath(brain_mask,
                                        "PropagateLabelsThroughMask",
                                        brain_mask * dkt_lobes)

        lobar_parcellation[atropos_seg == 5] = 5
        lobar_parcellation[atropos_seg == 6] = 6

        # Do left/right

        if verbose:
            print("   Doing left/right hemispheres.")

        left_labels = (*tuple(range(4, 8)), *tuple(range(10, 14)), 17, 18, 25,
                       26, 28, 30, 91, 1002, 1003, *tuple(range(1005, 1032)),
                       1034, 1035)
        right_labels = (*tuple(range(43, 47)), *tuple(range(49, 55)), 57, 58,
                        60, 62, 92, 2002, 2003, *tuple(range(2005, 2032)),
                        2034, 2035)

        hemisphere_labels = list()
        hemisphere_labels.append(left_labels)
        hemisphere_labels.append(right_labels)

        dkt_hemispheres = ants.image_clone(dkt_label_image)

        for i in range(len(hemisphere_labels)):
            for j in range(len(hemisphere_labels[i])):
                dkt_hemispheres[dkt_hemispheres == hemisphere_labels[i][j]] = i + 1

        dkt_hemispheres[dkt_hemispheres > 2] = 0

        atropos_brain_mask = ants.threshold_image(atropos_seg, 0, 0, 0, 1)
        hemisphere_parcellation = ants.iMath(
            atropos_brain_mask, "PropagateLabelsThroughMask",
            atropos_brain_mask * dkt_hemispheres)

        # The commented-out loop below is buggy: `and` between two ANTsImage
        # comparisons evaluates Python truthiness instead of an elementwise
        # logical-and, so only the latter condition takes effect.
        #
        # for i in range(6):
        #     lobar_parcellation[lobar_parcellation == (i + 1) and hemisphere_parcellation == 2] = 6 + i + 1

        hemisphere_parcellation *= ants.threshold_image(
            lobar_parcellation, 0, 0, 0, 1)
        hemisphere_parcellation[hemisphere_parcellation == 1] = 0
        hemisphere_parcellation[hemisphere_parcellation == 2] = 1
        hemisphere_parcellation *= 6
        lobar_parcellation += hemisphere_parcellation

    if return_probability_images and do_lobar_parcellation:
        return {
            'segmentation_image': dkt_label_image,
            'lobar_parcellation': lobar_parcellation,
            'inner_probability_images': inner_probability_images,
            'outer_probability_images': outer_probability_images
        }
    elif return_probability_images:
        return {
            'segmentation_image': dkt_label_image,
            'inner_probability_images': inner_probability_images,
            'outer_probability_images': outer_probability_images
        }
    elif do_lobar_parcellation:
        return {
            'segmentation_image': dkt_label_image,
            'lobar_parcellation': lobar_parcellation
        }
    else:
        return dkt_label_image
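
A usage sketch for the function above; the dictionary keys come directly from its return branches, while the file paths are hypothetical:

import ants

# hedged sketch: DKT labeling with the lobar parcellation enabled
t1 = ants.image_read("t1.nii.gz")  # hypothetical path
dkt = desikan_killiany_tourville_labeling(t1,
                                          do_preprocessing=True,
                                          do_lobar_parcellation=True,
                                          verbose=True)
labels = dkt['segmentation_image']  # DKT label image
lobes = dkt['lobar_parcellation']   # lobes, offset by hemisphere
ants.image_write(labels, "dkt_labels.nii.gz")  # hypothetical output path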
Code example #7
def deep_flash(t1,
               t2=None,
               do_preprocessing=True,
               use_rank_intensity=True,
               antsxnet_cache_directory=None,
               verbose=False):
    """
    Hippocampal/entorhinal segmentation using "Deep Flash"

    Perform hippocampal/entorhinal segmentation in T1 and T1/T2 images using
    labels from Mike Yassa's lab

    https://faculty.sites.uci.edu/myassa/

    The labeling is as follows:
    Label 0 :  background
    Label 5 :  left aLEC
    Label 6 :  right aLEC
    Label 7 :  left pMEC
    Label 8 :  right pMEC
    Label 9 :  left perirhinal
    Label 10:  right perirhinal
    Label 11:  left parahippocampal
    Label 12:  right parahippocampal
    Label 13:  left DG/CA2/CA3/CA4
    Label 14:  right DG/CA2/CA3/CA4
    Label 15:  left CA1
    Label 16:  right CA1
    Label 17:  left subiculum
    Label 18:  right subiculum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * affine registration to the "deep flash" template.
    The same steps are performed on the input images if do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    t2 : ANTsImage
        Optional 3-D T2-weighted brain image.  If specified, it is assumed to be
        pre-aligned to the t1.

    do_preprocessing : boolean
        See description above.

    use_rank_intensity : boolean
        If false, use histogram matching with cropped template ROI.  Otherwise,
        use a rank intensity transform on the cropped ROI.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary containing the segmentation image, the per-label probability
    images, and the foreground (region) probability images.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> flash = deep_flash(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import brain_extraction

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Options temporarily taken from the user
    #
    ################################

    # use_hierarchical_parcellation : boolean
    #     If True, use u-net model with additional outputs of the medial temporal lobe
    #     region, hippocampal, and entorhinal/perirhinal/parahippocampal regions.  Otherwise
    #     the only additional output is the medial temporal lobe.
    #
    # use_contralaterality : boolean
    #     Use both hemispherical models to also predict the corresponding contralateral
    #     segmentation and use both sets of priors to produce the results.  Mainly used
    #     for debugging.

    use_hierarchical_parcellation = True
    use_contralaterality = True

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    t1_mask = None
    t1_preprocessed_flipped = None
    t1_template = ants.image_read(
        get_antsxnet_data("deepFlashTemplateT1SkullStripped"))
    template_transforms = None
    if do_preprocessing:

        if verbose:
            print("Preprocessing T1.")

        # Brain extraction
        probability_mask = brain_extraction(
            t1_preprocessed,
            modality="t1",
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_mask = ants.threshold_image(probability_mask, 0.5, 1, 1, 0)
        t1_preprocessed = t1_preprocessed * t1_mask

        # Do bias correction
        t1_preprocessed = ants.n4_bias_field_correction(t1_preprocessed,
                                                        t1_mask,
                                                        shrink_factor=4,
                                                        verbose=verbose)

        # Warp to template
        registration = ants.registration(
            fixed=t1_template,
            moving=t1_preprocessed,
            type_of_transform="antsRegistrationSyNQuickRepro[a]",
            verbose=verbose)
        template_transforms = dict(fwdtransforms=registration['fwdtransforms'],
                                   invtransforms=registration['invtransforms'])
        t1_preprocessed = registration['warpedmovout']

    if use_contralaterality:
        t1_preprocessed_array = t1_preprocessed.numpy()
        t1_preprocessed_array_flipped = np.flip(t1_preprocessed_array, axis=0)
        t1_preprocessed_flipped = ants.from_numpy(
            t1_preprocessed_array_flipped,
            origin=t1_preprocessed.origin,
            spacing=t1_preprocessed.spacing,
            direction=t1_preprocessed.direction)

    t2_preprocessed = t2
    t2_preprocessed_flipped = None
    t2_template = None
    if t2 is not None:
        t2_template = ants.image_read(
            get_antsxnet_data("deepFlashTemplateT2SkullStripped"))
        t2_template = ants.copy_image_info(t1_template, t2_template)
        if do_preprocessing:

            if verbose:
                print("Preprocessing T2.")

            # Brain extraction
            t2_preprocessed = t2_preprocessed * t1_mask

            # Do bias correction
            t2_preprocessed = ants.n4_bias_field_correction(t2_preprocessed,
                                                            t1_mask,
                                                            shrink_factor=4,
                                                            verbose=verbose)

            # Warp to template
            t2_preprocessed = ants.apply_transforms(
                fixed=t1_template,
                moving=t2_preprocessed,
                transformlist=template_transforms['fwdtransforms'],
                verbose=verbose)

        if use_contralaterality:
            t2_preprocessed_array = t2_preprocessed.numpy()
            t2_preprocessed_array_flipped = np.flip(t2_preprocessed_array,
                                                    axis=0)
            t2_preprocessed_flipped = ants.from_numpy(
                t2_preprocessed_array_flipped,
                origin=t2_preprocessed.origin,
                spacing=t2_preprocessed.spacing,
                direction=t2_preprocessed.direction)

    probability_images = list()
    labels = (0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)
    image_size = (64, 64, 96)

    ################################
    #
    # Process left/right in split networks
    #
    ################################

    ################################
    #
    # Download spatial priors
    #
    ################################

    spatial_priors_file_name_path = get_antsxnet_data(
        "deepFlashPriors", antsxnet_cache_directory=antsxnet_cache_directory)
    spatial_priors = ants.image_read(spatial_priors_file_name_path)
    priors_image_list = ants.ndimage_to_list(spatial_priors)
    for i in range(len(priors_image_list)):
        priors_image_list[i] = ants.copy_image_info(t1_preprocessed,
                                                    priors_image_list[i])

    labels_left = labels[1::2]
    priors_image_left_list = priors_image_list[1::2]
    probability_images_left = list()
    foreground_probability_images_left = list()
    lower_bound_left = (76, 74, 56)
    upper_bound_left = (140, 138, 152)
    tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_left,
                                    upper_bound_left)
    origin_left = tmp_cropped.origin

    spacing = tmp_cropped.spacing
    direction = tmp_cropped.direction

    t1_template_roi_left = ants.crop_indices(t1_template, lower_bound_left,
                                             upper_bound_left)
    t1_template_roi_left = (t1_template_roi_left - t1_template_roi_left.min(
    )) / (t1_template_roi_left.max() - t1_template_roi_left.min()) * 2.0 - 1.0
    t2_template_roi_left = None
    if t2_template is not None:
        t2_template_roi_left = ants.crop_indices(t2_template, lower_bound_left,
                                                 upper_bound_left)
        t2_template_roi_left = (t2_template_roi_left -
                                t2_template_roi_left.min()) / (
                                    t2_template_roi_left.max() -
                                    t2_template_roi_left.min()) * 2.0 - 1.0

    labels_right = labels[2::2]
    priors_image_right_list = priors_image_list[2::2]
    probability_images_right = list()
    foreground_probability_images_right = list()
    lower_bound_right = (20, 74, 56)
    upper_bound_right = (84, 138, 152)
    tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_right,
                                    upper_bound_right)
    origin_right = tmp_cropped.origin

    t1_template_roi_right = ants.crop_indices(t1_template, lower_bound_right,
                                              upper_bound_right)
    t1_template_roi_right = (
        t1_template_roi_right - t1_template_roi_right.min()
    ) / (t1_template_roi_right.max() - t1_template_roi_right.min()) * 2.0 - 1.0
    t2_template_roi_right = None
    if t2_template is not None:
        t2_template_roi_right = ants.crop_indices(t2_template,
                                                  lower_bound_right,
                                                  upper_bound_right)
        t2_template_roi_right = (t2_template_roi_right -
                                 t2_template_roi_right.min()) / (
                                     t2_template_roi_right.max() -
                                     t2_template_roi_right.min()) * 2.0 - 1.0

    ################################
    #
    # Create model
    #
    ################################

    channel_size = 1 + len(labels_left)
    if t2 is not None:
        channel_size += 1

    number_of_classification_labels = 1 + len(labels_left)

    unet_model = create_unet_model_3d(
        (*image_size, channel_size),
        number_of_outputs=number_of_classification_labels,
        mode="classification",
        number_of_filters=(32, 64, 96, 128, 256),
        convolution_kernel_size=(3, 3, 3),
        deconvolution_kernel_size=(2, 2, 2),
        dropout_rate=0.0,
        weight_decay=0)

    penultimate_layer = unet_model.layers[-2].output

    # medial temporal lobe
    output1 = Conv3D(
        filters=1,
        kernel_size=(1, 1, 1),
        activation='sigmoid',
        kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

    if use_hierarchical_parcellation:

        # entorhinal, perirhinal, and parahippocampal cortices
        output2 = Conv3D(
            filters=1,
            kernel_size=(1, 1, 1),
            activation='sigmoid',
            kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

        # Hippocampus
        output3 = Conv3D(
            filters=1,
            kernel_size=(1, 1, 1),
            activation='sigmoid',
            kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

        unet_model = Model(
            inputs=unet_model.input,
            outputs=[unet_model.output, output1, output2, output3])
    else:
        unet_model = Model(inputs=unet_model.input,
                           outputs=[unet_model.output, output1])

    ################################
    #
    # Left:  build model and load weights
    #
    ################################

    network_name = 'deepFlashLeftT1'
    if t2 is not None:
        network_name = 'deepFlashLeftBoth'

    if use_hierarchical_parcellation:
        network_name += "Hierarchical"

    if use_rank_intensity:
        network_name += "_ri"

    if verbose:
        print("DeepFlash: retrieving model weights (left).")
    weights_file_name = get_pretrained_network(
        network_name, antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Left:  do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction (left).")

    batchX = None
    if use_contralaterality:
        batchX = np.zeros((2, *image_size, channel_size))
    else:
        batchX = np.zeros((1, *image_size, channel_size))

    t1_cropped = ants.crop_indices(t1_preprocessed, lower_bound_left,
                                   upper_bound_left)
    if use_rank_intensity:
        t1_cropped = ants.rank_intensity(t1_cropped)
    else:
        t1_cropped = ants.histogram_match_image(t1_cropped,
                                                t1_template_roi_left, 255, 64,
                                                False)
    batchX[0, :, :, :, 0] = t1_cropped.numpy()
    if use_contralaterality:
        t1_cropped = ants.crop_indices(t1_preprocessed_flipped,
                                       lower_bound_left, upper_bound_left)
        if use_rank_intensity:
            t1_cropped = ants.rank_intensity(t1_cropped)
        else:
            t1_cropped = ants.histogram_match_image(t1_cropped,
                                                    t1_template_roi_left, 255,
                                                    64, False)
        batchX[1, :, :, :, 0] = t1_cropped.numpy()
    if t2 is not None:
        t2_cropped = ants.crop_indices(t2_preprocessed, lower_bound_left,
                                       upper_bound_left)
        if use_rank_intensity:
            t2_cropped = ants.rank_intensity(t2_cropped)
        else:
            t2_cropped = ants.histogram_match_image(t2_cropped,
                                                    t2_template_roi_left, 255,
                                                    64, False)
        batchX[0, :, :, :, 1] = t2_cropped.numpy()
        if use_contralaterality:
            t2_cropped = ants.crop_indices(t2_preprocessed_flipped,
                                           lower_bound_left, upper_bound_left)
            if use_rank_intensity:
                t2_cropped = ants.rank_intensity(t2_cropped)
            else:
                t2_cropped = ants.histogram_match_image(
                    t2_cropped, t2_template_roi_left, 255, 64, False)
            batchX[1, :, :, :, 1] = t2_cropped.numpy()

    for i in range(len(priors_image_left_list)):
        cropped_prior = ants.crop_indices(priors_image_left_list[i],
                                          lower_bound_left, upper_bound_left)
        for j in range(batchX.shape[0]):
            batchX[j, :, :, :, i +
                   (channel_size - len(labels_left))] = cropped_prior.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    for i in range(1 + len(labels_left)):
        for j in range(predicted_data[0].shape[0]):
            probability_image = ants.from_numpy(
                np.squeeze(predicted_data[0][j, :, :, :, i]),
                origin=origin_left, spacing=spacing, direction=direction)
            if i > 0:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0)
            else:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0 + 1)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                probability_images_left.append(probability_image)
            else:  # flipped
                probability_images_right.append(probability_image)

    ################################
    #
    # Left:  do prediction of mtl, hippocampal, and ec regions and normalize to native space
    #
    ################################

    for i in range(1, len(predicted_data)):
        for j in range(predicted_data[i].shape[0]):
            probability_image = ants.from_numpy(
                np.squeeze(predicted_data[i][j, :, :, :, 0]),
                origin=origin_left, spacing=spacing, direction=direction)
            probability_image = ants.decrop_image(probability_image,
                                                  t1_preprocessed * 0)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                foreground_probability_images_left.append(probability_image)
            else:
                foreground_probability_images_right.append(probability_image)

    ################################
    #
    # Right:  build model and load weights
    #
    ################################

    network_name = 'deepFlashRightT1'
    if t2 is not None:
        network_name = 'deepFlashRightBoth'

    if use_hierarchical_parcellation:
        network_name += "Hierarchical"

    if use_rank_intensity:
        network_name += "_ri"

    if verbose:
        print("DeepFlash: retrieving model weights (right).")
    weights_file_name = get_pretrained_network(
        network_name, antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Right:  do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction (right).")

    batchX = None
    if use_contralaterality:
        batchX = np.zeros((2, *image_size, channel_size))
    else:
        batchX = np.zeros((1, *image_size, channel_size))

    t1_cropped = ants.crop_indices(t1_preprocessed, lower_bound_right,
                                   upper_bound_right)
    if use_rank_intensity:
        t1_cropped = ants.rank_intensity(t1_cropped)
    else:
        t1_cropped = ants.histogram_match_image(t1_cropped,
                                                t1_template_roi_right, 255, 64,
                                                False)
    batchX[0, :, :, :, 0] = t1_cropped.numpy()
    if use_contralaterality:
        t1_cropped = ants.crop_indices(t1_preprocessed_flipped,
                                       lower_bound_right, upper_bound_right)
        if use_rank_intensity:
            t1_cropped = ants.rank_intensity(t1_cropped)
        else:
            t1_cropped = ants.histogram_match_image(t1_cropped,
                                                    t1_template_roi_right, 255,
                                                    64, False)
        batchX[1, :, :, :, 0] = t1_cropped.numpy()
    if t2 is not None:
        t2_cropped = ants.crop_indices(t2_preprocessed, lower_bound_right,
                                       upper_bound_right)
        if use_rank_intensity:
            t2_cropped = ants.rank_intensity(t2_cropped)
        else:
            t2_cropped = ants.histogram_match_image(t2_cropped,
                                                    t2_template_roi_right, 255,
                                                    64, False)
        batchX[0, :, :, :, 1] = t2_cropped.numpy()
        if use_contralaterality:
            t2_cropped = ants.crop_indices(t2_preprocessed_flipped,
                                           lower_bound_right,
                                           upper_bound_right)
            if use_rank_intensity:
                t2_cropped = ants.rank_intensity(t2_cropped)
            else:
                t2_cropped = ants.histogram_match_image(
                    t2_cropped, t2_template_roi_right, 255, 64, False)
            batchX[1, :, :, :, 1] = t2_cropped.numpy()

    for i in range(len(priors_image_right_list)):
        cropped_prior = ants.crop_indices(priors_image_right_list[i],
                                          lower_bound_right, upper_bound_right)
        for j in range(batchX.shape[0]):
            batchX[j, :, :, :, i +
                   (channel_size - len(labels_right))] = cropped_prior.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    for i in range(1 + len(labels_right)):
        for j in range(predicted_data[0].shape[0]):
            probability_image = ants.from_numpy(
                np.squeeze(predicted_data[0][j, :, :, :, i]),
                origin=origin_right, spacing=spacing, direction=direction)
            if i > 0:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0)
            else:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0 + 1)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                if use_contralaterality:
                    probability_images_right[i] = (
                        probability_images_right[i] + probability_image) / 2
                else:
                    probability_images_right.append(probability_image)
            else:  # flipped
                probability_images_left[i] = (probability_images_left[i] +
                                              probability_image) / 2

    ################################
    #
    # Right:  do prediction of mtl, hippocampal, and ec regions and normalize to native space
    #
    ################################

    for i in range(1, len(predicted_data)):
        for j in range(predicted_data[i].shape[0]):
            probability_image = ants.from_numpy(
                np.squeeze(predicted_data[i][j, :, :, :, 0]),
                origin=origin_right, spacing=spacing, direction=direction)
            probability_image = ants.decrop_image(probability_image,
                                                  t1_preprocessed * 0)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                if use_contralaterality:
                    foreground_probability_images_right[
                        i - 1] = (foreground_probability_images_right[i - 1] +
                                  probability_image) / 2
                else:
                    foreground_probability_images_right.append(
                        probability_image)
            else:
                foreground_probability_images_left[
                    i - 1] = (foreground_probability_images_left[i - 1] +
                              probability_image) / 2

    ################################
    #
    # Combine left/right probability images
    #
    ################################

    probability_background_image = ants.image_clone(t1) * 0
    for i in range(1, len(probability_images_left)):
        probability_background_image += probability_images_left[i]
    for i in range(1, len(probability_images_right)):
        probability_background_image += probability_images_right[i]

    probability_images.append(probability_background_image * -1 + 1)
    for i in range(1, len(probability_images_left)):
        probability_images.append(probability_images_left[i])
        probability_images.append(probability_images_right[i])

    ################################
    #
    # Convert probability images to segmentation
    #
    ################################

    # image_matrix = ants.image_list_to_matrix(probability_images, t1 * 0 + 1)
    # segmentation_matrix = np.argmax(image_matrix, axis=0)
    # segmentation_image = ants.matrix_to_images(
    #     np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    image_matrix = ants.image_list_to_matrix(probability_images[1:],
                                             t1 * 0 + 1)
    background_foreground_matrix = np.stack([
        ants.image_list_to_matrix([probability_images[0]], t1 * 0 + 1),
        np.expand_dims(np.sum(image_matrix, axis=0), axis=0)
    ])
    foreground_matrix = np.argmax(background_foreground_matrix, axis=0)
    segmentation_matrix = (np.argmax(image_matrix, axis=0) +
                           1) * foreground_matrix
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    relabeled_image = ants.image_clone(segmentation_image)
    for i in range(len(labels)):
        relabeled_image[segmentation_image == i] = labels[i]

    foreground_probability_images = list()
    for i in range(len(foreground_probability_images_left)):
        foreground_probability_images.append(
            foreground_probability_images_left[i] +
            foreground_probability_images_right[i])

    return_dict = None
    if use_hierarchical_parcellation:
        return_dict = {
            'segmentation_image':
            relabeled_image,
            'probability_images':
            probability_images,
            'medial_temporal_lobe_probability_image':
            foreground_probability_images[0],
            'other_region_probability_image':
            foreground_probability_images[1],
            'hippocampal_probability_image':
            foreground_probability_images[2]
        }
    else:
        return_dict = {
            'segmentation_image':
            relabeled_image,
            'probability_images':
            probability_images,
            'medial_temporal_lobe_probability_image':
            foreground_probability_images[0]
        }

    return return_dict
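
A usage sketch for deep_flash; the keys are taken from the return_dict built above and the paths are hypothetical:

import ants

# hedged sketch: T1-only Deep Flash run
t1 = ants.image_read("t1.nii.gz")  # hypothetical path
flash = deep_flash(t1, do_preprocessing=True, verbose=True)
seg = flash['segmentation_image']                      # labels 0, 5-18
mtl = flash['medial_temporal_lobe_probability_image']  # foreground probability
ants.image_write(seg, "deep_flash_labels.nii.gz")      # hypothetical output path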
Code example #8
if not os.path.exists(weights_file_name):
    weights_file_name = antspynet.get_pretrained_network(
        "brainExtraction", weights_file_name)

unet_model.load_weights(weights_file_name)
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

start_time_total = time.time()

print("Reading ", input_file_name)
start_time = time.time()
image = ants.image_read(input_file_name)
mask = ants.image_read(input_mask_file_name)
mask = ants.threshold_image(mask, 0.5, 1.0, 1, 0)
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

print("Normalizing to template")
start_time = time.time()
center_of_mass_template = ants.get_center_of_mass(reorient_template)
center_of_mass_image = ants.get_center_of_mass(image)
translation = np.asarray(center_of_mass_image) - np.asarray(
    center_of_mass_template)
xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
                                  center=np.asarray(center_of_mass_template),
                                  translation=translation)
# the source snippet was truncated here; resampling onto the reorientation
# template is the assumed completion
warped_image = ants.apply_ants_transform_to_image(xfrm, image,
                                                  reorient_template)
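
Because the Euler transform above only translates the image onto the template's center of mass, anything computed in template space can be mapped back by inverting it; a sketch using ants.invert_ants_transform (check the helper's name against your ANTsPy version):

# hedged sketch: map a template-space result back to the native image grid
inv_xfrm = ants.invert_ants_transform(xfrm)
native_image = ants.apply_ants_transform_to_image(inv_xfrm, warped_image,
                                                  image)  # native reference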
Code example #9
def sysu_media_wmh_segmentation(flair,
                                t1=None,
                                do_preprocessing=True,
                                use_ensemble=True,
                                use_axial_slices_only=True,
                                antsxnet_cache_directory=None,
                                verbose=False):
    """
    Perform WMH segmentation using the winning submission in the MICCAI
    2017 challenge by the sysu_media team using FLAIR or T1/FLAIR.  The
    MICCAI challenge is discussed in

    https://pubmed.ncbi.nlm.nih.gov/30908194/

    with the sysu_media team's entry discussed in

    https://pubmed.ncbi.nlm.nih.gov/30125711/

    with the original implementation available here:

    https://github.com/hongweilibran/wmh_ibbmTum

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    do_preprocessing : boolean
        perform N4 bias correction.

    use_ensemble : boolean
        whether to use all three sets of weights (ensemble averaging).

    use_axial_slices_only : boolean
        If True, use original implementation which was trained on axial slices.
        If False, use ANTsXNet variant implementation which applies the slice-by-slice
        models to all 3 dimensions and averages the results.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> image = ants.image_read("flair.nii.gz")
    >>> probability_mask = sysu_media_wmh_segmentation(image)
    """

    from ..architectures import create_sysu_media_unet_model_2d
    from ..utilities import brain_extraction
    from ..utilities import crop_image_center
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import pad_or_crop_image_to_size

    if flair.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    flair_preprocessed = flair
    if do_preprocessing == True:
        flair_preprocessing = preprocess_brain_image(
            flair,
            truncate_intensity=(0.01, 0.99),
            do_brain_extraction=False,
            do_bias_correction=True,
            do_denoising=False,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        flair_preprocessed = flair_preprocessing["preprocessed_image"]

    number_of_channels = 1
    if t1 is not None:
        t1_preprocessed = t1
        if do_preprocessing == True:
            t1_preprocessing = preprocess_brain_image(
                t1,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            t1_preprocessed = t1_preprocessing["preprocessed_image"]
        number_of_channels = 2

    ################################
    #
    # Estimate mask
    #
    ################################

    brain_mask = None
    if verbose == True:
        print("Estimating brain mask.")
    if t1 is not None:
        brain_mask = brain_extraction(t1, modality="t1")
    else:
        brain_mask = brain_extraction(flair, modality="flair")

    reference_image = ants.make_image((200, 200, 200),
                                      voxval=1,
                                      spacing=(1, 1, 1),
                                      origin=(0, 0, 0),
                                      direction=np.identity(3))
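    # work in a fixed 200^3, 1 mm isotropic reference space so that every
    # extracted slice matches the 200 x 200 input size of the models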

    center_of_mass_reference = ants.get_center_of_mass(reference_image)
    center_of_mass_image = ants.get_center_of_mass(brain_mask)
    translation = np.asarray(center_of_mass_image) - np.asarray(
        center_of_mass_reference)
    xfrm = ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_reference),
        translation=translation)
    flair_preprocessed_warped = ants.apply_ants_transform_to_image(
        xfrm, flair_preprocessed, reference_image)
    brain_mask_warped = ants.threshold_image(
        ants.apply_ants_transform_to_image(xfrm, brain_mask, reference_image),
        0.5, 1.1, 1, 0)

    if t1 is not None:
        t1_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm, t1_preprocessed, reference_image)

    ################################
    #
    # Gaussian normalize intensity based on brain mask
    #
    ################################

    mean_flair = flair_preprocessed_warped[brain_mask_warped > 0].mean()
    std_flair = flair_preprocessed_warped[brain_mask_warped > 0].std()
    flair_preprocessed_warped = (flair_preprocessed_warped -
                                 mean_flair) / std_flair

    if number_of_channels == 2:
        mean_t1 = t1_preprocessed_warped[brain_mask_warped > 0].mean()
        std_t1 = t1_preprocessed_warped[brain_mask_warped > 0].std()
        t1_preprocessed_warped = (t1_preprocessed_warped - mean_t1) / std_t1

    ################################
    #
    # Build models and load weights
    #
    ################################

    number_of_models = 1
    if use_ensemble == True:
        number_of_models = 3

    unet_models = list()
    for i in range(number_of_models):
        if number_of_channels == 1:
            weights_file_name = get_pretrained_network(
                "sysuMediaWmhFlairOnlyModel" + str(i),
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            weights_file_name = get_pretrained_network(
                "sysuMediaWmhFlairT1Model" + str(i),
                antsxnet_cache_directory=antsxnet_cache_directory)
        unet_models.append(
            create_sysu_media_unet_model_2d((200, 200, number_of_channels)))
        unet_models[i].load_weights(weights_file_name)

    ################################
    #
    # Extract slices
    #
    ################################

    dimensions_to_predict = [2]
    if use_axial_slices_only == False:
        dimensions_to_predict = list(range(3))

    total_number_of_slices = 0
    for d in range(len(dimensions_to_predict)):
        total_number_of_slices += flair_preprocessed_warped.shape[
            dimensions_to_predict[d]]

    batchX = np.zeros((total_number_of_slices, 200, 200, number_of_channels))

    slice_count = 0
    for d in range(len(dimensions_to_predict)):
        number_of_slices = flair_preprocessed_warped.shape[
            dimensions_to_predict[d]]

        if verbose == True:
            print("Extracting slices for dimension ", dimensions_to_predict[d],
                  ".")

        for i in range(number_of_slices):
            flair_slice = pad_or_crop_image_to_size(
                ants.slice_image(flair_preprocessed_warped,
                                 dimensions_to_predict[d], i), (200, 200))
            batchX[slice_count, :, :, 0] = flair_slice.numpy()
            if number_of_channels == 2:
                t1_slice = pad_or_crop_image_to_size(
                    ants.slice_image(t1_preprocessed_warped,
                                     dimensions_to_predict[d], i), (200, 200))
                batchX[slice_count, :, :, 1] = t1_slice.numpy()

            slice_count += 1

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose == True:
        print("Prediction.")

    prediction = unet_models[0].predict(batchX, verbose=verbose)
    if number_of_models > 1:
        for i in range(1, number_of_models, 1):
            prediction += unet_models[i].predict(batchX, verbose=verbose)
    prediction /= number_of_models

    permutations = list()
    permutations.append((0, 1, 2))
    permutations.append((1, 0, 2))
    permutations.append((1, 2, 0))
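    # slices were stacked along axis 0 for prediction; each permutation moves
    # that stacking axis back to the dimension the slices were extracted from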

    prediction_image_average = ants.image_clone(flair_preprocessed_warped) * 0

    current_start_slice = 0
    for d in range(len(dimensions_to_predict)):
        current_end_slice = current_start_slice + flair_preprocessed_warped.shape[
            dimensions_to_predict[d]] - 1
        # inclusive range: without the "+ 1" the last slice of each
        # dimension would be dropped
        which_batch_slices = range(current_start_slice, current_end_slice + 1)
        prediction_per_dimension = prediction[which_batch_slices, :, :, :]
        prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                                        permutations[dimensions_to_predict[d]])
        prediction_image = ants.copy_image_info(
            flair_preprocessed_warped,
            pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                      flair_preprocessed_warped.shape))
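        # incremental mean over the predicted dimensions:
        # avg <- avg + (x - avg) / (d + 1)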
        prediction_image_average = prediction_image_average + (
            prediction_image - prediction_image_average) / (d + 1)
        current_start_slice = current_end_slice + 1

    probability_image = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), prediction_image_average, flair)

    return (probability_image)
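
A typical call, sketched under the assumption that matched FLAIR and T1
volumes exist on disk (the file names are placeholders); the returned
probability image can be binarized with ants.threshold_image:

import ants

flair = ants.image_read("flair.nii.gz")  # placeholder file name
t1 = ants.image_read("t1.nii.gz")        # placeholder file name

# two-channel ensemble, predicting along all three axes
probability = sysu_media_wmh_segmentation(flair,
                                          t1=t1,
                                          use_ensemble=True,
                                          use_axial_slices_only=False,
                                          verbose=True)
wmh_mask = ants.threshold_image(probability, 0.5, 1.0, 1, 0)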
Code example #10
 def test_threshold_image_example(self):
     image = ants.image_read(ants.get_ants_data('r16'))
     timage = ants.threshold_image(image, 0.5, 1e15)
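
For reference, ants.threshold_image also accepts explicit inside/outside
values, which is how the snippets in this collection binarize masks; a
minimal sketch:

import ants

image = ants.image_read(ants.get_ants_data('r16'))
# voxels in [90, 120] -> 1, everything else -> 0
binary = ants.threshold_image(image, 90, 120, 1, 0)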
Code example #11
 def test_get_centroids_example(self):
     image = ants.image_read(ants.get_ants_data("r16"))
     image = ants.threshold_image(image, 90, 120)
     image = ants.label_clusters(image, 10)
     cents = ants.get_centroids(image)
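
ants.get_centroids returns one row of physical-space coordinates per label,
so the cluster centroids computed above can be inspected directly (a
self-contained sketch; the printed values depend on the input):

import ants

image = ants.image_read(ants.get_ants_data("r16"))
labeled = ants.label_clusters(ants.threshold_image(image, 90, 120), 10)
cents = ants.get_centroids(labeled)
# one row of physical-space coordinates per labeled cluster
for label_index, centroid in enumerate(cents, start=1):
    print("cluster", label_index, "centroid:", centroid)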
Code example #12
def ew_david(flair,
             t1,
             do_preprocessing=True,
             which_model="sysu",
             which_axes=2,
             number_of_simulations=0,
             sd_affine=0.01,
             antsxnet_cache_directory=None,
             verbose=False):
    """
    Perform white matter hyperintensity (WMH) probabilistic segmentation
    using deep learning.

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * intensity truncation,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    do_preprocessing = True.

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    do_preprocessing : boolean
        perform n4 bias correction, intensity truncation, brain extraction.

    which_model : string
        one of:
            * "sysu" -- same as the original sysu network (without site specific preprocessing),
            * "sysu-ri" -- same as "sysu" but using ranked intensity scaling for input images,
            * "sysuWithAttention" -- "sysu" with attention gating,
            * "sysuWithAttentionAndSite" -- "sysu" with attention gating with site branch (see "sysuWithSite"),
            * "sysuPlus" -- "sysu" with attention gating and nn-Unet activation,
            * "sysuPlusSeg" -- "sysuPlus" with deep_atropos segmentation in an additional channel, and
            * "sysuWithSite" -- "sysu" with global pooling on encoding channels to predict "site".
            * "sysuPlusSegWithSite" -- "sysuPlusSeg" combined with "sysuWithSite"
        In addition to both modalities, all models have T1-only and FLAIR-only
        variants, except for "sysuPlusSeg" (which has only a T1-only variant) and
        "sysu-ri" (which has no single-modality variant).

    which_axes : string or scalar or tuple/vector
        apply 2-D model to 1 or more axes.  In addition to a scalar
        or vector, e.g., which_axes = (0, 2), one can use "max" for the
        axis with maximum anisotropy (default) or "all" for all axes.

    number_of_simulations : integer
        Number of random affine perturbations to transform the input.

    sd_affine : float
        Standard deviation of the affine transformation parameters.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> flair = ants.image_read("flair.nii.gz")
    >>> t1 = ants.image_read("t1.nii.gz")
    >>> probability_mask = ew_david(flair, t1)
    """

    from ..architectures import create_unet_model_2d
    from ..utilities import deep_atropos
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import randomly_transform_image_data
    from ..utilities import pad_or_crop_image_to_size

    do_t1_only = False
    do_flair_only = False

    if flair is None and t1 is not None:
        do_t1_only = True
    elif flair is not None and t1 is None:
        do_flair_only = True

    use_t1_segmentation = False
    if "Seg" in which_model:
        if do_flair_only:
            raise ValueError("Segmentation requires T1.")
        else:
            use_t1_segmentation = True

    if use_t1_segmentation and do_preprocessing == False:
        raise ValueError(
            "Using the t1 segmentation requires do_preprocessing=True.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    do_slicewise = True

    if do_slicewise == False:

        raise ValueError("Not available.")

        # ################################
        # #
        # # Preprocess images
        # #
        # ################################

        # t1_preprocessed = t1
        # t1_preprocessing = None
        # if do_preprocessing == True:
        #     t1_preprocessing = preprocess_brain_image(t1,
        #         truncate_intensity=(0.01, 0.99),
        #         brain_extraction_modality="t1",
        #         template="croppedMni152",
        #         template_transform_type="antsRegistrationSyNQuickRepro[a]",
        #         do_bias_correction=True,
        #         do_denoising=False,
        #         antsxnet_cache_directory=antsxnet_cache_directory,
        #         verbose=verbose)
        #     t1_preprocessed = t1_preprocessing["preprocessed_image"] * t1_preprocessing['brain_mask']

        # flair_preprocessed = flair
        # if do_preprocessing == True:
        #     flair_preprocessing = preprocess_brain_image(flair,
        #         truncate_intensity=(0.01, 0.99),
        #         brain_extraction_modality="t1",
        #         do_bias_correction=True,
        #         do_denoising=False,
        #         antsxnet_cache_directory=antsxnet_cache_directory,
        #         verbose=verbose)
        #     flair_preprocessed = ants.apply_transforms(fixed=t1_preprocessed,
        #         moving=flair_preprocessing["preprocessed_image"],
        #         transformlist=t1_preprocessing['template_transforms']['fwdtransforms'])
        #     flair_preprocessed = flair_preprocessed * t1_preprocessing['brain_mask']

        # ################################
        # #
        # # Build model and load weights
        # #
        # ################################

        # patch_size = (112, 112, 112)
        # stride_length = (t1_preprocessed.shape[0] - patch_size[0],
        #                 t1_preprocessed.shape[1] - patch_size[1],
        #                 t1_preprocessed.shape[2] - patch_size[2])

        # classes = ("background", "wmh" )
        # number_of_classification_labels = len(classes)
        # labels = (0, 1)

        # image_modalities = ("T1", "FLAIR")
        # channel_size = len(image_modalities)

        # unet_model = create_unet_model_3d((*patch_size, channel_size),
        #     number_of_outputs = number_of_classification_labels,
        #     number_of_layers = 4, number_of_filters_at_base_layer = 16, dropout_rate = 0.0,
        #     convolution_kernel_size = (3, 3, 3), deconvolution_kernel_size = (2, 2, 2),
        #     weight_decay = 1e-5, additional_options=("attentionGating"))

        # weights_file_name = get_pretrained_network("ewDavidWmhSegmentationWeights",
        #     antsxnet_cache_directory=antsxnet_cache_directory)
        # unet_model.load_weights(weights_file_name)

        # ################################
        # #
        # # Do prediction and normalize to native space
        # #
        # ################################

        # if verbose == True:
        #     print("ew_david:  prediction.")

        # batchX = np.zeros((8, *patch_size, channel_size))

        # t1_preprocessed = (t1_preprocessed - t1_preprocessed.mean()) / t1_preprocessed.std()
        # t1_patches = extract_image_patches(t1_preprocessed, patch_size=patch_size,
        #                                     max_number_of_patches="all", stride_length=stride_length,
        #                                     return_as_array=True)
        # batchX[:,:,:,:,0] = t1_patches

        # flair_preprocessed = (flair_preprocessed - flair_preprocessed.mean()) / flair_preprocessed.std()
        # flair_patches = extract_image_patches(flair_preprocessed, patch_size=patch_size,
        #                                     max_number_of_patches="all", stride_length=stride_length,
        #                                     return_as_array=True)
        # batchX[:,:,:,:,1] = flair_patches

        # predicted_data = unet_model.predict(batchX, verbose=verbose)

        # probability_images = list()
        # for i in range(len(labels)):
        #     print("Reconstructing image", classes[i])
        #     reconstructed_image = reconstruct_image_from_patches(predicted_data[:,:,:,:,i],
        #         domain_image=t1_preprocessed, stride_length=stride_length)

        #     if do_preprocessing == True:
        #         probability_images.append(ants.apply_transforms(fixed=t1,
        #             moving=reconstructed_image,
        #             transformlist=t1_preprocessing['template_transforms']['invtransforms'],
        #             whichtoinvert=[True], interpolator="linear", verbose=verbose))
        #     else:
        #         probability_images.append(reconstructed_image)

        # return(probability_images[1])

    else:  # do_slicewise

        ################################
        #
        # Preprocess images
        #
        ################################

        use_rank_intensity_scaling = False
        if "-ri" in which_model:
            use_rank_intensity_scaling = True

        t1_preprocessed = None
        t1_preprocessing = None
        brain_mask = None
        if t1 is not None:
            if do_preprocessing == True:
                t1_preprocessing = preprocess_brain_image(
                    t1,
                    truncate_intensity=(0.01, 0.995),
                    brain_extraction_modality="t1",
                    do_bias_correction=False,
                    do_denoising=False,
                    antsxnet_cache_directory=antsxnet_cache_directory,
                    verbose=verbose)
                brain_mask = ants.threshold_image(
                    t1_preprocessing["brain_mask"], 0.5, 1, 1, 0)
                t1_preprocessed = t1_preprocessing["preprocessed_image"]

        t1_segmentation = None
        if use_t1_segmentation:
            atropos_seg = deep_atropos(t1,
                                       do_preprocessing=True,
                                       verbose=verbose)
            t1_segmentation = atropos_seg['segmentation_image']

        flair_preprocessed = None
        if flair is not None:
            flair_preprocessed = flair
            if do_preprocessing == True:
                if brain_mask is None:
                    flair_preprocessing = preprocess_brain_image(
                        flair,
                        truncate_intensity=(0.01, 0.995),
                        brain_extraction_modality="flair",
                        do_bias_correction=False,
                        do_denoising=False,
                        antsxnet_cache_directory=antsxnet_cache_directory,
                        verbose=verbose)
                    brain_mask = ants.threshold_image(
                        flair_preprocessing["brain_mask"], 0.5, 1, 1, 0)
                else:
                    flair_preprocessing = preprocess_brain_image(
                        flair,
                        truncate_intensity=None,
                        brain_extraction_modality=None,
                        do_bias_correction=False,
                        do_denoising=False,
                        antsxnet_cache_directory=antsxnet_cache_directory,
                        verbose=verbose)
                flair_preprocessed = flair_preprocessing["preprocessed_image"]

        if t1_preprocessed is not None:
            t1_preprocessed = t1_preprocessed * brain_mask
        if flair_preprocessed is not None:
            flair_preprocessed = flair_preprocessed * brain_mask

        if t1_preprocessed is not None:
            resampling_params = list(ants.get_spacing(t1_preprocessed))
        else:
            resampling_params = list(ants.get_spacing(flair_preprocessed))

        do_resampling = False
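        # any dimension finer than 0.8 mm is coarsened to 1.0 mm, presumably
        # to match the resolution at which the 2-D models were trained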
        for d in range(len(resampling_params)):
            if resampling_params[d] < 0.8:
                resampling_params[d] = 1.0
                do_resampling = True

        resampling_params = tuple(resampling_params)

        if do_resampling:
            if flair_preprocessed is not None:
                flair_preprocessed = ants.resample_image(flair_preprocessed,
                                                         resampling_params,
                                                         use_voxels=False,
                                                         interp_type=0)
            if t1_preprocessed is not None:
                t1_preprocessed = ants.resample_image(t1_preprocessed,
                                                      resampling_params,
                                                      use_voxels=False,
                                                      interp_type=0)
            if t1_segmentation is not None:
                t1_segmentation = ants.resample_image(t1_segmentation,
                                                      resampling_params,
                                                      use_voxels=False,
                                                      interp_type=1)
            if brain_mask is not None:
                brain_mask = ants.resample_image(brain_mask,
                                                 resampling_params,
                                                 use_voxels=False,
                                                 interp_type=1)

        ################################
        #
        # Build model and load weights
        #
        ################################

        template_size = (208, 208)

        image_modalities = ("T1", "FLAIR")
        if do_flair_only:
            image_modalities = ("FLAIR", )
        elif do_t1_only:
            image_modalities = ("T1", )
        if use_t1_segmentation:
            image_modalities = (*image_modalities, "T1Seg")
        channel_size = len(image_modalities)

        unet_model = None
        if which_model == "sysu" or which_model == "sysu-ri":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("initialConvolutionKernelSize[5]", ))
        elif which_model == "sysuWithAttention":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("attentionGating",
                                    "initialConvolutionKernelSize[5]"))
        elif which_model == "sysuWithAttentionAndSite":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                scalar_output_size=3,
                scalar_output_activation="softmax",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("attentionGating",
                                    "initialConvolutionKernelSize[5]"))
        elif which_model == "sysuWithSite":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                scalar_output_size=3,
                scalar_output_activation="softmax",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("initialConvolutionKernelSize[5]", ))
        elif which_model == "sysuPlusSegWithSite":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                scalar_output_size=3,
                scalar_output_activation="softmax",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("nnUnetActivationStyle", "attentionGating",
                                    "initialConvolutionKernelSize[5]"))
        else:
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("nnUnetActivationStyle", "attentionGating",
                                    "initialConvolutionKernelSize[5]"))

        if verbose == True:
            print("ewDavid:  retrieving model weights.")

        weights_file_name = None
        if which_model == "sysu" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysu",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysu-ri" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuRankedIntensity",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysu" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysu" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttention" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttention",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttention" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttentionT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttention" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttentionFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttentionAndSite" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttentionAndSite",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttentionAndSite" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttentionAndSiteT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttentionAndSite" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttentionAndSiteFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlus" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlus",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlus" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlus" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlusSeg" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusSeg",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlusSeg" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusSegT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlusSegWithSite" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusSegWithSite",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlusSegWithSite" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusSegWithSiteT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithSite" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithSite",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithSite" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithSiteT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithSite" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithSiteFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            raise ValueError(
                "Incorrect model specification or image combination.")

        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Data augmentation and extract slices
        #
        ################################

        wmh_probability_image = None
        if t1 is not None:
            wmh_probability_image = ants.image_clone(t1_preprocessed) * 0
        else:
            wmh_probability_image = ants.image_clone(flair_preprocessed) * 0

        wmh_site = np.array([0, 0, 0])

        data_augmentation = None
        if number_of_simulations > 0:
            if do_flair_only:
                data_augmentation = randomly_transform_image_data(
                    reference_image=flair_preprocessed,
                    input_image_list=[[flair_preprocessed]],
                    number_of_simulations=number_of_simulations,
                    transform_type='affine',
                    sd_affine=sd_affine,
                    input_image_interpolator='linear')
            elif do_t1_only:
                if use_t1_segmentation:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[t1_preprocessed]],
                        segmentation_image_list=[t1_segmentation],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear',
                        segmentation_image_interpolator='nearestNeighbor')
                else:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[t1_preprocessed]],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear')
            else:
                if use_t1_segmentation:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[
                            flair_preprocessed, t1_preprocessed
                        ]],
                        segmentation_image_list=[t1_segmentation],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear',
                        segmentation_image_interpolator='nearestNeighbor')
                else:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[
                            flair_preprocessed, t1_preprocessed
                        ]],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear')

        dimensions_to_predict = list((0, ))
        if which_axes == "max":
            spacing = ants.get_spacing(wmh_probability_image)
            dimensions_to_predict = (spacing.index(max(spacing)), )
        elif which_axes == "all":
            dimensions_to_predict = list(range(3))
        else:
            if isinstance(which_axes, int):
                dimensions_to_predict = list((which_axes, ))
            else:
                dimensions_to_predict = list(which_axes)

        total_number_of_slices = 0
        for d in range(len(dimensions_to_predict)):
            total_number_of_slices += wmh_probability_image.shape[
                dimensions_to_predict[d]]

        batchX = np.zeros(
            (total_number_of_slices, *template_size, channel_size))

        for n in range(number_of_simulations + 1):

            batch_flair = flair_preprocessed
            batch_t1 = t1_preprocessed
            batch_t1_segmentation = t1_segmentation
            batch_brain_mask = brain_mask

            if n > 0:

                if do_flair_only:
                    batch_flair = data_augmentation['simulated_images'][n - 1][0]
                    batch_brain_mask = ants.apply_ants_transform_to_image(
                        data_augmentation['simulated_transforms'][n - 1],
                        brain_mask,
                        flair_preprocessed,
                        interpolation="nearestneighbor")

                elif do_t1_only:
                    batch_t1 = data_augmentation['simulated_images'][n - 1][0]
                    batch_brain_mask = ants.apply_ants_transform_to_image(
                        data_augmentation['simulated_transforms'][n - 1],
                        brain_mask,
                        t1_preprocessed,
                        interpolation="nearestneighbor")
                else:
                    batch_flair = data_augmentation['simulated_images'][n - 1][0]
                    batch_t1 = data_augmentation['simulated_images'][n - 1][1]
                    batch_brain_mask = ants.apply_ants_transform_to_image(
                        data_augmentation['simulated_transforms'][n - 1],
                        brain_mask,
                        flair_preprocessed,
                        interpolation="nearestneighbor")
                if use_t1_segmentation:
                    batch_t1_segmentation = data_augmentation[
                        'simulated_segmentation_images'][n - 1]

            if use_rank_intensity_scaling:
                if batch_t1 is not None:
                    batch_t1 = ants.rank_intensity(batch_t1,
                                                   batch_brain_mask) - 0.5
                if batch_flair is not None:
                    # rank the (possibly augmented) batch image, not the
                    # original preprocessed FLAIR
                    batch_flair = ants.rank_intensity(batch_flair,
                                                      batch_brain_mask) - 0.5
            else:
                if batch_t1 is not None:
                    batch_t1 = (batch_t1 -
                                batch_t1[batch_brain_mask == 1].mean()
                                ) / batch_t1[batch_brain_mask == 1].std()
                if batch_flair is not None:
                    batch_flair = (
                        batch_flair -
                        batch_flair[batch_brain_mask == 1].mean()
                    ) / batch_flair[batch_brain_mask == 1].std()

            slice_count = 0
            for d in range(len(dimensions_to_predict)):

                number_of_slices = None
                if batch_t1 is not None:
                    number_of_slices = batch_t1.shape[dimensions_to_predict[d]]
                else:
                    number_of_slices = batch_flair.shape[
                        dimensions_to_predict[d]]

                if verbose == True:
                    print("Extracting slices for dimension ",
                          dimensions_to_predict[d])

                for i in range(number_of_slices):

                    brain_mask_slice = pad_or_crop_image_to_size(
                        ants.slice_image(batch_brain_mask,
                                         dimensions_to_predict[d], i),
                        template_size)

                    channel_count = 0
                    if batch_flair is not None:
                        flair_slice = pad_or_crop_image_to_size(
                            ants.slice_image(batch_flair,
                                             dimensions_to_predict[d], i),
                            template_size)
                        flair_slice[brain_mask_slice == 0] = 0
                        batchX[slice_count, :, :,
                               channel_count] = flair_slice.numpy()
                        channel_count += 1
                    if batch_t1 is not None:
                        t1_slice = pad_or_crop_image_to_size(
                            ants.slice_image(batch_t1,
                                             dimensions_to_predict[d], i),
                            template_size)
                        t1_slice[brain_mask_slice == 0] = 0
                        batchX[slice_count, :, :,
                               channel_count] = t1_slice.numpy()
                        channel_count += 1
                    if t1_segmentation is not None:
                        t1_segmentation_slice = pad_or_crop_image_to_size(
                            ants.slice_image(batch_t1_segmentation,
                                             dimensions_to_predict[d], i),
                            template_size)
                        t1_segmentation_slice[brain_mask_slice == 0] = 0
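                        # deep_atropos labels run from 0 to 6, so this maps
                        # the segmentation channel to roughly [-0.5, 0.5]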
                        batchX[slice_count, :, :,
                               channel_count] = t1_segmentation_slice.numpy(
                               ) / 6 - 0.5

                    slice_count += 1

            ################################
            #
            # Do prediction and then restack into the image
            #
            ################################

            if verbose == True:
                if n == 0:
                    print("Prediction")
                else:
                    print("Prediction (simulation " + str(n) + ")")

            prediction = unet_model.predict(batchX, verbose=verbose)

            permutations = list()
            permutations.append((0, 1, 2))
            permutations.append((1, 0, 2))
            permutations.append((1, 2, 0))

            prediction_image_average = ants.image_clone(
                wmh_probability_image) * 0

            current_start_slice = 0
            for d in range(len(dimensions_to_predict)):
                current_end_slice = current_start_slice + wmh_probability_image.shape[
                    dimensions_to_predict[d]]
                which_batch_slices = range(current_start_slice,
                                           current_end_slice)
                if isinstance(prediction, list):
                    prediction_per_dimension = prediction[0][
                        which_batch_slices, :, :, 0]
                else:
                    prediction_per_dimension = prediction[
                        which_batch_slices, :, :, 0]
                prediction_array = np.transpose(
                    np.squeeze(prediction_per_dimension),
                    permutations[dimensions_to_predict[d]])
                prediction_image = ants.copy_image_info(
                    wmh_probability_image,
                    pad_or_crop_image_to_size(
                        ants.from_numpy(prediction_array),
                        wmh_probability_image.shape))
                prediction_image_average = prediction_image_average + (
                    prediction_image - prediction_image_average) / (d + 1)
                current_start_slice = current_end_slice

            wmh_probability_image = wmh_probability_image + (
                prediction_image_average - wmh_probability_image) / (n + 1)
            if isinstance(prediction, list):
                wmh_site = wmh_site + (np.mean(prediction[1], axis=0) -
                                       wmh_site) / (n + 1)

        if do_resampling:
            if t1 is not None:
                wmh_probability_image = ants.resample_image_to_target(
                    wmh_probability_image, t1)
            if flair is not None:
                wmh_probability_image = ants.resample_image_to_target(
                    wmh_probability_image, flair)

        if isinstance(prediction, list):
            return ([wmh_probability_image, wmh_site])
        else:
            return (wmh_probability_image)
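
A sketch of a typical invocation with test-time augmentation (file names are
placeholders; all parameters are as documented above):

import ants

flair = ants.image_read("flair.nii.gz")  # placeholder file name
t1 = ants.image_read("t1.nii.gz")        # placeholder file name

# attention-gated variant, predicting along all three axes and averaging
# over four random affine perturbations of the input
probability = ew_david(flair,
                       t1,
                       which_model="sysuWithAttention",
                       which_axes="all",
                       number_of_simulations=4,
                       sd_affine=0.01,
                       verbose=True)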
Code example #13
def hippmapp3r_segmentation(t1,
                            do_preprocessing=True,
                            antsxnet_cache_directory=None,
                            verbose=False):
    """
    Perform HippMapp3r (hippocampal) segmentation described in

    https://www.ncbi.nlm.nih.gov/pubmed/31609046

    with models and architecture ported from

    https://github.com/mgoubran/HippMapp3r

    Additional documentation and attribution resources found at

    https://hippmapp3r.readthedocs.io/en/latest/

    Preprocessing consists of:
       * n4 bias correction and
       * brain extraction
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    do_preprocessing = True

    Arguments
    ---------
    t1 : ANTsImage
        input image

    do_preprocessing : boolean
        See description above.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    ANTs labeled hippocampal image.

    Example
    -------
    >>> t1 = ants.image_read("t1.nii.gz")
    >>> mask = hippmapp3r_segmentation(t1)
    """

    from ..architectures import create_hippmapp3r_unet_model_3d
    from ..utilities import preprocess_brain_image
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    if verbose == True:
        print("*************  Preprocessing  ***************")
        print("")

    t1_preprocessed = t1
    if do_preprocessing == True:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=None,
            brain_extraction_modality="t1",
            template=None,
            do_bias_correction=True,
            do_denoising=False,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing[
            "preprocessed_image"] * t1_preprocessing['brain_mask']

    if verbose == True:
        print("*************  Initial stage segmentation  ***************")
        print("")

    # Normalize to mprage_hippmapp3r space
    if verbose == True:
        print("    HippMapp3r: template normalization.")

    template_file_name_path = get_antsxnet_data(
        "mprage_hippmapp3r", antsxnet_cache_directory=antsxnet_cache_directory)
    template_image = ants.image_read(template_file_name_path)

    registration = ants.registration(
        fixed=template_image,
        moving=t1_preprocessed,
        type_of_transform="antsRegistrationSyNQuickRepro[t]",
        verbose=verbose)
    image = registration['warpedmovout']
    transforms = dict(fwdtransforms=registration['fwdtransforms'],
                      invtransforms=registration['invtransforms'])

    # Threshold at 10th percentile of non-zero voxels in "robust range (fslmaths)"
    if verbose == True:
        print("    HippMapp3r: threshold.")

    image_array = image.numpy()
    image_robust_range = np.quantile(image_array[np.where(image_array != 0)],
                                     (0.02, 0.98))
    threshold_value = 0.10 * (image_robust_range[1] -
                              image_robust_range[0]) + image_robust_range[0]
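    # e.g., a robust range of (10, 250) gives 0.10 * (250 - 10) + 10 = 34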
    thresholded_mask = ants.threshold_image(image, -10000, threshold_value, 0,
                                            1)
    thresholded_image = image * thresholded_mask

    # Standardize image
    if verbose == True:
        print("    HippMapp3r: standardize.")

    mean_image = np.mean(thresholded_image[thresholded_mask == 1])
    sd_image = np.std(thresholded_image[thresholded_mask == 1])
    image_normalized = (image - mean_image) / sd_image
    image_normalized = image_normalized * thresholded_mask

    # Trim and resample image
    if verbose == True:
        print("    HippMapp3r: trim and resample to (160, 160, 128).")

    image_cropped = ants.crop_image(image_normalized, thresholded_mask, 1)
    shape_initial_stage = (160, 160, 128)
    image_resampled = ants.resample_image(image_cropped,
                                          shape_initial_stage,
                                          use_voxels=True,
                                          interp_type=1)

    if verbose == True:
        print("    HippMapp3r: generate first network and download weights.")

    model_initial_stage = create_hippmapp3r_unet_model_3d(
        (*shape_initial_stage, 1), do_first_network=True)

    initial_stage_weights_file_name = get_pretrained_network(
        "hippMapp3rInitial", antsxnet_cache_directory=antsxnet_cache_directory)
    model_initial_stage.load_weights(initial_stage_weights_file_name)

    if verbose == True:
        print("    HippMapp3r: prediction.")

    data_initial_stage = np.expand_dims(image_resampled.numpy(), axis=0)
    data_initial_stage = np.expand_dims(data_initial_stage, axis=-1)
    mask_array = model_initial_stage.predict(data_initial_stage,
                                             verbose=verbose)
    mask_image_resampled = ants.copy_image_info(
        image_resampled, ants.from_numpy(np.squeeze(mask_array)))
    mask_image = ants.resample_image(mask_image_resampled,
                                     image.shape,
                                     use_voxels=True,
                                     interp_type=0)
    mask_image[mask_image >= 0.5] = 1
    mask_image[mask_image < 0.5] = 0

    #########################################
    #
    # Perform refined (stage 2) segmentation
    #
    #########################################

    if verbose == True:
        print("")
        print("")
        print("*************  Refine stage segmentation  ***************")
        print("")

    mask_array = np.squeeze(mask_array)
    centroid_indices = np.where(mask_array == 1)
    centroid = np.zeros((3, ))
    centroid[0] = centroid_indices[0].mean()
    centroid[1] = centroid_indices[1].mean()
    centroid[2] = centroid_indices[2].mean()

    shape_refine_stage = (112, 112, 64)
    lower = (np.floor(centroid - 0.5 * np.array(shape_refine_stage)) -
             1).astype(int)
    upper = (lower + np.array(shape_refine_stage)).astype(int)

    image_trimmed = ants.crop_indices(image_resampled, lower.astype(int),
                                      upper.astype(int))

    if verbose == True:
        print("    HippMapp3r: generate second network and download weights.")

    model_refine_stage = create_hippmapp3r_unet_model_3d(
        (*shape_refine_stage, 1), do_first_network=False)

    refine_stage_weights_file_name = get_pretrained_network(
        "hippMapp3rRefine", antsxnet_cache_directory=antsxnet_cache_directory)
    model_refine_stage.load_weights(refine_stage_weights_file_name)

    data_refine_stage = np.expand_dims(image_trimmed.numpy(), axis=0)
    data_refine_stage = np.expand_dims(data_refine_stage, axis=-1)

    if verbose == True:
        print("    HippMapp3r: Monte Carlo iterations (SpatialDropout).")

    number_of_mci_iterations = 30
    prediction_refine_stage = np.zeros(shape_refine_stage)
    for i in range(number_of_mci_iterations):
        tf.random.set_seed(i)
        if verbose == True:
            print("        Monte Carlo iteration", i + 1, "out of",
                  number_of_mci_iterations)
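        # running average of the Monte Carlo dropout predictions:
        # mean_i = (x_i + i * mean_{i-1}) / (i + 1)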
        prediction_refine_stage = \
            (np.squeeze(model_refine_stage.predict(data_refine_stage, verbose=verbose)) + \
             i * prediction_refine_stage ) / (i + 1)

    prediction_refine_stage_array = np.zeros(image_resampled.shape)
    prediction_refine_stage_array[lower[0]:upper[0], lower[1]:upper[1],
                                  lower[2]:upper[2]] = prediction_refine_stage
    probability_mask_refine_stage_resampled = ants.copy_image_info(
        image_resampled, ants.from_numpy(prediction_refine_stage_array))

    segmentation_image_resampled = ants.label_clusters(ants.threshold_image(
        probability_mask_refine_stage_resampled, 0.0, 0.5, 0, 1),
                                                       min_cluster_size=10)
    segmentation_image_resampled[segmentation_image_resampled > 2] = 0
    geom = ants.label_geometry_measures(segmentation_image_resampled)
    if len(geom['VolumeInMillimeters']) < 2:
        raise ValueError("Error: left and right hippocampus not found.")

    if geom['Centroid_x'][0] < geom['Centroid_x'][1]:
        segmentation_image_resampled[segmentation_image_resampled == 1] = 3
        segmentation_image_resampled[segmentation_image_resampled == 2] = 1
        segmentation_image_resampled[segmentation_image_resampled == 3] = 2

    segmentation_image = ants.apply_transforms(
        fixed=t1,
        moving=segmentation_image_resampled,
        transformlist=transforms['invtransforms'],
        whichtoinvert=[True],
        interpolator="genericLabel",
        verbose=verbose)

    return (segmentation_image)
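
A sketch of end-to-end use, assuming a raw T1 on disk (placeholder file
name); per-label hippocampal volumes can then be read off with
ants.label_geometry_measures:

import ants

t1 = ants.image_read("t1.nii.gz")  # placeholder file name
hipp = hippmapp3r_segmentation(t1, do_preprocessing=True, verbose=True)

# one row per label (the two hippocampi)
geom = ants.label_geometry_measures(hipp)
print(geom['VolumeInMillimeters'])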
Code example #14
        "brainSegmentationPatchBased", weights_file_name)

unet_model.load_weights(weights_file_name)
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

# Process input

start_time_total = time.time()

print("Reading ", input_file_name)
start_time = time.time()
image = ants.image_read(input_file_name)
mask = ants.image_read(input_mask_file_name)
mask = ants.threshold_image(mask, 0.4999, 1.0001, 1, 0)
image = image * mask
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

print("Extracting patches based on mask.")
start_time = time.time()

image_patches = antspynet.extract_image_patches(image,
                                                stride_length=stride_length,
                                                patch_size=patch_size,
                                                max_number_of_patches="all",
                                                mask_image=mask,
                                                return_as_array=True)
image_patches = (image_patches - image_patches.min()) / (image_patches.max() -
Code example #15
def neural_style_transfer(content_image,
                          style_images,
                          initial_combination_image=None,
                          number_of_iterations=10,
                          learning_rate=1.0,
                          total_variation_weight=8.5e-5,
                          content_weight=0.025,
                          style_image_weights=1.0,
                          content_layer_names=['block5_conv2'],
                          style_layer_names="all",
                          content_mask=None,
                          style_masks=None,
                          use_shifted_activations=True,
                          use_chained_inference=True,
                          verbose=False,
                          output_prefix=None):
    """
    The popular neural style transfer described here:

    https://arxiv.org/abs/1508.06576 and https://arxiv.org/abs/1605.04603

    and taken from François Chollet's implementation

    https://keras.io/examples/generative/neural_style_transfer/

    and titu1994's modifications:

    https://github.com/titu1994/Neural-Style-Transfer

    in order to experiment with (and possibly modify) medical images.

    Arguments
    ---------
    content_image : ANTsImage (1 or 3-component)
        Content (or base) image.

    style_images : ANTsImage or list of ANTsImages
        Style (or reference) image.

    initial_combination_image : ANTsImage (1 or 3-component)
        Starting point for the optimization.  Allows one to start from the
        output from a previous run.  Otherwise, start from the content image.
        Note that the original paper starts with a noise image.

    number_of_iterations : integer
        Number of gradient steps taken during optimization.

    learning_rate : float
        Parameter for Adam optimization.

    total_variation_weight : float
        A penalty on the regularization term to keep the features
        of the output image locally coherent.

    content_weight : float
        Weight of the content layers in the optimization function.

    style_image_weights : float or list of floats
        Weights of the style term in the optimization function for each
        style image.  Can either specify a single scalar to be used for
        all the images or one for each image.  The
        style term computes the sum of the L2 norm between the Gram
        matrices of the different layers (using ImageNet-trained VGG)
        of the style and content images.

    content_layer_names : list of strings
        Names of VGG layers from which to compute the content loss.

    style_layer_names : list of strings
        Names of VGG layers from which to compute the style loss.  If "all",
        the layers used are ['block1_conv1', 'block1_conv2', 'block2_conv1',
        'block2_conv2', 'block3_conv1', 'block3_conv2', 'block3_conv3',
        'block3_conv4', 'block4_conv1', 'block4_conv2', 'block4_conv3',
        'block4_conv4', 'block5_conv1', 'block5_conv2', 'block5_conv3',
        'block5_conv4'].  This is a proposed improvement from
        https://arxiv.org/abs/1605.04603.  In the original implementation, the
        layers used are: ['block1_conv1', 'block2_conv1', 'block3_conv1',
         'block4_conv1', 'block5_conv1'].

    content_mask : ANTsImage
        Specify the region for content consideration.

    style_masks :  ANTsImage or list of ANTsImages
        Specify the region for style consideration.

    use_shifted_activations : boolean
        Use shifted activations in calculating the Gram matrix (improvement
        mentioned in https://arxiv.org/abs/1605.04603).

    use_chained_inference : boolean
        Another proposed improvement from https://arxiv.org/abs/1605.04603.

    verbose : boolean
        Print progress to the screen.

    output_prefix : string
        If specified, outputs a png image to disk at each iteration.

    Returns
    -------
    ANTs 3-component image.

    Example
    -------
    >>> image = neural_style_transfer(content_image, style_image)
    """
    def preprocess_ants_image(image, do_scale_and_center=True):
        array = None
        if image.components == 1:
            array = image.numpy()
            array = np.expand_dims(array, 2)
            array = np.repeat(array, 3, 2)
        elif image.components == 3:
            vector_image = image
            image_channels = ants.split_channels(vector_image)
            array = np.concatenate([
                np.expand_dims(image_channels[0].numpy(), axis=2),
                np.expand_dims(image_channels[1].numpy(), axis=2),
                np.expand_dims(image_channels[2].numpy(), axis=2)
            ],
                                   axis=2)
        else:
            raise ValueError("Unexpected number of components.")

        if do_scale_and_center == True:
            for i in range(3):
                array[:, :, i] = (array[:, :, i] - array[:, :, i].min()) / (
                    array[:, :, i].max() - array[:, :, i].min())
            array *= 255.0
            # RGB -> BGR
            array = array[:, :, ::-1]
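            # Subtract the ImageNet channel means (BGR order), matching
            # keras.applications.vgg19.preprocess_input.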
            array[:, :, 0] -= 103.939
            array[:, :, 1] -= 116.779
            array[:, :, 2] -= 123.68

        array = np.expand_dims(array, 0)
        return (array)

    def postprocess_array(array, reference_image):
        array = np.squeeze(array)

        array[:, :, 0] += 103.939
        array[:, :, 1] += 116.779
        array[:, :, 2] += 123.68
        # BGR -> RGB
        array = array[:, :, ::-1]
        array = np.clip(array, 0, 255)

        image = ants.from_numpy(array,
                                origin=reference_image.origin,
                                spacing=reference_image.spacing,
                                direction=reference_image.direction,
                                has_components=True)
        return (image)

    def gram_matrix(x, shifted_activations=False):
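        # Flatten each channel's feature map into a row of F, then form
        # F.F^T so that entry (i, j) measures the co-activation of channels
        # i and j.  The optional -1 shift implements the "shifted
        # activations" variant from https://arxiv.org/abs/1605.04603.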
        F = K.batch_flatten(K.permute_dimensions(x, (2, 0, 1)))
        if shifted_activations:
            F = F - 1
        gram = K.dot(F, K.transpose(F))
        return (gram)

    def process_mask(mask, shape):
        mask_processed = (tf.image.resize(
            mask,
            size=[shape[0], shape[1]],
            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)).numpy()
        mask_processed_tensor = np.empty(shape)
        for i in range(shape[2]):
            mask_processed_tensor[:, :, i] = mask_processed[:, :, 0]
        return (mask_processed_tensor)

    def style_loss(style_features,
                   combination_features,
                   image_shape,
                   style_mask=None,
                   content_mask=None):
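        # Gatys-style loss between the Gram matrices of the style and
        # combination features, normalized by 4 * channels^2 * size^2
        # (channels is fixed at 3 here; size is the pixel count of the image).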

        if content_mask is not None:
            mask_tensor = K.variable(
                process_mask(content_mask, combination_features.shape))
            combination_features = combination_features * K.stop_gradient(
                mask_tensor)
            del mask_tensor

        if style_mask is not None:
            mask_tensor = K.variable(
                process_mask(style_mask, style_features.shape))
            style_features = style_features * K.stop_gradient(mask_tensor)
            if content_mask is not None:
                combination_features = combination_features * K.stop_gradient(
                    mask_tensor)
            del mask_tensor

        style_gram = gram_matrix(style_features, use_shifted_activations)
        content_gram = gram_matrix(combination_features,
                                   use_shifted_activations)
        size = image_shape[0] * image_shape[1]
        number_of_channels = 3
        loss = tf.reduce_sum(
            tf.square(style_gram - content_gram)) / (4.0 *
                                                     (number_of_channels**2) *
                                                     (size**2))
        return (loss)

    def content_loss(content_features, combination_features):
        loss = tf.reduce_sum(tf.square(content_features -
                                       combination_features))
        return (loss)

    def total_variation_loss(x):
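        # Total-variation regularizer: penalize squared differences between
        # vertically (a) and horizontally (b) adjacent pixels, raised to the
        # 1.25 power, to keep the output locally coherent.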
        shape = x.shape
        a = tf.square(x[:, :(shape[1] - 1), :(shape[2] - 1), :] -
                      x[:, 1:, :(shape[2] - 1), :])
        b = tf.square(x[:, :(shape[1] - 1), :(shape[2] - 1), :] -
                      x[:, :(shape[1] - 1), 1:, :])
        loss = tf.reduce_sum(tf.pow(a + b, 1.25))
        return (loss)

    def compute_total_loss(content_array,
                           style_array_list,
                           combination_tensor,
                           feature_model,
                           content_layer_names,
                           style_layer_names,
                           image_shape,
                           content_mask_tensor=None,
                           style_mask_tensor_list=None):
        number_of_style_images = len(style_array_list)

        input_arrays = list()
        input_arrays.append(content_array)
        for i in range(number_of_style_images):
            input_arrays.append(style_array_list[i])
        input_arrays.append(combination_tensor)
        input_tensor = tf.concat(input_arrays, axis=0)

        features = feature_model(input_tensor)

        total_loss = tf.zeros(shape=())

        # content loss
        for i in range(len(content_layer_names)):
            layer_features = features[content_layer_names[i]]
            content_features = layer_features[0, :, :, :]
            # The combination image is the last entry in the input batch.
            combination_features = layer_features[number_of_style_images +
                                                  1, :, :, :]
            total_loss = total_loss + (
                content_loss(content_features, combination_features) *
                content_weight / len(content_layer_names))

        # style loss
        if use_chained_inference:
            for i in range(len(style_layer_names) - 1):
                layer_features = features[style_layer_names[i]]
                style_features = layer_features[1:(number_of_style_images +
                                                   1), :, :, :]
                combination_features = layer_features[number_of_style_images +
                                                      1, :, :, :]
                loss = list()
                for j in range(number_of_style_images):
                    if style_mask_tensor_list is None:
                        loss.append(
                            style_loss(style_features[j],
                                       combination_features,
                                       image_shape,
                                       style_mask=None,
                                       content_mask=content_mask_tensor))
                    else:
                        loss.append(
                            style_loss(style_features[j],
                                       combination_features,
                                       image_shape,
                                       style_mask=style_mask_tensor_list[j],
                                       content_mask=content_mask_tensor))

                layer_features = features[style_layer_names[i + 1]]
                style_features = layer_features[1:(number_of_style_images +
                                                   1), :, :, :]
                combination_features = layer_features[number_of_style_images +
                                                      1, :, :, :]
                loss_p1 = list()
                for j in range(number_of_style_images):
                    if style_mask_tensor_list is None:
                        loss_p1.append(
                            style_loss(style_features[j],
                                       combination_features,
                                       image_shape,
                                       style_mask=None,
                                       content_mask=content_mask_tensor))
                    else:
                        loss_p1.append(
                            style_loss(style_features[j],
                                       combination_features,
                                       image_shape,
                                       style_mask=style_mask_tensor_list[j],
                                       content_mask=content_mask_tensor))

                for j in range(number_of_style_images):
                    loss_difference = loss[j] - loss_p1[j]
                    total_loss = total_loss + (style_image_weights[j] *
                                               loss_difference /
                                               (2**(len(style_layer_names) -
                                                    (i + 1))))

        else:
            for i in range(len(style_layer_names)):
                layer_features = features[style_layer_names[i]]
                style_features = layer_features[1:(number_of_style_images +
                                                   1), :, :, :]
                combination_features = layer_features[number_of_style_images +
                                                      1, :, :, :]
                loss = list()
                for j in range(number_of_style_images):
                    if style_mask_tensor_list is None:
                        loss.append(
                            style_loss(style_features[j],
                                       combination_features,
                                       image_shape,
                                       style_mask=None,
                                       content_mask=content_mask_tensor))
                    else:
                        loss.append(
                            style_loss(style_features[j],
                                       combination_features,
                                       image_shape,
                                       style_mask=style_mask_tensor_list[j],
                                       content_mask=content_mask_tensor))

                for j in range(number_of_style_images):
                    total_loss = total_loss + (loss[j] * style_image_weights[j]
                                               / len(style_layer_names))

        # total variation loss
        total_loss = total_loss + total_variation_weight * total_variation_loss(
            combination_tensor)

        return (total_loss)

    def compute_loss_and_gradients(content_array,
                                   style_array_list,
                                   combination_tensor,
                                   feature_model,
                                   content_layer_names,
                                   style_layer_names,
                                   image_shape,
                                   content_mask_tensor=None,
                                   style_mask_tensor_list=None):
        with tf.GradientTape() as tape:
            loss = compute_total_loss(content_array, style_array_list,
                                      combination_tensor, feature_model,
                                      content_layer_names, style_layer_names,
                                      image_shape, content_mask_tensor,
                                      style_mask_tensor_list)
        gradients = tape.gradient(loss, combination_tensor)
        return loss, gradients

    number_of_style_images = 1
    if isinstance(style_images, list):
        number_of_style_images = len(style_images)

    style_image_list = list()
    if number_of_style_images == 1:
        style_image_list.append(style_images)
    else:
        style_image_list = style_images

    for i in range(number_of_style_images):
        if style_image_list[i].dimension != 2:
            raise ValueError("Input style images must be 2-D.")
        if style_image_list[i].shape != content_image.shape:
            raise ValueError(
                "Input images must have matching dimensions/shapes.")

    number_of_style_masks = 0
    style_mask_tensor_list = None
    if style_masks is not None:
        number_of_style_masks = 1
        if isinstance(style_masks, list):
            number_of_style_masks = len(style_masks)

        style_mask_tensor_list = list()
        if number_of_style_masks == 1:
            style_mask_array = (ants.threshold_image(style_masks, 0, 0, 0,
                                                     1)).numpy()
            style_mask_tensor = np.expand_dims(style_mask_array, -1)
            style_mask_tensor_list.append(style_mask_tensor)
        else:
            for i in range(len(style_masks)):
                style_mask_array = (ants.threshold_image(
                    style_masks[i], 0, 0, 0, 1)).numpy()
                style_mask_tensor = np.expand_dims(style_mask_array, -1)
                style_mask_tensor_list.append(style_mask_tensor)

    if number_of_style_masks > 0 and number_of_style_images != number_of_style_masks:
        raise ValueError("The number of style images/masks are not the same.")

    if isinstance(style_image_weights, (int, float)):
        style_image_weights = [style_image_weights] * len(style_image_list)
    else:
        if len(style_image_weights) == 1:
            style_image_weights = style_image_weights * len(style_image_list)
        elif not len(style_image_weights) == len(style_image_list):
            raise ValueError(
                "Length of style weights must be 1 or the number of style images."
            )

    if content_image.dimension != 2:
        raise ValueError("Input content image must be 2-D.")

    content_mask_tensor = None
    if content_mask is not None:
        content_mask_array = (ants.threshold_image(content_mask, 0, 0, 0,
                                                   1)).numpy()
        content_mask_tensor = np.expand_dims(content_mask_array, -1)

    if style_layer_names == "all":
        style_layer_names = [
            'block1_conv1', 'block1_conv2', 'block2_conv1', 'block2_conv2',
            'block3_conv1', 'block3_conv2', 'block3_conv3', 'block3_conv4',
            'block4_conv1', 'block4_conv2', 'block4_conv3', 'block4_conv4',
            'block5_conv1', 'block5_conv2', 'block5_conv3', 'block5_conv4'
        ]

    model = vgg19.VGG19(weights="imagenet", include_top=False)

    outputs_dictionary = dict([(layer.name, layer.output)
                               for layer in model.layers])
    # shapes_dictionary = dict([(layer.name, layer.output_shape) for layer in model.layers])

    feature_model = tf.keras.Model(inputs=model.inputs,
                                   outputs=outputs_dictionary)

    # Preprocess data
    content_array = preprocess_ants_image(content_image)
    style_array_list = list()
    for i in range(number_of_style_images):
        style_array_list.append(preprocess_ants_image(style_image_list[i]))

    image_shape = (content_array.shape[1], content_array.shape[2], 3)

    combination_tensor = None
    if initial_combination_image is None:
        combination_tensor = tf.Variable(np.copy(content_array))
    else:
        initial_combination_tensor = preprocess_ants_image(
            initial_combination_image, do_scale_and_center=False)
        combination_tensor = tf.Variable(initial_combination_tensor)

    if image_shape != (combination_tensor.shape[1],
                       combination_tensor.shape[2], 3):
        raise ValueError(
            "Initial combination image size does not match content image.")

    optimizer = tf.optimizers.Adam(learning_rate=learning_rate,
                                   beta_1=0.99,
                                   epsilon=0.1)

    for i in range(number_of_iterations):
        start_time = time.time()
        loss, gradients = compute_loss_and_gradients(
            content_array, style_array_list, combination_tensor, feature_model,
            content_layer_names, style_layer_names, image_shape,
            content_mask_tensor, style_mask_tensor_list)
        end_time = time.time()
        if verbose == True:
            print(
                "Iteration %d of %d: total loss = %.2f (elapsed time = %ds)" %
                (i + 1, number_of_iterations, loss, end_time - start_time))
        optimizer.apply_gradients([(gradients, combination_tensor)])

        if output_prefix is not None:
            combination_array = combination_tensor.numpy()
            combination_image = postprocess_array(combination_array,
                                                  content_image)
            combination_rgb = combination_image.vector_to_rgb()
            ants.image_write(combination_rgb,
                             output_prefix + "_iteration%d.png" % (i + 1))

    combination_array = combination_tensor.numpy()
    combination_image = postprocess_array(combination_array, content_image)
    return (combination_image)
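
For reference, a hedged end-to-end usage sketch (the file names are placeholders, not from the original source):

import ants

content = ants.image_read("content_slice.nii.gz")  # placeholder 2-D image
style = ants.image_read("style_slice.nii.gz")      # placeholder 2-D image
combined = neural_style_transfer(content, style,
                                 number_of_iterations=50,
                                 verbose=True)
ants.image_write(combined.vector_to_rgb(), "combined.png")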
Code example #16
File: testNoBrainer.py Project: kumeS/TestingScripts
    with open(target_file_image_name, 'wb') as f:
        f.write(r.content)

image = ants.image_read(target_file_image_name)

print("Preprocessing: bias correction")
image_n4 = ants.n4_bias_field_correction(image)
image_n4 = ants.image_math(image_n4, 'Normalize') * 255.0

print("Preprocessing:  thresholding")
image_n4_array = ((image_n4.numpy()).flatten())
image_n4_nonzero = image_n4_array[(image_n4_array > 0).nonzero()]
image_robust_range = np.quantile(image_n4_nonzero, (0.02, 0.98))
threshold_value = 0.10 * (image_robust_range[1] -
                          image_robust_range[0]) + image_robust_range[0]
thresholded_mask = ants.threshold_image(image_n4, -10000, threshold_value, 0,
                                        1)
thresholded_image = image_n4 * thresholded_mask

print("Preprocessing:  resampling")
image_resampled = ants.resample_image(thresholded_image, (256, 256, 256), True)
batchX = np.expand_dims(image_resampled.numpy(), axis=0)
batchX = np.expand_dims(batchX, axis=-1)

print("Prediction and write to disk.")
brain_mask_array = model.predict(batchX, verbose=0)
brain_mask_resampled = ants.from_numpy(np.squeeze(brain_mask_array[0, :, :, :,
                                                                   0]),
                                       origin=image_resampled.origin,
                                       spacing=image_resampled.spacing,
                                       direction=image_resampled.direction)
brain_mask_image = ants.resample_image(brain_mask_resampled, image.shape, True,
                                       1)  # truncated in source; interp_type=1 assumed
Code example #17
print("    Initial step 1: bias correction.")
start_time = time.time()
image_n4 = warped_image  # ants.n4_bias_field_correction(warped_image)
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

# Threshold at 10% of the "robust range" (fslmaths) of the non-zero voxels
print("    Initial step 2: threshold.")
start_time = time.time()
image_n4_array = ((image_n4.numpy()).flatten())
image_n4_nonzero = image_n4_array[(image_n4_array > 0).nonzero()]
image_robust_range = np.quantile(image_n4_nonzero, (0.02, 0.98))
threshold_value = 0.10 * (image_robust_range[1] -
                          image_robust_range[0]) + image_robust_range[0]
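# e.g., a robust range of (20.0, 820.0) gives a threshold of
# 0.10 * (820.0 - 20.0) + 20.0 = 100.0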
thresholded_mask = ants.threshold_image(image_n4, -10000, threshold_value, 0,
                                        1)
thresholded_image = image_n4 * thresholded_mask
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

# Standardize image (should do patch-based stuff but making a quicker substitute for testing)
print("    Initial step 3: standardize.")
start_time = time.time()
thresholded_array = ((thresholded_image.numpy()).flatten())
thresholded_nonzero = thresholded_array[(thresholded_array > 0).nonzero()]
image_mean = np.mean(thresholded_nonzero)
image_sd = np.std(thresholded_nonzero)
image_standard = (image_n4 - image_mean) / image_sd
image_standard = image_standard * thresholded_mask
end_time = time.time()
Code example #18
####################
#
# Run longitudinal cortical thickness
#
####################

kk_long = antspynet.longitudinal_cortical_thickness(t1s, initial_template="adni",
  number_of_iterations=2, refinement_transform="antsRegistrationSyNQuick[a]", verbose=True)

sst_file = output_directory + "/" + "SingleSubjectTemplate.nii.gz"
ants.image_write(kk_long[-1], sst_file)

for i in range(len(t1s)):
    kk_file = output_directory + "/" + t1s_basename_prefix[i] + "CorticalThickness.nii.gz"
    ants.image_write(kk_long[i]['thickness_image'], kk_file)

    t1_pre_file = output_directory + "/" + t1s_basename_prefix[i] + "Preprocessed.nii.gz"
    ants.image_write(kk_long[i]['preprocessed_image'], t1_pre_file)

    dkt_file = output_directory + "/" + t1s_basename_prefix[i] + "DktLabels.nii.gz"
    dkt = antspynet.desikan_killiany_tourville_labeling(kk_long[i]['preprocessed_image'], do_preprocessing=True, verbose=True)
    ants.image_write(dkt, dkt_file)

    dkt_prop_file = output_directory + "/" + t1s_basename_prefix[i] + "DktPropagatedLabels.nii.gz"
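    # Keep only cortical DKT labels (codes 1000-3000), then propagate them
    # through the non-zero-thickness mask so each cortical voxel gets a label.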
    dkt_mask = ants.threshold_image(dkt, 1000, 3000, 1, 0)
    dkt = dkt_mask * dkt
    ants_tmp = ants.threshold_image(kk_long[i]['thickness_image'], 0, 0, 0, 1)
    ants_dkt = ants.iMath(ants_tmp, "PropagateLabelsThroughMask", ants_tmp * dkt)
    ants.image_write(ants_dkt, dkt_prop_file)
Code example #19
 def test_image_to_cluster_images_example(self):
     image = ants.image_read(ants.get_ants_data('r16'))
     image = ants.threshold_image(image, 1, 1e15)
     image_cluster_list = ants.image_to_cluster_images(image)
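     # image_to_cluster_images splits the thresholded image into a list of
     # images, one per connected component (cluster).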
Code example #20
def claustrum_segmentation(t1,
                           do_preprocessing=True,
                           use_ensemble=True,
                           antsxnet_cache_directory=None,
                           verbose=False):
    """
    Claustrum segmentation

    Described here:

        https://arxiv.org/abs/2008.03465

    with the implementation available at:

        https://github.com/hongweilibran/claustrum_multi_view


    Arguments
    ---------
    t1 : ANTsImage
        input 3-D T1 brain image.

    do_preprocessing : boolean
        perform n4 bias correction.

    use_ensemble : boolean
        whether to use all 3 sets of weights.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, these data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Claustrum segmentation probability image

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> probability_mask = claustrum_segmentation(image)
    """

    from ..architectures import create_sysu_media_unet_model_2d
    from ..utilities import brain_extraction
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import pad_or_crop_image_to_size

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    image_size = (180, 180)

    ################################
    #
    # Preprocess images
    #
    ################################

    number_of_channels = 1
    t1_preprocessed = ants.image_clone(t1)
    brain_mask = ants.threshold_image(t1, 0, 0, 0, 1)
    if do_preprocessing == True:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing["preprocessed_image"]
        brain_mask = t1_preprocessing["brain_mask"]

    reference_image = ants.make_image((170, 256, 256),
                                      voxval=1,
                                      spacing=(1, 1, 1),
                                      origin=(0, 0, 0),
                                      direction=np.identity(3))
    center_of_mass_reference = ants.get_center_of_mass(reference_image)
    center_of_mass_image = ants.get_center_of_mass(brain_mask)
    translation = np.asarray(center_of_mass_image) - np.asarray(
        center_of_mass_reference)
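    # Build a translation-only Euler3D transform that moves the image's
    # center of mass onto the reference template's.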
    xfrm = ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_reference),
        translation=translation)
    t1_preprocessed_warped = ants.apply_ants_transform_to_image(
        xfrm, t1_preprocessed, reference_image)
    brain_mask_warped = ants.threshold_image(
        ants.apply_ants_transform_to_image(xfrm, brain_mask, reference_image),
        0.5, 1.1, 1, 0)

    ################################
    #
    # Gaussian normalize intensity based on brain mask
    #
    ################################

    mean_t1 = t1_preprocessed_warped[brain_mask_warped > 0].mean()
    std_t1 = t1_preprocessed_warped[brain_mask_warped > 0].std()
    t1_preprocessed_warped = (t1_preprocessed_warped - mean_t1) / std_t1

    t1_preprocessed_warped = t1_preprocessed_warped * brain_mask_warped

    ################################
    #
    # Build models and load weights
    #
    ################################

    number_of_models = 1
    if use_ensemble == True:
        number_of_models = 3

    if verbose == True:
        print("Claustrum:  retrieving axial model weights.")

    unet_axial_models = list()
    for i in range(number_of_models):
        weights_file_name = get_pretrained_network(
            "claustrum_axial_" + str(i),
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_axial_models.append(
            create_sysu_media_unet_model_2d((*image_size, number_of_channels),
                                            anatomy="claustrum"))
        unet_axial_models[i].load_weights(weights_file_name)

    if verbose == True:
        print("Claustrum:  retrieving coronal model weights.")

    unet_coronal_models = list()
    for i in range(number_of_models):
        weights_file_name = get_pretrained_network(
            "claustrum_coronal_" + str(i),
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_coronal_models.append(
            create_sysu_media_unet_model_2d((*image_size, number_of_channels),
                                            anatomy="claustrum"))
        unet_coronal_models[i].load_weights(weights_file_name)

    ################################
    #
    # Extract slices
    #
    ################################

    dimensions_to_predict = [1, 2]

    batch_coronal_X = np.zeros(
        (t1_preprocessed_warped.shape[1], *image_size, number_of_channels))
    batch_axial_X = np.zeros(
        (t1_preprocessed_warped.shape[2], *image_size, number_of_channels))

    for d in range(len(dimensions_to_predict)):
        number_of_slices = t1_preprocessed_warped.shape[
            dimensions_to_predict[d]]

        if verbose == True:
            print("Extracting slices for dimension ", dimensions_to_predict[d],
                  ".")

        for i in range(number_of_slices):
            t1_slice = pad_or_crop_image_to_size(
                ants.slice_image(t1_preprocessed_warped,
                                 dimensions_to_predict[d], i), image_size)
            if dimensions_to_predict[d] == 1:
                batch_coronal_X[i, :, :, 0] = np.rot90(t1_slice.numpy(), k=-1)
            else:
                batch_axial_X[i, :, :, 0] = np.rot90(t1_slice.numpy())

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose == True:
        print("Coronal prediction.")

    prediction_coronal = unet_coronal_models[0].predict(batch_coronal_X,
                                                        verbose=verbose)
    if number_of_models > 1:
        for i in range(1, number_of_models, 1):
            prediction_coronal += unet_coronal_models[i].predict(
                batch_coronal_X, verbose=verbose)
    prediction_coronal /= number_of_models

    for i in range(t1_preprocessed_warped.shape[1]):
        prediction_coronal[i, :, :, 0] = np.rot90(
            np.squeeze(prediction_coronal[i, :, :, 0]))

    if verbose == True:
        print("Axial prediction.")

    prediction_axial = unet_axial_models[0].predict(batch_axial_X,
                                                    verbose=verbose)
    if number_of_models > 1:
        for i in range(1, number_of_models, 1):
            prediction_axial += unet_axial_models[i].predict(batch_axial_X,
                                                             verbose=verbose)
    prediction_axial /= number_of_models

    for i in range(t1_preprocessed_warped.shape[2]):
        prediction_axial[i, :, :,
                         0] = np.rot90(np.squeeze(prediction_axial[i, :, :,
                                                                   0]),
                                       k=-1)

    if verbose == True:
        print("Restack image and transform back to native space.")

    permutations = list()
    permutations.append((0, 1, 2))
    permutations.append((1, 0, 2))
    permutations.append((1, 2, 0))

    prediction_image_average = ants.image_clone(t1_preprocessed_warped) * 0

    for d in range(len(dimensions_to_predict)):
        which_batch_slices = range(
            t1_preprocessed_warped.shape[dimensions_to_predict[d]])
        prediction_per_dimension = None
        if dimensions_to_predict[d] == 1:
            prediction_per_dimension = prediction_coronal[
                which_batch_slices, :, :, :]
        else:
            prediction_per_dimension = prediction_axial[
                which_batch_slices, :, :, :]
        prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                                        permutations[dimensions_to_predict[d]])
        prediction_image = ants.copy_image_info(
            t1_preprocessed_warped,
            pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                      t1_preprocessed_warped.shape))
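        # Running (online) mean across the predicted dimensions:
        # avg_new = avg_old + (x - avg_old) / (d + 1).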
        prediction_image_average = prediction_image_average + (
            prediction_image - prediction_image_average) / (d + 1)

    probability_image = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), prediction_image_average,
        t1) * ants.threshold_image(brain_mask, 0.5, 1, 1, 0)

    return (probability_image)
Code example #21
def preprocess_brain_image(image,
                           truncate_intensity=(0.01, 0.99),
                           brain_extraction_modality=None,
                           template_transform_type=None,
                           template="biobank",
                           do_bias_correction=True,
                           return_bias_field=False,
                           do_denoising=True,
                           intensity_matching_type=None,
                           reference_image=None,
                           intensity_normalization_type=None,
                           antsxnet_cache_directory=None,
                           verbose=True):

    """
    Basic preprocessing pipeline for T1-weighted brain MRI

    Standard preprocessing steps that have been previously described
    in various papers including the cortical thickness pipeline:

         https://www.ncbi.nlm.nih.gov/pubmed/24879923

    Arguments
    ---------
    image : ANTsImage
        input image

    truncate_intensity : 2-length tuple
        Defines the quantile threshold for truncating the image intensity

    brain_extraction_modality : string or None
        Perform brain extraction using antspynet tools.  One of "t1", "t1v0",
        "t1nobrainer", "t1combined", "flair", "t2", "bold", "fa", "t1infant",
        "t2infant", or None.

    template_transform_type : string
        See details in help for ants.registration.  Typically "Rigid" or
        "Affine".

    template : ANTs image (not skull-stripped)
        Alternatively, one can specify the default "biobank" or "croppedMni152"
        to download and use premade templates.

    do_bias_correction : boolean
        Perform N4 bias field correction.

    return_bias_field : boolean
        If True, return bias field as an additional output *without* bias
        correcting the preprocessed image.

    do_denoising : boolean
        Perform non-local means denoising.

    intensity_matching_type : string
        Either "regression" or "histogram". Only is performed if reference_image
        is not None.

    reference_image : ANTs image
        Reference image for intensity matching.

    intensity_normalization_type : string
        Either rescale the intensities to [0,1] (i.e., "01") or zero-mean, unit variance
        (i.e., "0mean").  If None normalization is not performed.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, these data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary with the preprocessed image and, when computed, the brain mask,
    bias field, and template transforms.

    Example
    -------
    >>> import ants
    >>> image = ants.image_read(ants.get_ants_data('r16'))
    >>> preprocessed_image = preprocess_brain_image(image, brain_extraction_modality=None)
    """

    from ..utilities import brain_extraction
    from ..utilities import regression_match_image
    from ..utilities import get_antsxnet_data

    preprocessed_image = ants.image_clone(image)

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    # Truncate intensity
    if truncate_intensity is not None:
        quantiles = (image.quantile(truncate_intensity[0]), image.quantile(truncate_intensity[1]))
        if verbose == True:
            print("Preprocessing:  truncate intensities ( low =", quantiles[0], ", high =", quantiles[1], ").")

        preprocessed_image[image < quantiles[0]] = quantiles[0]
        preprocessed_image[image > quantiles[1]] = quantiles[1]

    # Brain extraction
    mask = None
    if brain_extraction_modality is not None:
        if verbose == True:
            print("Preprocessing:  brain extraction.")

        probability_mask = brain_extraction(preprocessed_image, modality=brain_extraction_modality,
            antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose)
        mask = ants.threshold_image(probability_mask, 0.5, 1, 1, 0)
        mask = ants.morphology(mask,"close",6).iMath_fill_holes()

    # Template normalization
    transforms = None
    if template_transform_type is not None:
        template_image = None
        if isinstance(template, str):
            template_file_name_path = get_antsxnet_data(template, antsxnet_cache_directory=antsxnet_cache_directory)
            template_image = ants.image_read(template_file_name_path)
        else:
            template_image = template

        if mask is None:
            registration = ants.registration(fixed=template_image, moving=preprocessed_image,
                type_of_transform=template_transform_type, verbose=verbose)
            preprocessed_image = registration['warpedmovout']
            transforms = dict(fwdtransforms=registration['fwdtransforms'],
                              invtransforms=registration['invtransforms'])
        else:
            template_probability_mask = brain_extraction(template_image, modality=brain_extraction_modality, 
                antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose)
            template_mask = ants.threshold_image(template_probability_mask, 0.5, 1, 1, 0)
            template_brain_image = template_mask * template_image

            preprocessed_brain_image = preprocessed_image * mask

            registration = ants.registration(fixed=template_brain_image, moving=preprocessed_brain_image,
                type_of_transform=template_transform_type, verbose=verbose)
            transforms = dict(fwdtransforms=registration['fwdtransforms'],
                              invtransforms=registration['invtransforms'])

            preprocessed_image = ants.apply_transforms(fixed = template_image, moving = preprocessed_image,
                transformlist=registration['fwdtransforms'], interpolator="linear", verbose=verbose)
            mask = ants.apply_transforms(fixed = template_image, moving = mask,
                transformlist=registration['fwdtransforms'], interpolator="genericLabel", verbose=verbose)

    # Do bias correction
    bias_field = None
    if do_bias_correction == True:
        if verbose == True:
            print("Preprocessing:  brain correction.")
        n4_output = None
        if mask is None:
            n4_output = ants.n4_bias_field_correction(preprocessed_image, shrink_factor=4, return_bias_field=return_bias_field, verbose=verbose)
        else:
            n4_output = ants.n4_bias_field_correction(preprocessed_image, mask, shrink_factor=4, return_bias_field=return_bias_field, verbose=verbose)
        if return_bias_field == True:
            bias_field = n4_output
        else:
            preprocessed_image = n4_output

    # Denoising
    if do_denoising == True:
        if verbose == True:
            print("Preprocessing:  denoising.")

        if mask is None:
            preprocessed_image = ants.denoise_image(preprocessed_image, shrink_factor=1)
        else:
            preprocessed_image = ants.denoise_image(preprocessed_image, mask, shrink_factor=1)

    # Image matching
    if reference_image is not None and intensity_matching_type is not None:
        if verbose == True:
            print("Preprocessing:  intensity matching.")

        if intensity_matching_type == "regression":
            preprocessed_image = regression_match_image(preprocessed_image, reference_image)
        elif intensity_matching_type == "histogram":
            preprocessed_image = ants.histogram_match_image(preprocessed_image, reference_image)
        else:
            raise ValueError("Unrecognized intensity_matching_type.")

    # Intensity normalization
    if intensity_normalization_type is not None:
        if verbose == True:
            print("Preprocessing:  intensity normalization.")

        if intensity_normalization_type == "01":
            preprocessed_image = (preprocessed_image - preprocessed_image.min())/(preprocessed_image.max() - preprocessed_image.min())
        elif intensity_normalization_type == "0mean":
            preprocessed_image = (preprocessed_image - preprocessed_image.mean())/preprocessed_image.std()
        else:
            raise ValueError("Unrecognized intensity_normalization_type.")

    return_dict = {'preprocessed_image' : preprocessed_image}
    if mask is not None:
        return_dict['brain_mask'] = mask
    if bias_field is not None:
        return_dict['bias_field'] = bias_field
    if transforms is not None:
        return_dict['template_transforms'] = transforms

    return(return_dict)
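
A hedged usage sketch combining several of the options above (the input path is a placeholder):

import ants

t1 = ants.image_read("t1.nii.gz")  # placeholder path
prep = preprocess_brain_image(t1,
                              truncate_intensity=(0.01, 0.99),
                              brain_extraction_modality="t1",
                              do_bias_correction=True,
                              do_denoising=True,
                              intensity_normalization_type="01",
                              verbose=True)
t1_brain = prep['preprocessed_image'] * prep['brain_mask']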
Code example #22
import sys

import ants
import antspynet
import tensorflow as tf

t1_file = sys.argv[1]
output_prefix = sys.argv[2]
threads = int(sys.argv[3])

tf.keras.backend.clear_session()
config = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=threads,
                                  inter_op_parallelism_threads=threads)
session = tf.compat.v1.Session(config=config)
tf.compat.v1.keras.backend.set_session(session)

t1 = ants.image_read(t1_file)
kk = ants.image_read(output_prefix + "CorticalThickness.nii.gz")

# If one wants cortical labels one can run the following lines

dkt = antspynet.desikan_killiany_tourville_labeling(t1,
                                                    do_preprocessing=True,
                                                    verbose=True)
ants.image_write(dkt, output_prefix + "Dkt.nii.gz")

dkt_mask = ants.threshold_image(dkt, 1000, 3000, 1, 0)
dkt = dkt_mask * dkt
ants_tmp = ants.threshold_image(kk, 0, 0, 0, 1)
ants_dkt = ants.iMath(ants_tmp, "PropagateLabelsThroughMask", ants_tmp * dkt)

ants.image_write(ants_dkt, output_prefix + "DktPropagatedLabels.nii.gz")
Code example #23
File: sfJointReg.py Project: stnava/ANTsPyDocker
pts2bold = ants.apply_transforms_to_points( 3, powers_areal_mni_itk, concatx2, whichtoinvert = ( True, False, True, False ) )
locations = pts2bold.iloc[:,:3].values
ptImg = ants.make_points_image( locations, bmask, radius = 3 )
networks = powers_areal_mni_itk['SystemName'].unique()
dfnpts = np.where( powers_areal_mni_itk['SystemName'] == networks[5] )
dfnImg = ants.mask_image(  ptImg, ptImg, level = dfnpts[0].tolist(), binarize=False )

# plot( und, ptImg, axis=3, window.overlay = range( ptImg ) )

bold2ch2 = ants.apply_transforms( ch2, und,  concatx2, whichtoinvert = ( True, False, True, False ) )


# Extracting canonical functional network maps
## preprocessing

csfAndWM = ( ants.threshold_image( boldseg, 1, 1 ) +
             ants.threshold_image( boldseg, 3, 3 ) ).morphology("erode",1)
bold = ants.image_read( boldfnsR )
boldList = ants.ndimage_to_list( bold )
avgBold = ants.get_average_of_timeseries( bold, range( 5 ) )
boldUndTX = ants.registration( und, avgBold, "SyN", regIterations = (15,4),
  synMetric = "CC", synSampling = 2, verbose = False )
boldUndTS = ants.apply_transforms( und, bold, boldUndTX['fwdtransforms'], imagetype = 3  )
motCorr = ants.motion_correction( boldUndTS, avgBold,
    type_of_transform="Rigid", verbose = True )
tr = ants.get_spacing( bold )[3]
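# Censor by framewise displacement (FD): volumes with FD >= 0.5 mm count as
# high motion; only the remaining "goodtimes" are used downstream.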
highMotionTimes = np.where( motCorr['FD'] >= 0.5 )
goodtimes = np.where( motCorr['FD'] < 0.5 )
avgBold = ants.get_average_of_timeseries( motCorr['motion_corrected'], range( 5 ) )
#######################
nt = len(motCorr['FD'])
Code example #24
def brain_extraction(image,
                     modality="t1",
                     antsxnet_cache_directory=None,
                     verbose=False):
    """
    Perform brain extraction using U-net and ANTs-based training data.  "NoBrainer"
    is also possible, where brain extraction uses U-net and FreeSurfer training data
    ported from

    https://github.com/neuronets/nobrainer-models

    Arguments
    ---------
    image : ANTsImage
        input image (or list of images for multi-modal scenarios).

    modality : string
        Modality image type.  Options include:
            * "t1": T1-weighted MRI---ANTs-trained.  Update from "t1v0".
            * "t1v0":  T1-weighted MRI---ANTs-trained.
            * "t1nobrainer": T1-weighted MRI---FreeSurfer-trained: h/t Satra Ghosh and Jakub Kaczmarzyk.
            * "t1combined": Brian's combination of "t1" and "t1nobrainer".  One can also specify
                            "t1combined[X]" where X is the morphological radius.  X = 12 by default.
            * "flair": FLAIR MRI.
            * "t2": T2 MRI.
            * "bold": 3-D BOLD MRI.
            * "fa": Fractional anisotropy.
            * "t1t2infant": Combined T1-w/T2-w infant MRI h/t Martin Styner.
            * "t1infant": T1-w infant MRI h/t Martin Styner.
            * "t2infant": T2-w infant MRI h/t Martin Styner.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, these data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    ANTs probability brain mask image.

    Example
    -------
    >>> probability_brain_mask = brain_extraction(brain_image, modality="t1")
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..architectures import create_nobrainer_unet_model_3d

    classes = ("background", "brain")
    number_of_classification_labels = len(classes)

    channel_size = 1
    if isinstance(image, list):
        channel_size = len(image)

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    input_images = list()
    if channel_size == 1:
        input_images.append(image)
    else:
        input_images = image

    if input_images[0].dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if "t1combined" in modality:

        brain_extraction_t1 = brain_extraction(
            image,
            modality="t1",
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        brain_mask = ants.iMath_get_largest_component(
            ants.threshold_image(brain_extraction_t1, 0.5, 10000))

        # Need to change with voxel resolution
        morphological_radius = 12
        if '[' in modality and ']' in modality:
            morphological_radius = int(modality.split("[")[1].split("]")[0])

        brain_extraction_t1nobrainer = brain_extraction(
            image * ants.iMath_MD(brain_mask, radius=morphological_radius),
            modality="t1nobrainer",
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        brain_extraction_combined = ants.iMath_fill_holes(
            ants.iMath_get_largest_component(brain_extraction_t1nobrainer *
                                             brain_mask))

        brain_extraction_combined = brain_extraction_combined + ants.iMath_ME(
            brain_mask, morphological_radius) + brain_mask

        return (brain_extraction_combined)

    if modality != "t1nobrainer":

        #####################
        #
        # ANTs-based
        #
        #####################

        weights_file_name_prefix = None

        if modality == "t1v0":
            weights_file_name_prefix = "brainExtraction"
        elif modality == "t1":
            weights_file_name_prefix = "brainExtractionT1"
        elif modality == "t2":
            weights_file_name_prefix = "brainExtractionT2"
        elif modality == "flair":
            weights_file_name_prefix = "brainExtractionFLAIR"
        elif modality == "bold":
            weights_file_name_prefix = "brainExtractionBOLD"
        elif modality == "fa":
            weights_file_name_prefix = "brainExtractionFA"
        elif modality == "t1t2infant":
            weights_file_name_prefix = "brainExtractionInfantT1T2"
        elif modality == "t1infant":
            weights_file_name_prefix = "brainExtractionInfantT1"
        elif modality == "t2infant":
            weights_file_name_prefix = "brainExtractionInfantT2"
        else:
            raise ValueError("Unknown modality type.")

        weights_file_name = get_pretrained_network(
            weights_file_name_prefix,
            antsxnet_cache_directory=antsxnet_cache_directory)

        if verbose == True:
            print("Brain extraction:  retrieving template.")
        reorient_template_file_name_path = get_antsxnet_data(
            "S_template3", antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)
        resampled_image_size = reorient_template.shape

        if modality == "t1":
            classes = ("background", "head", "brain")
            number_of_classification_labels = len(classes)

        unet_model = create_unet_model_3d(
            (*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            number_of_layers=4,
            number_of_filters_at_base_layer=8,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5)

        unet_model.load_weights(weights_file_name)

        if verbose == True:
            print("Brain extraction:  normalizing image to the template.")

        center_of_mass_template = ants.get_center_of_mass(reorient_template)
        center_of_mass_image = ants.get_center_of_mass(input_images[0])
        translation = np.asarray(center_of_mass_image) - np.asarray(
            center_of_mass_template)
        xfrm = ants.create_ants_transform(
            transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_template),
            translation=translation)

        batchX = np.zeros((1, *resampled_image_size, channel_size))
        for i in range(len(input_images)):
            warped_image = ants.apply_ants_transform_to_image(
                xfrm, input_images[i], reorient_template)
            warped_array = warped_image.numpy()
            batchX[0, :, :, :, i] = (warped_array -
                                     warped_array.mean()) / warped_array.std()

        if verbose == True:
            print("Brain extraction:  prediction and decoding.")

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        origin = reorient_template.origin
        spacing = reorient_template.spacing
        direction = reorient_template.direction

        probability_images_array = list()
        probability_images_array.append(
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, 0]),
                            origin=origin,
                            spacing=spacing,
                            direction=direction))
        probability_images_array.append(
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, 1]),
                            origin=origin,
                            spacing=spacing,
                            direction=direction))
        if modality == "t1":
            probability_images_array.append(
                ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, 2]),
                                origin=origin,
                                spacing=spacing,
                                direction=direction))

        if verbose == True:
            print(
                "Brain extraction:  renormalize probability mask to native space."
            )
        probability_image = ants.apply_ants_transform_to_image(
            ants.invert_ants_transform(xfrm),
            probability_images_array[number_of_classification_labels - 1],
            input_images[0])

        return (probability_image)

    else:

        #####################
        #
        # NoBrainer
        #
        #####################

        if verbose == True:
            print("NoBrainer:  generating network.")

        model = create_nobrainer_unet_model_3d((None, None, None, 1))

        weights_file_name = get_pretrained_network(
            "brainExtractionNoBrainer",
            antsxnet_cache_directory=antsxnet_cache_directory)
        model.load_weights(weights_file_name)

        if verbose == True:
            print(
                "NoBrainer:  preprocessing (intensity truncation and resampling)."
            )

        image_array = image.numpy()
        image_robust_range = np.quantile(
            image_array[np.where(image_array != 0)], (0.02, 0.98))
        threshold_value = 0.10 * (image_robust_range[1] - image_robust_range[0]
                                  ) + image_robust_range[0]

        thresholded_mask = ants.threshold_image(image, -10000, threshold_value,
                                                0, 1)
        thresholded_image = image * thresholded_mask

        image_resampled = ants.resample_image(thresholded_image,
                                              (256, 256, 256),
                                              use_voxels=True)
        image_array = np.expand_dims(image_resampled.numpy(), axis=0)
        image_array = np.expand_dims(image_array, axis=-1)

        if verbose == True:
            print("NoBrainer:  predicting mask.")

        brain_mask_array = np.squeeze(
            model.predict(image_array, verbose=verbose))
        brain_mask_resampled = ants.copy_image_info(
            image_resampled, ants.from_numpy(brain_mask_array))
        brain_mask_image = ants.resample_image(brain_mask_resampled,
                                               image.shape,
                                               use_voxels=True,
                                               interp_type=1)

        spacing = ants.get_spacing(image)
        spacing_product = spacing[0] * spacing[1] * spacing[2]
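        # 649933.7 appears to be a minimum plausible brain volume in mm^3;
        # dividing by the voxel volume converts it to a minimum cluster size
        # (in voxels) for label_clusters.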
        minimum_brain_volume = round(649933.7 / spacing_product)
        brain_mask_labeled = ants.label_clusters(brain_mask_image,
                                                 minimum_brain_volume)

        return (brain_mask_labeled)