Example #1
    def __bias_correction(self):
        logger.info("performing N4 bias correction")
        print("performing N4 bias correction")
        if self._t1file is not None and self._t2file is not None:
            self._t1_n4 = ants.iMath(
                self._t1_reg['warpedmovout'].abp_n4(usen3=self._usen3),
                "Normalize") * 100
            self._t2_n4 = ants.iMath(self._t2_reg.abp_n4(usen3=self._usen3),
                                     "Normalize") * 100
            ants.image_write(
                self._t1_n4,
                os.path.join(self._outputdir, self._id + '_t1_final.nii.gz'))
            ants.image_write(
                self._t2_n4,
                os.path.join(self._outputdir, self._id + '_t2_final.nii.gz'))

        if self._t1file is not None and self._t2file is None:
            self._t1_n4 = ants.iMath(
                self._t1_reg['warpedmovout'].abp_n4(usen3=self._usen3),
                "Normalize") * 100
            ants.image_write(
                self._t1_n4,
                os.path.join(self._outputdir, self._id + '_t1_final.nii.gz'))

        if self._t2file is not None and self._t1file is None:
            self._t2_n4 = ants.iMath(
                self._t2_reg['warpedmovout'].abp_n4(usen3=self._usen3),
                "Normalize") * 100
            ants.image_write(
                self._t2_n4,
                os.path.join(self._outputdir, self._id + '_t2_final.nii.gz'))
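The pattern above pairs N4 (or N3, via usen3) bias correction with iMath "Normalize" and a rescale to [0, 100] before writing the result. A minimal standalone sketch of the same idea (assuming ANTsPy is installed; the file paths are placeholders):

import ants

# Placeholder paths; substitute a real T1-weighted image.
t1 = ants.image_read("t1.nii.gz")

# N4 bias correction (set usen3=True to use N3 instead), then rescale to [0, 100].
t1_n4 = ants.iMath(ants.abp_n4(t1, usen3=False), "Normalize") * 100

ants.image_write(t1_n4, "t1_final.nii.gz")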
Example #2
    def test_example(self):
        ref = ants.image_read(ants.get_ants_data('r16'))
        ref = ants.resample_image(ref, (50, 50), 1, 0)
        ref = ants.iMath(ref, 'Normalize')
        mi = ants.image_read(ants.get_ants_data('r27'))
        mi2 = ants.image_read(ants.get_ants_data('r30'))
        mi3 = ants.image_read(ants.get_ants_data('r62'))
        mi4 = ants.image_read(ants.get_ants_data('r64'))
        mi5 = ants.image_read(ants.get_ants_data('r85'))
        refmask = ants.get_mask(ref)
        refmask = ants.iMath(refmask, 'ME', 2)  # just to speed things up
        ilist = [mi, mi2, mi3, mi4, mi5]
        seglist = [None] * len(ilist)
        for i in range(len(ilist)):
            ilist[i] = ants.iMath(ilist[i], 'Normalize')
            mytx = ants.registration(fixed=ref,
                                     moving=ilist[i],
                                     type_of_transform='Affine')
            mywarpedimage = ants.apply_transforms(
                fixed=ref,
                moving=ilist[i],
                transformlist=mytx['fwdtransforms'])
            ilist[i] = mywarpedimage
            seg = ants.threshold_image(ilist[i], 'Otsu', 3)
            seglist[i] = (seg) + ants.threshold_image(seg, 1, 3).morphology(
                operation='dilate', radius=3)

        r = 2
        pp = ants.joint_label_fusion(ref,
                                     refmask,
                                     ilist,
                                     r_search=2,
                                     label_list=seglist,
                                     rad=[r] * ref.dimension)
        pp = ants.joint_label_fusion(ref, refmask, ilist, r_search=2, rad=2)
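The fused result pp is a dictionary; a short continuation of the test above for inspecting it (the 'segmentation' and 'intensity' keys are assumed here, as exposed by recent ANTsPy releases):

# Hypothetical follow-up to the test above; dictionary keys are assumed.
jlf_segmentation = pp['segmentation']   # consensus label image
jlf_intensity = pp['intensity']         # fused intensity image
ants.plot(ref, overlay=jlf_segmentation, overlay_alpha=0.5)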
Example #3
    def __gradient_magnitude(self):
        logger.info("computing gradient magnitude")
        print("computing gradient magnitude")
        if self._t1file is not None and self._t2file is not None:
            self._grad_t1 = ants.iMath(self._t1_n4, "Grad", 1)
            # self._grad_t2 = ants.iMath(self._t2_n4, "Grad", 1)
            ants.image_write(
                self._grad_t1,
                os.path.join(self._outputdir,
                             self._id + '_t1_gradient_magnitude.nii.gz'))
            # ants.image_write( self._grad_t1, os.path.join(self._outputdir, self._id+'_t2_gradient_magnitude.nii.gz'))

        if self._t1file is not None and self._t2file is None:
            self._grad_t1 = ants.iMath(self._t1_n4, "Grad", 1)
            ants.image_write(
                self._grad_t1,
                os.path.join(self._outputdir,
                             self._id + '_t1_gradient_magnitude.nii.gz'))
Example #4
import ants
import numpy as np


def gmsd(x, y):
    """
    Gradient magnitude similarity deviation

    A fast and simple metric that correlates to perceptual quality.

    Arguments
    ---------
    x : input image
        ants input image

    y : input image
        ants input image

    Returns
    -------
    Value

    Example
    -------
    >>> r16 = ants.image_read(ants.get_data("r16"))
    >>> r64 = ants.image_read(ants.get_data("r64"))
    >>> value = gmsd(r16, r64)
    """

    gx = ants.iMath(x, "Grad")
    gy = ants.iMath(y, "Grad")

    # see eqn 4 - 6 in https://arxiv.org/pdf/1308.3052.pdf

    constant = 0.0026
    gmsd_numerator = gx * gy * 2.0 + constant
    gmsd_denominator = gx**2 + gy**2 + constant
    gmsd = gmsd_numerator / gmsd_denominator

    product_dimension = 1
    for i in range(len(x.shape)):
       product_dimension *= x.shape[i]
    prefactor = 1.0 / product_dimension

    return(np.sqrt(prefactor * ((gmsd - gmsd.mean())**2).sum()))
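For reference, the quantities computed above correspond to equations 4-6 of the cited paper: with gradient-magnitude images g_x and g_y, stabilizing constant c = 0.0026, and N voxels,

\[
\mathrm{GMS}(i) = \frac{2\, g_x(i)\, g_y(i) + c}{g_x(i)^2 + g_y(i)^2 + c}, \qquad
\mathrm{GMSD} = \sqrt{\frac{1}{N} \sum_{i=1}^{N} \bigl(\mathrm{GMS}(i) - \overline{\mathrm{GMS}}\bigr)^{2}},
\]

where \overline{\mathrm{GMS}} is the mean of the GMS map; the prefactor and mean subtraction in the return statement implement exactly this.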
Example #5
    def test_multiple_inputs(self):
        img = ants.image_read(ants.get_ants_data("r16"))
        img = ants.resample_image(img, (64, 64), 1, 0)
        mask = ants.get_mask(img)
        segs1 = ants.atropos(a=img,
                             m='[0.2,1x1]',
                             c='[2,0]',
                             i='kmeans[3]',
                             x=mask)

        # Use probabilities from k-means seg as priors
        segs2 = ants.atropos(a=img,
                             m='[0.2,1x1]',
                             c='[2,0]',
                             i=segs1['probabilityimages'],
                             x=mask)

        # multiple inputs
        feats = [img, ants.iMath(img, "Laplacian"), ants.iMath(img, "Grad")]
        segs3 = ants.atropos(a=feats,
                             m='[0.2,1x1]',
                             c='[2,0]',
                             i=segs1['probabilityimages'],
                             x=mask)
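The multi-channel atropos call returns a dictionary; a brief continuation of the test above showing how its outputs are typically inspected (the 'probabilityimages' key is already used above, and the 'segmentation' key is assumed, as in current ANTsPy):

# Hypothetical continuation of the snippet above.
seg_image = segs3['segmentation']                 # assumed key, as in current ANTsPy
probability_images = segs3['probabilityimages']   # one probability image per class
ants.plot(img, overlay=seg_image, overlay_alpha=0.5)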
        "mriSuperResolution", weights_file_name)

model.load_weights(weights_file_name)
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

number_of_image_volumes = len(input_image_list)

output_image_list = list()
for i in range(number_of_image_volumes):
    print("Applying super resolution to image", i, "of",
          number_of_image_volumes)
    start_time = time.time()

    input_image = ants.iMath(input_image_list[i], "TruncateIntensity", 0.0001,
                             0.995)
    output_sr = antspynet.apply_super_resolution_model_to_image(
        input_image, model, target_range=(127.5, -127.5))
    input_image_resampled = ants.resample_image_to_target(
        input_image, output_sr)
    output_image_list.append(
        antspynet.regression_match_image(output_sr,
                                         input_image_resampled,
                                         poly_order=2))

    end_time = time.time()
    elapsed_time = end_time - start_time
    print("   (elapsed time:", elapsed_time, "seconds)")

print("Writing output image.")
if number_of_image_volumes == 1:
Example #7
File: dataset.py Project: alexgalayda/cAAE
 def get_mask(self):
     mask = ants.get_mask(self.get_brain())
     mask = ants.iMath(mask, 'ME', 2)
     return mask
Example #8
def desikan_killiany_tourville_labeling(t1,
                                        do_preprocessing=True,
                                        return_probability_images=False,
                                        do_lobar_parcellation=False,
                                        antsxnet_cache_directory=None,
                                        verbose=False):
    """
    Cortical and deep gray matter labeling using Desikan-Killiany-Tourville

    Perform DKT labeling using deep learning

    The labeling is as follows:

    Inner labels:
    Label 0: background
    Label 4: left lateral ventricle
    Label 5: left inferior lateral ventricle
    Label 6: left cerebellum exterior
    Label 7: left cerebellum white matter
    Label 10: left thalamus proper
    Label 11: left caudate
    Label 12: left putamen
    Label 13: left pallidum
    Label 15: 4th ventricle
    Label 16: brain stem
    Label 17: left hippocampus
    Label 18: left amygdala
    Label 24: CSF
    Label 25: left lesion
    Label 26: left accumbens area
    Label 28: left ventral DC
    Label 30: left vessel
    Label 43: right lateral ventricle
    Label 44: right inferior lateral ventricle
    Label 45: right cerebellum exterior
    Label 46: right cerebellum white matter
    Label 49: right thalamus proper
    Label 50: right caudate
    Label 51: right putamen
    Label 52: right pallidum
    Label 53: right hippocampus
    Label 54: right amygdala
    Label 57: right lesion
    Label 58: right accumbens area
    Label 60: right ventral DC
    Label 62: right vessel
    Label 72: 5th ventricle
    Label 85: optic chiasm
    Label 91: left basal forebrain
    Label 92: right basal forebrain
    Label 630: cerebellar vermal lobules I-V
    Label 631: cerebellar vermal lobules VI-VII
    Label 632: cerebellar vermal lobules VIII-X

    Outer labels:
    Label 1002: left caudal anterior cingulate
    Label 1003: left caudal middle frontal
    Label 1005: left cuneus
    Label 1006: left entorhinal
    Label 1007: left fusiform
    Label 1008: left inferior parietal
    Label 1009: left inferior temporal
    Label 1010: left isthmus cingulate
    Label 1011: left lateral occipital
    Label 1012: left lateral orbitofrontal
    Label 1013: left lingual
    Label 1014: left medial orbitofrontal
    Label 1015: left middle temporal
    Label 1016: left parahippocampal
    Label 1017: left paracentral
    Label 1018: left pars opercularis
    Label 1019: left pars orbitalis
    Label 1020: left pars triangularis
    Label 1021: left pericalcarine
    Label 1022: left postcentral
    Label 1023: left posterior cingulate
    Label 1024: left precentral
    Label 1025: left precuneus
    Label 1026: left rostral anterior cingulate
    Label 1027: left rostral middle frontal
    Label 1028: left superior frontal
    Label 1029: left superior parietal
    Label 1030: left superior temporal
    Label 1031: left supramarginal
    Label 1034: left transverse temporal
    Label 1035: left insula
    Label 2002: right caudal anterior cingulate
    Label 2003: right caudal middle frontal
    Label 2005: right cuneus
    Label 2006: right entorhinal
    Label 2007: right fusiform
    Label 2008: right inferior parietal
    Label 2009: right inferior temporal
    Label 2010: right isthmus cingulate
    Label 2011: right lateral occipital
    Label 2012: right lateral orbitofrontal
    Label 2013: right lingual
    Label 2014: right medial orbitofrontal
    Label 2015: right middle temporal
    Label 2016: right parahippocampal
    Label 2017: right paracentral
    Label 2018: right pars opercularis
    Label 2019: right pars orbitalis
    Label 2020: right pars triangularis
    Label 2021: right pericalcarine
    Label 2022: right postcentral
    Label 2023: right posterior cingulate
    Label 2024: right precentral
    Label 2025: right precuneus
    Label 2026: right rostral anterior cingulate
    Label 2027: right rostral middle frontal
    Label 2028: right superior frontal
    Label 2029: right superior parietal
    Label 2030: right superior temporal
    Label 2031: right supramarginal
    Label 2034: right transverse temporal
    Label 2035: right insula

    The lobar parcellation is based on the FreeSurfer division described here:

    https://surfer.nmr.mgh.harvard.edu/fswiki/CorticalParcellation

    Frontal lobe:
    Label 1002:  left caudal anterior cingulate
    Label 1003:  left caudal middle frontal
    Label 1012:  left lateral orbitofrontal
    Label 1014:  left medial orbitofrontal
    Label 1017:  left paracentral
    Label 1018:  left pars opercularis
    Label 1019:  left pars orbitalis
    Label 1020:  left pars triangularis
    Label 1024:  left precentral
    Label 1026:  left rostral anterior cingulate
    Label 1027:  left rostral middle frontal
    Label 1028:  left superior frontal
    Label 2002:  right caudal anterior cingulate
    Label 2003:  right caudal middle frontal
    Label 2012:  right lateral orbitofrontal
    Label 2014:  right medial orbitofrontal
    Label 2017:  right paracentral
    Label 2018:  right pars opercularis
    Label 2019:  right pars orbitalis
    Label 2020:  right pars triangularis
    Label 2024:  right precentral
    Label 2026:  right rostral anterior cingulate
    Label 2027:  right rostral middle frontal
    Label 2028:  right superior frontal

    Parietal:
    Label 1008:  left inferior parietal
    Label 1010:  left isthmus cingulate
    Label 1022:  left postcentral
    Label 1023:  left posterior cingulate
    Label 1025:  left precuneus
    Label 1029:  left superior parietal
    Label 1031:  left supramarginal
    Label 2008:  right inferior parietal
    Label 2010:  right isthmus cingulate
    Label 2022:  right postcentral
    Label 2023:  right posterior cingulate
    Label 2025:  right precuneus
    Label 2029:  right superior parietal
    Label 2031:  right supramarginal

    Temporal:
    Label 1006:  left entorhinal
    Label 1007:  left fusiform
    Label 1009:  left inferior temporal
    Label 1015:  left middle temporal
    Label 1016:  left parahippocampal
    Label 1030:  left superior temporal
    Label 1034:  left transverse temporal
    Label 2006:  right entorhinal
    Label 2007:  right fusiform
    Label 2009:  right inferior temporal
    Label 2015:  right middle temporal
    Label 2016:  right parahippocampal
    Label 2030:  right superior temporal
    Label 2034:  right transverse temporal

    Occipital:
    Label 1005:  left cuneus
    Label 1011:  left lateral occipital
    Label 1013:  left lingual
    Label 1021:  left pericalcarine
    Label 2005:  right cuneus
    Label 2011:  right lateral occipital
    Label 2013:  right lingual
    Label 2021:  right pericalcarine

    Other outer labels:
    Label 1035:  left insula
    Label 2035:  right insula

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e., set
    do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    return_probability_images : boolean
        Whether to return the two sets of probability images for the inner and outer
        labels.

    do_lobar_parcellation : boolean
        Perform lobar parcellation (also divided by hemisphere).

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if this is None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    DKT label image or, depending on the options, a dictionary that also
    contains the lobar parcellation and/or the probability images for the
    inner and outer labels.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> dkt = desikan_killiany_tourville_labeling(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import preprocess_brain_image
    from ..utilities import deep_atropos

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory == None:
        antsxnet_cache_directory = "ANTsXNet"

    template_transform_type = "antsRegistrationSyNQuickRepro[a]"
    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    if do_preprocessing == True:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            template="croppedMni152",
            template_transform_type=template_transform_type,
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing[
            "preprocessed_image"] * t1_preprocessing['brain_mask']

    ################################
    #
    # Download spatial priors for outer model
    #
    ################################

    spatial_priors_file_name_path = get_antsxnet_data(
        "priorDktLabels", antsxnet_cache_directory=antsxnet_cache_directory)
    spatial_priors = ants.image_read(spatial_priors_file_name_path)
    priors_image_list = ants.ndimage_to_list(spatial_priors)

    ################################
    #
    # Build outer model and load weights
    #
    ################################

    template_size = (96, 112, 96)
    labels = (0, 1002, 1003, *tuple(range(1005, 1032)), 1034, 1035, 2002, 2003,
              *tuple(range(2005, 2032)), 2034, 2035)
    channel_size = 1 + len(priors_image_list)

    unet_model = create_unet_model_3d((*template_size, channel_size),
                                      number_of_outputs=len(labels),
                                      number_of_layers=4,
                                      number_of_filters_at_base_layer=16,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      weight_decay=1e-5,
                                      additional_options=("attentionGating"))

    weights_file_name = None
    weights_file_name = get_pretrained_network(
        "dktOuterWithSpatialPriors",
        antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose == True:
        print("Outer model Prediction.")

    downsampled_image = ants.resample_image(t1_preprocessed,
                                            template_size,
                                            use_voxels=True,
                                            interp_type=0)
    image_array = downsampled_image.numpy()
    image_array = (image_array - image_array.mean()) / image_array.std()

    batchX = np.zeros((1, *template_size, channel_size))
    batchX[0, :, :, :, 0] = image_array

    for i in range(len(priors_image_list)):
        resampled_prior_image = ants.resample_image(priors_image_list[i],
                                                    template_size,
                                                    use_voxels=True,
                                                    interp_type=0)
        batchX[0, :, :, :, i + 1] = resampled_prior_image.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    origin = downsampled_image.origin
    spacing = downsampled_image.spacing
    direction = downsampled_image.direction

    inner_probability_images = list()
    for i in range(len(labels)):
        probability_image = \
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
            origin=origin, spacing=spacing, direction=direction)
        resampled_image = ants.resample_image(probability_image,
                                              t1_preprocessed.shape,
                                              use_voxels=True,
                                              interp_type=0)
        if do_preprocessing == True:
            inner_probability_images.append(
                ants.apply_transforms(
                    fixed=t1,
                    moving=resampled_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose))
        else:
            inner_probability_images.append(resampled_image)

    image_matrix = ants.image_list_to_matrix(inner_probability_images,
                                             t1 * 0 + 1)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    dkt_label_image = ants.image_clone(segmentation_image)
    for i in range(len(labels)):
        dkt_label_image[segmentation_image == i] = labels[i]

    ################################
    #
    # Build inner model and load weights
    #
    ################################

    template_size = (160, 192, 160)
    labels = (0, 4, 6, 7, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 28, 30,
              43, 44, 45, 46, 49, 50, 51, 52, 53, 54, 58, 60, 91, 92, 630, 631,
              632)

    unet_model = create_unet_model_3d((*template_size, 1),
                                      number_of_outputs=len(labels),
                                      number_of_layers=4,
                                      number_of_filters_at_base_layer=8,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      weight_decay=1e-5,
                                      additional_options=("attentionGating"))

    weights_file_name = get_pretrained_network(
        "dktInner", antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose == True:
        print("Prediction.")

    cropped_image = ants.crop_indices(t1_preprocessed, (12, 14, 0),
                                      (172, 206, 160))

    batchX = np.expand_dims(cropped_image.numpy(), axis=0)
    batchX = np.expand_dims(batchX, axis=-1)
    batchX = (batchX - batchX.mean()) / batchX.std()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    origin = cropped_image.origin
    spacing = cropped_image.spacing
    direction = cropped_image.direction

    outer_probability_images = list()
    for i in range(len(labels)):
        probability_image = \
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
            origin=origin, spacing=spacing, direction=direction)
        if i > 0:
            decropped_image = ants.decrop_image(probability_image,
                                                t1_preprocessed * 0)
        else:
            decropped_image = ants.decrop_image(probability_image,
                                                t1_preprocessed * 0 + 1)

        if do_preprocessing == True:
            outer_probability_images.append(
                ants.apply_transforms(
                    fixed=t1,
                    moving=decropped_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose))
        else:
            outer_probability_images.append(decropped_image)

    image_matrix = ants.image_list_to_matrix(outer_probability_images,
                                             t1 * 0 + 1)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    ################################
    #
    # Incorporate the inner model results into the final label image.
    # Note that we purposely prioritize the inner label results.
    #
    ################################

    for i in range(len(labels)):
        if labels[i] > 0:
            dkt_label_image[segmentation_image == i] = labels[i]

    if do_lobar_parcellation:

        if verbose == True:
            print("Doing lobar parcellation.")

        ################################
        #
        # Lobar/hemisphere parcellation
        #
        ################################

        # Consolidate lobar cortical labels

        if verbose == True:
            print("   Consolidating cortical labels.")

        frontal_labels = (1002, 1003, 1012, 1014, 1017, 1018, 1019, 1020, 1024,
                          1026, 1027, 1028, 2002, 2003, 2012, 2014, 2017, 2018,
                          2019, 2020, 2024, 2026, 2027, 2028)
        parietal_labels = (1008, 1010, 1022, 1023, 1025, 1029, 1031, 2008,
                           2010, 2022, 2023, 2025, 2029, 2031)
        temporal_labels = (1006, 1007, 1009, 1015, 1016, 1030, 1034, 2006,
                           2007, 2009, 2015, 2016, 2030, 2034)
        occipital_labels = (1005, 1011, 1013, 1021, 2005, 2011, 2013, 2021)

        lobar_labels = list()
        lobar_labels.append(frontal_labels)
        lobar_labels.append(parietal_labels)
        lobar_labels.append(temporal_labels)
        lobar_labels.append(occipital_labels)

        dkt_lobes = ants.image_clone(dkt_label_image)
        dkt_lobes[dkt_lobes < 1000] = 0

        for i in range(len(lobar_labels)):
            for j in range(len(lobar_labels[i])):
                dkt_lobes[dkt_lobes == lobar_labels[i][j]] = i + 1

        dkt_lobes[dkt_lobes > len(lobar_labels)] = 0

        six_tissue = deep_atropos(
            t1_preprocessed,
            do_preprocessing=False,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        atropos_seg = six_tissue['segmentation_image']
        if do_preprocessing == True:
            atropos_seg = ants.apply_transforms(
                fixed=t1,
                moving=atropos_seg,
                transformlist=t1_preprocessing['template_transforms']
                ['invtransforms'],
                whichtoinvert=[True],
                interpolator="genericLabel",
                verbose=verbose)

        brain_mask = ants.image_clone(atropos_seg)
        # Exclude labels 1, 5, and 6 from the mask; `or` does not operate
        # element-wise on images, so each label is zeroed separately.
        brain_mask[brain_mask == 1] = 0
        brain_mask[brain_mask == 5] = 0
        brain_mask[brain_mask == 6] = 0
        brain_mask = ants.threshold_image(brain_mask, 0, 0, 0, 1)

        lobar_parcellation = ants.iMath(brain_mask,
                                        "PropagateLabelsThroughMask",
                                        brain_mask * dkt_lobes)

        lobar_parcellation[atropos_seg == 5] = 5
        lobar_parcellation[atropos_seg == 6] = 6

        # Do left/right

        if verbose == True:
            print("   Doing left/right hemispheres.")

        left_labels = (*tuple(range(4, 8)), *tuple(range(10, 14)), 17, 18, 25,
                       26, 28, 30, 91, 1002, 1003, *tuple(range(1005, 1032)),
                       1034, 1035)
        right_labels = (*tuple(range(43, 47)), *tuple(range(49, 55)), 57, 58,
                        60, 62, 92, 2002, 2003, *tuple(range(2005, 2032)),
                        2034, 2035)

        hemisphere_labels = list()
        hemisphere_labels.append(left_labels)
        hemisphere_labels.append(right_labels)

        dkt_hemispheres = ants.image_clone(dkt_label_image)

        for i in range(len(hemisphere_labels)):
            for j in range(len(hemisphere_labels[i])):
                dkt_hemispheres[dkt_hemispheres == hemisphere_labels[i]
                                [j]] = i + 1

        dkt_hemispheres[dkt_hemispheres > 2] = 0

        atropos_brain_mask = ants.threshold_image(atropos_seg, 0, 0, 0, 1)
        hemisphere_parcellation = ants.iMath(
            atropos_brain_mask, "PropagateLabelsThroughMask",
            atropos_brain_mask * dkt_hemispheres)

        # The following contains a bug somewhere as only the latter condition is seen.
        # Need to fix it.
        #
        # for i in range(6):
        #     lobar_parcellation[lobar_parcellation == (i + 1) and hemisphere_parcellation == 2] = 6 + i + 1

        hemisphere_parcellation *= ants.threshold_image(
            lobar_parcellation, 0, 0, 0, 1)
        hemisphere_parcellation[hemisphere_parcellation == 1] = 0
        hemisphere_parcellation[hemisphere_parcellation == 2] = 1
        hemisphere_parcellation *= 6
        lobar_parcellation += hemisphere_parcellation

    if return_probability_images == True and do_lobar_parcellation == True:
        return_dict = {
            'segmentation_image': dkt_label_image,
            'lobar_parcellation': lobar_parcellation,
            'inner_probability_images': inner_probability_images,
            'outer_probability_images': outer_probability_images
        }
        return (return_dict)
    elif return_probability_images == True and do_lobar_parcellation == False:
        return_dict = {
            'segmentation_image': dkt_label_image,
            'inner_probability_images': inner_probability_images,
            'outer_probability_images': outer_probability_images
        }
        return (return_dict)
    elif return_probability_images == False and do_lobar_parcellation == True:
        return_dict = {
            'segmentation_image': dkt_label_image,
            'lobar_parcellation': lobar_parcellation
        }
        return (return_dict)
    else:
        return (dkt_label_image)
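Based on the return structure above, a usage sketch that also requests the lobar parcellation (the input path is a placeholder; assumes antspynet and its pretrained weights are available):

import ants
import antspynet

t1 = ants.image_read("t1.nii.gz")  # placeholder path
dkt = antspynet.desikan_killiany_tourville_labeling(t1,
                                                    do_preprocessing=True,
                                                    do_lobar_parcellation=True,
                                                    verbose=True)
ants.image_write(dkt['segmentation_image'], "dkt.nii.gz")
ants.image_write(dkt['lobar_parcellation'], "dkt_lobes.nii.gz")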
Example #9
def histogram_warp_image_intensities(image,
                                     break_points=(0.25, 0.5, 0.75),
                                     displacements=None,
                                     clamp_end_points=(False, False),
                                     sd_displacements=0.05,
                                     transform_domain_size=20):
    """
    Transform image intensities based on histogram mapping.

    Apply B-spline 1-D maps to an input image for intensity warping.

    Arguments
    ---------
    image : ANTsImage
        Input image.

    break_points : integer or tuple
        Parametric points at which the intensity transform displacements
        are specified between [0, 1].  Alternatively, a single number can
        be given and the sequence is linearly spaced in [0, 1].

    displacements : tuple
        Displacements that define the intensity warping.  Length must equal the
        number of break_points.  Alternatively, if None, random displacements are
        drawn (normal: mean = 0, sd = sd_displacements).

    sd_displacements : float
        Characterize the randomness of the intensity displacement.

    clamp_end_points : 2-element tuple of booleans
        Specify whether to anchor the end points of the histogram (i.e., force
        zero intensity change at the ends).

    transform_domain_size : integer
        Defines the sampling resolution of the B-spline warping.

    Returns
    -------
    ANTs image

    Example
    -------
    >>> import ants
    >>> image = ants.image_read(ants.get_ants_data("r64"))
    >>> transformed_image = histogram_warp_image_intensities( image )
    """

    if not len(clamp_end_points) == 2:
        raise ValueError(
            "clamp_end_points must be a boolean tuple of length 2.")

    if not isinstance(break_points, int):
        # Values outside [0, 1] on either side are invalid.
        if any(b < 0 for b in break_points) or any(b > 1 for b in break_points):
            raise ValueError(
                "If specifying break_points as a vector, values must be in the range [0, 1]"
            )

    parametric_points = None
    number_of_nonzero_displacements = 1
    if not isinstance(break_points, int):
        parametric_points = break_points
        number_of_nonzero_displacements = len(break_points)
        if clamp_end_points[0] is True:
            parametric_points = (0, *parametric_points)
        if clamp_end_points[1] is True:
            parametric_points = (*parametric_points, 1)
    else:
        total_number_of_break_points = break_points
        if clamp_end_points[0] is True:
            total_number_of_break_points += 1
        if clamp_end_points[1] is True:
            total_number_of_break_points += 1
        parametric_points = np.linspace(0, 1, total_number_of_break_points)
        number_of_nonzero_displacements = break_points

    if displacements is None:
        displacements = np.random.normal(loc=0.0,
                                         scale=sd_displacements,
                                         size=number_of_nonzero_displacements)

    weights = np.ones(len(displacements))
    if clamp_end_points[0] is True:
        displacements = (0, *displacements)
        weights = np.concatenate((1000 * np.ones(1), weights))
    if clamp_end_points[1] is True:
        displacements = (*displacements, 0)
        weights = np.concatenate((weights, 1000 * np.ones(1)))

    if not len(displacements) == len(parametric_points):
        raise ValueError(
            "Length of displacements does not match the length of the break points."
        )

    scattered_data = np.reshape(displacements, (len(displacements), 1))
    parametric_data = np.reshape(parametric_points,
                                 (len(parametric_points), 1))

    transform_domain_origin = 0
    transform_domain_spacing = (1.0 - transform_domain_origin) / (
        transform_domain_size - 1)

    bspline_histogram_transform = ants.fit_bspline_object_to_scattered_data(
        scattered_data,
        parametric_data, [transform_domain_origin], [transform_domain_spacing],
        [transform_domain_size],
        data_weights=weights)

    transform_domain = np.linspace(0, 1, transform_domain_size)

    normalized_image = ants.iMath(ants.image_clone(image), "Normalize")
    transformed_array = normalized_image.numpy()
    normalized_array = normalized_image.numpy()

    for i in range(len(transform_domain) - 1):
        indices = np.where((normalized_array >= transform_domain[i])
                           & (normalized_array < transform_domain[i + 1]))
        intensities = normalized_array[indices]

        alpha = (intensities - transform_domain[i]) / (
            transform_domain[i + 1] - transform_domain[i])
        xfrm = alpha * (
            bspline_histogram_transform[i + 1] -
            bspline_histogram_transform[i]) + bspline_histogram_transform[i]
        transformed_array[indices] = intensities + xfrm

    transformed_image = (ants.from_numpy(transformed_array,
                                         origin=image.origin,
                                         spacing=image.spacing,
                                         direction=image.direction) *
                         (image.max() - image.min())) + image.min()

    return (transformed_image)
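Beyond the random-displacement default shown in the docstring, the break points and displacements can also be fixed explicitly; a small sketch (assuming ANTsPy is available and the function above is importable):

import ants

image = ants.image_read(ants.get_ants_data("r64"))

# Brighten the mid-range intensities while anchoring both histogram ends.
warped = histogram_warp_image_intensities(image,
                                          break_points=(0.25, 0.5, 0.75),
                                          displacements=(0.0, 0.1, 0.0),
                                          clamp_end_points=(True, True))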
Example #10
import sys

import ants
import antspynet
import tensorflow as tf

t1_file = sys.argv[1]
output_prefix = sys.argv[2]
threads = int(sys.argv[3])

tf.keras.backend.clear_session()
config = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=threads,
                                  inter_op_parallelism_threads=threads)
session = tf.compat.v1.Session(config=config)
tf.compat.v1.keras.backend.set_session(session)

t1 = ants.image_read(t1_file)
kk = ants.image_read(output_prefix + "CorticalThickness.nii.gz")

# If one wants cortical labels one can run the following lines

dkt = antspynet.desikan_killiany_tourville_labeling(t1,
                                                    do_preprocessing=True,
                                                    verbose=True)
ants.image_write(dkt, output_prefix + "Dkt.nii.gz")

dkt_mask = ants.threshold_image(dkt, 1000, 3000, 1, 0)
dkt = dkt_mask * dkt
ants_tmp = ants.threshold_image(kk, 0, 0, 0, 1)
ants_dkt = ants.iMath(ants_tmp, "PropagateLabelsThroughMask", ants_tmp * dkt)

ants.image_write(ants_dkt, output_prefix + "DktPropagatedLabels.nii.gz")
#
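A natural follow-up is to summarize the cortical thickness image over the propagated DKT regions; a sketch continuing from the variables above (assumes ants.label_stats returns a pandas data frame, as in current ANTsPy):

# Hypothetical continuation: per-label statistics of the thickness image.
thickness_stats = ants.label_stats(kk, ants_dkt)
thickness_stats.to_csv(output_prefix + "DktRegionalThickness.csv", index=False)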
Example #11
def tid_neural_image_assessment(image,
                                mask=None,
                                patch_size=101,
                                stride_length=None,
                                padding_size=0,
                                dimensions_to_predict=0,
                                antsxnet_cache_directory=None,
                                which_model="tidsQualityAssessment",
                                image_scaling=[255, 127.5],
                                do_patch_scaling=False,
                                no_reconstruction=False,
                                verbose=False):
    """
    Perform MOS-based assessment of an image.

    Use a ResNet architecture to estimate image quality in 2D or 3D using subjective
    QC image databases described in

    https://www.sciencedirect.com/science/article/pii/S0923596514001490

    or

    https://doi.org/10.1109/TIP.2020.2967829

    where the image assessment is either "global", i.e., a single number or an image
    based on the specified patch size.  In the 3-D case, neighboring slices are used
    for each estimate.  Note that parameters should be kept as consistent as possible
    in order to enable comparison.  Patch size should be roughly 1/12th to 1/4th of
    image size to enable locality. A global estimate can be gained by setting
    patch_size = "global".

    Arguments
    ---------
    image : ANTsImage (2-D or 3-D)
        input image.

    mask : ANTsImage (2-D or 3-D)
        optional mask for designating calculation ROI.

    patch_size : integer or string
        Patch size (a prime number such as 101 works well).  Otherwise, choose
        "global" for a single global estimate of quality.

    stride_length : integer or vector of image dimension length
        optional value to speed up computation (typically less than patch size).

    padding_size : positive or negative integer or vector of image dimension length
        Padding (positive) or de-padding (negative) to remove edge effects.

    dimensions_to_predict : integer or vector
        if image dimension is 3, this parameter specifies which dimensions should be used for
        prediction.  If more than one dimension is specified, the results are averaged.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if this is None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    which_model : string or tf/keras model
        model type e.g. string tidsQualityAssessment, koniqMS, koniqMS2 or koniqMS3 where
        the former predicts mean opinion score (MOS) and MOS standard deviation and
        the latter koniq models predict mean opinion score (MOS) and sharpness.
        passing a user-defined model is also valid.

    image_scaling : two-vector
        The first value is the multiplier and the second value the subtractor,
        so each image will be scaled as img = ants.iMath(img, "Normalize") * m - s.

    do_patch_scaling : boolean
        Controls whether each patch is scaled or (if False) only a global
        scaling of the image is used.

    no_reconstruction : boolean
        Reconstruction is time consuming; turn this on if you just want the
        predicted values.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    List of QC results predicting both the human rater's mean and standard
    deviation of the MOS ("mean opinion scores") or sharpness, depending on the
    selected network.  Both aggregate and spatial scores are returned, the latter
    in the form of an image.

    Example
    -------
    >>> image = ants.image_read(ants.get_data("r16"))
    >>> mask = ants.get_mask(image)
    >>> tid = tid_neural_image_assessment(image, mask=mask, patch_size=101, stride_length=7)
    """

    from ..utilities import get_pretrained_network
    from ..utilities import pad_or_crop_image_to_size
    from ..utilities import extract_image_patches
    from ..utilities import reconstruct_image_from_patches

    def is_prime(n):
        if n == 2 or n == 3:
            return True
        if n < 2 or n % 2 == 0:
            return False
        if n < 9:
            return True
        if n % 3 == 0:
            return False
        r = int(n**0.5)
        f = 5
        while f <= r:
            if n % f == 0:
                return False
            if n % (f + 2) == 0:
                return False
            f += 6
        return True

    if type(which_model) is not type("x"):
        tid_model = which_model  # should be a tf model
        which_model = "user_defined"

    valid_models = ("tidsQualityAssessment", "koniqMS", "koniqMS2", "koniqMS3",
                    "user_defined")
    if not which_model in valid_models:
        raise ValueError("Please pass valid model")

    if antsxnet_cache_directory == None:
        antsxnet_cache_directory = "ANTsXNet"

    if verbose == True:
        print("Neural QA:  retreiving model and weights.")

    is_koniq = "koniq" in which_model
    if which_model != "user_defined":
        model_and_weights_file_name = get_pretrained_network(
            which_model, antsxnet_cache_directory=antsxnet_cache_directory)
        tid_model = tf.keras.models.load_model(model_and_weights_file_name,
                                               compile=False)

    padding_size_vector = padding_size
    if isinstance(padding_size, int):
        padding_size_vector = np.repeat(padding_size, image.dimension)
    elif len(padding_size) == 1:
        padding_size_vector = np.repeat(padding_size[0], image.dimension)

    if isinstance(dimensions_to_predict, int):
        dimensions_to_predict = (dimensions_to_predict, )

    padded_image_size = image.shape + padding_size_vector
    padded_image = pad_or_crop_image_to_size(image, padded_image_size)

    number_of_channels = 3

    if stride_length is None and patch_size != "global":
        stride_length = round(patch_size / 2)
        if image.dimension == 3:
            stride_length = (stride_length, stride_length, 1)

    ###############
    #
    #  Global
    #
    ###############
    if which_model == "tidsQualityAssessment":
        evaluation_image = ants.iMath(padded_image, "Normalize") * 255

    if is_koniq:
        evaluation_image = ants.iMath(padded_image, "Normalize") * 2.0 - 1.0

    if which_model == "user_defined":
        evaluation_image = ants.iMath(
            padded_image, "Normalize") * image_scaling[0] - image_scaling[1]

    if patch_size == 'global':

        if image.dimension == 2:
            batchX = np.zeros((1, *evaluation_image.shape, number_of_channels))
            for k in range(3):
                batchX[0, :, :, k] = evaluation_image.numpy()
            predicted_data = tid_model.predict(batchX, verbose=verbose)

            if which_model == "tidsQualityAssessment":
                return_dict = {
                    'MOS': None,
                    'MOS.standardDeviation': None,
                    'MOS.mean': predicted_data[0, 0],
                    'MOS.standardDeviationMean': predicted_data[0, 1]
                }
                return (return_dict)

            elif is_koniq or which_model == "user_defined":
                return_dict = {
                    'MOS.mean': predicted_data[0, 0],
                    'sharpness.mean': predicted_data[0, 1]
                }
                return (return_dict)

        elif image.dimension == 3:
            mos_mean = 0
            mos_standard_deviation = 0
            x = tuple(range(image.dimension))
            d = 0
            if True:
                #            for d in 0: # range(len(dimensions_to_predict)):
                not_padded_image_size = list(padded_image_size)
                del (not_padded_image_size[dimensions_to_predict[d]])
                newsize = not_padded_image_size
                newsize.insert(0, padded_image_size[dimensions_to_predict[d]])
                newsize.append(number_of_channels)
                batchX = np.zeros(newsize)
                for k in range(3):
                    batchX[:, :, :, k] = evaluation_image.numpy()
                predicted_data = tid_model.predict(batchX, verbose=verbose)
                mos_mean += predicted_data[0, 0]
                mos_standard_deviation += predicted_data[0, 1]

            mos_mean /= len(dimensions_to_predict)
            mos_standard_deviation /= len(dimensions_to_predict)
            if which_model == "tidsQualityAssessment":
                return_dict = {
                    'MOS.mean': mos_mean,
                    'MOS.standardDeviationMean': mos_standard_deviation
                }
                return (return_dict)
            else:
                return_dict = {
                    'MOS.mean': mos_mean,
                    'sharpness.mean': mos_standard_deviation
                }
                return (return_dict)

    ###############
    #
    #  Patchwise
    #
    ###############

    else:

        # if not is_prime(patch_size):
        #    print("patch_size should be a prime number:  13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97...")

        stride_length_vector = stride_length
        if isinstance(stride_length, int):
            if image.dimension == 2:
                stride_length_vector = (stride_length, stride_length)
        elif len(stride_length) == 1:
            if image.dimension == 2:
                stride_length_vector = (stride_length[0], stride_length[0])

        patch_size_vector = (patch_size, patch_size)

        if image.dimension == 2:
            dimensions_to_predict = (1, )

        permutations = list()

        mos = image * 0
        mos_standard_deviation = image * 0

        for d in range(len(dimensions_to_predict)):
            if image.dimension == 3:
                permutations.append((0, 1, 2))
                permutations.append((0, 2, 1))
                permutations.append((1, 2, 0))

                if dimensions_to_predict[d] == 0:
                    patch_size_vector = (patch_size, patch_size,
                                         number_of_channels)
                    if isinstance(stride_length, int):
                        stride_length_vector = (stride_length, stride_length,
                                                1)
                elif dimensions_to_predict[d] == 1:
                    patch_size_vector = (patch_size, number_of_channels,
                                         patch_size)
                    if isinstance(stride_length, int):
                        stride_length_vector = (stride_length, 1,
                                                stride_length)
                elif dimensions_to_predict[d] == 2:
                    patch_size_vector = (number_of_channels, patch_size,
                                         patch_size)
                    if isinstance(stride_length, int):
                        stride_length_vector = (1, stride_length,
                                                stride_length)
                else:
                    raise ValueError(
                        "dimensions_to_predict elements should be 1, 2, and/or 3 for 3-D image."
                    )

            if mask is None:
                patches = extract_image_patches(
                    evaluation_image,
                    patch_size=patch_size_vector,
                    stride_length=stride_length_vector,
                    return_as_array=False)
            else:
                patches = extract_image_patches(evaluation_image,
                                                patch_size=patch_size_vector,
                                                max_number_of_patches=int(
                                                    (mask == 1).sum()),
                                                return_as_array=False,
                                                mask_image=mask,
                                                randomize=False)

            batchX = np.zeros(
                (len(patches), patch_size, patch_size, number_of_channels))

            verbose = False
            if verbose:
                print("Predict begin")

            is_good_patch = np.repeat(False, len(patches))
            for i in range(len(patches)):
                if patches[i].var() > 0:
                    is_good_patch[i] = True
                    patch_image = patches[i]
                    patch_image = patch_image - patch_image.min()

                    if patch_image.max() > 0:
                        if which_model == "tidsQualityAssessment" and do_patch_scaling:
                            patch_image = patch_image / patch_image.max() * 255
                        elif is_koniq and do_patch_scaling:
                            patch_image = patch_image / patch_image.max(
                            ) * 2.0 - 1.0
                        elif which_model == "user_defined" and do_patch_scaling:
                            patch_image = patch_image / patch_image.max(
                            ) * image_scaling[0] - image_scaling[1]

                    if image.dimension == 2:
                        for j in range(number_of_channels):
                            batchX[i, :, :, j] = patch_image
                    elif image.dimension == 3:
                        batchX[i, :, :, :] = np.transpose(
                            np.squeeze(patch_image),
                            permutations[dimensions_to_predict[d]])

            good_batchX = batchX[is_good_patch, :, :, :]
            predicted_data = tid_model.predict(good_batchX, verbose=verbose)

            if no_reconstruction:
                return predicted_data

            if verbose:
                print("Predict done")

            patches_mos = list()
            patches_mos_standard_deviation = list()

            zero_patch_image = patch_image * 0

            count = 0
            for i in range(len(patches)):
                if is_good_patch[i]:
                    patches_mos.append(zero_patch_image +
                                       predicted_data[count, 0])
                    patches_mos_standard_deviation.append(zero_patch_image +
                                                          predicted_data[count,
                                                                         1])
                    count += 1
                else:
                    patches_mos.append(zero_patch_image)
                    patches_mos_standard_deviation.append(zero_patch_image)

            if verbose:
                print("reconstruct")

            if mask is None:
                mos += pad_or_crop_image_to_size(
                    reconstruct_image_from_patches(
                        patches_mos,
                        evaluation_image,
                        stride_length=stride_length_vector), image.shape)
                mos_standard_deviation += pad_or_crop_image_to_size(
                    reconstruct_image_from_patches(
                        patches_mos_standard_deviation,
                        evaluation_image,
                        stride_length=stride_length_vector), image.shape)
            else:
                mos += pad_or_crop_image_to_size(
                    reconstruct_image_from_patches(patches_mos,
                                                   mask,
                                                   domain_image_is_mask=True),
                    image.shape)
                mos_standard_deviation += pad_or_crop_image_to_size(
                    reconstruct_image_from_patches(
                        patches_mos_standard_deviation,
                        mask,
                        domain_image_is_mask=True), image.shape)

        mos /= len(dimensions_to_predict)
        mos_standard_deviation /= len(dimensions_to_predict)

        if mask is None:

            if which_model == "tidsQualityAssessment":
                return_dict = {
                    'MOS': mos,
                    'MOS.standardDeviation': mos_standard_deviation,
                    'MOS.mean': mos.mean(),
                    'MOS.standardDeviationMean': mos_standard_deviation.mean()
                }
                return (return_dict)

            elif is_koniq or which_model == 'user_defined':
                return_dict = {
                    'MOS': mos,
                    'sharpness': mos_standard_deviation,
                    'MOS.mean': mos.mean(),
                    'sharpness.mean': mos_standard_deviation.mean()
                }
                return (return_dict)

        else:

            if which_model == "tidsQualityAssessment":
                return_dict = {
                    'MOS':
                    mos,
                    'MOS.standardDeviation':
                    mos_standard_deviation,
                    'MOS.mean': (mos[mask >= 0.5]).mean(),
                    'MOS.standardDeviationMean':
                    (mos_standard_deviation[mask >= 0.5]).mean()
                }
                return (return_dict)

            elif is_koniq or which_model == 'user_defined':
                return_dict = {
                    'MOS': mos,
                    'sharpness': mos_standard_deviation,
                    'MOS.mean': (mos[mask >= 0.5]).mean(),
                    'sharpness.mean':
                    (mos_standard_deviation[mask >= 0.5]).mean()
                }
                return (return_dict)
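For a single whole-image estimate, the patchwise machinery can be bypassed with patch_size="global"; a short sketch (assumes the pretrained tidsQualityAssessment weights can be downloaded to the cache directory):

import ants

image = ants.image_read(ants.get_data("r16"))
qa = tid_neural_image_assessment(image, patch_size="global")
print("MOS mean:", qa['MOS.mean'])
print("MOS standard deviation (mean):", qa['MOS.standardDeviationMean'])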
Example #12
 def test_abp_n4_example(self):
     img = ants.image_read(ants.get_ants_data("r16"))
     img = ants.iMath(img, "Normalize") * 255.0
     img2 = ants.abp_n4(img)
Example #13
def data_augmentation(input_image_list,
                      segmentation_image_list=None,
                      pointset_list=None,
                      number_of_simulations=10,
                      reference_image=None,
                      transform_type='affineAndDeformation',
                      noise_model='additivegaussian',
                      noise_parameters=(0.0, 0.05),
                      sd_simulated_bias_field=0.05,
                      sd_histogram_warping=0.05,
                      sd_affine=0.05,
                      output_numpy_file_prefix=None,
                      verbose=False):
    """
    Randomly transform image data.

    Given an input image list (possibly multi-modal) and an optional corresponding
    segmentation image list, this function will perform data augmentation with
    the following augmentation possibilities:

    * spatial transformations
    * added image noise
    * simulated bias field
    * histogram warping

    Arguments
    ---------

    input_image_list : list of lists of ANTsImages
        List of lists of input images to warp.  Each inner list contains one or
        more images (per subject) which are assumed to be mutually aligned.  The
        outer list contains multiple subject lists which are randomly sampled to
        produce the output image list.

    segmentation_image_list : list of ANTsImages
        List of segmentation images corresponding to the input image list (optional).

    pointset_list: list of pointsets
        Numpy arrays corresponding to the input image list (optional).  If using this
        option, the transform_type must be invertible.

    number_of_simulations : integer
        Number of simulated output image sets.  Default = 10.

    reference_image : ANTsImage
        Defines the spatial domain for all output images.  If one is not specified,
        the first image in the input image list is used.

    transform_type : string
        One of the following options: "translation", "rigid", "scaleShear", "affine",
        "deformation", "affineAndDeformation".

    noise_model : string
        'additivegaussian', 'saltandpepper', 'shot', or 'speckle'.

    noise_parameters : tuple or array or float
        'additivegaussian': (mean, standardDeviation)
        'saltandpepper': (probability, saltValue, pepperValue)
        'shot': scale
        'speckle': standardDeviation
        Note that the standard deviation, scale, and probability values are *max* values
        and are randomly selected in the range [0, noise_parameter].  Also, the "mean",
        "saltValue" and "pepperValue" are assumed to be in the intensity normalized range
        of [0, 1].

    sd_simulated_bias_field : float
        Characterizes the standard deviation of the amplitude of the simulated
        bias field.

    sd_histogram_warping : float
        Determines the strength of the histogram warping.

    sd_affine : float
        Determines the amount of transformation-based change.

    output_numpy_file_prefix : string
        Filename of output numpy array containing all the simulated images and segmentations.

    Returns
    -------
    Dictionary of lists of simulated images and, optionally, the corresponding
    simulated segmentation images and pointsets.  Results may also be written to
    numpy files.

    Example
    -------
    >>> image1_list = list()
    >>> image1_list.append(ants.image_read(ants.get_ants_data("r16")))
    >>> image2_list = list()
    >>> image2_list.append(ants.image_read(ants.get_ants_data("r64")))
    >>> segmentation1 = ants.threshold_image(image1_list[0], "Otsu", 3)
    >>> segmentation2 = ants.threshold_image(image2_list[0], "Otsu", 3)
    >>> input_segmentations = list()
    >>> input_segmentations.append(segmentation1)
    >>> input_segmentations.append(segmentation2)
    >>> points1 = ants.get_centroids(segmentation1)[:,0:2]
    >>> points2 = ants.get_centroids(segmentation2)[:,0:2]
    >>> input_points = list()
    >>> input_points.append(points1)
    >>> input_points.append(points2)
    >>> input_images = list()
    >>> input_images.append(image1_list)
    >>> input_images.append(image2_list)
    >>> data = data_augmentation(input_images,
                                 input_segmentations,
                                 input_points,
                                 tranform_type="scaleShear")
    """

    from ..utilities import histogram_warp_image_intensities
    from ..utilities import simulate_bias_field
    from ..utilities import randomly_transform_image_data

    if reference_image is None:
        reference_image = input_image_list[0][0]

    number_of_modalities = len(input_image_list[0])

    # Set up numpy arrays if outputting to file.

    batch_X = None
    batch_Y = None
    batch_Y_points = None
    number_of_points = 0

    if pointset_list is not None:
        number_of_points = pointset_list[0].shape[0]
        batch_Y_points = np.zeros((number_of_simulations, number_of_points,
                                   reference_image.dimension))

    if output_numpy_file_prefix is not None:
        batch_X = np.zeros((number_of_simulations, *reference_image.shape,
                            number_of_modalities))
        if segmentation_image_list is not None:
            batch_Y = np.zeros((number_of_simulations, *reference_image.shape))

    # Spatially transform input image data

    if verbose:
        print("Randomly spatially transforming the image data.")

    transform_augmentation = randomly_transform_image_data(
        reference_image,
        input_image_list=input_image_list,
        segmentation_image_list=segmentation_image_list,
        number_of_simulations=number_of_simulations,
        transform_type=transform_type,
        sd_affine=sd_affine,
        deformation_transform_type="bspline",
        number_of_random_points=1000,
        sd_noise=2.0,
        number_of_fitting_levels=4,
        mesh_size=1,
        sd_smoothing=4.0,
        input_image_interpolator='linear',
        segmentation_image_interpolator='nearestNeighbor')
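    # The returned dictionary is consumed below through its 'simulated_images',
    # 'simulated_segmentation_images', 'simulated_transforms', and 'which_subject'
    # entries (one element per simulation).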

    simulated_image_list = list()
    simulated_segmentation_image_list = list()
    simulated_pointset_list = list()

    for i in range(number_of_simulations):

        if verbose:
            print("Processing simulation " + str(i))

        segmentation = None
        if segmentation_image_list is not None:
            segmentation = transform_augmentation[
                'simulated_segmentation_images'][i]
            simulated_segmentation_image_list.append(segmentation)
            if batch_Y is not None:
                if reference_image.dimension == 2:
                    batch_Y[i, :, :] = segmentation.numpy()
                else:
                    batch_Y[i, :, :, :] = segmentation.numpy()

        if pointset_list is not None:
            simulated_transform = transform_augmentation[
                'simulated_transforms'][i]
            simulated_transform_inverse = ants.invert_ants_transform(
                simulated_transform)
            which_subject = transform_augmentation['which_subject'][i]
            simulated_points = np.zeros(
                (number_of_points, reference_image.dimension))
            for j in range(number_of_points):
                simulated_points[j, :] = ants.apply_ants_transform_to_point(
                    simulated_transform_inverse,
                    pointset_list[which_subject][j, :])
            simulated_pointset_list.append(simulated_points)
            if batch_Y_points is not None:
                batch_Y_points[i, :, :] = simulated_points

        simulated_local_image_list = list()
        for j in range(number_of_modalities):

            if verbose:
                print("    Modality " + str(j))

            image = transform_augmentation['simulated_images'][i][j]
            image_range = image.range()

            # Normalize to [0, 1] before applying augmentation

            if verbose:
                print("        Normalizing to [0, 1].")

            image = ants.iMath(image, "Normalize")

            # Noise

            if noise_model is not None:

                if verbose:
                    print("        Adding noise (" + noise_model + ").")

                if any(np.array(noise_parameters) > 0):

                    if noise_model.lower() == "additivegaussian":
                        parameters = (noise_parameters[0],
                                      random.uniform(0.0, noise_parameters[1]))
                        image = ants.add_noise_to_image(
                            image,
                            noise_model="additivegaussian",
                            noise_parameters=parameters)
                    elif noise_model.lower() == "saltandpepper":
                        parameters = (random.uniform(0.0, noise_parameters[0]),
                                      noise_parameters[1], noise_parameters[2])
                        image = ants.add_noise_to_image(
                            image,
                            noise_model="saltandpepper",
                            noise_parameters=parameters)
                    elif noise_model.lower() == "shot":
                        parameters = (random.uniform(0.0, noise_parameters[0]))
                        image = ants.add_noise_to_image(
                            image,
                            noise_model="shot",
                            noise_parameters=parameters)
                    elif noise_model.lower() == "speckle":
                        parameters = (random.uniform(0.0, noise_parameters[0]))
                        image = ants.add_noise_to_image(
                            image,
                            noise_model="speckle",
                            noise_parameters=parameters)
                    else:
                        raise ValueError("Unrecognized noise model.")

            # Simulated bias field

            if sd_simulated_bias_field > 0:

                if verbose:
                    print("        Adding simulated bias field.")

                bias_field = simulate_bias_field(
                    image, sd_bias_field=sd_simulated_bias_field)
                image = image * (bias_field + 1)

            # Histogram intensity warping

            if sd_histogram_warping > 0:

                if verbose:
                    print("        Performing intensity histogram warping.")

                break_points = [0.2, 0.4, 0.6, 0.8]
                displacements = list()
                for b in range(len(break_points)):
                    displacements.append(random.gauss(0, sd_histogram_warping))
                image = histogram_warp_image_intensities(
                    image,
                    break_points=break_points,
                    clamp_end_points=(False, False),
                    displacements=displacements)

            # Rescale to original intensity range

            if verbose:
                print("        Rescaling to original intensity range.")

            image = ants.iMath(image, "Normalize") * (
                image_range[1] - image_range[0]) + image_range[0]

            simulated_local_image_list.append(image)

            if batch_X is not None:
                if reference_image.dimension == 2:
                    batch_X[i, :, :, j] = image.numpy()
                else:
                    batch_X[i, :, :, :, j] = image.numpy()

        simulated_image_list.append(simulated_local_image_list)

    if batch_X is not None:
        if output_numpy_file_prefix is not None:
            if verbose:
                print("Writing images to numpy array.")
            np.save(output_numpy_file_prefix + "SimulatedImages.npy", batch_X)
    if batch_Y is not None:
        if output_numpy_file_prefix is not None:
            if verbose:
                print("Writing segmentation images to numpy array.")
            np.save(
                output_numpy_file_prefix + "SimulatedSegmentationImages.npy",
                batch_Y)
    if batch_Y_points is not None:
        if output_numpy_file_prefix is not None:
            if verbose:
                print("Writing segmentation images to numpy array.")
            np.save(output_numpy_file_prefix + "SimulatedPointsets.npy",
                    batch_Y_points)

    if segmentation_image_list is None and pointset_list is None:
        return ({'simulated_images': simulated_image_list})
    elif segmentation_image_list is None:
        return ({
            'simulated_images': simulated_image_list,
            'simulated_pointset_list': simulated_pointset_list
        })
    elif pointset_list is None:
        return ({
            'simulated_images':
            simulated_image_list,
            'simulated_segmentation_images':
            simulated_segmentation_image_list
        })
    else:
        return ({
            'simulated_images': simulated_image_list,
            'simulated_segmentation_images': simulated_segmentation_image_list,
            'simulated_pointset_list': simulated_pointset_list
        })
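A brief usage sketch building on the docstring example above (the simulation count and output path are illustrative assumptions):

import ants

data = data_augmentation(input_images,
                         input_segmentations,
                         input_points,
                         transform_type="scaleShear",
                         number_of_simulations=4)
# with both segmentations and pointsets supplied, the result holds all three keys
for k, sim in enumerate(data['simulated_images']):
    # each entry is a per-subject list with one image per modality
    ants.image_write(sim[0], "/tmp/sim" + str(k) + "_mod0.nii.gz")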
예제 #14
0
# Excerpt: center-of-mass alignment of an image to a template, followed by masking,
# cropping, and resampling for network input (names such as 'image', 'mask',
# 'reorient_template', and 'center_of_mass_template' are defined earlier in the
# original script).
center_of_mass_image = ants.get_center_of_mass(image)
translation = np.asarray(center_of_mass_image) - np.asarray(
    center_of_mass_template)
xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
                                  center=np.asarray(center_of_mass_template),
                                  translation=translation)
warped_image = ants.apply_ants_transform_to_image(xfrm,
                                                  image,
                                                  reorient_template,
                                                  interpolation='linear')
warped_mask = ants.apply_ants_transform_to_image(xfrm,
                                                 mask,
                                                 reorient_template,
                                                 interpolation='linear')
warped_mask = ants.threshold_image(warped_mask, 0.4999, 1.0001, 1, 0)
warped_mask = ants.iMath(warped_mask, "MD", 3)
warped_cropped_image = ants.crop_image(warped_image, warped_mask, 1)
original_cropped_size = warped_cropped_image.shape
warped_cropped_image = ants.resample_image(warped_cropped_image,
                                           resampled_image_size,
                                           use_voxels=True)
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

# Assemble a single-sample, single-channel batch of shape (1, x, y, z, 1) and
# z-score the intensities before prediction.
batchX = np.expand_dims(warped_cropped_image.numpy(), axis=0)
batchX = np.expand_dims(batchX, axis=-1)
batchX = (batchX - batchX.mean()) / batchX.std()

print("Prediction and decoding")
start_time = time.time()
예제 #15
0
def brain_extraction(image,
                     modality,
                     antsxnet_cache_directory=None,
                     verbose=False):
    """
    Perform brain extraction using U-net and ANTs-based training data.  "NoBrainer"
    is also possible, where brain extraction uses U-net and FreeSurfer training data
    ported from the following repository:

    https://github.com/neuronets/nobrainer-models

    Arguments
    ---------
    image : ANTsImage
        input image (or list of images for multi-modal scenarios).

    modality : string
        Modality image type.  Options include:
            * "t1": T1-weighted MRI---ANTs-trained.  Previous versions are specified as "t1.v0", "t1.v1".
            * "t1nobrainer": T1-weighted MRI---FreeSurfer-trained: h/t Satra Ghosh and Jakub Kaczmarzyk.
            * "t1combined": Brian's combination of "t1" and "t1nobrainer".  One can also specify
                            "t1combined[X]" where X is the morphological radius.  X = 12 by default.
            * "flair": FLAIR MRI.   Previous versions are specified as "flair.v0".
            * "t2": T2 MRI.  Previous versions are specified as "t2.v0".
            * "t2star": T2Star MRI.
            * "bold": 3-D mean BOLD MRI.  Previous versions are specified as "bold.v0".
            * "fa": fractional anisotropy.  Previous versions are specified as "fa.v0".
            * "t1t2infant": Combined T1-w/T2-w infant MRI h/t Martin Styner.
            * "t1infant": T1-w infant MRI h/t Martin Styner.
            * "t2infant": T2-w infant MRI h/t Martin Styner.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if this is None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    ANTs probability brain mask image.

    Example
    -------
    >>> probability_brain_mask = brain_extraction(brain_image, modality="t1")
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..architectures import create_nobrainer_unet_model_3d
    from ..utilities import decode_unet

    classes = ("background", "brain")
    number_of_classification_labels = len(classes)

    channel_size = 1
    if isinstance(image, list):
        channel_size = len(image)

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    input_images = list()
    if channel_size == 1:
        input_images.append(image)
    else:
        input_images = image

    if input_images[0].dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if "t1combined" in modality:
        # Need to change with voxel resolution
        morphological_radius = 12
        if '[' in modality and ']' in modality:
            morphological_radius = int(modality.split("[")[1].split("]")[0])

        brain_extraction_t1 = brain_extraction(
            image,
            modality="t1",
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        brain_mask = ants.iMath_get_largest_component(
            ants.threshold_image(brain_extraction_t1, 0.5, 10000))
        brain_mask = ants.morphology(brain_mask, "close",
                                     morphological_radius).iMath_fill_holes()

        brain_extraction_t1nobrainer = brain_extraction(
            image * ants.iMath_MD(brain_mask, radius=morphological_radius),
            modality="t1nobrainer",
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        brain_extraction_combined = ants.iMath_fill_holes(
            ants.iMath_get_largest_component(brain_extraction_t1nobrainer *
                                             brain_mask))

        brain_extraction_combined = brain_extraction_combined + ants.iMath_ME(
            brain_mask, morphological_radius) + brain_mask

        return (brain_extraction_combined)

    if modality != "t1nobrainer":

        #####################
        #
        # ANTs-based
        #
        #####################

        weights_file_name_prefix = None
        is_standard_network = False

        if modality == "t1.v0":
            weights_file_name_prefix = "brainExtraction"
        elif modality == "t1.v1":
            weights_file_name_prefix = "brainExtractionT1v1"
            is_standard_network = True
        elif modality == "t1":
            weights_file_name_prefix = "brainExtractionRobustT1"
            is_standard_network = True
        elif modality == "t2.v0":
            weights_file_name_prefix = "brainExtractionT2"
        elif modality == "t2":
            weights_file_name_prefix = "brainExtractionRobustT2"
            is_standard_network = True
        elif modality == "t2star":
            weights_file_name_prefix = "brainExtractionRobustT2Star"
            is_standard_network = True
        elif modality == "flair.v0":
            weights_file_name_prefix = "brainExtractionFLAIR"
        elif modality == "flair":
            weights_file_name_prefix = "brainExtractionRobustFLAIR"
            is_standard_network = True
        elif modality == "bold.v0":
            weights_file_name_prefix = "brainExtractionBOLD"
        elif modality == "bold":
            weights_file_name_prefix = "brainExtractionRobustBOLD"
            is_standard_network = True
        elif modality == "fa.v0":
            weights_file_name_prefix = "brainExtractionFA"
        elif modality == "fa":
            weights_file_name_prefix = "brainExtractionRobustFA"
            is_standard_network = True
        elif modality == "t1t2infant":
            weights_file_name_prefix = "brainExtractionInfantT1T2"
        elif modality == "t1infant":
            weights_file_name_prefix = "brainExtractionInfantT1"
        elif modality == "t2infant":
            weights_file_name_prefix = "brainExtractionInfantT2"
        else:
            raise ValueError("Unknown modality type.")

        if verbose == True:
            print("Brain extraction:  retrieving model weights.")

        weights_file_name = get_pretrained_network(
            weights_file_name_prefix,
            antsxnet_cache_directory=antsxnet_cache_directory)

        if verbose == True:
            print("Brain extraction:  retrieving template.")

        reorient_template_file_name_path = get_antsxnet_data(
            "S_template3", antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)
        if is_standard_network and modality != "t1.v1":
            ants.set_spacing(reorient_template, (1.5, 1.5, 1.5))
        resampled_image_size = reorient_template.shape

        number_of_filters = (8, 16, 32, 64)
        mode = "classification"
        if is_standard_network:
            number_of_filters = (16, 32, 64, 128)
            number_of_classification_labels = 1
            mode = "sigmoid"

        unet_model = create_unet_model_3d(
            (*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            mode=mode,
            number_of_filters=number_of_filters,
            dropout_rate=0.0,
            convolution_kernel_size=3,
            deconvolution_kernel_size=2,
            weight_decay=1e-5)

        unet_model.load_weights(weights_file_name)

        if verbose == True:
            print("Brain extraction:  normalizing image to the template.")

        center_of_mass_template = ants.get_center_of_mass(reorient_template)
        center_of_mass_image = ants.get_center_of_mass(input_images[0])
        translation = np.asarray(center_of_mass_image) - np.asarray(
            center_of_mass_template)
        xfrm = ants.create_ants_transform(
            transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_template),
            translation=translation)
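        # Translation-only transform aligning the image's center of mass with the
        # template's; it is applied below to resample each input image onto the
        # reorientation template grid.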

        batchX = np.zeros((1, *resampled_image_size, channel_size))

        for i in range(len(input_images)):
            warped_image = ants.apply_ants_transform_to_image(
                xfrm, input_images[i], reorient_template)
            if is_standard_network and modality != "t1.v1":
                batchX[0, :, :, :, i] = (ants.iMath(warped_image,
                                                    "Normalize")).numpy()
            else:
                warped_array = warped_image.numpy()
                batchX[0, :, :, :,
                       i] = (warped_array -
                             warped_array.mean()) / warped_array.std()

        if verbose == True:
            print("Brain extraction:  prediction and decoding.")

        predicted_data = unet_model.predict(batchX, verbose=verbose)
        probability_images_array = decode_unet(predicted_data,
                                               reorient_template)

        if verbose == True:
            print(
                "Brain extraction:  renormalize probability mask to native space."
            )

        xfrm_inv = xfrm.invert()
        probability_image = xfrm_inv.apply_to_image(
            probability_images_array[0][number_of_classification_labels - 1],
            input_images[0])

        return (probability_image)

    else:

        #####################
        #
        # NoBrainer
        #
        #####################

        if verbose == True:
            print("NoBrainer:  generating network.")

        model = create_nobrainer_unet_model_3d((None, None, None, 1))

        weights_file_name = get_pretrained_network(
            "brainExtractionNoBrainer",
            antsxnet_cache_directory=antsxnet_cache_directory)
        model.load_weights(weights_file_name)

        if verbose == True:
            print(
                "NoBrainer:  preprocessing (intensity truncation and resampling)."
            )

        image_array = image.numpy()
        image_robust_range = np.quantile(
            image_array[np.where(image_array != 0)], (0.02, 0.98))
        threshold_value = 0.10 * (image_robust_range[1] - image_robust_range[0]
                                  ) + image_robust_range[0]

        thresholded_mask = ants.threshold_image(image, -10000, threshold_value,
                                                0, 1)
        thresholded_image = image * thresholded_mask

        image_resampled = ants.resample_image(thresholded_image,
                                              (256, 256, 256),
                                              use_voxels=True)
        image_array = np.expand_dims(image_resampled.numpy(), axis=0)
        image_array = np.expand_dims(image_array, axis=-1)

        if verbose == True:
            print("NoBrainer:  predicting mask.")

        brain_mask_array = np.squeeze(
            model.predict(image_array, verbose=verbose))
        brain_mask_resampled = ants.copy_image_info(
            image_resampled, ants.from_numpy(brain_mask_array))
        brain_mask_image = ants.resample_image(brain_mask_resampled,
                                               image.shape,
                                               use_voxels=True,
                                               interp_type=1)

        spacing = ants.get_spacing(image)
        spacing_product = spacing[0] * spacing[1] * spacing[2]
        minimum_brain_volume = round(649933.7 / spacing_product)
        brain_mask_labeled = ants.label_clusters(brain_mask_image,
                                                 minimum_brain_volume)

        return (brain_mask_labeled)
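A brief follow-up sketch using the docstring example above (the 0.5 cutoff for binarizing the probability mask is an assumed choice, not part of the function):

import ants

probability_brain_mask = brain_extraction(brain_image, modality="t1")
# binarize the probability mask and apply it to obtain a skull-stripped image
brain_mask = ants.threshold_image(probability_brain_mask, 0.5, 1.0, 1, 0)
skull_stripped_image = brain_image * brain_mask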