Example #1
    def test_matrix_to_images(self):
        # def matrix_to_images(data_matrix, mask):
        for img in self.imgs:
            imgmask = ants.image_clone(img > img.mean(), pixeltype='float')
            data = img[imgmask]
            dataflat = data.reshape(1,-1)
            mat = np.vstack([dataflat,dataflat]).astype('float32')
            imglist = ants.matrix_to_images(mat, imgmask)
            nptest.assert_allclose((img*imgmask).numpy(), imglist[0].numpy())
            nptest.assert_allclose((img*imgmask).numpy(), imglist[1].numpy())
            self.assertTrue(ants.image_physical_space_consistency(img,imglist[0]))
            self.assertTrue(ants.image_physical_space_consistency(img,imglist[1]))

            # go back to matrix
            mat2 = ants.images_to_matrix(imglist, imgmask)
            nptest.assert_allclose(mat, mat2)

            # test with matrix.ndim > 2
            img = img.clone()
            img.set_direction(img.direction*2)
            imgmask = ants.image_clone(img > img.mean(), pixeltype='float')
            arr = (img*imgmask).numpy()
            arr = arr[arr>=0.5]
            arr2 = arr.copy()
            mat = np.stack([arr,arr2])
            imglist = ants.matrix_to_images(mat, imgmask)
            for im in imglist:
                self.assertTrue(ants.allclose(im, imgmask*img))
                self.assertTrue(ants.image_physical_space_consistency(im, imgmask))
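A minimal standalone sketch of the same round trip, assuming only that ANTsPy is installed ('r16' is a 2-D sample image shipped with ANTsPy):

import ants

img = ants.image_read(ants.get_ants_data('r16'))
mask = ants.image_clone(img > img.mean(), pixeltype='float')
mat = ants.images_to_matrix([img, img], mask)   # one row per image
imgs = ants.matrix_to_images(mat, mask)         # back to a list of ANTsImages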
Example #2
    def test_images_to_matrix(self):
        # def images_to_matrix(image_list, mask=None, sigma=None, epsilon=0):
        for img in self.imgs:
            mask = ants.image_clone(img > img.mean(), pixeltype='float')
            imglist = [img.clone(),img.clone(),img.clone()]
            imgmat = ants.images_to_matrix(imglist, mask=mask)
            self.assertTrue(imgmat.shape[0] == len(imglist))
            self.assertTrue(imgmat.shape[1] == (mask>0).sum())

            # go back to images
            imglist2 = ants.matrix_to_images(imgmat, mask)
            for i1,i2 in zip(imglist,imglist2):
                self.assertTrue(ants.image_physical_space_consistency(i1,i2))
                nptest.assert_allclose(i1.numpy()*mask.numpy(),i2.numpy())

            if img.dimension == 2:
                # with sigma
                mask = ants.image_clone(img > img.mean(), pixeltype='float')
                imglist = [img.clone(),img.clone(),img.clone()]
                imgmat = ants.images_to_matrix(imglist, mask=mask, sigma=2.)

                # with no mask
                mask = ants.image_clone(img > img.mean(), pixeltype='float')
                imglist = [img.clone(),img.clone(),img.clone()]
                imgmat = ants.images_to_matrix(imglist)

                # with mask of different shape
                s = [65]*img.dimension
                mask2 = ants.from_numpy(np.random.randn(*s))
                mask2 = mask2 > mask2.mean()
                imgmat = ants.images_to_matrix(imglist, mask=mask2)
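A minimal sketch of the shape contract the test checks, assuming ANTsPy is installed: the matrix has one row per image and one column per voxel inside the mask.

import ants

img = ants.image_read(ants.get_ants_data('r16'))
mask = ants.get_mask(img)
mat = ants.images_to_matrix([img, img, img], mask=mask)
assert mat.shape == (3, int((mask > 0).sum()))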
Example #3
def cortical_thickness(t1,
                       antsxnet_cache_directory=None,
                       verbose=False):

    """
    Perform KellyKapowski cortical thickness using deep_atropos for
    segmentation.  A description of the implementation and evaluation:

    https://www.medrxiv.org/content/10.1101/2020.10.19.20215392v1

    Arguments
    ---------
    t1 : ANTsImage
        input 3-D unprocessed T1-weighted brain image.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if None, the data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Cortical thickness image and segmentation probability images.

    Example
    -------
    >>> image = ants.image_read("t1w_image.nii.gz")
    >>> kk = cortical_thickness(image)
    """

    from ..utilities import deep_atropos

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    atropos = deep_atropos(t1, do_preprocessing=True,
        antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose)

    # Kelly Kapowski cortical thickness.  deep_atropos labels deep gray matter
    # as 4; merge it into white matter (3) so thickness is computed against
    # the full white-matter boundary.

    kk_segmentation = ants.image_clone(atropos['segmentation_image'])
    kk_segmentation[kk_segmentation == 4] = 3
    gray_matter = atropos['probability_images'][2]
    white_matter = (atropos['probability_images'][3] + atropos['probability_images'][4])
    kk = ants.kelly_kapowski(s=kk_segmentation, g=gray_matter, w=white_matter,
                            its=45, r=0.025, m=1.5, x=0, verbose=int(verbose))

    return_dict = {'thickness_image' : kk,
                   'segmentation_image' : atropos['segmentation_image'],
                   'csf_probability_image' : atropos['probability_images'][1],
                   'gray_matter_probability_image' : atropos['probability_images'][2],
                   'white_matter_probability_image' : atropos['probability_images'][3],
                   'deep_gray_matter_probability_image' : atropos['probability_images'][4],
                   'brain_stem_probability_image' : atropos['probability_images'][5],
                   'cerebellum_probability_image' : atropos['probability_images'][6]
                  }
    return return_dict
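A hedged usage sketch of the dictionary returned above; the filename follows the docstring example:

t1_image = ants.image_read("t1w_image.nii.gz")
result = cortical_thickness(t1_image, verbose=True)
thickness = result['thickness_image']          # KellyKapowski output
gm_probability = result['gray_matter_probability_image']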
Example #4
    def test_n4_bias_field_correction(self):
        # def n4_bias_field_correction(image, mask=None, shrink_factor=4,
        #                              convergence={'iters':[50,50,50,50], 'tol':1e-07},
        #                              spline_param=200, verbose=False, weight_mask=None):
        for img in self.imgs:
            image_n4 = ants.n4_bias_field_correction(img)
            mask = ants.image_clone(img > img.mean(), pixeltype='float')
            image_n4 = ants.n4_bias_field_correction(img, mask=mask)
Example #5
def mi_img(y_true, y_pred, mask=None, metric_type="MattesMutualInformation"):
    """Compute the mutual information (MI) between two images.

    Parameters
    ----------
    y_true : np.array
        Image 1. Either (h, w) or (N, h, w). If (N, h, w), the decorator `multiple_images_decorator` takes care of
        the sample dimension.

    y_pred : np.array
        Image 2. Either (h, w) or (N, h, w). If (N, h, w), the decorator `multiple_images_decorator` takes care of
        the sample dimension.

    mask : np.array, optional
        If given, the computation is restricted to the masked region.

    metric_type: str, {'MattesMutualInformation', 'JointHistogramMutualInformation'}
        Type of mutual information computation.

    Returns
    -------
    mi : float
        The mutual information (MI) similarity metric; higher values indicate more similar images.

    """
    y_true_ants = ants.image_clone(ants.from_numpy(y_true), pixeltype="float")
    y_pred_ants = ants.image_clone(ants.from_numpy(y_pred), pixeltype="float")

    if mask is None:
        mi = ants.image_similarity(y_true_ants,
                                   y_pred_ants,
                                   metric_type=metric_type)

    else:
        mask_ants = ants.image_clone(ants.from_numpy(mask.astype(float)),
                                     pixeltype="float")
        mi = ants.image_similarity(
            y_true_ants,
            y_pred_ants,
            fixed_mask=mask_ants,
            moving_mask=mask_ants,
            metric_type=metric_type,
        )

    # ANTs/ITK metric values decrease as similarity increases, so negate to
    # make higher mean more similar.
    return -mi
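A hypothetical toy check of the sign convention above: the self-similarity score should exceed the score against unrelated noise.

import numpy as np

img = np.random.rand(64, 64).astype(np.float32)
noise = np.random.rand(64, 64).astype(np.float32)
print(mi_img(img, img), mi_img(img, noise))   # self-MI should be the larger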
Example #6
def cross_correlation_img(y_true, y_pred, mask=None):
    """Compute the cross correlation metric between two images.

    Parameters
    ----------
    y_true : np.array
        Image 1. Either (h, w) or (N, h, w). If (N, h, w), the decorator `multiple_images_decorator` takes care of
        the sample dimension.

    y_pred : np.array
        Image 2. Either (h, w) or (N, h, w). If (N, h, w), the decorator `multiple_images_decorator` takes care of
        the sample dimension.

    mask : np.array, optional
        If given, the computation is restricted to the masked region.

    Returns
    -------
    cc : float
        The cross-correlation similarity metric; higher values indicate more similar images.

    """
    y_true_ants = ants.image_clone(ants.from_numpy(y_true), pixeltype="float")
    y_pred_ants = ants.image_clone(ants.from_numpy(y_pred), pixeltype="float")

    if mask is None:
        cc = ants.image_similarity(y_true_ants,
                                   y_pred_ants,
                                   metric_type="Correlation")

    else:
        mask_ants = ants.image_clone(ants.from_numpy(mask.astype(float)),
                                     pixeltype="float")
        cc = ants.image_similarity(
            y_true_ants,
            y_pred_ants,
            fixed_mask=mask_ants,
            moving_mask=mask_ants,
            metric_type="Correlation",
        )

    # ANTs/ITK metric values decrease as similarity increases, so negate to
    # make higher mean more similar.
    return -cc
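The same usage pattern applies here; a hedged sketch with a mask restricting the comparison region:

import numpy as np

img = np.random.rand(64, 64).astype(np.float32)
mask = (img > 0.5).astype(float)
print(cross_correlation_img(img, img))             # ~1.0 for identical images
print(cross_correlation_img(img, img, mask=mask))  # restricted to the mask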
Example #7
    def test_image_clone(self):
        for img in self.imgs:
            img = ants.image_clone(img, 'unsigned char')
            orig_ptype = img.pixeltype
            for ptype in self.pixeltypes:
                imgcloned = ants.image_clone(img, ptype)
                self.assertTrue(ants.image_physical_space_consistency(img,imgcloned))
                nptest.assert_allclose(img.numpy(), imgcloned.numpy())
                self.assertEqual(imgcloned.pixeltype, ptype)
                self.assertEqual(img.pixeltype, orig_ptype)

        for img in self.vecimgs:
            img = img.clone('unsigned char')
            orig_ptype = img.pixeltype
            for ptype in self.pixeltypes:
                imgcloned = ants.image_clone(img, ptype)
                self.assertTrue(ants.image_physical_space_consistency(img,imgcloned))
                self.assertEqual(imgcloned.components, img.components)
                nptest.assert_allclose(img.numpy(), imgcloned.numpy())
                self.assertEqual(imgcloned.pixeltype, ptype)
                self.assertEqual(img.pixeltype, orig_ptype)
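A minimal sketch of the non-destructive clone the test exercises, assuming ANTsPy is installed; the original keeps its pixeltype:

import ants

img = ants.image_read(ants.get_ants_data('r16'))   # read as 'float' by default
img_uc = ants.image_clone(img, 'unsigned char')
assert img.pixeltype == 'float' and img_uc.pixeltype == 'unsigned char'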
Example #8
    def test_make_image(self):
        self.setUp()

        for arr in self.arrs:
            voxval = 6.
            img = ants.make_image(arr.shape, voxval=voxval)
            self.assertEqual(img.dimension, arr.ndim)
            self.assertEqual(img.shape, arr.shape)
            nptest.assert_allclose(img.mean(), voxval)

            new_origin = tuple([6.9] * arr.ndim)
            new_spacing = tuple([3.6] * arr.ndim)
            new_direction = np.eye(arr.ndim) * 9.6
            img2 = ants.make_image(arr.shape,
                                   voxval=voxval,
                                   origin=new_origin,
                                   spacing=new_spacing,
                                   direction=new_direction)

            self.assertEqual(img2.dimension, arr.ndim)
            self.assertEqual(img2.shape, arr.shape)
            nptest.assert_allclose(img2.mean(), voxval)
            self.assertEqual(img2.origin, new_origin)
            self.assertEqual(img2.spacing, new_spacing)
            nptest.assert_allclose(img2.direction, new_direction)

            for ptype in self.pixeltypes:
                img = ants.make_image(arr.shape, voxval=1., pixeltype=ptype)
                self.assertEqual(img.pixeltype, ptype)

        # test with components
        img = ants.make_image((69, 70, 4), has_components=True)
        self.assertEqual(img.components, 4)
        self.assertEqual(img.dimension, 2)
        nptest.assert_allclose(img.mean(), 0.)

        img = ants.make_image((69, 70, 71, 4), has_components=True)
        self.assertEqual(img.components, 4)
        self.assertEqual(img.dimension, 3)
        nptest.assert_allclose(img.mean(), 0.)

        # set from image
        for img in self.imgs:
            mask = ants.image_clone(img > img.mean(), pixeltype='float')
            arr = img[mask]
            img2 = ants.make_image(mask, voxval=arr)
            nptest.assert_allclose(img2.numpy(), (img * mask).numpy())
            self.assertTrue(ants.image_physical_space_consistency(img2, mask))

            # set with arr.ndim > 1
            img2 = ants.make_image(mask, voxval=np.expand_dims(arr, -1))
            nptest.assert_allclose(img2.numpy(), (img * mask).numpy())
            self.assertTrue(ants.image_physical_space_consistency(img2, mask))
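A minimal sketch of make_image with explicit geometry, mirroring the parameters exercised above:

import numpy as np
import ants

img = ants.make_image((64, 64), voxval=2.0,
                      spacing=(1.5, 1.5), origin=(0.0, 0.0),
                      direction=np.eye(2))
assert img.spacing == (1.5, 1.5) and abs(img.mean() - 2.0) < 1e-6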
Example #9
    def test_n4_bias_field_correction_example(self):
        image = ants.image_read(ants.get_ants_data('r16'))
        image_n4 = ants.n4_bias_field_correction(image)

        # spline param list
        image_n4 = ants.n4_bias_field_correction(image, spline_param=(10, 10))

        # image not float
        image_ui = image.clone('unsigned int')
        image_n4 = ants.n4_bias_field_correction(image_ui)

        # weight mask
        mask = ants.image_clone(image > image.mean(), pixeltype='float')
        ants.n4_bias_field_correction(image, weight_mask=mask)

        # len(spline_param) != img.dimension
        with self.assertRaises(Exception):
            ants.n4_bias_field_correction(image, spline_param=(10, 10, 10))

        # weight mask not ANTsImage
        with self.assertRaises(Exception):
            ants.n4_bias_field_correction(image, weight_mask=0.4)
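Beyond returning the corrected image, N4 can return the estimated field itself via return_bias_field (the same flag preprocess_brain_image uses below); a minimal sketch, assuming the multiplicative N4 bias model:

image = ants.image_read(ants.get_ants_data('r16'))
bias_field = ants.n4_bias_field_correction(image, return_bias_field=True)
corrected = image / bias_field   # N4 models the bias multiplicatively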
Example #10
def create_ROIs(path: str, extension: str = '.nii.gz'):
    """
    This ROI generator follows the structure FILE/FILE+extension
    and generates a ROI file with the name   FILE/FILE+_ROI+extension
    Parameters
    ----------
    path        Images Path

    Returns     Generated ROI images
    -------

    """

    for scan_id in os.listdir(path):
        print('Creating ROI for: ', scan_id)
        scan = ants.image_read(os.path.join(path, scan_id,
                                            scan_id + extension))
        # mask_image is assumed to be defined elsewhere in this module.
        brainmask = ants.image_clone(scan).apply(mask_image)
        brainmask.to_filename(
            os.path.join(path, scan_id, scan_id + '_ROI' + extension))

    return
Example #11
def crop_image_center(image, crop_size):
    """
    Crop the center of an image.

    Arguments
    ---------
    image : ANTsImage
        Input image

    crop_size: n-D tuple (depending on dimensionality).
        Width, height, depth (if 3-D), and time (if 4-D) of crop region.

    Returns
    -------
    The cropped ANTsImage.

    Example
    -------
    >>> import ants
    >>> image = ants.image_read(ants.get_ants_data('r16'))
    >>> cropped_image = crop_image_center(image, crop_size=(64, 64))
    """

    image_size = np.array(image.shape)

    if len(image_size) != len(crop_size):
        raise ValueError("crop_size does not match image size.")

    if (np.asarray(crop_size) > np.asarray(image_size)).any():
        raise ValueError("A crop_size dimension is larger than image_size.")

    start_index = (np.floor(
        0.5 * (np.asarray(image_size) - np.asarray(crop_size)))).astype(int)
    end_index = start_index + np.asarray(crop_size).astype(int)

    cropped_image = ants.crop_indices(
        ants.image_clone(image) * 1, start_index, end_index)

    return cropped_image
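A worked instance of the index arithmetic above: for a 256x256 image and a (64, 64) crop, the region spans voxels [96, 160) along each axis.

import numpy as np

image_size = np.array([256, 256])
crop_size = np.array([64, 64])
start_index = np.floor(0.5 * (image_size - crop_size)).astype(int)  # [96 96]
end_index = start_index + crop_size                                  # [160 160]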
Example #12
def preprocess_brain_image(image,
                           truncate_intensity=(0.01, 0.99),
                           brain_extraction_modality=None,
                           template_transform_type=None,
                           template="biobank",
                           do_bias_correction=True,
                           return_bias_field=False,
                           do_denoising=True,
                           intensity_matching_type=None,
                           reference_image=None,
                           intensity_normalization_type=None,
                           antsxnet_cache_directory=None,
                           verbose=True):

    """
    Basic preprocessing pipeline for T1-weighted brain MRI.

    Standard preprocessing steps that have been previously described
    in various papers including the cortical thickness pipeline:

         https://www.ncbi.nlm.nih.gov/pubmed/24879923

    Arguments
    ---------
    image : ANTsImage
        input image

    truncate_intensity : 2-length tuple
        Defines the quantile thresholds for truncating the image intensities.

    brain_extraction_modality : string or None
        Perform brain extraction using antspynet tools.  One of "t1", "t1v0",
        "t1nobrainer", "t1combined", "flair", "t2", "bold", "fa", "t1infant",
        "t2infant", or None.

    template_transform_type : string
        See details in help for ants.registration.  Typically "Rigid" or
        "Affine".

    template : ANTs image (not skull-stripped)
        Alternatively, one can specify the default "biobank" or "croppedMni152"
        to download and use premade templates.

    do_bias_correction : boolean
        Perform N4 bias field correction.

    return_bias_field : boolean
        If True, return bias field as an additional output *without* bias
        correcting the preprocessed image.

    do_denoising : boolean
        Perform non-local means denoising.

    intensity_matching_type : string
        Either "regression" or "histogram". Only is performed if reference_image
        is not None.

    reference_image : ANTs image
        Reference image for intensity matching.

    intensity_normalization_type : string
        Either rescale the intensities to [0,1] (i.e., "01") or zero-mean, unit variance
        (i.e., "0mean").  If None normalization is not performed.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if None, the data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary containing the preprocessed image and, when computed, the brain
    mask, bias field, and template transforms.

    Example
    -------
    >>> import ants
    >>> image = ants.image_read(ants.get_ants_data('r16'))
    >>> preprocessed_image = preprocess_brain_image(image, brain_extraction_modality=None)
    """

    from ..utilities import brain_extraction
    from ..utilities import regression_match_image
    from ..utilities import get_antsxnet_data

    preprocessed_image = ants.image_clone(image)

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    # Truncate intensity
    if truncate_intensity is not None:
        quantiles = (image.quantile(truncate_intensity[0]), image.quantile(truncate_intensity[1]))
        if verbose == True:
            print("Preprocessing:  truncate intensities ( low =", quantiles[0], ", high =", quantiles[1], ").")

        preprocessed_image[image < quantiles[0]] = quantiles[0]
        preprocessed_image[image > quantiles[1]] = quantiles[1]

    # Brain extraction
    mask = None
    if brain_extraction_modality is not None:
        if verbose == True:
            print("Preprocessing:  brain extraction.")

        probability_mask = brain_extraction(preprocessed_image, modality=brain_extraction_modality,
            antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose)
        mask = ants.threshold_image(probability_mask, 0.5, 1, 1, 0)
        mask = ants.morphology(mask,"close",6).iMath_fill_holes()

    # Template normalization
    transforms = None
    if template_transform_type is not None:
        template_image = None
        if isinstance(template, str):
            template_file_name_path = get_antsxnet_data(template, antsxnet_cache_directory=antsxnet_cache_directory)
            template_image = ants.image_read(template_file_name_path)
        else:
            template_image = template

        if mask is None:
            registration = ants.registration(fixed=template_image, moving=preprocessed_image,
                type_of_transform=template_transform_type, verbose=verbose)
            preprocessed_image = registration['warpedmovout']
            transforms = dict(fwdtransforms=registration['fwdtransforms'],
                              invtransforms=registration['invtransforms'])
        else:
            template_probability_mask = brain_extraction(template_image, modality=brain_extraction_modality, 
                antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose)
            template_mask = ants.threshold_image(template_probability_mask, 0.5, 1, 1, 0)
            template_brain_image = template_mask * template_image

            preprocessed_brain_image = preprocessed_image * mask

            registration = ants.registration(fixed=template_brain_image, moving=preprocessed_brain_image,
                type_of_transform=template_transform_type, verbose=verbose)
            transforms = dict(fwdtransforms=registration['fwdtransforms'],
                              invtransforms=registration['invtransforms'])

            preprocessed_image = ants.apply_transforms(fixed = template_image, moving = preprocessed_image,
                transformlist=registration['fwdtransforms'], interpolator="linear", verbose=verbose)
            mask = ants.apply_transforms(fixed = template_image, moving = mask,
                transformlist=registration['fwdtransforms'], interpolator="genericLabel", verbose=verbose)

    # Do bias correction
    bias_field = None
    if do_bias_correction == True:
        if verbose == True:
            print("Preprocessing:  brain correction.")
        n4_output = None
        if mask is None:
            n4_output = ants.n4_bias_field_correction(preprocessed_image, shrink_factor=4, return_bias_field=return_bias_field, verbose=verbose)
        else:
            n4_output = ants.n4_bias_field_correction(preprocessed_image, mask, shrink_factor=4, return_bias_field=return_bias_field, verbose=verbose)
        if return_bias_field == True:
            bias_field = n4_output
        else:
            preprocessed_image = n4_output

    # Denoising
    if do_denoising == True:
        if verbose == True:
            print("Preprocessing:  denoising.")

        if mask is None:
            preprocessed_image = ants.denoise_image(preprocessed_image, shrink_factor=1)
        else:
            preprocessed_image = ants.denoise_image(preprocessed_image, mask, shrink_factor=1)

    # Image matching
    if reference_image is not None and intensity_matching_type is not None:
        if verbose == True:
            print("Preprocessing:  intensity matching.")

        if intensity_matching_type == "regression":
            preprocessed_image = regression_match_image(preprocessed_image, reference_image)
        elif intensity_matching_type == "histogram":
            preprocessed_image = ants.histogram_match_image(preprocessed_image, reference_image)
        else:
            raise ValueError("Unrecognized intensity_matching_type.")

    # Intensity normalization
    if intensity_normalization_type is not None:
        if verbose == True:
            print("Preprocessing:  intensity normalization.")

        if intensity_normalization_type == "01":
            preprocessed_image = (preprocessed_image - preprocessed_image.min())/(preprocessed_image.max() - preprocessed_image.min())
        elif intensity_normalization_type == "0mean":
            preprocessed_image = (preprocessed_image - preprocessed_image.mean())/preprocessed_image.std()
        else:
            raise ValueError("Unrecognized intensity_normalization_type.")

    return_dict = {'preprocessed_image' : preprocessed_image}
    if mask is not None:
        return_dict['brain_mask'] = mask
    if bias_field is not None:
        return_dict['bias_field'] = bias_field
    if transforms is not None:
        return_dict['template_transforms'] = transforms

    return return_dict
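A hedged end-to-end sketch on the 2-D 'r16' sample (no brain extraction or template registration, so no model weights are downloaded); the parameter names mirror the signature above:

import ants

image = ants.image_read(ants.get_ants_data('r16'))
prep = preprocess_brain_image(image,
                              truncate_intensity=(0.01, 0.99),
                              do_bias_correction=True,
                              do_denoising=False,
                              intensity_normalization_type="01",
                              verbose=False)
preprocessed = prep['preprocessed_image']   # intensities rescaled to [0, 1]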
Example #13
def el_bicho(ventilation_image,
             mask,
             use_coarse_slices_only=True,
             antsxnet_cache_directory=None,
             verbose=False):
    """
    Perform functional lung segmentation using hyperpolarized gases.

    https://pubmed.ncbi.nlm.nih.gov/30195415/

    Arguments
    ---------
    ventilation_image : ANTsImage
        input ventilation image.

    mask : ANTsImage
        input mask.

    use_coarse_slices_only : boolean
        If True, apply network only in the dimension of greatest slice thickness.
        If False, apply to all dimensions and average the results.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if None, the data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Ventilation segmentation and corresponding probability images

    Example
    -------
    >>> image = ants.image_read("ventilation.nii.gz")
    >>> mask = ants.image_read("mask.nii.gz")
    >>> lung_seg = el_bicho(image, mask, use_coarse_slices_only=True, verbose=False)
    """

    from ..architectures import create_unet_model_2d
    from ..utilities import get_pretrained_network
    from ..utilities import pad_or_crop_image_to_size

    if ventilation_image.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if ventilation_image.shape != mask.shape:
        raise ValueError(
            "Ventilation image and mask are not the same size.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess image
    #
    ################################

    template_size = (256, 256)
    classes = (0, 1, 2, 3, 4)
    number_of_classification_labels = len(classes)

    image_modalities = ("Ventilation", "Mask")
    channel_size = len(image_modalities)

    preprocessed_image = (ventilation_image -
                          ventilation_image.mean()) / ventilation_image.std()
    ants.set_direction(preprocessed_image, np.identity(3))

    mask_identity = ants.image_clone(mask)
    ants.set_direction(mask_identity, np.identity(3))

    ################################
    #
    # Build models and load weights
    #
    ################################

    unet_model = create_unet_model_2d(
        (*template_size, channel_size),
        number_of_outputs=number_of_classification_labels,
        number_of_layers=4,
        number_of_filters_at_base_layer=32,
        dropout_rate=0.0,
        convolution_kernel_size=(3, 3),
        deconvolution_kernel_size=(2, 2),
        weight_decay=1e-5,
        additional_options=("attentionGating"))

    if verbose == True:
        print("El Bicho: retrieving model weights.")

    weights_file_name = get_pretrained_network(
        "elBicho", antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Extract slices
    #
    ################################

    spacing = ants.get_spacing(preprocessed_image)
    dimensions_to_predict = (spacing.index(max(spacing)), )
    if use_coarse_slices_only == False:
        dimensions_to_predict = list(range(3))

    total_number_of_slices = 0
    for d in range(len(dimensions_to_predict)):
        total_number_of_slices += preprocessed_image.shape[
            dimensions_to_predict[d]]

    batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

    slice_count = 0
    for d in range(len(dimensions_to_predict)):
        number_of_slices = preprocessed_image.shape[dimensions_to_predict[d]]

        if verbose == True:
            print("Extracting slices for dimension ", dimensions_to_predict[d],
                  ".")

        for i in range(number_of_slices):
            ventilation_slice = pad_or_crop_image_to_size(
                ants.slice_image(preprocessed_image, dimensions_to_predict[d],
                                 i), template_size)
            batchX[slice_count, :, :, 0] = ventilation_slice.numpy()

            mask_slice = pad_or_crop_image_to_size(
                ants.slice_image(mask_identity, dimensions_to_predict[d], i),
                template_size)
            batchX[slice_count, :, :, 1] = mask_slice.numpy()

            slice_count += 1

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose == True:
        print("Prediction.")

    prediction = unet_model.predict(batchX, verbose=verbose)

    permutations = list()
    permutations.append((0, 1, 2))
    permutations.append((1, 0, 2))
    permutations.append((1, 2, 0))

    probability_images = list()
    for l in range(number_of_classification_labels):
        probability_images.append(ants.image_clone(mask) * 0)

    current_start_slice = 0
    for d in range(len(dimensions_to_predict)):
        # range() excludes its endpoint, so keep current_end_slice exclusive
        # to cover every slice of this dimension.
        current_end_slice = current_start_slice + preprocessed_image.shape[
            dimensions_to_predict[d]]
        which_batch_slices = range(current_start_slice, current_end_slice)

        for l in range(number_of_classification_labels):
            prediction_per_dimension = prediction[which_batch_slices, :, :, l]
            prediction_array = np.transpose(
                np.squeeze(prediction_per_dimension),
                permutations[dimensions_to_predict[d]])
            prediction_image = ants.copy_image_info(
                ventilation_image,
                pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                          ventilation_image.shape))
            probability_images[l] = probability_images[l] + (
                prediction_image - probability_images[l]) / (d + 1)

        current_start_slice = current_end_slice

    ################################
    #
    # Convert probability images to segmentation
    #
    ################################

    image_matrix = ants.image_list_to_matrix(
        probability_images[1:], mask * 0 + 1)
    background_foreground_matrix = np.stack([
        ants.image_list_to_matrix([probability_images[0]], mask * 0 + 1),
        np.expand_dims(np.sum(image_matrix, axis=0), axis=0)
    ])
    foreground_matrix = np.argmax(background_foreground_matrix, axis=0)
    segmentation_matrix = (np.argmax(image_matrix, axis=0) +
                           1) * foreground_matrix
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), mask * 0 + 1)[0]

    return_dict = {
        'segmentation_image': segmentation_image,
        'probability_images': probability_images
    }
    return return_dict
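The probability update inside the dimension loop is an incremental mean: after the d-th predicted dimension (0-based), each probability image equals the average of the first d + 1 predictions. A scalar sketch of the same recurrence:

p = 0.0
for d, x in enumerate([0.2, 0.4, 0.9]):
    p = p + (x - p) / (d + 1)
assert abs(p - (0.2 + 0.4 + 0.9) / 3) < 1e-12   # p is the running mean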
Example #14
def arterial_lesion_segmentation(image,
                                 antsxnet_cache_directory=None,
                                 verbose=False):
    """
    Perform arterial lesion segmentation using U-net.

    Arguments
    ---------
    image : ANTsImage
        input image

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if None, the data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Foreground probability image.

    Example
    -------
    >>> output = arterial_lesion_segmentation(histology_image)
    """

    from ..architectures import create_unet_model_2d
    from ..utilities import get_pretrained_network

    if image.dimension != 2:
        raise ValueError("Image dimension must be 2.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    channel_size = 1

    weights_file_name = get_pretrained_network(
        "arterialLesionWeibinShi",
        antsxnet_cache_directory=antsxnet_cache_directory)

    resampled_image_size = (512, 512)

    unet_model = create_unet_model_2d(
        (*resampled_image_size, channel_size),
        number_of_outputs=1,
        mode="sigmoid",
        number_of_filters=(64, 96, 128, 256, 512),
        convolution_kernel_size=(3, 3),
        deconvolution_kernel_size=(2, 2),
        dropout_rate=0.0,
        weight_decay=0,
        additional_options=("initialConvolutionKernelSize[5]",
                            "attentionGating"))
    unet_model.load_weights(weights_file_name)

    if verbose == True:
        print("Preprocessing:  Resampling and N4 bias correction.")

    preprocessed_image = ants.image_clone(image)
    preprocessed_image = preprocessed_image / preprocessed_image.max()
    preprocessed_image = ants.resample_image(preprocessed_image,
                                             resampled_image_size,
                                             use_voxels=True,
                                             interp_type=0)
    mask = ants.image_clone(preprocessed_image) * 0 + 1
    preprocessed_image = ants.n4_bias_field_correction(preprocessed_image,
                                                       mask=mask,
                                                       shrink_factor=2,
                                                       return_bias_field=False,
                                                       verbose=verbose)

    batchX = np.expand_dims(preprocessed_image.numpy(), axis=0)
    batchX = np.expand_dims(batchX, axis=-1)
    batchX = (batchX - batchX.min()) / (batchX.max() - batchX.min())

    predicted_data = unet_model.predict(batchX, verbose=int(verbose))

    origin = preprocessed_image.origin
    spacing = preprocessed_image.spacing
    direction = preprocessed_image.direction

    foreground_probability_image = ants.from_numpy(np.squeeze(
        predicted_data[0, :, :, 0]),
                                                   origin=origin,
                                                   spacing=spacing,
                                                   direction=direction)

    if verbose == True:
        print("Post-processing:  resampling to original space.")

    foreground_probability_image = ants.resample_image_to_target(
        foreground_probability_image, image)

    return foreground_probability_image
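A hedged follow-up: binarizing the returned probability image with ants.threshold_image yields a lesion mask (histology_image is the placeholder input from the docstring example).

probability = arterial_lesion_segmentation(histology_image, verbose=True)
lesion_mask = ants.threshold_image(probability, 0.5, 1.0, 1, 0)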
Example #15
def antspy_registration(
    fixed_img,
    moving_img,
    registration_type="SyN",
    reg_iterations=(40, 20, 0),
    aff_metric="mattes",
    syn_metric="mattes",
    verbose=False,
    initial_transform=None,
    path=GLOBAL_CACHE_FOLDER,
):
    """Register images using ANTsPY.

    Parameters
    ----------
    fixed_img: np.ndarray
        Fixed image.

    moving_img: np.ndarray
        Moving image to register.

    registration_type: {'Translation', 'Rigid', 'Similarity', 'QuickRigid', 'DenseRigid', 'BOLDRigid', 'Affine',
                        'AffineFast', 'BOLDAffine', 'TRSAA', 'ElasticSyN', 'SyN', 'SyNRA', 'SyNOnly', 'SyNCC', 'SyNabp',
                        'SyNBold', 'SyNBoldAff', 'SyNAggro', 'TVMSQ', 'TVMSQC'}, default 'SyN'

        Optimization algorithm to use to register (more info:
        https://antspy.readthedocs.io/en/latest/registration.html?highlight=registration#ants.registration)

    reg_iterations: tuple, default (40, 20, 0)
        Vector of iterations for SyN.

    aff_metric: {'GC', 'mattes', 'meansquares'}, default 'mattes'
        The metric for the affine part.

    syn_metric: {'CC', 'mattes', 'meansquares', 'demons'}, default 'mattes'
        The metric for the SyN part.

    verbose : bool, default False
        If True, then the inner solver prints convergence related information in standard output.

    path : str
        Path to a folder to where to save the `.nii.gz` file representing the composite transform.

    initial_transform : list or None
        Transforms to prepend before the registration.

    Returns
    -------
    df: DisplacementField
        Displacement field between the moving and the fixed image

    meta : dict
        Contains relevant images and paths.

    """
    path = str(path)
    path += "" if path[-1] == "/" else "/"

    fixed_ants_image = ants.image_clone(ants.from_numpy(fixed_img),
                                        pixeltype="float")
    moving_ants_image = ants.image_clone(ants.from_numpy(moving_img),
                                         pixeltype="float")
    meta = ants.registration(
        fixed_ants_image,
        moving_ants_image,
        registration_type,
        reg_iterations=reg_iterations,
        aff_metric=aff_metric,
        syn_metric=syn_metric,
        verbose=verbose,
        initial_transform=initial_transform,
        syn_sampling=32,
        aff_sampling=32,
    )

    filename = ants.apply_transforms(
        fixed_ants_image,
        moving_ants_image,
        meta["fwdtransforms"],
        compose=path + "final_transform",
    )

    df = nib.load(filename)
    data = df.get_fdata()
    data = data.squeeze()
    dx = data[:, :, 1]
    dy = data[:, :, 0]
    df_final = DisplacementField(dx, dy)

    return df_final, meta
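A hypothetical call with random 2-D arrays; GLOBAL_CACHE_FOLDER and DisplacementField come from the surrounding project, so only the array inputs are new here:

import numpy as np

fixed = np.random.rand(64, 64).astype(np.float32)
moving = np.random.rand(64, 64).astype(np.float32)
df, meta = antspy_registration(fixed, moving, registration_type="Affine")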
Example #16
def ew_david(flair,
             t1,
             do_preprocessing=True,
             which_model="sysu",
             which_axes=2,
             number_of_simulations=0,
             sd_affine=0.01,
             antsxnet_cache_directory=None,
             verbose=False):
    """
    Perform white matter hyperintensity probabilistic segmentation
    using deep learning.

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * intensity truncation,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e., set
    do_preprocessing=True.

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    do_preprocessing : boolean
        perform n4 bias correction, intensity truncation, brain extraction.

    which_model : string
        one of:
            * "sysu" -- same as the original sysu network (without site specific preprocessing),
            * "sysu-ri" -- same as "sysu" but using ranked intensity scaling for input images,
            * "sysuWithAttention" -- "sysu" with attention gating,
            * "sysuWithAttentionAndSite" -- "sysu" with attention gating with site branch (see "sysuWithSite"),
            * "sysuPlus" -- "sysu" with attention gating and nn-Unet activation,
            * "sysuPlusSeg" -- "sysuPlus" with deep_atropos segmentation in an additional channel, and
            * "sysuWithSite" -- "sysu" with global pooling on encoding channels to predict "site".
            * "sysuPlusSegWithSite" -- "sysuPlusSeg" combined with "sysuWithSite"
        In addition to both modalities, all models have T1-only and FLAIR-only variants except
        for "sysuPlusSeg" (which only has a T1-only variant) and "sysu-ri" (which has neither
        single-modality variant).

    which_axes : string or scalar or tuple/vector
        apply 2-D model to 1 or more axes.  In addition to a scalar
        or vector, e.g., which_axes = (0, 2), one can use "max" for the
        axis with maximum anisotropy (default) or "all" for all axes.

    number_of_simulations : integer
        Number of random affine perturbations to transform the input.

    sd_affine : float
        Define the standard deviation of the affine transformation parameter.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if None, the data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> image = ants.image_read("flair.nii.gz")
    >>> probability_mask = sysu_media_wmh_segmentation(image)
    """

    from ..architectures import create_unet_model_2d
    from ..utilities import deep_atropos
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import randomly_transform_image_data
    from ..utilities import pad_or_crop_image_to_size

    do_t1_only = False
    do_flair_only = False

    if flair is None and t1 is not None:
        do_t1_only = True
    elif flair is not None and t1 is None:
        do_flair_only = True

    use_t1_segmentation = False
    if "Seg" in which_model:
        if do_flair_only:
            raise ValueError("Segmentation requires T1.")
        else:
            use_t1_segmentation = True

    if use_t1_segmentation and do_preprocessing == False:
        raise ValueError(
            "Using the t1 segmentation requires do_preprocessing=True.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    do_slicewise = True

    if do_slicewise == False:

        raise ValueError("Not available.")

        # ################################
        # #
        # # Preprocess images
        # #
        # ################################

        # t1_preprocessed = t1
        # t1_preprocessing = None
        # if do_preprocessing == True:
        #     t1_preprocessing = preprocess_brain_image(t1,
        #         truncate_intensity=(0.01, 0.99),
        #         brain_extraction_modality="t1",
        #         template="croppedMni152",
        #         template_transform_type="antsRegistrationSyNQuickRepro[a]",
        #         do_bias_correction=True,
        #         do_denoising=False,
        #         antsxnet_cache_directory=antsxnet_cache_directory,
        #         verbose=verbose)
        #     t1_preprocessed = t1_preprocessing["preprocessed_image"] * t1_preprocessing['brain_mask']

        # flair_preprocessed = flair
        # if do_preprocessing == True:
        #     flair_preprocessing = preprocess_brain_image(flair,
        #         truncate_intensity=(0.01, 0.99),
        #         brain_extraction_modality="t1",
        #         do_bias_correction=True,
        #         do_denoising=False,
        #         antsxnet_cache_directory=antsxnet_cache_directory,
        #         verbose=verbose)
        #     flair_preprocessed = ants.apply_transforms(fixed=t1_preprocessed,
        #         moving=flair_preprocessing["preprocessed_image"],
        #         transformlist=t1_preprocessing['template_transforms']['fwdtransforms'])
        #     flair_preprocessed = flair_preprocessed * t1_preprocessing['brain_mask']

        # ################################
        # #
        # # Build model and load weights
        # #
        # ################################

        # patch_size = (112, 112, 112)
        # stride_length = (t1_preprocessed.shape[0] - patch_size[0],
        #                 t1_preprocessed.shape[1] - patch_size[1],
        #                 t1_preprocessed.shape[2] - patch_size[2])

        # classes = ("background", "wmh" )
        # number_of_classification_labels = len(classes)
        # labels = (0, 1)

        # image_modalities = ("T1", "FLAIR")
        # channel_size = len(image_modalities)

        # unet_model = create_unet_model_3d((*patch_size, channel_size),
        #     number_of_outputs = number_of_classification_labels,
        #     number_of_layers = 4, number_of_filters_at_base_layer = 16, dropout_rate = 0.0,
        #     convolution_kernel_size = (3, 3, 3), deconvolution_kernel_size = (2, 2, 2),
        #     weight_decay = 1e-5, additional_options=("attentionGating"))

        # weights_file_name = get_pretrained_network("ewDavidWmhSegmentationWeights",
        #     antsxnet_cache_directory=antsxnet_cache_directory)
        # unet_model.load_weights(weights_file_name)

        # ################################
        # #
        # # Do prediction and normalize to native space
        # #
        # ################################

        # if verbose == True:
        #     print("ew_david:  prediction.")

        # batchX = np.zeros((8, *patch_size, channel_size))

        # t1_preprocessed = (t1_preprocessed - t1_preprocessed.mean()) / t1_preprocessed.std()
        # t1_patches = extract_image_patches(t1_preprocessed, patch_size=patch_size,
        #                                     max_number_of_patches="all", stride_length=stride_length,
        #                                     return_as_array=True)
        # batchX[:,:,:,:,0] = t1_patches

        # flair_preprocessed = (flair_preprocessed - flair_preprocessed.mean()) / flair_preprocessed.std()
        # flair_patches = extract_image_patches(flair_preprocessed, patch_size=patch_size,
        #                                     max_number_of_patches="all", stride_length=stride_length,
        #                                     return_as_array=True)
        # batchX[:,:,:,:,1] = flair_patches

        # predicted_data = unet_model.predict(batchX, verbose=verbose)

        # probability_images = list()
        # for i in range(len(labels)):
        #     print("Reconstructing image", classes[i])
        #     reconstructed_image = reconstruct_image_from_patches(predicted_data[:,:,:,:,i],
        #         domain_image=t1_preprocessed, stride_length=stride_length)

        #     if do_preprocessing == True:
        #         probability_images.append(ants.apply_transforms(fixed=t1,
        #             moving=reconstructed_image,
        #             transformlist=t1_preprocessing['template_transforms']['invtransforms'],
        #             whichtoinvert=[True], interpolator="linear", verbose=verbose))
        #     else:
        #         probability_images.append(reconstructed_image)

        # return(probability_images[1])

    else:  # do_slicewise

        ################################
        #
        # Preprocess images
        #
        ################################

        use_rank_intensity_scaling = False
        if "-ri" in which_model:
            use_rank_intensity_scaling = True

        t1_preprocessed = None
        t1_preprocessing = None
        brain_mask = None
        if t1 is not None:
            if do_preprocessing == True:
                t1_preprocessing = preprocess_brain_image(
                    t1,
                    truncate_intensity=(0.01, 0.995),
                    brain_extraction_modality="t1",
                    do_bias_correction=False,
                    do_denoising=False,
                    antsxnet_cache_directory=antsxnet_cache_directory,
                    verbose=verbose)
                brain_mask = ants.threshold_image(
                    t1_preprocessing["brain_mask"], 0.5, 1, 1, 0)
                t1_preprocessed = t1_preprocessing["preprocessed_image"]

        t1_segmentation = None
        if use_t1_segmentation:
            atropos_seg = deep_atropos(t1,
                                       do_preprocessing=True,
                                       verbose=verbose)
            t1_segmentation = atropos_seg['segmentation_image']

        flair_preprocessed = None
        if flair is not None:
            flair_preprocessed = flair
            if do_preprocessing == True:
                if brain_mask is None:
                    flair_preprocessing = preprocess_brain_image(
                        flair,
                        truncate_intensity=(0.01, 0.995),
                        brain_extraction_modality="flair",
                        do_bias_correction=False,
                        do_denoising=False,
                        antsxnet_cache_directory=antsxnet_cache_directory,
                        verbose=verbose)
                    brain_mask = ants.threshold_image(
                        flair_preprocessing["brain_mask"], 0.5, 1, 1, 0)
                else:
                    flair_preprocessing = preprocess_brain_image(
                        flair,
                        truncate_intensity=None,
                        brain_extraction_modality=None,
                        do_bias_correction=False,
                        do_denoising=False,
                        antsxnet_cache_directory=antsxnet_cache_directory,
                        verbose=verbose)
                flair_preprocessed = flair_preprocessing["preprocessed_image"]

        if t1_preprocessed is not None:
            t1_preprocessed = t1_preprocessed * brain_mask
        if flair_preprocessed is not None:
            flair_preprocessed = flair_preprocessed * brain_mask

        if t1_preprocessed is not None:
            resampling_params = list(ants.get_spacing(t1_preprocessed))
        else:
            resampling_params = list(ants.get_spacing(flair_preprocessed))

        do_resampling = False
        for d in range(len(resampling_params)):
            if resampling_params[d] < 0.8:
                resampling_params[d] = 1.0
                do_resampling = True

        resampling_params = tuple(resampling_params)

        if do_resampling:
            if flair_preprocessed is not None:
                flair_preprocessed = ants.resample_image(flair_preprocessed,
                                                         resampling_params,
                                                         use_voxels=False,
                                                         interp_type=0)
            if t1_preprocessed is not None:
                t1_preprocessed = ants.resample_image(t1_preprocessed,
                                                      resampling_params,
                                                      use_voxels=False,
                                                      interp_type=0)
            if t1_segmentation is not None:
                t1_segmentation = ants.resample_image(t1_segmentation,
                                                      resampling_params,
                                                      use_voxels=False,
                                                      interp_type=1)
            if brain_mask is not None:
                brain_mask = ants.resample_image(brain_mask,
                                                 resampling_params,
                                                 use_voxels=False,
                                                 interp_type=1)

        ################################
        #
        # Build model and load weights
        #
        ################################

        template_size = (208, 208)

        image_modalities = ("T1", "FLAIR")
        if do_flair_only:
            image_modalities = ("FLAIR", )
        elif do_t1_only:
            image_modalities = ("T1", )
        if use_t1_segmentation:
            image_modalities = (*image_modalities, "T1Seg")
        channel_size = len(image_modalities)

        unet_model = None
        if which_model == "sysu" or which_model == "sysu-ri":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("initialConvolutionKernelSize[5]", ))
        elif which_model == "sysuWithAttention":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("attentionGating",
                                    "initialConvolutionKernelSize[5]"))
        elif which_model == "sysuWithAttentionAndSite":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                scalar_output_size=3,
                scalar_output_activation="softmax",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("attentionGating",
                                    "initialConvolutionKernelSize[5]"))
        elif which_model == "sysuWithSite":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                scalar_output_size=3,
                scalar_output_activation="softmax",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("initialConvolutionKernelSize[5]", ))
        elif which_model == "sysuPlusSegWithSite":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                scalar_output_size=3,
                scalar_output_activation="softmax",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("nnUnetActivationStyle", "attentionGating",
                                    "initialConvolutionKernelSize[5]"))
        else:
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("nnUnetActivationStyle", "attentionGating",
                                    "initialConvolutionKernelSize[5]"))

        if verbose == True:
            print("ewDavid:  retrieving model weights.")

        weights_file_name = None
        if which_model == "sysu" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysu",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysu-ri" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuRankedIntensity",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysu" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysu" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttention" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttention",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttention" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttentionT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttention" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttentionFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttentionAndSite" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttentionAndSite",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttentionAndSite" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttentionAndSiteT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttentionAndSite" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttentionAndSiteFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlus" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlus",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlus" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlus" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlusSeg" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusSeg",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlusSeg" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusSegT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlusSegWithSite" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusSegWithSite",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlusSegWithSite" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusSegWithSiteT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithSite" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithSite",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithSite" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithSiteT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithSite" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithSiteFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            raise ValueError(
                "Incorrect model specification or image combination.")

        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Data augmentation and extract slices
        #
        ################################

        wmh_probability_image = None
        if t1 is not None:
            wmh_probability_image = ants.image_clone(t1_preprocessed) * 0
        else:
            wmh_probability_image = ants.image_clone(flair_preprocessed) * 0

        wmh_site = np.array([0, 0, 0])

        data_augmentation = None
        if number_of_simulations > 0:
            if do_flair_only:
                data_augmentation = randomly_transform_image_data(
                    reference_image=flair_preprocessed,
                    input_image_list=[[flair_preprocessed]],
                    number_of_simulations=number_of_simulations,
                    transform_type='affine',
                    sd_affine=sd_affine,
                    input_image_interpolator='linear')
            elif do_t1_only:
                if use_t1_segmentation:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[t1_preprocessed]],
                        segmentation_image_list=[t1_segmentation],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear',
                        segmentation_image_interpolator='nearestNeighbor')
                else:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[t1_preprocessed]],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear')
            else:
                if use_t1_segmentation:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[
                            flair_preprocessed, t1_preprocessed
                        ]],
                        segmentation_image_list=[t1_segmentation],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear',
                        segmentation_image_interpolator='nearestNeighbor')
                else:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[
                            flair_preprocessed, t1_preprocessed
                        ]],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear')
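        # Each simulation applies a random affine perturbation (scaled by
        # sd_affine) to the inputs; predictions on the original and simulated
        # images are averaged below to make the output more robust to pose.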

        dimensions_to_predict = [0]
        if which_axes == "max":
            spacing = ants.get_spacing(wmh_probability_image)
            dimensions_to_predict = [spacing.index(max(spacing))]
        elif which_axes == "all":
            dimensions_to_predict = list(range(3))
        elif isinstance(which_axes, int):
            dimensions_to_predict = [which_axes]
        else:
            dimensions_to_predict = list(which_axes)

        total_number_of_slices = 0
        for d in range(len(dimensions_to_predict)):
            total_number_of_slices += wmh_probability_image.shape[
                dimensions_to_predict[d]]

        batchX = np.zeros(
            (total_number_of_slices, *template_size, channel_size))
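        # batchX holds every 2-D slice to be predicted: for each requested
        # axis, all slices along that axis, padded/cropped to template_size.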

        for n in range(number_of_simulations + 1):

            batch_flair = flair_preprocessed
            batch_t1 = t1_preprocessed
            batch_t1_segmentation = t1_segmentation
            batch_brain_mask = brain_mask

            if n > 0:

                if do_flair_only:
                    batch_flair = data_augmentation['simulated_images'][n - 1][0]
                    batch_brain_mask = ants.apply_ants_transform_to_image(
                        data_augmentation['simulated_transforms'][n - 1],
                        brain_mask,
                        flair_preprocessed,
                        interpolation="nearestneighbor")

                elif do_t1_only:
                    batch_t1 = data_augmentation['simulated_images'][n - 1][0]
                    batch_brain_mask = ants.apply_ants_transform_to_image(
                        data_augmentation['simulated_transforms'][n - 1],
                        brain_mask,
                        t1_preprocessed,
                        interpolation="nearestneighbor")
                else:
                    batch_flair = data_augmentation['simulated_images'][n - 1][0]
                    batch_t1 = data_augmentation['simulated_images'][n - 1][1]
                    batch_brain_mask = ants.apply_ants_transform_to_image(
                        data_augmentation['simulated_transforms'][n - 1],
                        brain_mask,
                        flair_preprocessed,
                        interpolation="nearestneighbor")
                if use_t1_segmentation:
                    batch_t1_segmentation = data_augmentation[
                        'simulated_segmentation_images'][n - 1]

            if use_rank_intensity_scaling:
                if batch_t1 is not None:
                    batch_t1 = ants.rank_intensity(batch_t1,
                                                   batch_brain_mask) - 0.5
                if batch_flair is not None:
                    # Use the (possibly augmented) batch image, not the
                    # original flair_preprocessed, so simulations are scaled too.
                    batch_flair = ants.rank_intensity(batch_flair,
                                                      batch_brain_mask) - 0.5
            else:
                if batch_t1 is not None:
                    batch_t1 = (batch_t1 -
                                batch_t1[batch_brain_mask == 1].mean()
                                ) / batch_t1[batch_brain_mask == 1].std()
                if batch_flair is not None:
                    batch_flair = (
                        batch_flair -
                        batch_flair[batch_brain_mask == 1].mean()
                    ) / batch_flair[batch_brain_mask == 1].std()
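            # Two normalization modes: rank_intensity replaces intensities with
            # normalized ranks (shifted here to roughly [-0.5, 0.5]), while the
            # default branch z-scores each image within the brain mask.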

            slice_count = 0
            for d in range(len(dimensions_to_predict)):

                number_of_slices = None
                if batch_t1 is not None:
                    number_of_slices = batch_t1.shape[dimensions_to_predict[d]]
                else:
                    number_of_slices = batch_flair.shape[
                        dimensions_to_predict[d]]

                if verbose:
                    print("Extracting slices for dimension ",
                          dimensions_to_predict[d])

                for i in range(number_of_slices):

                    brain_mask_slice = pad_or_crop_image_to_size(
                        ants.slice_image(batch_brain_mask,
                                         dimensions_to_predict[d], i),
                        template_size)

                    channel_count = 0
                    if batch_flair is not None:
                        flair_slice = pad_or_crop_image_to_size(
                            ants.slice_image(batch_flair,
                                             dimensions_to_predict[d], i),
                            template_size)
                        flair_slice[brain_mask_slice == 0] = 0
                        batchX[slice_count, :, :,
                               channel_count] = flair_slice.numpy()
                        channel_count += 1
                    if batch_t1 is not None:
                        t1_slice = pad_or_crop_image_to_size(
                            ants.slice_image(batch_t1,
                                             dimensions_to_predict[d], i),
                            template_size)
                        t1_slice[brain_mask_slice == 0] = 0
                        batchX[slice_count, :, :,
                               channel_count] = t1_slice.numpy()
                        channel_count += 1
                    if t1_segmentation is not None:
                        t1_segmentation_slice = pad_or_crop_image_to_size(
                            ants.slice_image(batch_t1_segmentation,
                                             dimensions_to_predict[d], i),
                            template_size)
                        t1_segmentation_slice[brain_mask_slice == 0] = 0
                        batchX[slice_count, :, :, channel_count] = \
                            t1_segmentation_slice.numpy() / 6 - 0.5

                    slice_count += 1

            ################################
            #
            # Do prediction and then restack into the image
            #
            ################################

            if verbose:
                if n == 0:
                    print("Prediction")
                else:
                    print("Prediction (simulation " + str(n) + ")")

            prediction = unet_model.predict(batchX, verbose=verbose)

            permutations = list()
            permutations.append((0, 1, 2))
            permutations.append((1, 0, 2))
            permutations.append((1, 2, 0))
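            # Transposes that map a (slice, row, col) stack back to (x, y, z)
            # when slicing was done along axis 0, 1, or 2, respectively.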

            prediction_image_average = ants.image_clone(
                wmh_probability_image) * 0

            current_start_slice = 0
            for d in range(len(dimensions_to_predict)):
                current_end_slice = current_start_slice + wmh_probability_image.shape[
                    dimensions_to_predict[d]]
                which_batch_slices = range(current_start_slice,
                                           current_end_slice)
                if isinstance(prediction, list):
                    prediction_per_dimension = prediction[0][
                        which_batch_slices, :, :, 0]
                else:
                    prediction_per_dimension = prediction[
                        which_batch_slices, :, :, 0]
                prediction_array = np.transpose(
                    np.squeeze(prediction_per_dimension),
                    permutations[dimensions_to_predict[d]])
                prediction_image = ants.copy_image_info(
                    wmh_probability_image,
                    pad_or_crop_image_to_size(
                        ants.from_numpy(prediction_array),
                        wmh_probability_image.shape))
                prediction_image_average = prediction_image_average + (
                    prediction_image - prediction_image_average) / (d + 1)
                current_start_slice = current_end_slice

            wmh_probability_image = wmh_probability_image + (
                prediction_image_average - wmh_probability_image) / (n + 1)
            if isinstance(prediction, list):
                wmh_site = wmh_site + (np.mean(prediction[1], axis=0) -
                                       wmh_site) / (n + 1)
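            # Both updates are incremental means, m_{k+1} = m_k + (x - m_k)/(k+1):
            # first across the predicted axes (d), then across simulations (n).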

        if do_resampling:
            if t1 is not None:
                wmh_probability_image = ants.resample_image_to_target(
                    wmh_probability_image, t1)
            if flair is not None:
                wmh_probability_image = ants.resample_image_to_target(
                    wmh_probability_image, flair)

        if isinstance(prediction, list):
            return [wmh_probability_image, wmh_site]
        else:
            return wmh_probability_image
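
A minimal usage sketch for the WMH function above (its signature is not shown
in this excerpt, so the call below assumes ANTsPyNet's ew_david argument
names; treat it as an illustration rather than a verbatim API):

# Hedged sketch -- argument names inferred from the function body above.
import ants
import antspynet

flair = ants.image_read("flair.nii.gz")  # placeholder file names
t1 = ants.image_read("t1.nii.gz")
wmh = antspynet.ew_david(flair, t1, which_model="sysu", which_axes="max",
                         number_of_simulations=0, verbose=True)
# A single probability image is returned here; the "...WithSite" models
# return [probability_image, site_vector] instead.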
Example #17
def deep_flash(t1,
               do_preprocessing=True,
               output_directory=None,
               verbose=False):
    """
    Hippocampal/entorhinal segmentation using "Deep Flash"

    Perform hippocampal/entorhinal segmentation in T1 images using
    labels from Mike Yassa's lab

    https://faculty.sites.uci.edu/myassa/

    The labeling is as follows:
    Label 0 :  background
    Label 5 :  left aLEC
    Label 6 :  right aLEC
    Label 7 :  left pMEC
    Label 8 :  right pMEC
    Label 9 :  left perirhinal
    Label 10:  right perirhinal
    Label 11:  left parahippocampal
    Label 12:  right parahippocampal
    Label 13:  left DG/CA3
    Label 14:  right DG/CA3
    Label 15:  left CA1
    Label 16:  right CA1
    Label 17:  left subiculum
    Label 18:  right subiculum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e., set
    do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    output_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if this is None, the data are downloaded to
        a tempfile.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary consisting of the segmentation image and probability images for
    each label.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> flash = deep_flash(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import categorical_focal_loss
    from ..utilities import preprocess_brain_image
    from ..utilities import crop_image_center
    from ..utilities import pad_or_crop_image_to_size

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            do_brain_extraction=True,
            template="croppedMni152",
            template_transform_type="AffineFast",
            do_bias_correction=True,
            do_denoising=True,
            output_directory=output_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing[
            "preprocessed_image"] * t1_preprocessing['brain_mask']

    ################################
    #
    # Build model and load weights
    #
    ################################

    template_size = (160, 192, 160)
    labels = (0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)

    unet_model = create_unet_model_3d((*template_size, 1),
                                      number_of_outputs=len(labels),
                                      number_of_layers=4,
                                      number_of_filters_at_base_layer=8,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      weight_decay=1e-5,
                                      add_attention_gating=True)

    weights_file_name = None
    if output_directory is not None:
        weights_file_name = output_directory + "/deepFlashWeights.h5"
        if not os.path.exists(weights_file_name):
            if verbose:
                print("Deep Flash:  downloading model weights.")
            weights_file_name = get_pretrained_network("deepFlashWeights",
                                                       weights_file_name)
    else:
        weights_file_name = get_pretrained_network("deepFlash")

    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction.")

    cropped_image = pad_or_crop_image_to_size(t1_preprocessed, template_size)

    batchX = np.expand_dims(cropped_image.numpy(), axis=0)
    batchX = np.expand_dims(batchX, axis=-1)
    batchX = (batchX - batchX.mean()) / batchX.std()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    origin = cropped_image.origin
    spacing = cropped_image.spacing
    direction = cropped_image.direction

    probability_images = list()
    for i in range(len(labels)):
        probability_image = \
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
            origin=origin, spacing=spacing, direction=direction)
        if i > 0:
            decropped_image = ants.decrop_image(probability_image,
                                                t1_preprocessed * 0)
        else:
            decropped_image = ants.decrop_image(probability_image,
                                                t1_preprocessed * 0 + 1)

        if do_preprocessing:
            probability_images.append(
                ants.apply_transforms(
                    fixed=t1,
                    moving=decropped_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose))
        else:
            probability_images.append(decropped_image)

    image_matrix = ants.image_list_to_matrix(probability_images, t1 * 0 + 1)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]
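    # image_list_to_matrix flattens each probability image over the whole-image
    # mask; argmax over the label axis picks the most probable class per voxel.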

    relabeled_image = ants.image_clone(segmentation_image)
    for i in range(len(labels)):
        relabeled_image[segmentation_image == i] = labels[i]

    return_dict = {
        'segmentation_image': relabeled_image,
        'probability_images': probability_images
    }
    return return_dict
Example #18
def deep_flash(t1,
               t2=None,
               do_preprocessing=True,
               use_rank_intensity=True,
               antsxnet_cache_directory=None,
               verbose=False):
    """
    Hippocampal/entorhinal segmentation using "Deep Flash"

    Perform hippocampal/entorhinal segmentation in T1 and T1/T2 images using
    labels from Mike Yassa's lab

    https://faculty.sites.uci.edu/myassa/

    The labeling is as follows:
    Label 0 :  background
    Label 5 :  left aLEC
    Label 6 :  right aLEC
    Label 7 :  left pMEC
    Label 8 :  right pMEC
    Label 9 :  left perirhinal
    Label 10:  right perirhinal
    Label 11:  left parahippocampal
    Label 12:  right parahippocampal
    Label 13:  left DG/CA2/CA3/CA4
    Label 14:  right DG/CA2/CA3/CA4
    Label 15:  left CA1
    Label 16:  right CA1
    Label 17:  left subiculum
    Label 18:  right subiculum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * affine registration to the "deep flash" template.
    These steps are performed on the input images if do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    t2 : ANTsImage
        Optional 3-D T2-weighted brain image.  If specified, it is assumed to be
        pre-aligned to the t1.

    do_preprocessing : boolean
        See description above.

    use_rank_intensity : boolean
        If False, use histogram matching with the cropped template ROI.  Otherwise,
        use a rank intensity transform on the cropped ROI.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if this is None, the data are downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary consisting of the segmentation image, probability images for
    each label, and foreground probability images.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> flash = deep_flash(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import brain_extraction

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Options temporarily taken from the user
    #
    ################################

    # use_hierarchical_parcellation : boolean
    #     If True, use u-net model with additional outputs of the medial temporal lobe
    #     region, hippocampal, and entorhinal/perirhinal/parahippocampal regions.  Otherwise
    #     the only additional output is the medial temporal lobe.
    #
    # use_contralaterality : boolean
    #     Use both hemispherical models to also predict the corresponding contralateral
    #     segmentation and use both sets of priors to produce the results.  Mainly used
    #     for debugging.

    use_hierarchical_parcellation = True
    use_contralaterality = True

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    t1_mask = None
    t1_preprocessed_flipped = None
    t1_template = ants.image_read(
        get_antsxnet_data("deepFlashTemplateT1SkullStripped"))
    template_transforms = None
    if do_preprocessing:

        if verbose:
            print("Preprocessing T1.")

        # Brain extraction
        probability_mask = brain_extraction(
            t1_preprocessed,
            modality="t1",
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_mask = ants.threshold_image(probability_mask, 0.5, 1, 1, 0)
        t1_preprocessed = t1_preprocessed * t1_mask

        # Do bias correction
        t1_preprocessed = ants.n4_bias_field_correction(t1_preprocessed,
                                                        t1_mask,
                                                        shrink_factor=4,
                                                        verbose=verbose)

        # Warp to template
        registration = ants.registration(
            fixed=t1_template,
            moving=t1_preprocessed,
            type_of_transform="antsRegistrationSyNQuickRepro[a]",
            verbose=verbose)
        template_transforms = dict(fwdtransforms=registration['fwdtransforms'],
                                   invtransforms=registration['invtransforms'])
        t1_preprocessed = registration['warpedmovout']

    if use_contralaterality:
        t1_preprocessed_array = t1_preprocessed.numpy()
        t1_preprocessed_array_flipped = np.flip(t1_preprocessed_array, axis=0)
        t1_preprocessed_flipped = ants.from_numpy(
            t1_preprocessed_array_flipped,
            origin=t1_preprocessed.origin,
            spacing=t1_preprocessed.spacing,
            direction=t1_preprocessed.direction)
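    # Contralaterality: flipping along axis 0 mirrors the hemispheres, so each
    # single-hemisphere network can also segment the (mirrored) opposite side;
    # the flipped predictions are flipped back and averaged in later.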

    t2_preprocessed = t2
    t2_preprocessed_flipped = None
    t2_template = None
    if t2 is not None:
        t2_template = ants.image_read(
            get_antsxnet_data("deepFlashTemplateT2SkullStripped"))
        t2_template = ants.copy_image_info(t1_template, t2_template)
        if do_preprocessing:

            if verbose:
                print("Preprocessing T2.")

            # Brain extraction
            t2_preprocessed = t2_preprocessed * t1_mask

            # Do bias correction
            t2_preprocessed = ants.n4_bias_field_correction(t2_preprocessed,
                                                            t1_mask,
                                                            shrink_factor=4,
                                                            verbose=verbose)

            # Warp to template
            t2_preprocessed = ants.apply_transforms(
                fixed=t1_template,
                moving=t2_preprocessed,
                transformlist=template_transforms['fwdtransforms'],
                verbose=verbose)

        if use_contralaterality:
            t2_preprocessed_array = t2_preprocessed.numpy()
            t2_preprocessed_array_flipped = np.flip(t2_preprocessed_array,
                                                    axis=0)
            t2_preprocessed_flipped = ants.from_numpy(
                t2_preprocessed_array_flipped,
                origin=t2_preprocessed.origin,
                spacing=t2_preprocessed.spacing,
                direction=t2_preprocessed.direction)

    probability_images = list()
    labels = (0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)
    image_size = (64, 64, 96)

    ################################
    #
    # Process left/right in split networks
    #
    ################################

    ################################
    #
    # Download spatial priors
    #
    ################################

    spatial_priors_file_name_path = get_antsxnet_data(
        "deepFlashPriors", antsxnet_cache_directory=antsxnet_cache_directory)
    spatial_priors = ants.image_read(spatial_priors_file_name_path)
    priors_image_list = ants.ndimage_to_list(spatial_priors)
    for i in range(len(priors_image_list)):
        priors_image_list[i] = ants.copy_image_info(t1_preprocessed,
                                                    priors_image_list[i])

    labels_left = labels[1::2]
    priors_image_left_list = priors_image_list[1::2]
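    # Labels alternate left/right, so [1::2] selects the left-hemisphere labels
    # (5, 7, 9, 11, 13, 15, 17) and [2::2] (used below) the right-hemisphere
    # labels (6, 8, ..., 18); the same interleaving applies to the priors.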
    probability_images_left = list()
    foreground_probability_images_left = list()
    lower_bound_left = (76, 74, 56)
    upper_bound_left = (140, 138, 152)
    tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_left,
                                    upper_bound_left)
    origin_left = tmp_cropped.origin

    spacing = tmp_cropped.spacing
    direction = tmp_cropped.direction

    t1_template_roi_left = ants.crop_indices(t1_template, lower_bound_left,
                                             upper_bound_left)
    t1_template_roi_left = (
        (t1_template_roi_left - t1_template_roi_left.min()) /
        (t1_template_roi_left.max() - t1_template_roi_left.min())) * 2.0 - 1.0
    t2_template_roi_left = None
    if t2_template is not None:
        t2_template_roi_left = ants.crop_indices(t2_template, lower_bound_left,
                                                 upper_bound_left)
        t2_template_roi_left = (t2_template_roi_left -
                                t2_template_roi_left.min()) / (
                                    t2_template_roi_left.max() -
                                    t2_template_roi_left.min()) * 2.0 - 1.0

    labels_right = labels[2::2]
    priors_image_right_list = priors_image_list[2::2]
    probability_images_right = list()
    foreground_probability_images_right = list()
    lower_bound_right = (20, 74, 56)
    upper_bound_right = (84, 138, 152)
    tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_right,
                                    upper_bound_right)
    origin_right = tmp_cropped.origin

    t1_template_roi_right = ants.crop_indices(t1_template, lower_bound_right,
                                              upper_bound_right)
    t1_template_roi_right = (
        (t1_template_roi_right - t1_template_roi_right.min()) /
        (t1_template_roi_right.max() - t1_template_roi_right.min())) * 2.0 - 1.0
    t2_template_roi_right = None
    if t2_template is not None:
        t2_template_roi_right = ants.crop_indices(t2_template,
                                                  lower_bound_right,
                                                  upper_bound_right)
        t2_template_roi_right = (t2_template_roi_right -
                                 t2_template_roi_right.min()) / (
                                     t2_template_roi_right.max() -
                                     t2_template_roi_right.min()) * 2.0 - 1.0

    ################################
    #
    # Create model
    #
    ################################

    channel_size = 1 + len(labels_left)
    if t2 is not None:
        channel_size += 1

    number_of_classification_labels = 1 + len(labels_left)
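    # Channel layout: the T1 (and optional T2) image channels come first,
    # followed by one spatial-prior channel per hemisphere label (7 per side).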

    unet_model = create_unet_model_3d(
        (*image_size, channel_size),
        number_of_outputs=number_of_classification_labels,
        mode="classification",
        number_of_filters=(32, 64, 96, 128, 256),
        convolution_kernel_size=(3, 3, 3),
        deconvolution_kernel_size=(2, 2, 2),
        dropout_rate=0.0,
        weight_decay=0)

    penultimate_layer = unet_model.layers[-2].output

    # medial temporal lobe
    output1 = Conv3D(
        filters=1,
        kernel_size=(1, 1, 1),
        activation='sigmoid',
        kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

    if use_hierarchical_parcellation:

        # EC, perirhinal, and parahippo.
        output2 = Conv3D(
            filters=1,
            kernel_size=(1, 1, 1),
            activation='sigmoid',
            kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

        # Hippocampus
        output3 = Conv3D(
            filters=1,
            kernel_size=(1, 1, 1),
            activation='sigmoid',
            kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

        unet_model = Model(
            inputs=unet_model.input,
            outputs=[unet_model.output, output1, output2, output3])
    else:
        unet_model = Model(inputs=unet_model.input,
                           outputs=[unet_model.output, output1])
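    # The auxiliary sigmoid heads share the penultimate u-net features and
    # predict coarser regions (medial temporal lobe and, if hierarchical,
    # EC/perirhinal/parahippocampal and hippocampus) alongside the per-label
    # classification output.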

    ################################
    #
    # Left:  build model and load weights
    #
    ################################

    network_name = 'deepFlashLeftT1'
    if t2 is not None:
        network_name = 'deepFlashLeftBoth'

    if use_hierarchical_parcellation:
        network_name += "Hierarchical"

    if use_rank_intensity:
        network_name += "_ri"

    if verbose:
        print("DeepFlash: retrieving model weights (left).")
    weights_file_name = get_pretrained_network(
        network_name, antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Left:  do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction (left).")

    batchX = None
    if use_contralaterality:
        batchX = np.zeros((2, *image_size, channel_size))
    else:
        batchX = np.zeros((1, *image_size, channel_size))

    t1_cropped = ants.crop_indices(t1_preprocessed, lower_bound_left,
                                   upper_bound_left)
    if use_rank_intensity:
        t1_cropped = ants.rank_intensity(t1_cropped)
    else:
        t1_cropped = ants.histogram_match_image(t1_cropped,
                                                t1_template_roi_left, 255, 64,
                                                False)
    batchX[0, :, :, :, 0] = t1_cropped.numpy()
    if use_contralaterality:
        t1_cropped = ants.crop_indices(t1_preprocessed_flipped,
                                       lower_bound_left, upper_bound_left)
        if use_rank_intensity:
            t1_cropped = ants.rank_intensity(t1_cropped)
        else:
            t1_cropped = ants.histogram_match_image(t1_cropped,
                                                    t1_template_roi_left, 255,
                                                    64, False)
        batchX[1, :, :, :, 0] = t1_cropped.numpy()
    if t2 is not None:
        t2_cropped = ants.crop_indices(t2_preprocessed, lower_bound_left,
                                       upper_bound_left)
        if use_rank_intensity:
            t2_cropped = ants.rank_intensity(t2_cropped)
        else:
            t2_cropped = ants.histogram_match_image(t2_cropped,
                                                    t2_template_roi_left, 255,
                                                    64, False)
        batchX[0, :, :, :, 1] = t2_cropped.numpy()
        if use_contralaterality:
            t2_cropped = ants.crop_indices(t2_preprocessed_flipped,
                                           lower_bound_left, upper_bound_left)
            if use_rank_intensity:
                t2_cropped = ants.rank_intensity(t2_cropped)
            else:
                t2_cropped = ants.histogram_match_image(
                    t2_cropped, t2_template_roi_left, 255, 64, False)
            batchX[1, :, :, :, 1] = t2_cropped.numpy()

    for i in range(len(priors_image_left_list)):
        cropped_prior = ants.crop_indices(priors_image_left_list[i],
                                          lower_bound_left, upper_bound_left)
        for j in range(batchX.shape[0]):
            batchX[j, :, :, :, i +
                   (channel_size - len(labels_left))] = cropped_prior.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    for i in range(1 + len(labels_left)):
        for j in range(predicted_data[0].shape[0]):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[0][j, :, :, :, i]),
                origin=origin_left, spacing=spacing, direction=direction)
            if i > 0:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0)
            else:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0 + 1)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                probability_images_left.append(probability_image)
            else:  # flipped
                probability_images_right.append(probability_image)

    ################################
    #
    # Left:  do prediction of mtl, hippocampal, and ec regions and normalize to native space
    #
    ################################

    for i in range(1, len(predicted_data)):
        for j in range(predicted_data[i].shape[0]):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[i][j, :, :, :, 0]),
                origin=origin_left, spacing=spacing, direction=direction)
            probability_image = ants.decrop_image(probability_image,
                                                  t1_preprocessed * 0)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                foreground_probability_images_left.append(probability_image)
            else:
                foreground_probability_images_right.append(probability_image)

    ################################
    #
    # Right:  build model and load weights
    #
    ################################

    network_name = 'deepFlashRightT1'
    if t2 is not None:
        network_name = 'deepFlashRightBoth'

    if use_hierarchical_parcellation:
        network_name += "Hierarchical"

    if use_rank_intensity:
        network_name += "_ri"

    if verbose:
        print("DeepFlash: retrieving model weights (right).")
    weights_file_name = get_pretrained_network(
        network_name, antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Right:  do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction (right).")

    batchX = None
    if use_contralaterality:
        batchX = np.zeros((2, *image_size, channel_size))
    else:
        batchX = np.zeros((1, *image_size, channel_size))

    t1_cropped = ants.crop_indices(t1_preprocessed, lower_bound_right,
                                   upper_bound_right)
    if use_rank_intensity:
        t1_cropped = ants.rank_intensity(t1_cropped)
    else:
        t1_cropped = ants.histogram_match_image(t1_cropped,
                                                t1_template_roi_right, 255, 64,
                                                False)
    batchX[0, :, :, :, 0] = t1_cropped.numpy()
    if use_contralaterality:
        t1_cropped = ants.crop_indices(t1_preprocessed_flipped,
                                       lower_bound_right, upper_bound_right)
        if use_rank_intensity:
            t1_cropped = ants.rank_intensity(t1_cropped)
        else:
            t1_cropped = ants.histogram_match_image(t1_cropped,
                                                    t1_template_roi_right, 255,
                                                    64, False)
        batchX[1, :, :, :, 0] = t1_cropped.numpy()
    if t2 is not None:
        t2_cropped = ants.crop_indices(t2_preprocessed, lower_bound_right,
                                       upper_bound_right)
        if use_rank_intensity:
            t2_cropped = ants.rank_intensity(t2_cropped)
        else:
            t2_cropped = ants.histogram_match_image(t2_cropped,
                                                    t2_template_roi_right, 255,
                                                    64, False)
        batchX[0, :, :, :, 1] = t2_cropped.numpy()
        if use_contralaterality:
            t2_cropped = ants.crop_indices(t2_preprocessed_flipped,
                                           lower_bound_right,
                                           upper_bound_right)
            if use_rank_intensity:
                t2_cropped = ants.rank_intensity(t2_cropped)
            else:
                t2_cropped = ants.histogram_match_image(
                    t2_cropped, t2_template_roi_right, 255, 64, False)
            batchX[1, :, :, :, 1] = t2_cropped.numpy()

    for i in range(len(priors_image_right_list)):
        cropped_prior = ants.crop_indices(priors_image_right_list[i],
                                          lower_bound_right, upper_bound_right)
        for j in range(batchX.shape[0]):
            batchX[j, :, :, :, i +
                   (channel_size - len(labels_right))] = cropped_prior.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    for i in range(1 + len(labels_right)):
        for j in range(predicted_data[0].shape[0]):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[0][j, :, :, :, i]),
                origin=origin_right, spacing=spacing, direction=direction)
            if i > 0:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0)
            else:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0 + 1)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                if use_contralaterality:
                    probability_images_right[i] = (
                        probability_images_right[i] + probability_image) / 2
                else:
                    probability_images_right.append(probability_image)
            else:  # flipped
                probability_images_left[i] = (probability_images_left[i] +
                                              probability_image) / 2

    ################################
    #
    # Right:  do prediction of mtl, hippocampal, and ec regions and normalize to native space
    #
    ################################

    for i in range(1, len(predicted_data)):
        for j in range(predicted_data[i].shape[0]):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[i][j, :, :, :, 0]),
                origin=origin_right, spacing=spacing, direction=direction)
            probability_image = ants.decrop_image(probability_image,
                                                  t1_preprocessed * 0)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                if use_contralaterality:
                    foreground_probability_images_right[
                        i - 1] = (foreground_probability_images_right[i - 1] +
                                  probability_image) / 2
                else:
                    foreground_probability_images_right.append(
                        probability_image)
            else:
                foreground_probability_images_left[
                    i - 1] = (foreground_probability_images_left[i - 1] +
                              probability_image) / 2

    ################################
    #
    # Combine priors
    #
    ################################

    probability_background_image = ants.image_clone(t1) * 0
    for i in range(1, len(probability_images_left)):
        probability_background_image += probability_images_left[i]
    for i in range(1, len(probability_images_right)):
        probability_background_image += probability_images_right[i]

    probability_images.append(probability_background_image * -1 + 1)
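    # The background probability is 1 minus the summed foreground
    # probabilities; `x * -1 + 1` is an image-level (1 - x).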
    for i in range(1, len(probability_images_left)):
        probability_images.append(probability_images_left[i])
        probability_images.append(probability_images_right[i])

    ################################
    #
    # Convert probability images to segmentation
    #
    ################################

    # image_matrix = ants.image_list_to_matrix(probability_images, t1 * 0 + 1)
    # segmentation_matrix = np.argmax(image_matrix, axis=0)
    # segmentation_image = ants.matrix_to_images(
    #     np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    image_matrix = ants.image_list_to_matrix(
        probability_images[1:(len(probability_images))], t1 * 0 + 1)
    background_foreground_matrix = np.stack([
        ants.image_list_to_matrix([probability_images[0]], t1 * 0 + 1),
        np.expand_dims(np.sum(image_matrix, axis=0), axis=0)
    ])
    foreground_matrix = np.argmax(background_foreground_matrix, axis=0)
    segmentation_matrix = (np.argmax(image_matrix, axis=0) +
                           1) * foreground_matrix
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]
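    # Two-stage argmax: first decide foreground vs. background from the summed
    # foreground probabilities, then pick the best foreground label; the "+ 1"
    # offsets indices because the background image was excluded from
    # image_matrix.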

    relabeled_image = ants.image_clone(segmentation_image)
    for i in range(len(labels)):
        relabeled_image[segmentation_image == i] = labels[i]

    foreground_probability_images = list()
    for i in range(len(foreground_probability_images_left)):
        foreground_probability_images.append(
            foreground_probability_images_left[i] +
            foreground_probability_images_right[i])

    return_dict = None
    if use_hierarchical_parcellation:
        return_dict = {
            'segmentation_image':
            relabeled_image,
            'probability_images':
            probability_images,
            'medial_temporal_lobe_probability_image':
            foreground_probability_images[0],
            'other_region_probability_image':
            foreground_probability_images[1],
            'hippocampal_probability_image':
            foreground_probability_images[2]
        }
    else:
        return_dict = {
            'segmentation_image':
            relabeled_image,
            'probability_images':
            probability_images,
            'medial_temporal_lobe_probability_image':
            foreground_probability_images[0]
        }

    return return_dict
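
A hedged sketch of the two-modality call described in the docstring above
(file names are placeholders; the t2 image is assumed pre-aligned to the t1):

import ants
t1 = ants.image_read("t1.nii.gz")
t2 = ants.image_read("t2.nii.gz")
flash = deep_flash(t1, t2=t2, use_rank_intensity=True, verbose=True)
seg = flash['segmentation_image']  # labels 0, 5, ..., 18
mtl = flash['medial_temporal_lobe_probability_image']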
Example #19
def deep_flash_deprecated(t1,
                          do_preprocessing=True,
                          do_per_hemisphere=True,
                          which_hemisphere_models="new",
                          antsxnet_cache_directory=None,
                          verbose=False):
    """
    Hippocampal/entorhinal segmentation using "Deep Flash"

    Perform hippocampal/entorhinal segmentation in T1 images using
    labels from Mike Yassa's lab

    https://faculty.sites.uci.edu/myassa/

    The labeling is as follows:
    Label 0 :  background
    Label 5 :  left aLEC
    Label 6 :  right aLEC
    Label 7 :  left pMEC
    Label 8 :  right pMEC
    Label 9 :  left perirhinal
    Label 10:  right perirhinal
    Label 11:  left parahippocampal
    Label 12:  right parahippocampal
    Label 13:  left DG/CA3
    Label 14:  right DG/CA3
    Label 15:  left CA1
    Label 16:  right CA1
    Label 17:  left subiculum
    Label 18:  right subiculum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e., set
    do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    do_per_hemisphere : boolean
        If True, do prediction based on separate networks per hemisphere.  Otherwise,
        use the single network trained for both hemispheres.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if this is None, the data are downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary consisting of the segmentation image and probability images for
    each label.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> flash = deep_flash(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import preprocess_brain_image
    from ..utilities import pad_or_crop_image_to_size

    print("This function is deprecated.  Please update to deep_flash().")

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            template="croppedMni152",
            template_transform_type="antsRegistrationSyNQuickRepro[a]",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing[
            "preprocessed_image"] * t1_preprocessing['brain_mask']

    probability_images = list()
    labels = (0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)

    ################################
    #
    # Process left/right in same network
    #
    ################################

    if not do_per_hemisphere:

        ################################
        #
        # Build model and load weights
        #
        ################################

        template_size = (160, 192, 160)

        unet_model = create_unet_model_3d(
            (*template_size, 1),
            number_of_outputs=len(labels),
            number_of_layers=4,
            number_of_filters_at_base_layer=8,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5,
            additional_options=("attentionGating", ))

        if verbose:
            print("DeepFlash: retrieving model weights.")

        weights_file_name = get_pretrained_network(
            "deepFlash", antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Do prediction and normalize to native space
        #
        ################################

        if verbose:
            print("Prediction.")

        cropped_image = pad_or_crop_image_to_size(t1_preprocessed,
                                                  template_size)

        batchX = np.expand_dims(cropped_image.numpy(), axis=0)
        batchX = np.expand_dims(batchX, axis=-1)
        batchX = (batchX - batchX.mean()) / batchX.std()

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        origin = cropped_image.origin
        spacing = cropped_image.spacing
        direction = cropped_image.direction

        for i in range(len(labels)):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                origin=origin, spacing=spacing, direction=direction)
            if i > 0:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0)
            else:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0 + 1)

            if do_preprocessing:
                probability_images.append(
                    ants.apply_transforms(
                        fixed=t1,
                        moving=decropped_image,
                        transformlist=t1_preprocessing['template_transforms']
                        ['invtransforms'],
                        whichtoinvert=[True],
                        interpolator="linear",
                        verbose=verbose))
            else:
                probability_images.append(decropped_image)

    ################################
    #
    # Process left/right in split networks
    #
    ################################

    else:

        ################################
        #
        # Left:  download spatial priors
        #
        ################################

        spatial_priors_left_file_name_path = get_antsxnet_data(
            "priorDeepFlashLeftLabels",
            antsxnet_cache_directory=antsxnet_cache_directory)
        spatial_priors_left = ants.image_read(
            spatial_priors_left_file_name_path)
        priors_image_left_list = ants.ndimage_to_list(spatial_priors_left)

        ################################
        #
        # Left:  build model and load weights
        #
        ################################

        template_size = (64, 96, 96)
        labels_left = (0, 5, 7, 9, 11, 13, 15, 17)
        channel_size = 1 + len(labels_left)

        number_of_filters = 16
        network_name = ''
        if which_hemisphere_models == "old":
            network_name = "deepFlashLeft16"
        elif which_hemisphere_models == "new":
            network_name = "deepFlashLeft16new"
        else:
            raise ValueError("network_name must be \"old\" or \"new\".")

        unet_model = create_unet_model_3d(
            (*template_size, channel_size),
            number_of_outputs=len(labels_left),
            number_of_layers=4,
            number_of_filters_at_base_layer=number_of_filters,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5,
            additional_options=("attentionGating", ))

        if verbose:
            print("DeepFlash: retrieving model weights (left).")
        weights_file_name = get_pretrained_network(
            network_name, antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Left:  do prediction and normalize to native space
        #
        ################################

        if verbose:
            print("Prediction (left).")

        cropped_image = ants.crop_indices(t1_preprocessed, (30, 51, 0),
                                          (94, 147, 96))
        image_array = cropped_image.numpy()
        image_array = (image_array - image_array.mean()) / image_array.std()

        batchX = np.zeros((1, *template_size, channel_size))
        batchX[0, :, :, :, 0] = image_array

        for i in range(len(priors_image_left_list)):
            cropped_prior = ants.crop_indices(priors_image_left_list[i],
                                              (30, 51, 0), (94, 147, 96))
            batchX[0, :, :, :, i + 1] = cropped_prior.numpy()

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        origin = cropped_image.origin
        spacing = cropped_image.spacing
        direction = cropped_image.direction

        probability_images_left = list()
        for i in range(len(labels_left)):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                origin=origin, spacing=spacing, direction=direction)
            if i > 0:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0)
            else:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0 + 1)

            if do_preprocessing:
                probability_images_left.append(
                    ants.apply_transforms(
                        fixed=t1,
                        moving=decropped_image,
                        transformlist=t1_preprocessing['template_transforms']
                        ['invtransforms'],
                        whichtoinvert=[True],
                        interpolator="linear",
                        verbose=verbose))
            else:
                probability_images_left.append(decropped_image)

        ################################
        #
        # Right:  download spatial priors
        #
        ################################

        spatial_priors_right_file_name_path = get_antsxnet_data(
            "priorDeepFlashRightLabels",
            antsxnet_cache_directory=antsxnet_cache_directory)
        spatial_priors_right = ants.image_read(
            spatial_priors_right_file_name_path)
        priors_image_right_list = ants.ndimage_to_list(spatial_priors_right)

        ################################
        #
        # Right:  build model and load weights
        #
        ################################

        template_size = (64, 96, 96)
        labels_right = (0, 6, 8, 10, 12, 14, 16, 18)
        channel_size = 1 + len(labels_right)

        number_of_filters = 16
        network_name = ''
        if which_hemisphere_models == "old":
            network_name = "deepFlashRight16"
        elif which_hemisphere_models == "new":
            network_name = "deepFlashRight16new"
        else:
            raise ValueError("network_name must be \"old\" or \"new\".")

        unet_model = create_unet_model_3d(
            (*template_size, channel_size),
            number_of_outputs=len(labels_right),
            number_of_layers=4,
            number_of_filters_at_base_layer=number_of_filters,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5,
            additional_options=("attentionGating", ))

        if verbose:
            print("DeepFlash: retrieving model weights (right).")
        weights_file_name = get_pretrained_network(
            network_name, antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Right:  do prediction and normalize to native space
        #
        ################################

        if verbose:
            print("Prediction (right).")

        cropped_image = ants.crop_indices(t1_preprocessed, (88, 51, 0),
                                          (152, 147, 96))
        image_array = cropped_image.numpy()
        image_array = (image_array - image_array.mean()) / image_array.std()

        batchX = np.zeros((1, *template_size, channel_size))
        batchX[0, :, :, :, 0] = image_array

        for i in range(len(priors_image_right_list)):
            cropped_prior = ants.crop_indices(priors_image_right_list[i],
                                              (88, 51, 0), (152, 147, 96))
            batchX[0, :, :, :, i + 1] = cropped_prior.numpy()

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        origin = cropped_image.origin
        spacing = cropped_image.spacing
        direction = cropped_image.direction

        probability_images_right = list()
        for i in range(len(labels_right)):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                origin=origin, spacing=spacing, direction=direction)
            if i > 0:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0)
            else:
                decropped_image = ants.decrop_image(probability_image,
                                                    t1_preprocessed * 0 + 1)

            if do_preprocessing:
                probability_images_right.append(
                    ants.apply_transforms(
                        fixed=t1,
                        moving=decropped_image,
                        transformlist=t1_preprocessing['template_transforms']
                        ['invtransforms'],
                        whichtoinvert=[True],
                        interpolator="linear",
                        verbose=verbose))
            else:
                probability_images_right.append(decropped_image)

        ################################
        #
        # Combine left/right probability images into the final list
        #
        ################################

        probability_background_image = ants.image_clone(t1) * 0
        for i in range(1, len(probability_images_left)):
            probability_background_image += probability_images_left[i]
        for i in range(1, len(probability_images_right)):
            probability_background_image += probability_images_right[i]

        probability_images.append(probability_background_image * -1 + 1)
        for i in range(1, len(probability_images_left)):
            probability_images.append(probability_images_left[i])
            probability_images.append(probability_images_right[i])

    ################################
    #
    # Convert probability images to segmentation
    #
    ################################

    # image_matrix = ants.image_list_to_matrix(probability_images, t1 * 0 + 1)
    # segmentation_matrix = np.argmax(image_matrix, axis=0)
    # segmentation_image = ants.matrix_to_images(
    #     np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    # Stack all non-background label probabilities and compare their voxelwise
    # sum against the background probability to decide foreground membership.
    image_matrix = ants.image_list_to_matrix(
        probability_images[1:(len(probability_images))], t1 * 0 + 1)
    background_foreground_matrix = np.stack([
        ants.image_list_to_matrix([probability_images[0]], t1 * 0 + 1),
        np.expand_dims(np.sum(image_matrix, axis=0), axis=0)
    ])
    foreground_matrix = np.argmax(background_foreground_matrix, axis=0)
    # Assign each foreground voxel its most probable label (offset by 1 to
    # skip background); background voxels remain 0.
    segmentation_matrix = (np.argmax(image_matrix, axis=0) +
                           1) * foreground_matrix
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    relabeled_image = ants.image_clone(segmentation_image)
    for i in range(len(labels)):
        relabeled_image[segmentation_image == i] = labels[i]

    return_dict = {
        'segmentation_image': relabeled_image,
        'probability_images': probability_images
    }
    return (return_dict)
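
The probability-to-segmentation conversion above gates a per-voxel argmax by whether the summed label probabilities outweigh the background probability. A minimal numpy sketch of the same logic, using toy numbers rather than model output:

import numpy as np

# Toy probabilities for 4 voxels: row 0 = background, rows 1-2 = two labels.
probs = np.array([[0.6, 0.2, 0.1, 0.5],
                  [0.3, 0.5, 0.2, 0.3],
                  [0.1, 0.3, 0.7, 0.2]])

label_probs = probs[1:]
# Foreground wherever the labels jointly outweigh the background probability.
foreground = (label_probs.sum(axis=0) > probs[0]).astype(int)
# Most probable label (offset by 1 so background stays 0), gated by foreground.
segmentation = (np.argmax(label_probs, axis=0) + 1) * foreground
print(segmentation)  # [0 1 2 0]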
Example #20
def histogram_warp_image_intensities(image,
                                     break_points=(0.25, 0.5, 0.75),
                                     displacements=None,
                                     clamp_end_points=(False, False),
                                     sd_displacements=0.05,
                                     transform_domain_size=20):
    """
    Transform image intensities based on histogram mapping.

    Apply B-spline 1-D maps to an input image for intensity warping.

    Arguments
    ---------
    image : ANTsImage
        Input image.

    break_points : integer or tuple
        Parametric points at which the intensity transform displacements
        are specified, each in the range [0, 1].  Alternatively, a single
        number can be given and the sequence is linearly spaced in [0, 1].

    displacements : tuple
        Displacements defining the intensity warping.  Length must equal
        that of break_points.  Alternatively, if None, random displacements
        are drawn (normal distribution:  mean = 0, sd = sd_displacements).

    sd_displacements : float
        Standard deviation characterizing the randomness of the intensity
        displacements.

    clamp_end_points : 2-element tuple of booleans
        Specify whether to clamp the end points of the histogram, i.e.,
        force zero intensity change at the histogram's ends.

    transform_domain_size : integer
        Defines the sampling resolution of the B-spline warping.

    Returns
    -------
    ANTs image

    Example
    -------
    >>> import ants
    >>> image = ants.image_read(ants.get_ants_data("r64"))
    >>> transformed_image = histogram_warp_image_intensities(image)
    """

    if not len(clamp_end_points) == 2:
        raise ValueError(
            "clamp_end_points must be a boolean tuple of length 2.")

    if not isinstance(break_points, int):
        if any(b < 0 for b in break_points) or any(b > 1
                                                   for b in break_points):
            raise ValueError(
                "If specifying break_points as a vector, values must be in the range [0, 1]"
            )

    parametric_points = None
    number_of_nonzero_displacements = 1
    if not isinstance(break_points, int):
        parametric_points = break_points
        number_of_nonzero_displacements = len(break_points)
        if clamp_end_points[0] is True:
            parametric_points = (0, *parametric_points)
        if clamp_end_points[1] is True:
            parametric_points = (*parametric_points, 1)
    else:
        total_number_of_break_points = break_points
        if clamp_end_points[0] is True:
            total_number_of_break_points += 1
        if clamp_end_points[1] is True:
            total_number_of_break_points += 1
        parametric_points = np.linspace(0, 1, total_number_of_break_points)
        number_of_nonzero_displacements = break_points

    if displacements is None:
        displacements = np.random.normal(loc=0.0,
                                         scale=sd_displacements,
                                         size=number_of_nonzero_displacements)

    weights = np.ones(len(displacements))
    if clamp_end_points[0] is True:
        displacements = (0, *displacements)
        weights = np.concatenate((1000 * np.ones(1), weights))
    if clamp_end_points[1] is True:
        displacements = (*displacements, 0)
        weights = np.concatenate((weights, 1000 * np.ones(1)))

    if not len(displacements) == len(parametric_points):
        raise ValueError(
            "Length of displacements does not match the length of the break points."
        )

    scattered_data = np.reshape(displacements, (len(displacements), 1))
    parametric_data = np.reshape(parametric_points,
                                 (len(parametric_points), 1))

    transform_domain_origin = 0
    transform_domain_spacing = (1.0 - transform_domain_origin) / (
        transform_domain_size - 1)

    bspline_histogram_transform = ants.fit_bspline_object_to_scattered_data(
        scattered_data,
        parametric_data, [transform_domain_origin], [transform_domain_spacing],
        [transform_domain_size],
        data_weights=weights)

    transform_domain = np.linspace(0, 1, transform_domain_size)

    normalized_image = ants.iMath(ants.image_clone(image), "Normalize")
    transformed_array = normalized_image.numpy()
    normalized_array = normalized_image.numpy()

    # For each interval of the transform domain, linearly interpolate the
    # fitted B-spline displacement and add it to the voxel intensities.
    for i in range(len(transform_domain) - 1):
        indices = np.where((normalized_array >= transform_domain[i])
                           & (normalized_array < transform_domain[i + 1]))
        intensities = normalized_array[indices]

        alpha = (intensities - transform_domain[i]) / (
            transform_domain[i + 1] - transform_domain[i])
        xfrm = alpha * (
            bspline_histogram_transform[i + 1] -
            bspline_histogram_transform[i]) + bspline_histogram_transform[i]
        transformed_array[indices] = intensities + xfrm

    transformed_image = (ants.from_numpy(transformed_array,
                                         origin=image.origin,
                                         spacing=image.spacing,
                                         direction=image.direction) *
                         (image.max() - image.min())) + image.min()

    return (transformed_image)
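
A hedged augmentation sketch using the function above, with fixed displacements and both end points clamped; the break points and displacement values are illustrative choices, not recommended settings:

import ants

image = ants.image_read(ants.get_ants_data("r64"))
augmented = histogram_warp_image_intensities(
    image,
    break_points=(0.25, 0.5, 0.75),
    displacements=(0.05, -0.02, -0.04),  # one displacement per break point
    clamp_end_points=(True, True))       # zero intensity change at both ends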
Example #21
# imagCine is assumed to be a 4-D cine MR image loaded earlier in the original
# script; frame 23 (taken as end-diastole) is extracted as a 3-D ANTs image.
arrC = imagCine.view()
arrC = arrC[23, :, :, :]
# Note: the 4-D metadata are reduced to their 3-D components (an assumption,
# since the surrounding context is not shown).
ED_img = ants.from_numpy(arrC, origin=imagCine.origin[:3],
                         spacing=imagCine.spacing[:3],
                         direction=imagCine.direction[:3, :3])


# ED_img = ants.resample_image(ED_img,[1.1429,1.1429,1.1429])
# ED_Re = ants.to_nibabel(ED_img)
# ED_Re.set_data_dtype('uint16')
# nib.save(ED_Re,'ED/ED_resampled.nii')

ED_imgCH = ED_img.view()
# Registration
regED_vl = ants.registration(ED_img, images_vLong, type_of_transform="Rigid", aff_metric='mattes')
regED_hl = ants.registration(ED_img, images_hLong, type_of_transform="Rigid", aff_metric='mattes')

cloneVL = ants.image_clone(images_vLong)
cloneHL = ants.image_clone(images_hLong)
roi_vl = overROI(cloneVL,regED_vl['fwdtransforms'],ED_img)
roi_hl = overROI(cloneHL,regED_hl['fwdtransforms'],ED_img)
roi_hlArr = roi_hl.view()
roi_hlArr = roi_hlArr == 1
roi_vlArr = roi_vl.view()
roi_vlArr = roi_vlArr == 1

print('ROI extraction Done')

newvl = ants.to_nibabel(regED_vl['warpedmovout'])
newvl.set_data_dtype('int16')
newhl = ants.to_nibabel(regED_hl['warpedmovout'])
newhl.set_data_dtype('int16')
nib.save(newvl,path+'/ED_1301.nii')
Example #22
def sysu_media_wmh_segmentation(flair,
                                t1=None,
                                use_ensemble=True,
                                antsxnet_cache_directory=None,
                                verbose=False):
    """
    Perform WMH segmentation using the winning submission in the MICCAI
    2017 challenge by the sysu_media team using FLAIR or T1/FLAIR.  The
    MICCAI challenge is discussed in

    https://pubmed.ncbi.nlm.nih.gov/30908194/

    and the sysu_media team's entry is discussed in

    https://pubmed.ncbi.nlm.nih.gov/30125711/

    with the original implementation available here:

    https://github.com/hongweilibran/wmh_ibbmTum

    The original implementation used global thresholding as a quick
    brain extraction approach.  Due to possible generalization difficulties,
    we leave such post-processing steps to the user (a post-processing
    sketch follows this function).  For brain or white matter masking see
    functions brain_extraction or deep_atropos, respectively.

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    use_ensemble : boolean
        whether to use all 3 sets of weights.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, these data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> image = ants.image_read("flair.nii.gz")
    >>> probability_mask = sysu_media_wmh_segmentation(image)
    """

    from ..architectures import create_sysu_media_unet_model_2d
    from ..utilities import get_pretrained_network
    from ..utilities import pad_or_crop_image_to_size
    from ..utilities import preprocess_brain_image
    from ..utilities import binary_dice_coefficient

    if flair.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    image_size = (200, 200)

    ################################
    #
    # Preprocess images
    #
    ################################

    def closest_simplified_direction_matrix(direction):
        closest = (np.abs(direction) + 0.5).astype(int).astype(float)
        closest[direction < 0] *= -1.0
        return closest

    simplified_direction = closest_simplified_direction_matrix(flair.direction)

    flair_preprocessing = preprocess_brain_image(
        flair,
        truncate_intensity=None,
        brain_extraction_modality=None,
        do_bias_correction=False,
        do_denoising=False,
        antsxnet_cache_directory=antsxnet_cache_directory,
        verbose=verbose)
    flair_preprocessed = flair_preprocessing["preprocessed_image"]
    flair_preprocessed.set_direction(simplified_direction)
    flair_preprocessed.set_origin((0, 0, 0))
    flair_preprocessed.set_spacing((1, 1, 1))
    number_of_channels = 1

    t1_preprocessed = None
    if t1 is not None:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=None,
            brain_extraction_modality=None,
            do_bias_correction=False,
            do_denoising=False,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing["preprocessed_image"]
        t1_preprocessed.set_direction(simplified_direction)
        t1_preprocessed.set_origin((0, 0, 0))
        t1_preprocessed.set_spacing((1, 1, 1))
        number_of_channels = 2

    ################################
    #
    # Reorient images
    #
    ################################

    reference_image = ants.make_image((256, 256, 256),
                                      voxval=0,
                                      spacing=(1, 1, 1),
                                      origin=(0, 0, 0),
                                      direction=np.identity(3))
    center_of_mass_reference = np.floor(
        ants.get_center_of_mass(reference_image * 0 + 1))
    center_of_mass_image = np.floor(
        ants.get_center_of_mass(flair_preprocessed))
    translation = np.asarray(center_of_mass_image) - np.asarray(
        center_of_mass_reference)
    xfrm = ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_reference),
        translation=translation)
    flair_preprocessed_warped = ants.apply_ants_transform_to_image(
        xfrm,
        flair_preprocessed,
        reference_image,
        interpolation="nearestneighbor")
    crop_image = ants.image_clone(flair_preprocessed) * 0 + 1
    crop_image_warped = ants.apply_ants_transform_to_image(
        xfrm, crop_image, reference_image, interpolation="nearestneighbor")
    flair_preprocessed_warped = ants.crop_image(flair_preprocessed_warped,
                                                crop_image_warped, 1)

    if t1 is not None:
        t1_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm,
            t1_preprocessed,
            reference_image,
            interpolation="nearestneighbor")
        t1_preprocessed_warped = ants.crop_image(t1_preprocessed_warped,
                                                 crop_image_warped, 1)

    ################################
    #
    # Gaussian normalize intensity
    #
    ################################

    mean_flair = flair_preprocessed.mean()
    std_flair = flair_preprocessed.std()
    if number_of_channels == 2:
        mean_t1 = t1_preprocessed.mean()
        std_t1 = t1_preprocessed.std()

    flair_preprocessed_warped = (flair_preprocessed_warped -
                                 mean_flair) / std_flair
    if number_of_channels == 2:
        t1_preprocessed_warped = (t1_preprocessed_warped - mean_t1) / std_t1

    ################################
    #
    # Build models and load weights
    #
    ################################

    number_of_models = 1
    if use_ensemble == True:
        number_of_models = 3

    if verbose == True:
        print("White matter hyperintensity:  retrieving model weights.")

    unet_models = list()
    for i in range(number_of_models):
        if number_of_channels == 1:
            weights_file_name = get_pretrained_network(
                "sysuMediaWmhFlairOnlyModel" + str(i),
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            weights_file_name = get_pretrained_network(
                "sysuMediaWmhFlairT1Model" + str(i),
                antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model = create_sysu_media_unet_model_2d(
            (*image_size, number_of_channels))
        unet_loss = binary_dice_coefficient(smoothing_factor=1.)
        unet_model.compile(optimizer=keras.optimizers.Adam(learning_rate=2e-4),
                           loss=unet_loss)
        unet_model.load_weights(weights_file_name)
        unet_models.append(unet_model)

    ################################
    #
    # Extract slices
    #
    ################################

    dimensions_to_predict = [2]

    total_number_of_slices = 0
    for d in range(len(dimensions_to_predict)):
        total_number_of_slices += flair_preprocessed_warped.shape[
            dimensions_to_predict[d]]

    batchX = np.zeros(
        (total_number_of_slices, *image_size, number_of_channels))

    slice_count = 0
    for d in range(len(dimensions_to_predict)):
        number_of_slices = flair_preprocessed_warped.shape[
            dimensions_to_predict[d]]

        if verbose == True:
            print("Extracting slices for dimension ", dimensions_to_predict[d],
                  ".")

        for i in range(number_of_slices):
            flair_slice = pad_or_crop_image_to_size(
                ants.slice_image(flair_preprocessed_warped,
                                 dimensions_to_predict[d], i), image_size)
            batchX[slice_count, :, :, 0] = flair_slice.numpy()
            if number_of_channels == 2:
                t1_slice = pad_or_crop_image_to_size(
                    ants.slice_image(t1_preprocessed_warped,
                                     dimensions_to_predict[d], i), image_size)
                batchX[slice_count, :, :, 1] = t1_slice.numpy()
            slice_count += 1

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose == True:
        print("Prediction.")

    prediction = unet_models[0].predict(np.transpose(batchX,
                                                     axes=(0, 2, 1, 3)),
                                        verbose=verbose)
    if number_of_models > 1:
        for i in range(1, number_of_models, 1):
            prediction += unet_models[i].predict(np.transpose(batchX,
                                                              axes=(0, 2, 1,
                                                                    3)),
                                                 verbose=verbose)
    prediction /= number_of_models
    prediction = np.transpose(prediction, axes=(0, 2, 1, 3))

    permutations = list()
    permutations.append((0, 1, 2))
    permutations.append((1, 0, 2))
    permutations.append((1, 2, 0))

    prediction_image_average = ants.image_clone(flair_preprocessed_warped) * 0

    current_start_slice = 0
    for d in range(len(dimensions_to_predict)):
        current_end_slice = current_start_slice + flair_preprocessed_warped.shape[
            dimensions_to_predict[d]]
        which_batch_slices = range(current_start_slice, current_end_slice)
        prediction_per_dimension = prediction[which_batch_slices, :, :, :]
        prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                                        permutations[dimensions_to_predict[d]])
        prediction_image = ants.copy_image_info(
            flair_preprocessed_warped,
            pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                      flair_preprocessed_warped.shape))
        prediction_image_average = prediction_image_average + (
            prediction_image - prediction_image_average) / (d + 1)
        current_start_slice = current_end_slice

    probability_image = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), prediction_image_average,
        flair_preprocessed)
    probability_image = ants.copy_image_info(flair, probability_image)

    return (probability_image)
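
As the docstring notes, brain masking and thresholding are left to the user. A hedged post-processing sketch; the 0.5 cutoffs are illustrative choices, and brain_extraction is the ANTsXNet utility referenced in the docstring (antspynet.utilities.brain_extraction):

import ants

flair = ants.image_read("flair.nii.gz")
probability_image = sysu_media_wmh_segmentation(flair, verbose=True)

# Restrict the WMH probabilities to the brain before thresholding.
brain_probability = brain_extraction(flair, modality="flair")
brain_mask = ants.threshold_image(brain_probability, 0.5, 1.0)
wmh_mask = ants.threshold_image(probability_image * brain_mask, 0.5, 1.0)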
Example #23
def deep_atropos(t1,
                 do_preprocessing=True,
                 output_directory=None,
                 verbose=False):
    """
    Six-tissue segmentation.

    Perform Atropos-style six tissue segmentation using deep learning.

    The labeling is as follows:
    Label 0 :  background
    Label 1 :  CSF
    Label 2 :  gray matter
    Label 3 :  white matter
    Label 4 :  deep gray matter
    Label 5 :  brain stem
    Label 6 :  cerebellum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e., set
    do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    output_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, these data will be downloaded to a
        tempfile.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary consisting of the segmentation image and the probability
    images for each label.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> atropos = deep_atropos(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import categorical_focal_loss
    from ..utilities import preprocess_brain_image
    from ..utilities import extract_image_patches
    from ..utilities import reconstruct_image_from_patches

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    if do_preprocessing == True:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            do_brain_extraction=True,
            template="croppedMni152",
            template_transform_type="AffineFast",
            do_bias_correction=True,
            do_denoising=True,
            output_directory=output_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing[
            "preprocessed_image"] * t1_preprocessing['brain_mask']

    ################################
    #
    # Build model and load weights
    #
    ################################

    patch_size = (112, 112, 112)
    stride_length = (t1_preprocessed.shape[0] - patch_size[0],
                     t1_preprocessed.shape[1] - patch_size[1],
                     t1_preprocessed.shape[2] - patch_size[2])

    classes = ("background", "csf", "gray matter", "white matter",
               "deep gray matter", "brain stem", "cerebellum")
    labels = (0, 1, 2, 3, 4, 5, 6)

    unet_model = create_unet_model_3d((*patch_size, 1),
                                      number_of_outputs=len(labels),
                                      number_of_layers=4,
                                      number_of_filters_at_base_layer=16,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      weight_decay=1e-5,
                                      add_attention_gating=True)

    weights_file_name = None
    if output_directory is not None:
        weights_file_name = output_directory + "/sixTissueOctantSegmentationWeights.h5"
        if not os.path.exists(weights_file_name):
            if verbose == True:
                print("Deep Atropos:  downloading model weights.")
            weights_file_name = get_pretrained_network(
                "sixTissueBrainSegmentation", weights_file_name)
    else:
        weights_file_name = get_pretrained_network(
            "sixTissueOctantBrainSegmentation")

    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose == True:
        print("Prediction.")

    t1_preprocessed = (t1_preprocessed -
                       t1_preprocessed.mean()) / t1_preprocessed.std()
    image_patches = extract_image_patches(t1_preprocessed,
                                          patch_size=patch_size,
                                          max_number_of_patches="all",
                                          stride_length=stride_length,
                                          return_as_array=True)
    batchX = np.expand_dims(image_patches, axis=-1)
    predicted_data = unet_model.predict(batchX, verbose=verbose)

    probability_images = list()
    for i in range(len(labels)):
        print("Reconstructing image", classes[i])
        reconstructed_image = reconstruct_image_from_patches(
            predicted_data[:, :, :, :, i],
            domain_image=t1_preprocessed,
            stride_length=stride_length)

        if do_preprocessing == True:
            probability_images.append(
                ants.apply_transforms(
                    fixed=t1,
                    moving=reconstructed_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose))
        else:
            probability_images.append(reconstructed_image)

    image_matrix = ants.image_list_to_matrix(probability_images, t1 * 0 + 1)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    relabeled_image = ants.image_clone(segmentation_image)
    for i in range(len(labels)):
        relabeled_image[segmentation_image == i] = labels[i]

    return_dict = {
        'segmentation_image': relabeled_image,
        'probability_images': probability_images
    }
    return (return_dict)
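
A short usage sketch for deriving binary tissue masks from the returned dictionary; the label values follow the table in the docstring:

import ants

t1 = ants.image_read("t1.nii.gz")
atropos = deep_atropos(t1, do_preprocessing=True)

segmentation = atropos['segmentation_image']
gray_matter_mask = ants.threshold_image(segmentation, 2, 2)   # label 2 = gray matter
white_matter_mask = ants.threshold_image(segmentation, 3, 3)  # label 3 = white matter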
Example #24
def desikan_killiany_tourville_labeling(t1,
                                        do_preprocessing=True,
                                        return_probability_images=False,
                                        antsxnet_cache_directory=None,
                                        verbose=False):
    """
    Cortical and deep gray matter labeling using Desikan-Killiany-Tourville

    Perform DKT labeling using deep learning

    The labeling is as follows:

    Inner labels:
    Label 0: background
    Label 4: left lateral ventricle
    Label 5: left inferior lateral ventricle
    Label 6: left cerebellum exterior
    Label 7: left cerebellum white matter
    Label 10: left thalamus proper
    Label 11: left caudate
    Label 12: left putamen
    Label 13: left pallidum
    Label 15: 4th ventricle
    Label 16: brain stem
    Label 17: left hippocampus
    Label 18: left amygdala
    Label 24: CSF
    Label 25: left lesion
    Label 26: left accumbens area
    Label 28: left ventral DC
    Label 30: left vessel
    Label 43: right lateral ventricle
    Label 44: right inferior lateral ventricle
    Label 45: right cerebellum exterior
    Label 46: right cerebellum white matter
    Label 49: right thalamus proper
    Label 50: right caudate
    Label 51: right putamen
    Label 52: right pallidum
    Label 53: right hippocampus
    Label 54: right amygdala
    Label 57: right lesion
    Label 58: right accumbens area
    Label 60: right ventral DC
    Label 62: right vessel
    Label 72: 5th ventricle
    Label 85: optic chiasm
    Label 91: left basal forebrain
    Label 92: right basal forebrain
    Label 630: cerebellar vermal lobules I-V
    Label 631: cerebellar vermal lobules VI-VII
    Label 632: cerebellar vermal lobules VIII-X

    Outer labels:
    Label 1002: left caudal anterior cingulate
    Label 1003: left caudal middle frontal
    Label 1005: left cuneus
    Label 1006: left entorhinal
    Label 1007: left fusiform
    Label 1008: left inferior parietal
    Label 1009: left inferior temporal
    Label 1010: left isthmus cingulate
    Label 1011: left lateral occipital
    Label 1012: left lateral orbitofrontal
    Label 1013: left lingual
    Label 1014: left medial orbitofrontal
    Label 1015: left middle temporal
    Label 1016: left parahippocampal
    Label 1017: left paracentral
    Label 1018: left pars opercularis
    Label 1019: left pars orbitalis
    Label 1020: left pars triangularis
    Label 1021: left pericalcarine
    Label 1022: left postcentral
    Label 1023: left posterior cingulate
    Label 1024: left precentral
    Label 1025: left precuneus
    Label 1026: left rostral anterior cingulate
    Label 1027: left rostral middle frontal
    Label 1028: left superior frontal
    Label 1029: left superior parietal
    Label 1030: left superior temporal
    Label 1031: left supramarginal
    Label 1034: left transverse temporal
    Label 1035: left insula
    Label 2002: right caudal anterior cingulate
    Label 2003: right caudal middle frontal
    Label 2005: right cuneus
    Label 2006: right entorhinal
    Label 2007: right fusiform
    Label 2008: right inferior parietal
    Label 2009: right inferior temporal
    Label 2010: right isthmus cingulate
    Label 2011: right lateral occipital
    Label 2012: right lateral orbitofrontal
    Label 2013: right lingual
    Label 2014: right medial orbitofrontal
    Label 2015: right middle temporal
    Label 2016: right parahippocampal
    Label 2017: right paracentral
    Label 2018: right pars opercularis
    Label 2019: right pars orbitalis
    Label 2020: right pars triangularis
    Label 2021: right pericalcarine
    Label 2022: right postcentral
    Label 2023: right posterior cingulate
    Label 2024: right precentral
    Label 2025: right precuneus
    Label 2026: right rostral anterior cingulate
    Label 2027: right rostral middle frontal
    Label 2028: right superior frontal
    Label 2029: right superior parietal
    Label 2030: right superior temporal
    Label 2031: right supramarginal
    Label 2034: right transverse temporal
    Label 2035: right insula

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e., set
    do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    return_probability_images : boolean
        Whether to return the two sets of probability images for the inner and outer
        labels.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, these data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    DKT label image, or, if return_probability_images is True, a dictionary
    containing the label image and the inner and outer probability images.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> dkt = desikan_killiany_tourville_labeling(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import categorical_focal_loss
    from ..utilities import preprocess_brain_image
    from ..utilities import crop_image_center

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    if do_preprocessing == True:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            do_brain_extraction=True,
            template="croppedMni152",
            template_transform_type="AffineFast",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing[
            "preprocessed_image"] * t1_preprocessing['brain_mask']

    ################################
    #
    # Download spatial priors for outer model
    #
    ################################

    spatial_priors_file_name_path = get_antsxnet_data(
        "priorDktLabels", antsxnet_cache_directory=antsxnet_cache_directory)
    spatial_priors = ants.image_read(spatial_priors_file_name_path)
    priors_image_list = ants.ndimage_to_list(spatial_priors)

    ################################
    #
    # Build outer model and load weights
    #
    ################################

    template_size = (96, 112, 96)
    labels = (0, 1002, 1003, *tuple(range(1005, 1032)), 1034, 1035, 2002, 2003,
              *tuple(range(2005, 2032)), 2034, 2035)
    channel_size = 1 + len(priors_image_list)

    unet_model = create_unet_model_3d((*template_size, channel_size),
                                      number_of_outputs=len(labels),
                                      number_of_layers=4,
                                      number_of_filters_at_base_layer=16,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      weight_decay=1e-5,
                                      add_attention_gating=True)

    weights_file_name = None
    weights_file_name = get_pretrained_network(
        "dktOuterWithSpatialPriors",
        antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose == True:
        print("Outer model Prediction.")

    downsampled_image = ants.resample_image(t1_preprocessed,
                                            template_size,
                                            use_voxels=True,
                                            interp_type=0)
    image_array = downsampled_image.numpy()
    image_array = (image_array - image_array.mean()) / image_array.std()

    batchX = np.zeros((1, *template_size, channel_size))
    batchX[0, :, :, :, 0] = image_array

    for i in range(len(priors_image_list)):
        resampled_prior_image = ants.resample_image(priors_image_list[i],
                                                    template_size,
                                                    use_voxels=True,
                                                    interp_type=0)
        batchX[0, :, :, :, i + 1] = resampled_prior_image.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    origin = downsampled_image.origin
    spacing = downsampled_image.spacing
    direction = downsampled_image.direction

    outer_probability_images = list()
    for i in range(len(labels)):
        probability_image = \
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
            origin=origin, spacing=spacing, direction=direction)
        resampled_image = ants.resample_image(probability_image,
                                              t1_preprocessed.shape,
                                              use_voxels=True,
                                              interp_type=0)
        if do_preprocessing == True:
            outer_probability_images.append(
                ants.apply_transforms(
                    fixed=t1,
                    moving=resampled_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose))
        else:
            outer_probability_images.append(resampled_image)

    image_matrix = ants.image_list_to_matrix(outer_probability_images,
                                             t1 * 0 + 1)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    dkt_label_image = ants.image_clone(segmentation_image)
    for i in range(len(labels)):
        dkt_label_image[segmentation_image == i] = labels[i]

    ################################
    #
    # Build inner model and load weights
    #
    ################################

    template_size = (160, 192, 160)
    labels = (0, 4, 6, 7, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 28, 30,
              43, 44, 45, 46, 49, 50, 51, 52, 53, 54, 58, 60, 91, 92, 630, 631,
              632)

    unet_model = create_unet_model_3d((*template_size, 1),
                                      number_of_outputs=len(labels),
                                      number_of_layers=4,
                                      number_of_filters_at_base_layer=8,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      weight_decay=1e-5,
                                      add_attention_gating=True)

    weights_file_name = get_pretrained_network(
        "dktInner", antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose == True:
        print("Prediction.")

    cropped_image = ants.crop_indices(t1_preprocessed, (12, 14, 0),
                                      (172, 206, 160))

    batchX = np.expand_dims(cropped_image.numpy(), axis=0)
    batchX = np.expand_dims(batchX, axis=-1)
    batchX = (batchX - batchX.mean()) / batchX.std()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    origin = cropped_image.origin
    spacing = cropped_image.spacing
    direction = cropped_image.direction

    inner_probability_images = list()
    for i in range(len(labels)):
        probability_image = \
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
            origin=origin, spacing=spacing, direction=direction)
        if i > 0:
            decropped_image = ants.decrop_image(probability_image,
                                                t1_preprocessed * 0)
        else:
            decropped_image = ants.decrop_image(probability_image,
                                                t1_preprocessed * 0 + 1)

        if do_preprocessing == True:
            inner_probability_images.append(
                ants.apply_transforms(
                    fixed=t1,
                    moving=decropped_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose))
        else:
            inner_probability_images.append(decropped_image)

    image_matrix = ants.image_list_to_matrix(inner_probability_images,
                                             t1 * 0 + 1)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    ################################
    #
    # Incorporate the inner model results into the final label image.
    # Note that we purposely prioritize the inner label results.
    #
    ################################

    for i in range(len(labels)):
        if labels[i] > 0:
            dkt_label_image[segmentation_image == i] = labels[i]

    if return_probability_images == True:
        return_dict = {
            'segmentation_image': dkt_label_image,
            'inner_probability_images': inner_probability_images,
            'outer_probability_images': outer_probability_images
        }
        return (return_dict)
    else:
        return (dkt_label_image)
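
A usage sketch for isolating a single DKT region from the returned label image; label 1035 is the left insula per the table above:

import ants

t1 = ants.image_read("t1.nii.gz")
dkt = desikan_killiany_tourville_labeling(t1, do_preprocessing=True)

left_insula_mask = ants.threshold_image(dkt, 1035, 1035)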
Example #25
def longitudinal_cortical_thickness(t1s,
                                    initial_template="oasis",
                                    number_of_iterations=1,
                                    refinement_transform="antsRegistrationSyNQuick[a]",
                                    antsxnet_cache_directory=None,
                                    verbose=False):

    """
    Perform KellyKapowski cortical thickness longitudinally using deep_atropos
    for segmentation of the derived single-subject template (SST).  It takes
    inspiration from the work described here:

    https://pubmed.ncbi.nlm.nih.gov/31356207/

    Arguments
    ---------
    t1s : list of ANTsImage
        Input list of 3-D unprocessed t1-weighted brain images from a single subject.

    initial_template : string or ANTsImage
        Input image to define the orientation of the SST.  Can be a string (see
        get_antsxnet_data) or a specified template.  This allows the user to create a
        SST outside of this routine.

    number_of_iterations : int
        Defines the number of iterations for refining the SST.

    refinement_transform : string
       Transform for defining the refinement registration transform. See options in
       ants.registration.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, these data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    List of per-time-point dictionaries (preprocessed image, thickness image,
    segmentation image, probability images, and template transforms), with
    the single-subject template (SST) appended as the final element.

    Example
    -------
    >>> t1s = list()
    >>> t1s.append(ants.image_read("t1w_image.nii.gz"))
    >>> kk = longitudinal_cortical_thickness(t1s)
    """

    from ..utilities import get_antsxnet_data
    from ..utilities import preprocess_brain_image
    from ..utilities import deep_atropos

    ###################
    #
    #  Initial SST + optional affine refinement
    #
    ###################

    sst = None
    if isinstance(initial_template, str):
        template_file_name_path = get_antsxnet_data(initial_template, antsxnet_cache_directory=antsxnet_cache_directory)
        sst = ants.image_read(template_file_name_path)
    else:
        sst = initial_template

    for s in range(number_of_iterations):
        if verbose:
            print("Refinement iteration", s, "( out of", number_of_iterations, ")")

        sst_tmp = ants.image_clone(sst) * 0
        for i in range(len(t1s)):
            if verbose:
                print("***************************")
                print( "SST processing image", i, "( out of", len(t1s), ")")
                print( "***************************" )
            transform_type = "antsRegistrationSyNQuick[r]"
            if s > 0:
                transform_type = refinement_transform
            t1_preprocessed = preprocess_brain_image(t1s[i],
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                template=sst,
                template_transform_type=transform_type,
                do_bias_correction=False,
                do_denoising=False,
                intensity_normalization_type="01",
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            sst_tmp += t1_preprocessed['preprocessed_image']

        sst = sst_tmp / len(t1s)

    ###################
    #
    #  Preprocessing and affine transform to final SST
    #
    ###################

    t1s_preprocessed = list()
    for i in range(len(t1s)):
        if verbose:
            print("***************************")
            print( "Final processing image", i, "( out of", len(t1s), ")")
            print( "***************************" )
        t1_preprocessed = preprocess_brain_image(t1s[i],
            truncate_intensity=(0.01, 0.99),
            do_brain_extraction=True,
            template=sst,
            template_transform_type="antsRegistrationSyNQuick[a]",
            do_bias_correction=True,
            do_denoising=True,
            intensity_normalization_type="01",
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1s_preprocessed.append(t1_preprocessed)

    ###################
    #
    #  Deep Atropos of SST for priors
    #
    ###################

    sst_atropos = deep_atropos(sst, do_preprocessing=True,
        antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose)

    ###################
    #
    #  Traditional Atropos + KK for each image
    #
    ###################

    return_list = list()
    for i in range(len(t1s_preprocessed)):
        if verbose:
            print("Atropos for image", i, "( out of", len(t1s), ")")
        atropos_output = ants.atropos(t1s_preprocessed[i]['preprocessed_image'],
            x=t1s_preprocessed[i]['brain_mask'], i=sst_atropos['probability_images'][1:7],
            m="[0.1,1x1x1]", c="[5,0]", priorweight=0.5, p="Socrates[1]", verbose=int(verbose))

        kk_segmentation = ants.image_clone(atropos_output['segmentation'])
        kk_segmentation[kk_segmentation == 4] = 3
        gray_matter = atropos_output['probabilityimages'][1]
        white_matter = atropos_output['probabilityimages'][2] + atropos_output['probabilityimages'][3]
        kk = ants.kelly_kapowski(s=kk_segmentation, g=gray_matter, w=white_matter,
            its=45, r=0.025, m=1.5, x=0, verbose=int(verbose))

        t1_dict = {'preprocessed_image' : t1s_preprocessed[i]['preprocessed_image'],
                   'thickness_image' : kk,
                   'segmentation_image' : atropos_output['segmentation'],
                   'csf_probability_image' : atropos_output['probabilityimages'][0],
                   'gray_matter_probability_image' : atropos_output['probabilityimages'][1],
                   'white_matter_probability_image' : atropos_output['probabilityimages'][2],
                   'deep_gray_matter_probability_image' : atropos_output['probabilityimages'][3],
                   'brain_stem_probability_image' : atropos_output['probabilityimages'][4],
                   'cerebellum_probability_image' : atropos_output['probabilityimages'][5],
                   'template_transforms' : t1s_preprocessed[i]['template_transforms']
                  }
        return_list.append(t1_dict)

    return_list.append(sst)

    return(return_list)
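
A sketch of consuming the returned list, assuming t1s is the list of T1 images from the docstring example: the final element is the single-subject template and the preceding elements are per-time-point dictionaries.

results = longitudinal_cortical_thickness(t1s, number_of_iterations=2)
sst = results[-1]
for i, t1_dict in enumerate(results[:-1]):
    thickness = t1_dict['thickness_image'].numpy()
    # Mean thickness over non-zero (cortical) voxels for this time point.
    print("Time point", i, "mean cortical thickness:", thickness[thickness > 0].mean())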
Example #26
def mainRegScript(patientPath,SA_name,LA_4CH_name,LA_2CH_name,pathSave,typeRe):
    #Load initial images
    SA = ants.image_read(patientPath+SA_name)
    SA = ants.resample_image(SA,[min(SA.spacing),min(SA.spacing),min(SA.spacing)])
    LA_4CH = ants.image_read(patientPath+LA_4CH_name)
    LA_2CH = ants.image_read(patientPath+LA_2CH_name)

    #Registration
    regSA_4CH = ants.registration(SA,LA_4CH,type_of_transform = typeRe,aff_metric = 'mattes')
    regSA_2CH = ants.registration(SA,LA_2CH,type_of_transform = typeRe,aff_metric = 'mattes')
    print('Registration Done')

    #ROI extraction
    clone4CH = ants.image_clone(LA_4CH)
    clone2CH = ants.image_clone(LA_2CH)
    roi_4CH = overROI(clone4CH,regSA_4CH['fwdtransforms'],SA)
    roi_2CH = overROI(clone2CH,regSA_2CH['fwdtransforms'],SA)
    roi_4CHArr = roi_4CH.view()
    roi_4CHArr = roi_4CHArr == 1
    roi_2CHArr = roi_2CH.view()
    roi_2CHArr = roi_2CHArr == 1

    print('ROI extraction Done')

    new4CH = ants.to_nibabel(regSA_4CH['warpedmovout'])
    new4CH.set_data_dtype('int16')
    new2CH = ants.to_nibabel(regSA_2CH['warpedmovout'])
    new2CH.set_data_dtype('int16')
    nib.save(new4CH,pathSave + '4CH.nii')
    nib.save(new2CH,pathSave + '2CH.nii')

    #Normalization

    new4CH_CH = regSA_4CH['warpedmovout'].view()
    new2CH_CH = regSA_2CH['warpedmovout'].view()
    short_CH = SA.view()
    for t in range(0,new2CH_CH.shape[2]):
        short_CH[:,:,t] = normOver(short_CH[:,:,t],roi_4CHArr[:,:,t]+roi_2CHArr[:,:,t])
        new4CH_CH[:,:,t] = normOver(new4CH_CH[:,:,t],roi_4CHArr[:,:,t])
        new2CH_CH[:,:,t] = normOver(new2CH_CH[:,:,t],roi_2CHArr[:,:,t])


    new4CH = ants.to_nibabel(regSA_4CH['warpedmovout'])
    new4CH.set_data_dtype('int16')
    new2CH = ants.to_nibabel(regSA_2CH['warpedmovout'])
    new2CH.set_data_dtype('int16')
    shortNorm = ants.to_nibabel(SA)
    shortNorm.set_data_dtype('int16')
    nib.save(new4CH,pathSave + 'norm_4CH.nii')
    nib.save(new2CH,pathSave + 'norm_2CH.nii')
    nib.save(shortNorm,pathSave + 'norm_SA.nii')
    print('Normalization Done')

    # Checkerboard comparison of the registered views
    ch_4CH = np.zeros(short_CH.shape)
    ch_2CH = np.zeros(short_CH.shape)
    for t in np.arange(0, short_CH.shape[2]):
        ch_4CH[:,:,t] = cheBoard(short_CH[:,:,t], new4CH_CH[:,:,t], 16, roi_4CHArr[:,:,t])
        ch_2CH[:,:,t] = cheBoard(short_CH[:,:,t], new2CH_CH[:,:,t], 16, roi_2CHArr[:,:,t])
    chest_4CH = nib.Nifti1Image(ch_4CH, new4CH.affine, new4CH.header)
    chest_2CH = nib.Nifti1Image(ch_2CH, new2CH.affine, new2CH.header)
    print("Checkerboard filter applied")
    nib.save(chest_4CH,pathSave + 'chest_4CH.nii')
    nib.save(chest_2CH,pathSave + 'chest_2CH.nii')
    print("Done")
    final = (SA,regSA_4CH,regSA_2CH)
    return final
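
A hypothetical invocation of the script above; the paths and file names are placeholders, and overROI, normOver, and cheBoard are helpers assumed to be defined elsewhere in the original script:

final = mainRegScript('/data/patient01/', 'SA.nii', 'LA_4CH.nii',
                      'LA_2CH.nii', '/data/patient01/output/', 'Rigid')
SA, regSA_4CH, regSA_2CH = final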
Example #27
def sysu_media_wmh_segmentation(flair,
                                t1=None,
                                do_preprocessing=True,
                                use_ensemble=True,
                                use_axial_slices_only=True,
                                antsxnet_cache_directory=None,
                                verbose=False):
    """
    Perform WMH segmentation using the winning submission in the MICCAI
    2017 challenge by the sysu_media team using FLAIR or T1/FLAIR.  The
    MICCAI challenge is discussed in

    https://pubmed.ncbi.nlm.nih.gov/30908194/

    with the sysu_media team's entry discussed in

    https://pubmed.ncbi.nlm.nih.gov/30125711/

    and the original implementation available here:

    https://github.com/hongweilibran/wmh_ibbmTum

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    do_preprocessing : boolean
        perform preprocessing, i.e., intensity truncation and n4 bias
        correction.

    use_ensemble : boolean
        if True, use an ensemble of all 3 sets of pretrained weights;
        otherwise use a single model.

    use_axial_slices_only : boolean
        If True, use the original implementation, which was trained on axial
        slices.  If False, use the ANTsXNet variant, which applies the
        slice-by-slice models to all 3 dimensions and averages the results.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> image = ants.image_read("flair.nii.gz")
    >>> probability_mask = sysu_media_wmh_segmentation(image)
    """

    from ..architectures import create_sysu_media_unet_model_2d
    from ..utilities import brain_extraction
    from ..utilities import crop_image_center
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import pad_or_crop_image_to_size

    if flair.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    flair_preprocessed = flair
    if do_preprocessing == True:
        flair_preprocessing = preprocess_brain_image(
            flair,
            truncate_intensity=(0.01, 0.99),
            do_brain_extraction=False,
            do_bias_correction=True,
            do_denoising=False,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        flair_preprocessed = flair_preprocessing["preprocessed_image"]

    number_of_channels = 1
    if t1 is not None:
        t1_preprocessed = t1
        if do_preprocessing == True:
            t1_preprocessing = preprocess_brain_image(
                t1,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            t1_preprocessed = t1_preprocessing["preprocessed_image"]
        number_of_channels = 2

    ################################
    #
    # Estimate mask
    #
    ################################

    brain_mask = None
    if verbose == True:
        print("Estimating brain mask.")
    if t1 is not None:
        brain_mask = brain_extraction(t1, modality="t1")
    else:
        brain_mask = brain_extraction(flair, modality="flair")

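    # Rigidly reposition the head into a 200^3, 1 mm isotropic reference space
    # by matching the brain mask's center of mass to that of the reference
    # image (translation-only Euler transform; no rotation is estimated).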
    reference_image = ants.make_image((200, 200, 200),
                                      voxval=1,
                                      spacing=(1, 1, 1),
                                      origin=(0, 0, 0),
                                      direction=np.identity(3))

    center_of_mass_reference = ants.get_center_of_mass(reference_image)
    center_of_mass_image = ants.get_center_of_mass(brain_mask)
    translation = np.asarray(center_of_mass_image) - np.asarray(
        center_of_mass_reference)
    xfrm = ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_reference),
        translation=translation)
    flair_preprocessed_warped = ants.apply_ants_transform_to_image(
        xfrm, flair_preprocessed, reference_image)
    brain_mask_warped = ants.threshold_image(
        ants.apply_ants_transform_to_image(xfrm, brain_mask, reference_image),
        0.5, 1.1, 1, 0)

    if t1 is not None:
        t1_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm, t1_preprocessed, reference_image)

    ################################
    #
    # Gaussian normalize intensity based on brain mask
    #
    ################################

    mean_flair = flair_preprocessed_warped[brain_mask_warped > 0].mean()
    std_flair = flair_preprocessed_warped[brain_mask_warped > 0].std()
    flair_preprocessed_warped = (flair_preprocessed_warped -
                                 mean_flair) / std_flair

    if number_of_channels == 2:
        mean_t1 = t1_preprocessed_warped[brain_mask_warped > 0].mean()
        std_t1 = t1_preprocessed_warped[brain_mask_warped > 0].std()
        t1_preprocessed_warped = (t1_preprocessed_warped - mean_t1) / std_t1

    ################################
    #
    # Build models and load weights
    #
    ################################

    number_of_models = 1
    if use_ensemble == True:
        number_of_models = 3

    unet_models = list()
    for i in range(number_of_models):
        if number_of_channels == 1:
            weights_file_name = get_pretrained_network(
                "sysuMediaWmhFlairOnlyModel" + str(i),
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            weights_file_name = get_pretrained_network(
                "sysuMediaWmhFlairT1Model" + str(i),
                antsxnet_cache_directory=antsxnet_cache_directory)
        unet_models.append(
            create_sysu_media_unet_model_2d((200, 200, number_of_channels)))
        unet_models[i].load_weights(weights_file_name)

    ################################
    #
    # Extract slices
    #
    ################################

    dimensions_to_predict = [2]
    if use_axial_slices_only == False:
        dimensions_to_predict = list(range(3))

    total_number_of_slices = 0
    for d in range(len(dimensions_to_predict)):
        total_number_of_slices += flair_preprocessed_warped.shape[
            dimensions_to_predict[d]]

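    # batchX holds one row per extracted 2-D slice; each slice is padded or
    # cropped to 200 x 200 to match the U-net input size.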
    batchX = np.zeros((total_number_of_slices, 200, 200, number_of_channels))

    slice_count = 0
    for d in range(len(dimensions_to_predict)):
        number_of_slices = flair_preprocessed_warped.shape[
            dimensions_to_predict[d]]

        if verbose == True:
            print("Extracting slices for dimension ", dimensions_to_predict[d],
                  ".")

        for i in range(number_of_slices):
            flair_slice = pad_or_crop_image_to_size(
                ants.slice_image(flair_preprocessed_warped,
                                 dimensions_to_predict[d], i), (200, 200))
            batchX[slice_count, :, :, 0] = flair_slice.numpy()
            if number_of_channels == 2:
                t1_slice = pad_or_crop_image_to_size(
                    ants.slice_image(t1_preprocessed_warped,
                                     dimensions_to_predict[d], i), (200, 200))
                batchX[slice_count, :, :, 1] = t1_slice.numpy()

            slice_count += 1

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose == True:
        print("Prediction.")

    prediction = unet_models[0].predict(batchX, verbose=verbose)
    if number_of_models > 1:
        for i in range(1, number_of_models, 1):
            prediction += unet_models[i].predict(batchX, verbose=verbose)
    prediction /= number_of_models

    permutations = list()
    permutations.append((0, 1, 2))
    permutations.append((1, 0, 2))
    permutations.append((1, 2, 0))
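    # permutations[d] transposes a stack of 2-D predictions taken along axis d
    # back to the original (x, y, z) voxel ordering.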

    prediction_image_average = ants.image_clone(flair_preprocessed_warped) * 0

    current_start_slice = 0
    for d in range(len(dimensions_to_predict)):
        current_end_slice = current_start_slice + flair_preprocessed_warped.shape[
            dimensions_to_predict[d]] - 1
        which_batch_slices = range(current_start_slice, current_end_slice + 1)  # include the final slice
        prediction_per_dimension = prediction[which_batch_slices, :, :, :]
        prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                                        permutations[dimensions_to_predict[d]])
        prediction_image = ants.copy_image_info(
            flair_preprocessed_warped,
            pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                      flair_preprocessed_warped.shape))
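        # Streaming mean over the predicted dimensions:
        # avg_d = avg_{d-1} + (x_d - avg_{d-1}) / d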
        prediction_image_average = prediction_image_average + (
            prediction_image - prediction_image_average) / (d + 1)
        current_start_slice = current_end_slice + 1

    probability_image = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), prediction_image_average, flair)

    return (probability_image)
Example #28
import ants
import numpy as np

def ew_david(flair,
             t1,
             do_preprocessing=True,
             do_slicewise=True,
             antsxnet_cache_directory=None,
             verbose=False):

    """
    Perform white matter hyperintensity probabilistic segmentation
    using deep learning.

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    do_preprocessing=True.

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    do_preprocessing : boolean
        perform preprocessing, i.e., intensity truncation and n4 bias
        correction (plus brain extraction and affine registration to MNI
        when do_slicewise is False).

    do_slicewise : boolean
        apply the 2-D model slicewise along the direction of maximal slice
        thickness.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> flair = ants.image_read("flair.nii.gz")
    >>> t1 = ants.image_read("t1.nii.gz")
    >>> probability_mask = ew_david(flair, t1)
    """

    from ..architectures import create_unet_model_2d
    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import extract_image_patches
    from ..utilities import reconstruct_image_from_patches
    from ..utilities import pad_or_crop_image_to_size

    if flair.dimension != 3:
        raise ValueError( "Image dimension must be 3." )

    if t1.dimension != 3:
        raise ValueError( "Image dimension must be 3." )

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    if do_slicewise == False:

        ################################
        #
        # Preprocess images
        #
        ################################

        t1_preprocessed = t1
        t1_preprocessing = None
        if do_preprocessing == True:
            t1_preprocessing = preprocess_brain_image(t1,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=True,
                template="croppedMni152",
                template_transform_type="AffineFast",
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            t1_preprocessed = t1_preprocessing["preprocessed_image"] * t1_preprocessing['brain_mask']

        flair_preprocessed = flair
        if do_preprocessing == True:
            flair_preprocessing = preprocess_brain_image(flair,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            flair_preprocessed = ants.apply_transforms(fixed=t1_preprocessed,
                moving=flair_preprocessing["preprocessed_image"],
                transformlist=t1_preprocessing['template_transforms']['fwdtransforms'])
            flair_preprocessed = flair_preprocessed * t1_preprocessing['brain_mask']

        ################################
        #
        # Build model and load weights
        #
        ################################

        patch_size = (112, 112, 112)
        stride_length = (t1_preprocessed.shape[0] - patch_size[0],
                        t1_preprocessed.shape[1] - patch_size[1],
                        t1_preprocessed.shape[2] - patch_size[2])
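        # With stride = image extent - patch size, extract_image_patches yields
        # the 8 corner patches of the volume (2 offsets per axis), matching the
        # batch dimension of 8 allocated below.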

        classes = ("background", "wmh" )
        number_of_classification_labels = len(classes)
        labels = (0, 1)

        image_modalities = ("T1", "FLAIR")
        channel_size = len(image_modalities)

        unet_model = create_unet_model_3d((*patch_size, channel_size),
            number_of_outputs = number_of_classification_labels,
            number_of_layers = 4, number_of_filters_at_base_layer = 16, dropout_rate = 0.0,
            convolution_kernel_size = (3, 3, 3), deconvolution_kernel_size = (2, 2, 2),
            weight_decay = 1e-5, nn_unet_activation_style=False, add_attention_gating=True)

        weights_file_name = get_pretrained_network("ewDavidWmhSegmentationWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Do prediction and normalize to native space
        #
        ################################

        if verbose == True:
            print("ew_david:  prediction.")

        batchX = np.zeros((8, *patch_size, channel_size))

        t1_preprocessed = (t1_preprocessed - t1_preprocessed.mean()) / t1_preprocessed.std()
        t1_patches = extract_image_patches(t1_preprocessed, patch_size=patch_size,
                                            max_number_of_patches="all", stride_length=stride_length,
                                            return_as_array=True)
        batchX[:,:,:,:,0] = t1_patches

        flair_preprocessed = (flair_preprocessed - flair_preprocessed.mean()) / flair_preprocessed.std()
        flair_patches = extract_image_patches(flair_preprocessed, patch_size=patch_size,
                                            max_number_of_patches="all", stride_length=stride_length,
                                            return_as_array=True)
        batchX[:,:,:,:,1] = flair_patches

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        probability_images = list()
        for i in range(len(labels)):
            print("Reconstructing image", classes[i])
            reconstructed_image = reconstruct_image_from_patches(predicted_data[:,:,:,:,i],
                domain_image=t1_preprocessed, stride_length=stride_length)

            if do_preprocessing == True:
                probability_images.append(ants.apply_transforms(fixed=t1,
                    moving=reconstructed_image,
                    transformlist=t1_preprocessing['template_transforms']['invtransforms'],
                    whichtoinvert=[True], interpolator="linear", verbose=verbose))
            else:
                probability_images.append(reconstructed_image)

        return(probability_images[1])

    else:  # do_slicewise

        ################################
        #
        # Preprocess images
        #
        ################################

        t1_preprocessed = t1
        t1_preprocessing = None
        if do_preprocessing == True:
            t1_preprocessing = preprocess_brain_image(t1,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            t1_preprocessed = t1_preprocessing["preprocessed_image"]

        flair_preprocessed = flair
        if do_preprocessing == True:
            flair_preprocessing = preprocess_brain_image(flair,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            flair_preprocessed = flair_preprocessing["preprocessed_image"]

        resampling_params = list(ants.get_spacing(flair_preprocessed))

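        # Any axis finer than 0.8 mm is resampled to 1.0 mm; coarser axes
        # keep their original spacing.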
        do_resampling = False
        for d in range(len(resampling_params)):
            if resampling_params[d] < 0.8:
                resampling_params[d] = 1.0
                do_resampling = True

        resampling_params = tuple(resampling_params)

        if do_resampling:
            flair_preprocessed = ants.resample_image(flair_preprocessed, resampling_params, use_voxels=False, interp_type=0)
            t1_preprocessed = ants.resample_image(t1_preprocessed, resampling_params, use_voxels=False, interp_type=0)

        flair_preprocessed = (flair_preprocessed - flair_preprocessed.mean()) / flair_preprocessed.std()
        t1_preprocessed = (t1_preprocessed - t1_preprocessed.mean()) / t1_preprocessed.std()

        ################################
        #
        # Build model and load weights
        #
        ################################

        template_size = (256, 256)

        classes = ("background", "wmh" )
        number_of_classification_labels = len(classes)
        labels = (0, 1)

        image_modalities = ("T1", "FLAIR")
        channel_size = len(image_modalities)

        unet_model = create_unet_model_2d((*template_size, channel_size),
            number_of_outputs = number_of_classification_labels,
            number_of_layers = 4, number_of_filters_at_base_layer = 32, dropout_rate = 0.0,
            convolution_kernel_size = (3, 3), deconvolution_kernel_size = (2, 2),
            weight_decay = 1e-5, nn_unet_activation_style=True, add_attention_gating=True)

        if verbose == True:
            print("ewDavid:  retrieving model weights.")

        weights_file_name = get_pretrained_network("ewDavidWmhSegmentationSlicewiseWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Extract slices
        #
        ################################

        use_coarse_slices_only = True

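        # Predict along the coarsest axis only, i.e., the axis with the
        # largest spacing (typically the through-plane slice direction).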
        spacing = ants.get_spacing(flair_preprocessed)
        dimensions_to_predict = (spacing.index(max(spacing)),)
        if use_coarse_slices_only == False:
            dimensions_to_predict = list(range(3))

        total_number_of_slices = 0
        for d in range(len(dimensions_to_predict)):
            total_number_of_slices += flair_preprocessed.shape[dimensions_to_predict[d]]

        batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

        slice_count = 0
        for d in range(len(dimensions_to_predict)):
            number_of_slices = flair_preprocessed.shape[dimensions_to_predict[d]]

            if verbose == True:
                print("Extracting slices for dimension ", dimensions_to_predict[d], ".")

            for i in range(number_of_slices):
                flair_slice = pad_or_crop_image_to_size(ants.slice_image(flair_preprocessed, dimensions_to_predict[d], i), template_size)
                batchX[slice_count,:,:,0] = flair_slice.numpy()

                t1_slice = pad_or_crop_image_to_size(ants.slice_image(t1_preprocessed, dimensions_to_predict[d], i), template_size)
                batchX[slice_count,:,:,1] = t1_slice.numpy()

                slice_count += 1


        ################################
        #
        # Do prediction and then restack into the image
        #
        ################################

        if verbose == True:
            print("Prediction.")

        prediction = unet_model.predict(batchX, verbose=verbose)

        permutations = list()
        permutations.append((0, 1, 2))
        permutations.append((1, 0, 2))
        permutations.append((1, 2, 0))

        prediction_image_average = ants.image_clone(flair_preprocessed) * 0

        current_start_slice = 0
        for d in range(len(dimensions_to_predict)):
            current_end_slice = current_start_slice + flair_preprocessed.shape[dimensions_to_predict[d]] - 1
            which_batch_slices = range(current_start_slice, current_end_slice + 1)  # include the final slice
            prediction_per_dimension = prediction[which_batch_slices,:,:,1]
            prediction_array = np.transpose(np.squeeze(prediction_per_dimension), permutations[dimensions_to_predict[d]])
            prediction_image = ants.copy_image_info(flair_preprocessed,
                pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                flair_preprocessed.shape))
            prediction_image_average = prediction_image_average + (prediction_image - prediction_image_average) / (d + 1)

            current_start_slice = current_end_slice + 1

        if do_resampling:
            prediction_image_average = ants.resample_image_to_target(prediction_image_average, flair)

        return(prediction_image_average)
Example #29
import ants
import numpy as np

# Keras imports assumed; they are not shown in the original snippet.
from tensorflow.keras import regularizers
from tensorflow.keras.layers import Conv3D
from tensorflow.keras.models import Model

def lung_extraction(image,
                    modality="proton",
                    antsxnet_cache_directory=None,
                    verbose=False):

    """
    Perform proton or ct lung extraction using U-net.

    Arguments
    ---------
    image : ANTsImage
        input image

    modality : string
        Modality image type.  Options include "ct", "proton", "protonLobes", 
        "maskLobes", and "ventilation".

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary of ANTs segmentation and probability images.

    Example
    -------
    >>> output = lung_extraction(lung_image, modality="proton")
    """

    from ..architectures import create_unet_model_2d
    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import pad_or_crop_image_to_size

    if image.dimension != 3:
        raise ValueError( "Image dimension must be 3." )

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    image_mods = [modality]
    channel_size = len(image_mods)

    weights_file_name = None
    unet_model = None

    if modality == "proton":
        weights_file_name = get_pretrained_network("protonLungMri",
            antsxnet_cache_directory=antsxnet_cache_directory)

        classes = ("background", "left_lung", "right_lung")
        number_of_classification_labels = len(classes)

        reorient_template_file_name_path = get_antsxnet_data("protonLungTemplate",
            antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)

        resampled_image_size = reorient_template.shape

        unet_model = create_unet_model_3d((*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            number_of_layers=4, number_of_filters_at_base_layer=16, dropout_rate=0.0,
            convolution_kernel_size=(7, 7, 5), deconvolution_kernel_size=(7, 7, 5))
        unet_model.load_weights(weights_file_name)

        if verbose == True:
            print("Lung extraction:  normalizing image to the template.")

        center_of_mass_template = ants.get_center_of_mass(reorient_template * 0 + 1)
        center_of_mass_image = ants.get_center_of_mass(image * 0 + 1)
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_template), translation=translation)
        warped_image = ants.apply_ants_transform_to_image(xfrm, image, reorient_template)

        batchX = np.expand_dims(warped_image.numpy(), axis=0)
        batchX = np.expand_dims(batchX, axis=-1)
        batchX = (batchX - batchX.mean()) / batchX.std()

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        origin = warped_image.origin
        spacing = warped_image.spacing
        direction = warped_image.direction

        probability_images_array = list()
        for i in range(number_of_classification_labels):
            probability_images_array.append(
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                origin=origin, spacing=spacing, direction=direction))

        if verbose == True:
            print("Lung extraction:  renormalize probability mask to native space.")

        for i in range(number_of_classification_labels):
            probability_images_array[i] = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), probability_images_array[i], image)

        image_matrix = ants.image_list_to_matrix(probability_images_array, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        return_dict = {'segmentation_image' : segmentation_image,
                       'probability_images' : probability_images_array}
        return(return_dict)

    if modality == "protonLobes" or modality == "maskLobes":
        reorient_template_file_name_path = get_antsxnet_data("protonLungTemplate",
            antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)

        resampled_image_size = reorient_template.shape

        spatial_priors_file_name_path = get_antsxnet_data("protonLobePriors",
            antsxnet_cache_directory=antsxnet_cache_directory)
        spatial_priors = ants.image_read(spatial_priors_file_name_path)
        priors_image_list = ants.ndimage_to_list(spatial_priors)

        channel_size = 1 + len(priors_image_list)
        number_of_classification_labels = 1 + len(priors_image_list)

        unet_model = create_unet_model_3d((*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels, mode="classification", 
            number_of_filters_at_base_layer=16, number_of_layers=4,
            convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
            dropout_rate=0.0, weight_decay=0, additional_options=("attentionGating",))

        if modality == "protonLobes":
            penultimate_layer = unet_model.layers[-2].output
            outputs2 = Conv3D(filters=1,
                            kernel_size=(1, 1, 1),
                            activation='sigmoid',
                            kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)
            unet_model = Model(inputs=unet_model.input, outputs=[unet_model.output, outputs2])
            weights_file_name = get_pretrained_network("protonLobes",
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            weights_file_name = get_pretrained_network("maskLobes",
                antsxnet_cache_directory=antsxnet_cache_directory)

        unet_model.load_weights(weights_file_name)

        if verbose == True:
            print("Lung extraction:  normalizing image to the template.")

        center_of_mass_template = ants.get_center_of_mass(reorient_template * 0 + 1)
        center_of_mass_image = ants.get_center_of_mass(image * 0 + 1)
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_template), translation=translation)
        warped_image = ants.apply_ants_transform_to_image(xfrm, image, reorient_template)
        warped_array = warped_image.numpy()
        if modality == "protonLobes":
            warped_array = (warped_array - warped_array.mean()) / warped_array.std()
        else:
            warped_array[warped_array != 0] = 1
       
        batchX = np.zeros((1, *warped_array.shape, channel_size))
        batchX[0,:,:,:,0] = warped_array
        for i in range(len(priors_image_list)):
            batchX[0,:,:,:,i+1] = priors_image_list[i].numpy()

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        origin = warped_image.origin
        spacing = warped_image.spacing
        direction = warped_image.direction

        probability_images_array = list()
        for i in range(number_of_classification_labels):
            if modality == "protonLobes":
                probability_images_array.append(
                    ants.from_numpy(np.squeeze(predicted_data[0][0, :, :, :, i]),
                    origin=origin, spacing=spacing, direction=direction))
            else:
                probability_images_array.append(
                    ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                    origin=origin, spacing=spacing, direction=direction))

        if verbose == True:
            print("Lung extraction:  renormalize probability images to native space.")

        for i in range(number_of_classification_labels):
            probability_images_array[i] = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), probability_images_array[i], image)

        image_matrix = ants.image_list_to_matrix(probability_images_array, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        if modality == "protonLobes":
            whole_lung_mask = ants.from_numpy(np.squeeze(predicted_data[1][0, :, :, :, 0]),
                origin=origin, spacing=spacing, direction=direction)
            whole_lung_mask = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), whole_lung_mask, image)

            return_dict = {'segmentation_image' : segmentation_image,
                           'probability_images' : probability_images_array,
                           'whole_lung_mask_image' : whole_lung_mask}
            return(return_dict)
        else:
            return_dict = {'segmentation_image' : segmentation_image,
                           'probability_images' : probability_images_array}
            return(return_dict)


    elif modality == "ct":

        ################################
        #
        # Preprocess image
        #
        ################################

        if verbose == True:
            print("Preprocess CT image.")

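        # Snap each direction cosine to the nearest axis-aligned value of
        # -1, 0, or +1 (e.g., 0.98 -> 1.0, -0.07 -> 0.0).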
        def closest_simplified_direction_matrix(direction):
            closest = np.floor(np.abs(direction) + 0.5)
            closest[direction < 0] *= -1.0
            return closest

        simplified_direction = closest_simplified_direction_matrix(image.direction)

        reference_image_size = (128, 128, 128)

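        # Resample to 128^3, clamp intensities to the [-1000, 400] HU window
        # (air to soft tissue), and place the image in a simplified
        # axis-aligned, 1 mm isotropic space.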
        ct_preprocessed = ants.resample_image(image, reference_image_size, use_voxels=True, interp_type=0)
        ct_preprocessed[ct_preprocessed < -1000] = -1000
        ct_preprocessed[ct_preprocessed > 400] = 400
        ct_preprocessed.set_direction(simplified_direction)
        ct_preprocessed.set_origin((0, 0, 0))
        ct_preprocessed.set_spacing((1, 1, 1))

        ################################
        #
        # Reorient image
        #
        ################################

        reference_image = ants.make_image(reference_image_size,
                                          voxval=0,
                                          spacing=(1, 1, 1),
                                          origin=(0, 0, 0),
                                          direction=np.identity(3))
        center_of_mass_reference = np.floor(ants.get_center_of_mass(reference_image * 0 + 1))
        center_of_mass_image = np.floor(ants.get_center_of_mass(ct_preprocessed * 0 + 1))
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_reference)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_reference), translation=translation)
        ct_preprocessed = ((ct_preprocessed - ct_preprocessed.min()) /
            (ct_preprocessed.max() - ct_preprocessed.min()))
        ct_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm, ct_preprocessed, reference_image, interpolation="nearestneighbor")
        ct_preprocessed_warped = ((ct_preprocessed_warped - ct_preprocessed_warped.min()) /
            (ct_preprocessed_warped.max() - ct_preprocessed_warped.min())) - 0.5

        ################################
        #
        # Build models and load weights
        #
        ################################

        if verbose == True:
            print("Build model and load weights.")

        weights_file_name = get_pretrained_network("lungCtWithPriorsSegmentationWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)

        classes = ("background", "left lung", "right lung", "airways")
        number_of_classification_labels = len(classes)

        luna16_priors = ants.ndimage_to_list(ants.image_read(get_antsxnet_data("luna16LungPriors")))
        for i in range(len(luna16_priors)):
            luna16_priors[i] = ants.resample_image(luna16_priors[i], reference_image_size, use_voxels=True)
        channel_size = len(luna16_priors) + 1

        unet_model = create_unet_model_3d((*reference_image_size, channel_size),
            number_of_outputs=number_of_classification_labels, mode="classification",
            number_of_layers=4, number_of_filters_at_base_layer=16, dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5, additional_options=("attentionGating",))
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Do prediction and normalize to native space
        #
        ################################

        if verbose == True:
            print("Prediction.")

        batchX = np.zeros((1, *reference_image_size, channel_size))
        batchX[:,:,:,:,0] = ct_preprocessed_warped.numpy()

        for i in range(len(luna16_priors)):
            batchX[:,:,:,:,i+1] = luna16_priors[i].numpy() - 0.5

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        probability_images = list()
        for i in range(number_of_classification_labels):
            if verbose == True:
                print("Reconstructing image", classes[i])
            probability_image = ants.from_numpy(np.squeeze(predicted_data[:,:,:,:,i]),
                origin=ct_preprocessed_warped.origin, spacing=ct_preprocessed_warped.spacing,
                direction=ct_preprocessed_warped.direction)
            probability_image = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), probability_image, ct_preprocessed)
            probability_image = ants.resample_image(probability_image,
               resample_params=image.shape, use_voxels=True, interp_type=0)
            probability_image = ants.copy_image_info(image, probability_image)
            probability_images.append(probability_image)

        image_matrix = ants.image_list_to_matrix(probability_images, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        return_dict = {'segmentation_image' : segmentation_image,
                       'probability_images' : probability_images}
        return(return_dict)

    elif modality == "ventilation":

        ################################
        #
        # Preprocess image
        #
        ################################

        if verbose == True:
            print("Preprocess ventilation image.")

        template_size = (256, 256)

        image_modalities = ("Ventilation",)
        channel_size = len(image_modalities)

        preprocessed_image = (image - image.mean()) / image.std()
        ants.set_direction(preprocessed_image, np.identity(3))

        ################################
        #
        # Build models and load weights
        #
        ################################

        unet_model = create_unet_model_2d((*template_size, channel_size),
            number_of_outputs=1, mode='sigmoid',
            number_of_layers=4, number_of_filters_at_base_layer=32, dropout_rate=0.0,
            convolution_kernel_size=(3, 3), deconvolution_kernel_size=(2, 2),
            weight_decay=0)

        if verbose == True:
            print("Whole lung mask: retrieving model weights.")

        weights_file_name = get_pretrained_network("wholeLungMaskFromVentilation",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Extract slices
        #
        ################################

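        # Predict along the coarsest axis only, i.e., the axis with the
        # largest spacing (the slice direction).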
        spacing = ants.get_spacing(preprocessed_image)
        dimensions_to_predict = (spacing.index(max(spacing)),)

        total_number_of_slices = 0
        for d in range(len(dimensions_to_predict)):
            total_number_of_slices += preprocessed_image.shape[dimensions_to_predict[d]]

        batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

        slice_count = 0
        for d in range(len(dimensions_to_predict)):
            number_of_slices = preprocessed_image.shape[dimensions_to_predict[d]]

            if verbose == True:
                print("Extracting slices for dimension ", dimensions_to_predict[d], ".")

            for i in range(number_of_slices):
                ventilation_slice = pad_or_crop_image_to_size(ants.slice_image(preprocessed_image, dimensions_to_predict[d], i), template_size)
                batchX[slice_count,:,:,0] = ventilation_slice.numpy()
                slice_count += 1

        ################################
        #
        # Do prediction and then restack into the image
        #
        ################################

        if verbose == True:
            print("Prediction.")

        prediction = unet_model.predict(batchX, verbose=verbose)

        permutations = list()
        permutations.append((0, 1, 2))
        permutations.append((1, 0, 2))
        permutations.append((1, 2, 0))

        probability_image = ants.image_clone(image) * 0

        current_start_slice = 0
        for d in range(len(dimensions_to_predict)):
            current_end_slice = current_start_slice + preprocessed_image.shape[dimensions_to_predict[d]] - 1
            which_batch_slices = range(current_start_slice, current_end_slice + 1)  # include the final slice

            prediction_per_dimension = prediction[which_batch_slices,:,:,0]
            prediction_array = np.transpose(np.squeeze(prediction_per_dimension), permutations[dimensions_to_predict[d]])
            prediction_image = ants.copy_image_info(image,
                pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                image.shape))
            probability_image = probability_image + (prediction_image - probability_image) / (d + 1)

            current_start_slice = current_end_slice + 1

        return(probability_image)

    else:
        raise ValueError("Unrecognized modality.")
Example #30
import ants
import numpy as np

def claustrum_segmentation(t1,
                           do_preprocessing=True,
                           use_ensemble=True,
                           antsxnet_cache_directory=None,
                           verbose=False):
    """
    Claustrum segmentation

    Described here:

        https://arxiv.org/abs/2008.03465

    with the implementation available at:

        https://github.com/hongweilibran/claustrum_multi_view


    Arguments
    ---------
    t1 : ANTsImage
        input 3-D T1 brain image.

    do_preprocessing : boolean
        perform preprocessing, i.e., intensity truncation, brain extraction,
        n4 bias correction, and denoising.

    use_ensemble : boolean
        if True, use an ensemble of all 3 sets of pretrained weights;
        otherwise use a single model.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Claustrum segmentation probability image

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> probability_mask = claustrum_segmentation(image)
    """

    from ..architectures import create_sysu_media_unet_model_2d
    from ..utilities import brain_extraction
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import pad_or_crop_image_to_size

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    image_size = (180, 180)

    ################################
    #
    # Preprocess images
    #
    ################################

    number_of_channels = 1
    t1_preprocessed = ants.image_clone(t1)
    brain_mask = ants.threshold_image(t1, 0, 0, 0, 1)
    if do_preprocessing == True:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing["preprocessed_image"]
        brain_mask = t1_preprocessing["brain_mask"]

    reference_image = ants.make_image((170, 256, 256),
                                      voxval=1,
                                      spacing=(1, 1, 1),
                                      origin=(0, 0, 0),
                                      direction=np.identity(3))
    center_of_mass_reference = ants.get_center_of_mass(reference_image)
    center_of_mass_image = ants.get_center_of_mass(brain_mask)
    translation = np.asarray(center_of_mass_image) - np.asarray(
        center_of_mass_reference)
    xfrm = ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_reference),
        translation=translation)
    t1_preprocessed_warped = ants.apply_ants_transform_to_image(
        xfrm, t1_preprocessed, reference_image)
    brain_mask_warped = ants.threshold_image(
        ants.apply_ants_transform_to_image(xfrm, brain_mask, reference_image),
        0.5, 1.1, 1, 0)

    ################################
    #
    # Gaussian normalize intensity based on brain mask
    #
    ################################

    mean_t1 = t1_preprocessed_warped[brain_mask_warped > 0].mean()
    std_t1 = t1_preprocessed_warped[brain_mask_warped > 0].std()
    t1_preprocessed_warped = (t1_preprocessed_warped - mean_t1) / std_t1

    t1_preprocessed_warped = t1_preprocessed_warped * brain_mask_warped

    ################################
    #
    # Build models and load weights
    #
    ################################

    number_of_models = 1
    if use_ensemble == True:
        number_of_models = 3

    if verbose == True:
        print("Claustrum:  retrieving axial model weights.")

    unet_axial_models = list()
    for i in range(number_of_models):
        weights_file_name = get_pretrained_network(
            "claustrum_axial_" + str(i),
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_axial_models.append(
            create_sysu_media_unet_model_2d((*image_size, number_of_channels),
                                            anatomy="claustrum"))
        unet_axial_models[i].load_weights(weights_file_name)

    if verbose == True:
        print("Claustrum:  retrieving coronal model weights.")

    unet_coronal_models = list()
    for i in range(number_of_models):
        weights_file_name = get_pretrained_network(
            "claustrum_coronal_" + str(i),
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_coronal_models.append(
            create_sysu_media_unet_model_2d((*image_size, number_of_channels),
                                            anatomy="claustrum"))
        unet_coronal_models[i].load_weights(weights_file_name)

    ################################
    #
    # Extract slices
    #
    ################################

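    # Dimension 1 yields coronal slices and dimension 2 yields axial slices in
    # the 170 x 256 x 256 reference space defined above.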
    dimensions_to_predict = [1, 2]

    batch_coronal_X = np.zeros(
        (t1_preprocessed_warped.shape[1], *image_size, number_of_channels))
    batch_axial_X = np.zeros(
        (t1_preprocessed_warped.shape[2], *image_size, number_of_channels))

    for d in range(len(dimensions_to_predict)):
        number_of_slices = t1_preprocessed_warped.shape[
            dimensions_to_predict[d]]

        if verbose == True:
            print("Extracting slices for dimension ", dimensions_to_predict[d],
                  ".")

        for i in range(number_of_slices):
            t1_slice = pad_or_crop_image_to_size(
                ants.slice_image(t1_preprocessed_warped,
                                 dimensions_to_predict[d], i), image_size)
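            # Rotate each slice into the orientation the pretrained models
            # expect (presumably matching the original claustrum_multi_view
            # training data); the inverse rotation is applied to the
            # predictions below.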
            if dimensions_to_predict[d] == 1:
                batch_coronal_X[i, :, :, 0] = np.rot90(t1_slice.numpy(), k=-1)
            else:
                batch_axial_X[i, :, :, 0] = np.rot90(t1_slice.numpy())

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose == True:
        print("Coronal prediction.")

    prediction_coronal = unet_coronal_models[0].predict(batch_coronal_X,
                                                        verbose=verbose)
    if number_of_models > 1:
        for i in range(1, number_of_models, 1):
            prediction_coronal += unet_coronal_models[i].predict(
                batch_coronal_X, verbose=verbose)
    prediction_coronal /= number_of_models

    for i in range(t1_preprocessed_warped.shape[1]):
        prediction_coronal[i, :, :, 0] = np.rot90(
            np.squeeze(prediction_coronal[i, :, :, 0]))

    if verbose == True:
        print("Axial prediction.")

    prediction_axial = unet_axial_models[0].predict(batch_axial_X,
                                                    verbose=verbose)
    if number_of_models > 1:
        for i in range(1, number_of_models, 1):
            prediction_axial += unet_axial_models[i].predict(batch_axial_X,
                                                             verbose=verbose)
    prediction_axial /= number_of_models

    for i in range(t1_preprocessed_warped.shape[2]):
        prediction_axial[i, :, :, 0] = np.rot90(
            np.squeeze(prediction_axial[i, :, :, 0]), k=-1)

    if verbose == True:
        print("Restack image and transform back to native space.")

    permutations = list()
    permutations.append((0, 1, 2))
    permutations.append((1, 0, 2))
    permutations.append((1, 2, 0))

    prediction_image_average = ants.image_clone(t1_preprocessed_warped) * 0

    for d in range(len(dimensions_to_predict)):
        which_batch_slices = range(
            t1_preprocessed_warped.shape[dimensions_to_predict[d]])
        prediction_per_dimension = None
        if dimensions_to_predict[d] == 1:
            prediction_per_dimension = prediction_coronal[
                which_batch_slices, :, :, :]
        else:
            prediction_per_dimension = prediction_axial[
                which_batch_slices, :, :, :]
        prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                                        permutations[dimensions_to_predict[d]])
        prediction_image = ants.copy_image_info(
            t1_preprocessed_warped,
            pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                      t1_preprocessed_warped.shape))
        prediction_image_average = prediction_image_average + (
            prediction_image - prediction_image_average) / (d + 1)

    probability_image = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), prediction_image_average,
        t1) * ants.threshold_image(brain_mask, 0.5, 1, 1, 0)

    return (probability_image)