Esempio n. 1
0
 def test_invert_ants_transform(self):
     """Smoke-test inverting a 2-D transform and applying both directions.

     Reads the bundled "r16" sample image as float, warps it with a
     hand-parameterised 2-D transform, inverts that transform and applies the
     inverse to the warped image.  No assertion is made on the result; the
     test only checks that the calls run without raising.
     """
     source_image = ants.image_read(ants.get_ants_data("r16")).clone('float')

     forward_tx = ants.new_ants_transform(dimension=2)
     forward_tx.set_parameters((0.9, 0, 0, 1.1, 10, 11))
     warped_image = forward_tx.apply_to_image(source_image, source_image)

     # Round trip: map the warped image back with the inverse transform.
     backward_tx = ants.invert_ants_transform(forward_tx)
     restored_image = backward_tx.apply_to_image(warped_image, warped_image)
Esempio n. 2
0
def lung_extraction(image,
                    modality="proton",
                    antsxnet_cache_directory=None,
                    verbose=False):

    """
    Perform lung extraction using a pretrained U-net.

    Arguments
    ---------
    image : ANTsImage
        input 3-D image.

    modality : string
        Modality image type.  Options include "ct", "proton", "protonLobes",
        "maskLobes", and "ventilation".

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  If None, the name "ANTsXNet" is passed to the download
        helpers (documented upstream as resolving to ~/.keras/ANTsXNet/ --
        confirm against the helpers).

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    For "proton", "maskLobes" and "ct": a dictionary with keys
    'segmentation_image' and 'probability_images'.  For "protonLobes" the
    dictionary additionally holds 'whole_lung_mask_image'.  For "ventilation"
    a single probability ANTsImage is returned.

    Raises
    ------
    ValueError
        If the image is not 3-D or the modality is not recognized.

    Example
    -------
    >>> output = lung_extraction(lung_image, modality="proton")
    """

    # Validate before importing or downloading anything so bad input fails fast.
    if image.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if modality not in ("proton", "protonLobes", "maskLobes", "ct", "ventilation"):
        raise ValueError("Unrecognized modality.")

    from ..architectures import create_unet_model_2d
    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import pad_or_crop_image_to_size

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    if modality == "proton":
        channel_size = 1

        weights_file_name = get_pretrained_network("protonLungMri",
            antsxnet_cache_directory=antsxnet_cache_directory)

        classes = ("background", "left_lung", "right_lung")
        number_of_classification_labels = len(classes)

        reorient_template_file_name_path = get_antsxnet_data("protonLungTemplate",
            antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)

        resampled_image_size = reorient_template.shape

        unet_model = create_unet_model_3d((*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            number_of_layers=4, number_of_filters_at_base_layer=16, dropout_rate=0.0,
            convolution_kernel_size=(7, 7, 5), deconvolution_kernel_size=(7, 7, 5))
        unet_model.load_weights(weights_file_name)

        if verbose:
            print("Lung extraction:  normalizing image to the template.")

        # Translate the image so its center of mass coincides with the template's.
        center_of_mass_template = ants.get_center_of_mass(reorient_template * 0 + 1)
        center_of_mass_image = ants.get_center_of_mass(image * 0 + 1)
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_template), translation=translation)
        warped_image = ants.apply_ants_transform_to_image(xfrm, image, reorient_template)

        # Shape (1, x, y, z, 1), then z-score the whole batch.
        batchX = np.expand_dims(warped_image.numpy(), axis=0)
        batchX = np.expand_dims(batchX, axis=-1)
        batchX = (batchX - batchX.mean()) / batchX.std()

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        origin = warped_image.origin
        spacing = warped_image.spacing
        direction = warped_image.direction

        probability_images_array = [
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                origin=origin, spacing=spacing, direction=direction)
            for i in range(number_of_classification_labels)]

        if verbose:
            print("Lung extraction:  renormalize probability mask to native space.")

        # The inverse transform is loop-invariant, so compute it once.
        inverse_xfrm = ants.invert_ants_transform(xfrm)
        probability_images_array = [
            ants.apply_ants_transform_to_image(inverse_xfrm, probability_image, image)
            for probability_image in probability_images_array]

        # Label each voxel with the class of highest probability.
        image_matrix = ants.image_list_to_matrix(probability_images_array, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        return_dict = {'segmentation_image' : segmentation_image,
                       'probability_images' : probability_images_array}
        return return_dict

    elif modality == "protonLobes" or modality == "maskLobes":
        reorient_template_file_name_path = get_antsxnet_data("protonLungTemplate",
            antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)

        resampled_image_size = reorient_template.shape

        spatial_priors_file_name_path = get_antsxnet_data("protonLobePriors",
            antsxnet_cache_directory=antsxnet_cache_directory)
        spatial_priors = ants.image_read(spatial_priors_file_name_path)
        priors_image_list = ants.ndimage_to_list(spatial_priors)

        # One input channel for the image plus one per spatial prior.
        channel_size = 1 + len(priors_image_list)
        number_of_classification_labels = 1 + len(priors_image_list)

        unet_model = create_unet_model_3d((*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels, mode="classification",
            number_of_filters_at_base_layer=16, number_of_layers=4,
            convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
            dropout_rate=0.0, weight_decay=0, additional_options=("attentionGating",))

        if modality == "protonLobes":
            # The proton lobes network has a second, single-channel sigmoid
            # output attached to the penultimate layer (the whole-lung mask).
            penultimate_layer = unet_model.layers[-2].output
            outputs2 = Conv3D(filters=1,
                            kernel_size=(1, 1, 1),
                            activation='sigmoid',
                            kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)
            unet_model = Model(inputs=unet_model.input, outputs=[unet_model.output, outputs2])
            weights_file_name = get_pretrained_network("protonLobes",
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            weights_file_name = get_pretrained_network("maskLobes",
                antsxnet_cache_directory=antsxnet_cache_directory)

        unet_model.load_weights(weights_file_name)

        if verbose:
            print("Lung extraction:  normalizing image to the template.")

        # Translate the image so its center of mass coincides with the template's.
        center_of_mass_template = ants.get_center_of_mass(reorient_template * 0 + 1)
        center_of_mass_image = ants.get_center_of_mass(image * 0 + 1)
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_template), translation=translation)
        warped_image = ants.apply_ants_transform_to_image(xfrm, image, reorient_template)
        warped_array = warped_image.numpy()
        if modality == "protonLobes":
            warped_array = (warped_array - warped_array.mean()) / warped_array.std()
        else:
            # "maskLobes" input is binarized: any non-zero voxel becomes 1.
            warped_array[warped_array != 0] = 1

        batchX = np.zeros((1, *warped_array.shape, channel_size))
        batchX[0, :, :, :, 0] = warped_array
        for i, prior_image in enumerate(priors_image_list):
            batchX[0, :, :, :, i + 1] = prior_image.numpy()

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        origin = warped_image.origin
        spacing = warped_image.spacing
        direction = warped_image.direction

        # The two-output "protonLobes" model returns [labels, whole_lung_mask].
        if modality == "protonLobes":
            label_predictions = predicted_data[0]
        else:
            label_predictions = predicted_data

        probability_images_array = [
            ants.from_numpy(np.squeeze(label_predictions[0, :, :, :, i]),
                origin=origin, spacing=spacing, direction=direction)
            for i in range(number_of_classification_labels)]

        if verbose:
            print("Lung extraction:  renormalize probability images to native space.")

        # The inverse transform is loop-invariant, so compute it once.
        inverse_xfrm = ants.invert_ants_transform(xfrm)
        probability_images_array = [
            ants.apply_ants_transform_to_image(inverse_xfrm, probability_image, image)
            for probability_image in probability_images_array]

        # Label each voxel with the class of highest probability.
        image_matrix = ants.image_list_to_matrix(probability_images_array, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        return_dict = {'segmentation_image' : segmentation_image,
                       'probability_images' : probability_images_array}

        if modality == "protonLobes":
            whole_lung_mask = ants.from_numpy(np.squeeze(predicted_data[1][0, :, :, :, 0]),
                origin=origin, spacing=spacing, direction=direction)
            whole_lung_mask = ants.apply_ants_transform_to_image(
                inverse_xfrm, whole_lung_mask, image)
            return_dict['whole_lung_mask_image'] = whole_lung_mask

        return return_dict

    elif modality == "ct":

        ################################
        #
        # Preprocess image
        #
        ################################

        if verbose:
            print("Preprocess CT image.")

        def closest_simplified_direction_matrix(direction):
            # Round every entry's magnitude to the nearest integer, keeping its sign.
            closest = np.floor(np.abs(direction) + 0.5)
            closest[direction < 0] *= -1.0
            return closest

        simplified_direction = closest_simplified_direction_matrix(image.direction)

        reference_image_size = (128, 128, 128)

        # Resample to the network grid and clamp intensities to [-1000, 400].
        ct_preprocessed = ants.resample_image(image, reference_image_size, use_voxels=True, interp_type=0)
        ct_preprocessed[ct_preprocessed < -1000] = -1000
        ct_preprocessed[ct_preprocessed > 400] = 400
        ct_preprocessed.set_direction(simplified_direction)
        ct_preprocessed.set_origin((0, 0, 0))
        ct_preprocessed.set_spacing((1, 1, 1))

        ################################
        #
        # Reorient image
        #
        ################################

        reference_image = ants.make_image(reference_image_size,
                                          voxval=0,
                                          spacing=(1, 1, 1),
                                          origin=(0, 0, 0),
                                          direction=np.identity(3))
        center_of_mass_reference = np.floor(ants.get_center_of_mass(reference_image * 0 + 1))
        center_of_mass_image = np.floor(ants.get_center_of_mass(ct_preprocessed * 0 + 1))
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_reference)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_reference), translation=translation)
        ct_preprocessed = ((ct_preprocessed - ct_preprocessed.min()) /
            (ct_preprocessed.max() - ct_preprocessed.min()))
        ct_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm, ct_preprocessed, reference_image, interpolation="nearestneighbor")
        # Rescale to [0, 1] and shift to [-0.5, 0.5].
        ct_preprocessed_warped = ((ct_preprocessed_warped - ct_preprocessed_warped.min()) /
            (ct_preprocessed_warped.max() - ct_preprocessed_warped.min())) - 0.5

        ################################
        #
        # Build models and load weights
        #
        ################################

        if verbose:
            print("Build model and load weights.")

        weights_file_name = get_pretrained_network("lungCtWithPriorsSegmentationWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)

        classes = ("background", "left lung", "right lung", "airways")
        number_of_classification_labels = len(classes)

        # Forward the cache directory here too, like every other download above.
        luna16_priors_file_name_path = get_antsxnet_data("luna16LungPriors",
            antsxnet_cache_directory=antsxnet_cache_directory)
        luna16_priors = [
            ants.resample_image(prior_image, reference_image_size, use_voxels=True)
            for prior_image in ants.ndimage_to_list(ants.image_read(luna16_priors_file_name_path))]
        channel_size = len(luna16_priors) + 1

        unet_model = create_unet_model_3d((*reference_image_size, channel_size),
            number_of_outputs=number_of_classification_labels, mode="classification",
            number_of_layers=4, number_of_filters_at_base_layer=16, dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5, additional_options=("attentionGating",))
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Do prediction and normalize to native space
        #
        ################################

        if verbose:
            print("Prediction.")

        batchX = np.zeros((1, *reference_image_size, channel_size))
        batchX[:, :, :, :, 0] = ct_preprocessed_warped.numpy()

        for i, prior_image in enumerate(luna16_priors):
            batchX[:, :, :, :, i + 1] = prior_image.numpy() - 0.5

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        # The inverse transform is loop-invariant, so compute it once.
        inverse_xfrm = ants.invert_ants_transform(xfrm)

        probability_images = list()
        for i in range(number_of_classification_labels):
            if verbose:
                print("Reconstructing image", classes[i])
            probability_image = ants.from_numpy(np.squeeze(predicted_data[:, :, :, :, i]),
                origin=ct_preprocessed_warped.origin, spacing=ct_preprocessed_warped.spacing,
                direction=ct_preprocessed_warped.direction)
            probability_image = ants.apply_ants_transform_to_image(
                inverse_xfrm, probability_image, ct_preprocessed)
            # Back to the input voxel grid, then restore the input's header.
            probability_image = ants.resample_image(probability_image,
               resample_params=image.shape, use_voxels=True, interp_type=0)
            probability_image = ants.copy_image_info(image, probability_image)
            probability_images.append(probability_image)

        # Label each voxel with the class of highest probability.
        image_matrix = ants.image_list_to_matrix(probability_images, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        return_dict = {'segmentation_image' : segmentation_image,
                       'probability_images' : probability_images}
        return return_dict

    else:  # modality == "ventilation" (validated at the top of the function)

        ################################
        #
        # Preprocess image
        #
        ################################

        if verbose:
            print("Preprocess ventilation image.")

        template_size = (256, 256)

        image_modalities = ("Ventilation",)
        channel_size = len(image_modalities)

        preprocessed_image = (image - image.mean()) / image.std()
        ants.set_direction(preprocessed_image, np.identity(3))

        ################################
        #
        # Build models and load weights
        #
        ################################

        unet_model = create_unet_model_2d((*template_size, channel_size),
            number_of_outputs=1, mode='sigmoid',
            number_of_layers=4, number_of_filters_at_base_layer=32, dropout_rate=0.0,
            convolution_kernel_size=(3, 3), deconvolution_kernel_size=(2, 2),
            weight_decay=0)

        if verbose:
            print("Whole lung mask: retrieving model weights.")

        weights_file_name = get_pretrained_network("wholeLungMaskFromVentilation",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Extract slices
        #
        ################################

        # Slice only along the axis with the largest voxel spacing.
        spacing = ants.get_spacing(preprocessed_image)
        dimensions_to_predict = (spacing.index(max(spacing)),)

        total_number_of_slices = sum(
            preprocessed_image.shape[dimension] for dimension in dimensions_to_predict)

        batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

        slice_count = 0
        for dimension in dimensions_to_predict:
            number_of_slices = preprocessed_image.shape[dimension]

            if verbose:
                print("Extracting slices for dimension ", dimension, ".")

            for i in range(number_of_slices):
                ventilation_slice = pad_or_crop_image_to_size(
                    ants.slice_image(preprocessed_image, dimension, i), template_size)
                batchX[slice_count, :, :, 0] = ventilation_slice.numpy()
                slice_count += 1

        ################################
        #
        # Do prediction and then restack into the image
        #
        ################################

        if verbose:
            print("Prediction.")

        prediction = unet_model.predict(batchX, verbose=int(verbose))

        # Axis order that moves the slice axis of the stacked predictions
        # back to its original position, indexed by the sliced dimension.
        permutations = list()
        permutations.append((0, 1, 2))
        permutations.append((1, 0, 2))
        permutations.append((1, 2, 0))

        probability_image = ants.image_clone(image) * 0

        current_start_slice = 0
        for d, dimension in enumerate(dimensions_to_predict):
            # batchX holds exactly shape[dimension] consecutive slices for this
            # axis; the half-open range [start, end) must cover all of them.
            current_end_slice = current_start_slice + preprocessed_image.shape[dimension]

            prediction_per_dimension = prediction[current_start_slice:current_end_slice, :, :, 0]
            prediction_array = np.transpose(np.squeeze(prediction_per_dimension), permutations[dimension])
            prediction_image = ants.copy_image_info(image,
                pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                image.shape))
            # Running mean over the predicted dimensions.
            probability_image = probability_image + (prediction_image - probability_image) / (d + 1)

            current_start_slice = current_end_slice

        return probability_image
Esempio n. 3
0
def sysu_media_wmh_segmentation(flair,
                                t1=None,
                                do_preprocessing=True,
                                use_ensemble=True,
                                use_axial_slices_only=True,
                                antsxnet_cache_directory=None,
                                verbose=False):
    """
    Perform WMH segmentation using the winning submission in the MICCAI
    2017 challenge by the sysu_media team using FLAIR or T1/FLAIR.  The
    MICCAI challenge is discussed in

    https://pubmed.ncbi.nlm.nih.gov/30908194/

    with the sysu_media team's entry discussed in

    https://pubmed.ncbi.nlm.nih.gov/30125711/

    and the original implementation available here:

    https://github.com/hongweilibran/wmh_ibbmTum

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        optional input 3-D T1 brain image (not skull-stripped).  When given,
        the two-channel FLAIR/T1 weights are used and the brain mask is
        estimated from the T1.

    do_preprocessing : boolean
        If True, run preprocess_brain_image on each input with intensity
        truncation at (0.01, 0.99) and bias correction (no brain extraction,
        no denoising).

    use_ensemble : boolean
        If True, average the predictions of all 3 sets of weights; otherwise
        use only the first set.

    use_axial_slices_only : boolean
        If True, use original implementation which was trained on axial slices.
        If False, use ANTsXNet variant implementation which applies the slice-by-slice
        models to all 3 dimensions and averages the results.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  If None, the name "ANTsXNet" is passed to the download
        helpers (documented upstream as resolving to ~/.keras/ANTsXNet/ --
        confirm against the helpers).

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image in the space of the input FLAIR.

    Raises
    ------
    ValueError
        If the FLAIR image is not 3-D.

    Example
    -------
    >>> image = ants.image_read("flair.nii.gz")
    >>> probability_mask = sysu_media_wmh_segmentation(image)
    """

    # Validate before importing or downloading anything so bad input fails fast.
    if flair.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    from ..architectures import create_sysu_media_unet_model_2d
    from ..utilities import brain_extraction
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import pad_or_crop_image_to_size

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    flair_preprocessed = flair
    if do_preprocessing:
        flair_preprocessing = preprocess_brain_image(
            flair,
            truncate_intensity=(0.01, 0.99),
            do_brain_extraction=False,
            do_bias_correction=True,
            do_denoising=False,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        flair_preprocessed = flair_preprocessing["preprocessed_image"]

    number_of_channels = 1
    if t1 is not None:
        t1_preprocessed = t1
        if do_preprocessing:
            t1_preprocessing = preprocess_brain_image(
                t1,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            t1_preprocessed = t1_preprocessing["preprocessed_image"]
        number_of_channels = 2

    ################################
    #
    # Estimate mask
    #
    ################################

    if verbose:
        print("Estimating brain mask.")
    # NOTE(review): the cache directory and verbose flag are not forwarded to
    # brain_extraction; its signature is not visible here -- confirm whether
    # it accepts them.
    if t1 is not None:
        brain_mask = brain_extraction(t1, modality="t1")
    else:
        brain_mask = brain_extraction(flair, modality="flair")

    reference_image = ants.make_image((200, 200, 200),
                                      voxval=1,
                                      spacing=(1, 1, 1),
                                      origin=(0, 0, 0),
                                      direction=np.identity(3))

    # Translate so the brain's center of mass sits at the reference center.
    center_of_mass_reference = ants.get_center_of_mass(reference_image)
    center_of_mass_image = ants.get_center_of_mass(brain_mask)
    translation = np.asarray(center_of_mass_image) - np.asarray(
        center_of_mass_reference)
    xfrm = ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_reference),
        translation=translation)
    flair_preprocessed_warped = ants.apply_ants_transform_to_image(
        xfrm, flair_preprocessed, reference_image)
    brain_mask_warped = ants.threshold_image(
        ants.apply_ants_transform_to_image(xfrm, brain_mask, reference_image),
        0.5, 1.1, 1, 0)

    if t1 is not None:
        t1_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm, t1_preprocessed, reference_image)

    ################################
    #
    # Gaussian normalize intensity based on brain mask
    #
    ################################

    mean_flair = flair_preprocessed_warped[brain_mask_warped > 0].mean()
    std_flair = flair_preprocessed_warped[brain_mask_warped > 0].std()
    flair_preprocessed_warped = (flair_preprocessed_warped -
                                 mean_flair) / std_flair

    if number_of_channels == 2:
        mean_t1 = t1_preprocessed_warped[brain_mask_warped > 0].mean()
        std_t1 = t1_preprocessed_warped[brain_mask_warped > 0].std()
        t1_preprocessed_warped = (t1_preprocessed_warped - mean_t1) / std_t1

    ################################
    #
    # Build models and load weights
    #
    ################################

    number_of_models = 3 if use_ensemble else 1

    if number_of_channels == 1:
        weights_name_prefix = "sysuMediaWmhFlairOnlyModel"
    else:
        weights_name_prefix = "sysuMediaWmhFlairT1Model"

    unet_models = list()
    for i in range(number_of_models):
        weights_file_name = get_pretrained_network(
            weights_name_prefix + str(i),
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model = create_sysu_media_unet_model_2d(
            (200, 200, number_of_channels))
        unet_model.load_weights(weights_file_name)
        unet_models.append(unet_model)

    ################################
    #
    # Extract slices
    #
    ################################

    if use_axial_slices_only:
        dimensions_to_predict = [2]
    else:
        dimensions_to_predict = list(range(3))

    total_number_of_slices = sum(
        flair_preprocessed_warped.shape[dimension]
        for dimension in dimensions_to_predict)

    batchX = np.zeros((total_number_of_slices, 200, 200, number_of_channels))

    slice_count = 0
    for dimension in dimensions_to_predict:
        number_of_slices = flair_preprocessed_warped.shape[dimension]

        if verbose:
            print("Extracting slices for dimension ", dimension,
                  ".")

        for i in range(number_of_slices):
            flair_slice = pad_or_crop_image_to_size(
                ants.slice_image(flair_preprocessed_warped, dimension, i),
                (200, 200))
            batchX[slice_count, :, :, 0] = flair_slice.numpy()
            if number_of_channels == 2:
                t1_slice = pad_or_crop_image_to_size(
                    ants.slice_image(t1_preprocessed_warped, dimension, i),
                    (200, 200))
                batchX[slice_count, :, :, 1] = t1_slice.numpy()

            slice_count += 1

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose:
        print("Prediction.")

    # Average the predictions of all loaded models.
    prediction = unet_models[0].predict(batchX, verbose=int(verbose))
    for unet_model in unet_models[1:]:
        prediction += unet_model.predict(batchX, verbose=int(verbose))
    prediction /= number_of_models

    # Axis order that moves the slice axis of the stacked predictions back
    # to its original position, indexed by the sliced dimension.
    permutations = list()
    permutations.append((0, 1, 2))
    permutations.append((1, 0, 2))
    permutations.append((1, 2, 0))

    prediction_image_average = ants.image_clone(flair_preprocessed_warped) * 0

    current_start_slice = 0
    for d, dimension in enumerate(dimensions_to_predict):
        # batchX holds exactly shape[dimension] consecutive slices for this
        # axis; the half-open range [start, end) must cover all of them.
        current_end_slice = (current_start_slice +
                             flair_preprocessed_warped.shape[dimension])
        prediction_per_dimension = prediction[
            current_start_slice:current_end_slice, :, :, :]
        prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                                        permutations[dimension])
        prediction_image = ants.copy_image_info(
            flair_preprocessed_warped,
            pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                      flair_preprocessed_warped.shape))
        # Running mean over the predicted dimensions.
        prediction_image_average = prediction_image_average + (
            prediction_image - prediction_image_average) / (d + 1)
        current_start_slice = current_end_slice

    # Map the averaged prediction back into the space of the input FLAIR.
    probability_image = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), prediction_image_average, flair)

    return probability_image
# Tail of a script whose earlier part is not shown: it relies on
# prediction_refine_stage_array, image_resampled, image, xfrm,
# output_file_name, start_time and start_time_total being defined above.

# Wrap the refine-stage prediction in an image on the resampled grid, then
# resample it onto the grid of the original image.
probability_mask_refine_stage_resampled = ants.from_numpy(
    prediction_refine_stage_array,
    origin=image_resampled.origin,
    spacing=image_resampled.spacing,
    direction=image_resampled.direction)
probability_mask_refine_stage = ants.resample_image_to_target(
    probability_mask_refine_stage_resampled, image)
# NOTE(review): this file is overwritten by the final write below with the
# native-space image -- confirm whether this intermediate write is intended.
ants.image_write(probability_mask_refine_stage, output_file_name)
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

print("Renormalize to native space")
start_time = time.time()
probability_image = ants.apply_ants_transform_to_image(
    ants.invert_ants_transform(xfrm), probability_mask_refine_stage, image)
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

print("Writing", output_file_name)
start_time = time.time()
ants.image_write(probability_image, output_file_name)
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

end_time_total = time.time()
elapsed_time_total = end_time_total - start_time_total
# Report the total, not the duration of the last step.
print("  (Total elapsed time: ", elapsed_time_total, " seconds)")
Esempio n. 5
0
def lung_extraction(image,
                    modality="proton",
                    output_directory=None,
                    verbose=None):
    """
    Perform proton or ct lung extraction using U-net.

    The image is aligned to a modality-specific template with a
    center-of-mass translation, passed through a 3-D U-net, and the per-class
    probability maps are mapped back onto the grid of the input image.

    Arguments
    ---------
    image : ANTsImage
        input image

    modality : string
        Modality image type.  Options include "ct" and "proton".

    output_directory : string
        Destination directory for storing the downloaded template and model
        weights so that they can be reused.  If None, the weights are
        retrieved with the default destination of get_pretrained_network and
        the template is downloaded to a temporary file which is removed once
        it has been read.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary of ANTs segmentation and probability images with the keys
    'segmentation_image' and 'probability_images'.

    Raises
    ------
    ValueError
        If the image is not 3-D or if the modality is not recognized.

    Example
    -------
    >>> output = lung_extraction(lung_image, modality="proton")
    """

    # Validate the arguments first so that bad input fails fast, before any
    # package import, download or model construction takes place.
    if image.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    # Everything that differs between the supported modalities.
    modality_settings = {
        "proton": {
            "network_name": "protonLungMri",
            "weights_file_name": "protonLungSegmentationWeights.h5",
            "template_file_name": "protonLungTemplate.nii.gz",
            "template_url": "https://ndownloader.figshare.com/files/22707338",
            "classes": ("background", "left_lung", "right_lung"),
            "number_of_filters_at_base_layer": 16,
            "convolution_kernel_size": (7, 7, 5),
            "deconvolution_kernel_size": (7, 7, 5),
        },
        "ct": {
            "network_name": "ctHumanLung",
            "weights_file_name": "ctLungSegmentationWeights.h5",
            "template_file_name": "ctLungTemplate.nii.gz",
            "template_url": "https://ndownloader.figshare.com/files/22707335",
            "classes": ("background", "left_lung", "right_lung", "trachea"),
            "number_of_filters_at_base_layer": 8,
            "convolution_kernel_size": (3, 3, 3),
            "deconvolution_kernel_size": (2, 2, 2),
        },
    }

    if modality not in modality_settings:
        raise ValueError(
            "Unrecognized modality '" + str(modality) +
            "'.  Options are \"proton\" and \"ct\".")
    settings = modality_settings[modality]

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network

    # A single modality is fed to the network, i.e. one input channel.
    channel_size = 1
    number_of_classification_labels = len(settings["classes"])

    # Model weights:  reuse a cached copy when one is available.
    if output_directory is not None:
        weights_file_name = os.path.join(output_directory,
                                         settings["weights_file_name"])
        if not os.path.exists(weights_file_name):
            if verbose:
                print("Lung extraction:  downloading weights.")
            weights_file_name = get_pretrained_network(
                settings["network_name"], weights_file_name)
    else:
        weights_file_name = get_pretrained_network(settings["network_name"])

    reorient_template = _load_lung_reorient_template(
        settings["template_file_name"], settings["template_url"],
        output_directory, verbose)

    # The network input size is dictated by the template grid.
    unet_model = create_unet_model_3d(
        (*reorient_template.shape, channel_size),
        number_of_outputs=number_of_classification_labels,
        number_of_layers=4,
        number_of_filters_at_base_layer=settings[
            "number_of_filters_at_base_layer"],
        dropout_rate=0.0,
        convolution_kernel_size=settings["convolution_kernel_size"],
        deconvolution_kernel_size=settings["deconvolution_kernel_size"])
    unet_model.load_weights(weights_file_name)

    if verbose:
        print("Lung extraction:  normalizing image to the template.")

    # `x * 0 + 1` is an all-ones image on the grid of x, so these are the
    # centers of the two image domains rather than intensity-weighted centers.
    image_domain = image * 0 + 1
    center_of_mass_template = ants.get_center_of_mass(reorient_template * 0 +
                                                      1)
    center_of_mass_image = ants.get_center_of_mass(image_domain)
    translation = np.asarray(center_of_mass_image) - np.asarray(
        center_of_mass_template)
    xfrm = ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_template),
        translation=translation)
    warped_image = ants.apply_ants_transform_to_image(
        xfrm, image, reorient_template)

    # Add the batch and channel axes and standardize the intensities.
    batchX = np.expand_dims(warped_image.numpy(), axis=0)
    batchX = np.expand_dims(batchX, axis=-1)
    batchX = (batchX - batchX.mean()) / batchX.std()

    predicted_data = unet_model.predict(batchX, verbose=0)

    if verbose:
        print(
            "Lung extraction:  renormalize probability mask to native space."
        )

    # One probability image per class, created on the template grid and then
    # mapped onto the grid of the input image with the inverse transform.
    inverse_xfrm = ants.invert_ants_transform(xfrm)
    probability_images_array = list()
    for i in range(number_of_classification_labels):
        probability_image = ants.from_numpy(
            np.squeeze(predicted_data[0, :, :, :, i]),
            origin=warped_image.origin,
            spacing=warped_image.spacing,
            direction=warped_image.direction)
        probability_images_array.append(
            ants.apply_ants_transform_to_image(inverse_xfrm,
                                               probability_image, image))

    # Label every voxel with the class of highest probability.
    image_matrix = ants.image_list_to_matrix(probability_images_array,
                                             image_domain)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), image_domain)[0]

    return {
        'segmentation_image': segmentation_image,
        'probability_images': probability_images_array
    }


def _load_lung_reorient_template(template_file_name, template_url,
                                 output_directory, verbose):
    """
    Read the reorientation template, downloading it when it is not cached.

    When output_directory is given and already holds template_file_name, that
    file is read.  Otherwise the template is downloaded from template_url,
    read, and (when output_directory is given) copied there for reuse.
    """

    cached_template_file_name = None
    if output_directory is not None:
        cached_template_file_name = os.path.join(output_directory,
                                                 template_file_name)
        if os.path.exists(cached_template_file_name):
            return ants.image_read(cached_template_file_name)

    if verbose:
        print("Lung extraction:  downloading template.")
    response = requests.get(template_url)
    # Fail loudly instead of saving an HTTP error page as a NIfTI file.
    response.raise_for_status()

    with tempfile.NamedTemporaryFile(suffix=".nii.gz",
                                     delete=False) as download_file:
        download_file.write(response.content)
        downloaded_file_name = download_file.name

    try:
        reorient_template = ants.image_read(downloaded_file_name)
        if cached_template_file_name is not None:
            shutil.copy(downloaded_file_name, cached_template_file_name)
    finally:
        # The download is only a staging file; do not leave it behind.
        os.remove(downloaded_file_name)

    return reorient_template
Esempio n. 6
0
def claustrum_segmentation(t1,
                           do_preprocessing=True,
                           use_ensemble=True,
                           antsxnet_cache_directory=None,
                           verbose=False):
    """
    Claustrum segmentation

    Described here:

        https://arxiv.org/abs/2008.03465

    with the implementation available at:

        https://github.com/hongweilibran/claustrum_multi_view

    Coronal and axial slices of the (optionally preprocessed) T1 image are
    predicted with separate 2-D U-nets, the two restacked volumes are
    averaged, and the result is mapped back onto the grid of the input image.

    Arguments
    ---------
    t1 : ANTsImage
        input 3-D T1 brain image.

    do_preprocessing : boolean
        run preprocess_brain_image (intensity truncation, T1 brain
        extraction, bias correction and denoising) on the input.

    use_ensemble : boolean
        use all 3 sets of weights per view instead of only the first.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if is None, these data will be
        downloaded to a ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Claustrum segmentation probability image

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> probability_mask = claustrum_segmentation(image)
    """

    from ..architectures import create_sysu_media_unet_model_2d
    from ..utilities import brain_extraction
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import pad_or_crop_image_to_size

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    show_progress = (verbose == True)
    slice_shape = (180, 180)
    channel_count = 1

    # ---------------------------------------------------------------
    # Preprocessing and alignment to a fixed reference grid
    # ---------------------------------------------------------------

    t1_preprocessed = ants.image_clone(t1)
    brain_mask = ants.threshold_image(t1, 0, 0, 0, 1)
    if do_preprocessing == True:
        preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = preprocessing["preprocessed_image"]
        brain_mask = preprocessing["brain_mask"]

    reference_image = ants.make_image((170, 256, 256),
                                      voxval=1,
                                      spacing=(1, 1, 1),
                                      origin=(0, 0, 0),
                                      direction=np.identity(3))

    # Translation that moves the center of mass of the brain mask onto the
    # center of mass of the reference grid.
    reference_center = ants.get_center_of_mass(reference_image)
    brain_center = ants.get_center_of_mass(brain_mask)
    xfrm = ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(reference_center),
        translation=np.asarray(brain_center) - np.asarray(reference_center))

    t1_warped = ants.apply_ants_transform_to_image(xfrm, t1_preprocessed,
                                                   reference_image)
    mask_on_reference = ants.apply_ants_transform_to_image(
        xfrm, brain_mask, reference_image)
    mask_warped = ants.threshold_image(mask_on_reference, 0.5, 1.1, 1, 0)

    # ---------------------------------------------------------------
    # Standardize intensities using the voxels inside the brain mask
    # ---------------------------------------------------------------

    brain_mean = t1_warped[mask_warped > 0].mean()
    brain_std = t1_warped[mask_warped > 0].std()
    t1_warped = (t1_warped - brain_mean) / brain_std
    t1_warped = t1_warped * mask_warped

    # ---------------------------------------------------------------
    # Models:  one (or three, for the ensemble) network per view
    # ---------------------------------------------------------------

    model_count = 3 if use_ensemble == True else 1

    def load_view_models(view_name):
        # Build one network per weight set of the requested view.
        view_models = []
        for index in range(model_count):
            weights_file_name = get_pretrained_network(
                "claustrum_" + view_name + "_" + str(index),
                antsxnet_cache_directory=antsxnet_cache_directory)
            model = create_sysu_media_unet_model_2d(
                (*slice_shape, channel_count), anatomy="claustrum")
            model.load_weights(weights_file_name)
            view_models.append(model)
        return view_models

    if show_progress:
        print("Claustrum:  retrieving axial model weights.")
    axial_models = load_view_models("axial")

    if show_progress:
        print("Claustrum:  retrieving coronal model weights.")
    coronal_models = load_view_models("coronal")

    # ---------------------------------------------------------------
    # Slice extraction (axis 1 -> coronal batch, axis 2 -> axial batch)
    # ---------------------------------------------------------------

    axes_to_predict = [1, 2]

    coronal_batch = np.zeros(
        (t1_warped.shape[1], *slice_shape, channel_count))
    axial_batch = np.zeros(
        (t1_warped.shape[2], *slice_shape, channel_count))

    for axis in axes_to_predict:
        if show_progress:
            print("Extracting slices for dimension ", axis, ".")

        for index in range(t1_warped.shape[axis]):
            padded_slice = pad_or_crop_image_to_size(
                ants.slice_image(t1_warped, axis, index), slice_shape)
            if axis == 1:
                coronal_batch[index, :, :, 0] = np.rot90(padded_slice.numpy(),
                                                         k=-1)
            else:
                axial_batch[index, :, :, 0] = np.rot90(padded_slice.numpy())

    # ---------------------------------------------------------------
    # Prediction:  average over the models, then undo the slice rotation
    # ---------------------------------------------------------------

    def averaged_prediction(view_models, batch):
        # Accumulate the prediction of every model in place and normalize.
        averaged = view_models[0].predict(batch, verbose=verbose)
        for model in view_models[1:]:
            averaged += model.predict(batch, verbose=verbose)
        averaged /= model_count
        return averaged

    if show_progress:
        print("Coronal prediction.")

    prediction_coronal = averaged_prediction(coronal_models, coronal_batch)
    for index in range(t1_warped.shape[1]):
        prediction_coronal[index, :, :, 0] = np.rot90(
            np.squeeze(prediction_coronal[index, :, :, 0]))

    if show_progress:
        print("Axial prediction.")

    prediction_axial = averaged_prediction(axial_models, axial_batch)
    for index in range(t1_warped.shape[2]):
        prediction_axial[index, :, :, 0] = np.rot90(
            np.squeeze(prediction_axial[index, :, :, 0]), k=-1)

    # ---------------------------------------------------------------
    # Restack the slices into volumes and average the two views
    # ---------------------------------------------------------------

    if show_progress:
        print("Restack image and transform back to native space.")

    # Axis order that puts the slice axis of a stacked batch back in place,
    # indexed by the axis that was sliced.
    axis_permutations = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]

    running_mean = ants.image_clone(t1_warped) * 0

    for view_index, axis in enumerate(axes_to_predict):
        slice_indices = range(t1_warped.shape[axis])
        if axis == 1:
            view_prediction = prediction_coronal[slice_indices, :, :, :]
        else:
            view_prediction = prediction_axial[slice_indices, :, :, :]

        view_array = np.transpose(np.squeeze(view_prediction),
                                  axis_permutations[axis])
        view_image = ants.copy_image_info(
            t1_warped,
            pad_or_crop_image_to_size(ants.from_numpy(view_array),
                                      t1_warped.shape))
        # Incremental mean over the views processed so far.
        running_mean = running_mean + (view_image -
                                       running_mean) / (view_index + 1)

    # Back onto the grid of the input image, restricted to the brain mask.
    native_prediction = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), running_mean, t1)
    probability_image = native_prediction * ants.threshold_image(
        brain_mask, 0.5, 1, 1, 0)

    return probability_image
Esempio n. 7
0
# Wrap each class channel of `predicted_data` (first batch element) in an ANTs
# image, using the `origin`, `spacing` and `direction` values defined before
# this fragment.
for i in range(number_of_classification_labels):
    probability_images_array.append(
        ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                        origin=origin,
                        spacing=spacing,
                        direction=direction))

# Report the time elapsed since `start_time` was last set.
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

# Replace every probability image by its counterpart mapped through the
# inverse of `xfrm`, with `image` as the reference grid.
print("Renormalize to native space")
start_time = time.time()
for i in range(number_of_classification_labels):
    probability_images_array[i] = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), probability_images_array[i], image)
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

# Write one probability image per class, timing each write.  The loop starts
# at 1, so the class at index 0 is not written (presumably the background
# class -- confirm against the definition of `classes`).
for i in range(1, number_of_classification_labels):
    print("Writing", classes[i])
    start_time = time.time()
    ants.image_write(
        probability_images_array[i],
        output_file_name_prefix + classes[i] + "Probability.nii.gz")
    end_time = time.time()
    elapsed_time = end_time - start_time
    print("  (elapsed time: ", elapsed_time, " seconds)")

# Timestamp taken after all of the writes above.
end_time_total = time.time()
Esempio n. 8
0
def sysu_media_wmh_segmentation(flair,
                                t1=None,
                                use_ensemble=True,
                                antsxnet_cache_directory=None,
                                verbose=False):
    """
    Perform WMH segmentation using the winning submission in the MICCAI
    2017 challenge by the sysu_media team using FLAIR or T1/FLAIR.  The
    MICCAI challenge is discussed in

    https://pubmed.ncbi.nlm.nih.gov/30908194/

    with the sysu_media's team entry is discussed in

     https://pubmed.ncbi.nlm.nih.gov/30125711/

    with the original implementation available here:

    https://github.com/hongweilibran/wmh_ibbmTum

    The original implementation used global thresholding as a quick
    brain extraction approach.  Due to possible generalization difficulties,
    we leave such post-processing steps to the user.  For brain or white
    matter masking see functions brain_extraction or deep_atropos,
    respectively.

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).  Optional; when given,
        the two-channel FLAIR/T1 models are used.

    use_ensemble : boolean
        use all 3 sets of weights instead of only the first.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if is None, these data will be
        downloaded to a ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> image = ants.image_read("flair.nii.gz")
    >>> probability_mask = sysu_media_wmh_segmentation(image)
    """

    from ..architectures import create_sysu_media_unet_model_2d
    from ..utilities import get_pretrained_network
    from ..utilities import pad_or_crop_image_to_size
    from ..utilities import preprocess_brain_image
    from ..utilities import binary_dice_coefficient

    if flair.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    show_progress = (verbose == True)
    slice_shape = (200, 200)

    # ---------------------------------------------------------------
    # Preprocessing:  simplified direction, zero origin, unit spacing
    # ---------------------------------------------------------------

    def round_direction_entries(direction):
        # Round the magnitude of every entry to the nearest integer and
        # restore the sign of the negative entries.
        magnitude = (np.abs(direction) + 0.5).astype(int).astype(float)
        return np.where(direction < 0, magnitude * -1.0, magnitude)

    simplified_direction = round_direction_entries(flair.direction)

    def preprocess_on_unit_grid(image):
        # Same preprocessing call for both modalities, followed by the
        # replacement of the image geometry.
        preprocessing = preprocess_brain_image(
            image,
            truncate_intensity=None,
            brain_extraction_modality=None,
            do_bias_correction=False,
            do_denoising=False,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        preprocessed = preprocessing["preprocessed_image"]
        preprocessed.set_direction(simplified_direction)
        preprocessed.set_origin((0, 0, 0))
        preprocessed.set_spacing((1, 1, 1))
        return preprocessed

    flair_preprocessed = preprocess_on_unit_grid(flair)
    channel_count = 1

    t1_preprocessed = None
    if t1 is not None:
        t1_preprocessed = preprocess_on_unit_grid(t1)
        channel_count = 2

    # ---------------------------------------------------------------
    # Alignment to a reference grid, cropped to the FLAIR domain
    # ---------------------------------------------------------------

    reference_image = ants.make_image((256, 256, 256),
                                      voxval=0,
                                      spacing=(1, 1, 1),
                                      origin=(0, 0, 0),
                                      direction=np.identity(3))
    reference_center = np.floor(
        ants.get_center_of_mass(reference_image * 0 + 1))
    flair_center = np.floor(ants.get_center_of_mass(flair_preprocessed))
    xfrm = ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(reference_center),
        translation=np.asarray(flair_center) - np.asarray(reference_center))

    def warp_to_reference(image):
        return ants.apply_ants_transform_to_image(
            xfrm, image, reference_image, interpolation="nearestneighbor")

    flair_warped = warp_to_reference(flair_preprocessed)
    # All-ones image on the FLAIR grid; once warped it marks the part of the
    # reference grid covered by the FLAIR image and is used as the crop mask.
    flair_domain = ants.image_clone(flair_preprocessed) * 0 + 1
    flair_domain_warped = warp_to_reference(flair_domain)
    flair_warped = ants.crop_image(flair_warped, flair_domain_warped, 1)

    if t1 is not None:
        t1_warped = warp_to_reference(t1_preprocessed)
        t1_warped = ants.crop_image(t1_warped, flair_domain_warped, 1)

    # ---------------------------------------------------------------
    # Standardize intensities with the statistics of the unwarped images
    # ---------------------------------------------------------------

    flair_mean = flair_preprocessed.mean()
    flair_std = flair_preprocessed.std()
    flair_warped = (flair_warped - flair_mean) / flair_std

    if channel_count == 2:
        t1_mean = t1_preprocessed.mean()
        t1_std = t1_preprocessed.std()
        t1_warped = (t1_warped - t1_mean) / t1_std

    # ---------------------------------------------------------------
    # Models
    # ---------------------------------------------------------------

    model_count = 3 if use_ensemble == True else 1

    if show_progress:
        print("White matter hyperintensity:  retrieving model weights.")

    if channel_count == 1:
        network_prefix = "sysuMediaWmhFlairOnlyModel"
    else:
        network_prefix = "sysuMediaWmhFlairT1Model"

    unet_models = []
    for index in range(model_count):
        weights_file_name = get_pretrained_network(
            network_prefix + str(index),
            antsxnet_cache_directory=antsxnet_cache_directory)
        model = create_sysu_media_unet_model_2d(
            (*slice_shape, channel_count))
        dice_loss = binary_dice_coefficient(smoothing_factor=1.)
        model.compile(optimizer=keras.optimizers.Adam(learning_rate=2e-4),
                      loss=dice_loss)
        model.load_weights(weights_file_name)
        unet_models.append(model)

    # ---------------------------------------------------------------
    # Slice extraction
    # ---------------------------------------------------------------

    axes_to_predict = [2]

    total_slice_count = sum(flair_warped.shape[axis]
                            for axis in axes_to_predict)
    batchX = np.zeros((total_slice_count, *slice_shape, channel_count))

    batch_row = 0
    for axis in axes_to_predict:
        if show_progress:
            print("Extracting slices for dimension ", axis, ".")

        for index in range(flair_warped.shape[axis]):
            flair_slice = pad_or_crop_image_to_size(
                ants.slice_image(flair_warped, axis, index), slice_shape)
            batchX[batch_row, :, :, 0] = flair_slice.numpy()
            if channel_count == 2:
                t1_slice = pad_or_crop_image_to_size(
                    ants.slice_image(t1_warped, axis, index), slice_shape)
                batchX[batch_row, :, :, 1] = t1_slice.numpy()
            batch_row += 1

    # ---------------------------------------------------------------
    # Prediction, averaged over the models
    # ---------------------------------------------------------------

    if show_progress:
        print("Prediction.")

    # The networks are fed the slices with their two in-plane axes swapped;
    # the swap is undone on the averaged prediction.
    swapped_batch = np.transpose(batchX, axes=(0, 2, 1, 3))
    prediction = unet_models[0].predict(swapped_batch, verbose=verbose)
    for model in unet_models[1:]:
        prediction += model.predict(swapped_batch, verbose=verbose)
    prediction /= model_count
    prediction = np.transpose(prediction, axes=(0, 2, 1, 3))

    # ---------------------------------------------------------------
    # Restack the slices into a volume and return to the input grid
    # ---------------------------------------------------------------

    # Axis order that puts the slice axis of a stacked batch back in place,
    # indexed by the axis that was sliced.
    axis_permutations = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]

    running_mean = ants.image_clone(flair_warped) * 0

    first_row = 0
    for view_index, axis in enumerate(axes_to_predict):
        last_row = first_row + flair_warped.shape[axis]
        view_prediction = prediction[range(first_row, last_row), :, :, :]
        view_array = np.transpose(np.squeeze(view_prediction),
                                  axis_permutations[axis])
        view_image = ants.copy_image_info(
            flair_warped,
            pad_or_crop_image_to_size(ants.from_numpy(view_array),
                                      flair_warped.shape))
        # Incremental mean over the views processed so far.
        running_mean = running_mean + (view_image -
                                       running_mean) / (view_index + 1)
        first_row = last_row

    probability_image = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), running_mean, flair_preprocessed)
    # Restore the geometry of the original FLAIR image.
    probability_image = ants.copy_image_info(flair, probability_image)

    return probability_image
Esempio n. 9
0
def data_augmentation(input_image_list,
                      segmentation_image_list=None,
                      pointset_list=None,
                      number_of_simulations=10,
                      reference_image=None,
                      transform_type='affineAndDeformation',
                      noise_model='additivegaussian',
                      noise_parameters=(0.0, 0.05),
                      sd_simulated_bias_field=0.05,
                      sd_histogram_warping=0.05,
                      sd_affine=0.05,
                      output_numpy_file_prefix=None,
                      verbose=False):
    """
    Randomly transform image data.

    Given an input image list (possibly multi-modal) and an optional corresponding
    segmentation image list, this function will perform data augmentation with
    the following augmentation possibilities:

    * spatial transformations
    * added image noise
    * simulated bias field
    * histogram warping

    Each simulated image is normalized to [0, 1] before the intensity
    augmentations are applied and rescaled to its pre-augmentation intensity
    range afterwards.

    Arguments
    ---------

    input_image_list : list of lists of ANTsImages
        List of lists of input images to warp.  The internal list sets contain one
        or more images (per subject) which are assumed to be mutually aligned.  The
        outer list contains multiple subject lists which are randomly sampled to
        produce output image list.

    segmentation_image_list : list of ANTsImages
        List of segmentation images corresponding to the input image list (optional).

    pointset_list: list of pointsets
        Numpy arrays corresponding to the input image list (optional).  If using this
        option, the transform_type must be invertible.  The number of points is
        taken from the first pointset and is used for every subject.

    number_of_simulations : integer
        Number of simulated output image sets.  Default = 10.

    reference_image : ANTsImage
        Defines the spatial domain for all output images.  If one is not specified,
        we use the first image in the input image list.

    transform_type : string
        One of the following options: "translation", "rigid", "scaleShear", "affine",
        "deformation", "affineAndDeformation".

    noise_model : string
        'additivegaussian', 'saltandpepper', 'shot', or 'speckle' (case
        insensitive).  Use None to skip the noise step.

    noise_parameters : tuple or array or float
        'additivegaussian': (mean, standardDeviation)
        'saltandpepper': (probability, saltValue, pepperValue)
        'shot': scale
        'speckle': standardDeviation
        Note that the standard deviation, scale, and probability values are *max* values
        and are randomly selected in the range [0, noise_parameter].  Also, the "mean",
        "saltValue" and "pepperValue" are assumed to be in the intensity normalized range
        of [0, 1].  No noise is added unless at least one value is > 0.

    sd_simulated_bias_field : float
        Standard deviation passed to the bias field simulation.  The bias field
        step is skipped unless this value is > 0.

    sd_histogram_warping : float
        Standard deviation of the zero-mean Gaussian from which the displacement
        at each histogram break point is drawn.  The histogram warping step is
        skipped unless this value is > 0.

    sd_affine : float
        Determines the amount of transformation based change.

    output_numpy_file_prefix : string
        Prefix of the output numpy files.  When specified, the simulated images
        are written to <prefix>SimulatedImages.npy and, if the corresponding
        inputs were given, the segmentations to
        <prefix>SimulatedSegmentationImages.npy and the pointsets to
        <prefix>SimulatedPointsets.npy.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary with the key 'simulated_images' (list of lists of transformed
    images) and, when the corresponding inputs were given,
    'simulated_segmentation_images' and/or 'simulated_pointset_list'.

    Raises
    ------
    ValueError
        If noise is to be added and noise_model is not a recognized model.

    Example
    -------
    >>> image1_list = list()
    >>> image1_list.append(ants.image_read(ants.get_ants_data("r16")))
    >>> image2_list = list()
    >>> image2_list.append(ants.image_read(ants.get_ants_data("r64")))
    >>> segmentation1 = ants.threshold_image(image1_list[0], "Otsu", 3)
    >>> segmentation2 = ants.threshold_image(image2_list[0], "Otsu", 3)
    >>> input_segmentations = list()
    >>> input_segmentations.append(segmentation1)
    >>> input_segmentations.append(segmentation2)
    >>> points1 = ants.get_centroids(segmentation1)[:,0:2]
    >>> points2 = ants.get_centroids(segmentation2)[:,0:2]
    >>> input_points = list()
    >>> input_points.append(points1)
    >>> input_points.append(points2)
    >>> input_images = list()
    >>> input_images.append(image1_list)
    >>> input_images.append(image2_list)
    >>> data = data_augmentation(input_images,
                                 input_segmentations,
                                 input_points,
                                 transform_type="scaleShear")
    """

    from ..utilities import histogram_warp_image_intensities
    from ..utilities import simulate_bias_field
    from ..utilities import randomly_transform_image_data

    if reference_image is None:
        reference_image = input_image_list[0][0]

    number_of_modalities = len(input_image_list[0])

    # Normalize the noise specification once.  The parameters may be given as
    # a bare scalar (documented for 'shot' and 'speckle'), so flatten them to
    # a tuple of plain floats that can be indexed and tested uniformly.
    noise_name = None
    noise_values = ()
    if noise_model is not None:
        noise_name = noise_model.lower()
        noise_values = tuple(
            float(value) for value in np.atleast_1d(noise_parameters))
    add_noise = any(value > 0 for value in noise_values)

    # Set up numpy arrays if outputting to file.

    batch_X = None
    batch_Y = None
    batch_Y_points = None
    number_of_points = 0

    if pointset_list is not None:
        number_of_points = pointset_list[0].shape[0]
        batch_Y_points = np.zeros((number_of_simulations, number_of_points,
                                   reference_image.dimension))

    if output_numpy_file_prefix is not None:
        batch_X = np.zeros((number_of_simulations, *reference_image.shape,
                            number_of_modalities))
        if segmentation_image_list is not None:
            batch_Y = np.zeros((number_of_simulations, *reference_image.shape))

    # Spatially transform input image data

    if verbose:
        print("Randomly spatially transforming the image data.")

    transform_augmentation = randomly_transform_image_data(
        reference_image,
        input_image_list=input_image_list,
        segmentation_image_list=segmentation_image_list,
        number_of_simulations=number_of_simulations,
        transform_type=transform_type,
        sd_affine=sd_affine,
        deformation_transform_type="bspline",
        number_of_random_points=1000,
        sd_noise=2.0,
        number_of_fitting_levels=4,
        mesh_size=1,
        sd_smoothing=4.0,
        input_image_interpolator='linear',
        segmentation_image_interpolator='nearestNeighbor')

    simulated_image_list = list()
    simulated_segmentation_image_list = list()
    simulated_pointset_list = list()

    for i in range(number_of_simulations):

        if verbose:
            print("Processing simulation " + str(i))

        if segmentation_image_list is not None:
            segmentation = transform_augmentation[
                'simulated_segmentation_images'][i]
            simulated_segmentation_image_list.append(segmentation)
            if batch_Y is not None:
                batch_Y[i, ...] = segmentation.numpy()

        if pointset_list is not None:
            # Points are mapped with the inverse of the simulated transform.
            simulated_transform_inverse = ants.invert_ants_transform(
                transform_augmentation['simulated_transforms'][i])
            subject_points = pointset_list[
                transform_augmentation['which_subject'][i]]
            simulated_points = np.zeros(
                (number_of_points, reference_image.dimension))
            for j in range(number_of_points):
                simulated_points[j, :] = ants.apply_ants_transform_to_point(
                    simulated_transform_inverse, subject_points[j, :])
            simulated_pointset_list.append(simulated_points)
            batch_Y_points[i, :, :] = simulated_points

        simulated_local_image_list = list()
        for j in range(number_of_modalities):

            if verbose:
                print("    Modality " + str(j))

            image = transform_augmentation['simulated_images'][i][j]
            image_range = image.range()

            # Normalize to [0, 1] before applying augmentation

            if verbose:
                print("        Normalizing to [0, 1].")

            image = ants.iMath(image, "Normalize")

            # Noise

            if noise_model is not None:

                if verbose:
                    print("        Adding noise (" + noise_model + ").")

                if add_noise:

                    # Only the spread-like parameter of each model is
                    # randomized; the remaining values are passed as given.
                    if noise_name == "additivegaussian":
                        parameters = (noise_values[0],
                                      random.uniform(0.0, noise_values[1]))
                    elif noise_name == "saltandpepper":
                        parameters = (random.uniform(0.0, noise_values[0]),
                                      noise_values[1], noise_values[2])
                    elif noise_name in ("shot", "speckle"):
                        parameters = random.uniform(0.0, noise_values[0])
                    else:
                        raise ValueError("Unrecognized noise model.")

                    image = ants.add_noise_to_image(
                        image,
                        noise_model=noise_name,
                        noise_parameters=parameters)

            # Simulated bias field

            if sd_simulated_bias_field > 0:

                if verbose:
                    print("        Adding simulated bias field.")

                bias_field = simulate_bias_field(
                    image, sd_bias_field=sd_simulated_bias_field)
                image = image * (bias_field + 1)

            # Histogram intensity warping

            if sd_histogram_warping > 0:

                if verbose:
                    print("        Performing intensity histogram warping.")

                break_points = [0.2, 0.4, 0.6, 0.8]
                displacements = [
                    random.gauss(0, sd_histogram_warping)
                    for _ in break_points
                ]
                image = histogram_warp_image_intensities(
                    image,
                    break_points=break_points,
                    clamp_end_points=(False, False),
                    displacements=displacements)

            # Rescale to original intensity range

            if verbose:
                print("        Rescaling to original intensity range.")

            image = ants.iMath(image, "Normalize") * (
                image_range[1] - image_range[0]) + image_range[0]

            simulated_local_image_list.append(image)

            if batch_X is not None:
                batch_X[i, ..., j] = image.numpy()

        simulated_image_list.append(simulated_local_image_list)

    # Nothing is written unless a file prefix was supplied.
    if output_numpy_file_prefix is not None:
        if verbose:
            print("Writing images to numpy array.")
        np.save(output_numpy_file_prefix + "SimulatedImages.npy", batch_X)
        if batch_Y is not None:
            if verbose:
                print("Writing segmentation images to numpy array.")
            np.save(
                output_numpy_file_prefix + "SimulatedSegmentationImages.npy",
                batch_Y)
        if batch_Y_points is not None:
            if verbose:
                print("Writing pointsets to numpy array.")
            np.save(output_numpy_file_prefix + "SimulatedPointsets.npy",
                    batch_Y_points)

    # Only report the outputs for which inputs were provided.
    results = {'simulated_images': simulated_image_list}
    if segmentation_image_list is not None:
        results['simulated_segmentation_images'] = (
            simulated_segmentation_image_list)
    if pointset_list is not None:
        results['simulated_pointset_list'] = simulated_pointset_list
    return results
# Map each class probability image back to the space of the original image
# and write the non-background classes to disk, reporting elapsed times.
start_time = time.time()

# All-zero image carrying the geometry of the warped image; it is handed to
# ants.decrop_image as the full-size image for each cropped probability map.
zeros = np.zeros(warped_image.shape)
zeros_image = ants.from_numpy(zeros,
                              origin=warped_image.origin,
                              spacing=warped_image.spacing,
                              direction=warped_image.direction)

# The inverse transform does not depend on the label, so compute it once
# instead of re-inverting xfrm on every loop iteration.
inverse_xfrm = ants.invert_ants_transform(xfrm)

for i in range(number_of_classification_labels):
    probability_image = ants.resample_image(probability_images_array[i],
                                            original_cropped_size,
                                            use_voxels=True,
                                            interp_type=0)
    probability_image = ants.decrop_image(probability_image, zeros_image)
    probability_images_array[i] = ants.apply_ants_transform_to_image(
        inverse_xfrm, probability_image, image)

end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

# Index 0 is skipped when writing, so only classes[1:] reach disk.
for i in range(1, number_of_classification_labels):
    print("Writing", classes[i])
    start_time = time.time()
    ants.image_write(
        probability_images_array[i],
        output_file_name_prefix + classes[i] + "Segmentation.nii.gz")
    end_time = time.time()
    elapsed_time = end_time - start_time
    print("  (elapsed time: ", elapsed_time, " seconds)")
Esempio n. 11
0
def brain_extraction(image,
                     modality="t1",
                     antsxnet_cache_directory=None,
                     verbose=False):
    """
    Perform brain extraction using U-net and ANTs-based training data.  "NoBrainer"
    is also possible where brain extraction uses U-net and FreeSurfer training data
    ported from

    https://github.com/neuronets/nobrainer-models

    Arguments
    ---------
    image : ANTsImage
        input image (or list of images for multi-modal scenarios).  The
        "t1combined" and "t1nobrainer" modalities only use the first image.

    modality : string
        Modality image type.  Options include:
            * "t1": T1-weighted MRI---ANTs-trained.  Update from "t1v0".
            * "t1v0":  T1-weighted MRI---ANTs-trained.
            * "t1nobrainer": T1-weighted MRI---FreeSurfer-trained: h/t Satra Ghosh and Jakub Kaczmarzyk.
            * "t1combined": Brian's combination of "t1" and "t1nobrainer".  One can also specify
                            "t1combined[X]" where X is the morphological radius.  X = 12 by default.
            * "flair": FLAIR MRI.
            * "t2": T2 MRI.
            * "bold": 3-D BOLD MRI.
            * "fa": Fractional anisotropy.
            * "t1t2infant": Combined T1-w/T2-w infant MRI h/t Martin Styner.
            * "t1infant": T1-w infant MRI h/t Martin Styner.
            * "t2infant": T2-w infant MRI h/t Martin Styner.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if this is None, these data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    ANTsImage whose content depends on the modality:
        * ANTs-trained modalities: the brain probability image (last output
          channel of the network) mapped back to the space of the first input
          image.
        * "t1nobrainer": the predicted mask resampled to the input grid and
          passed through ants.label_clusters.
        * "t1combined...": the sum of the combined brain mask, the eroded "t1"
          brain mask and the "t1" brain mask.

    Raises
    ------
    ValueError
        If the (first) input image is not 3-D or the modality is unknown.

    Example
    -------
    >>> probability_brain_mask = brain_extraction(brain_image, modality="t1")
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..architectures import create_nobrainer_unet_model_3d

    classes = ("background", "brain")
    number_of_classification_labels = len(classes)

    # Normalize the input to a list of images.  A one-element list must not be
    # wrapped a second time, otherwise the dimension check below fails.
    if isinstance(image, (list, tuple)):
        input_images = list(image)
    else:
        input_images = [image]
    channel_size = len(input_images)
    primary_image = input_images[0]

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    if primary_image.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if "t1combined" in modality:

        brain_extraction_t1 = brain_extraction(
            primary_image,
            modality="t1",
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        brain_mask = ants.iMath_get_largest_component(
            ants.threshold_image(brain_extraction_t1, 0.5, 10000))

        # Need to change with voxel resolution
        morphological_radius = 12
        if '[' in modality and ']' in modality:
            # Parse X from "t1combined[X]".
            morphological_radius = int(modality.split("[")[1].split("]")[0])

        # Run NoBrainer only inside a dilated version of the "t1" brain mask.
        brain_extraction_t1nobrainer = brain_extraction(
            primary_image * ants.iMath_MD(brain_mask,
                                          radius=morphological_radius),
            modality="t1nobrainer",
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        brain_extraction_combined = ants.iMath_fill_holes(
            ants.iMath_get_largest_component(brain_extraction_t1nobrainer *
                                             brain_mask))

        brain_extraction_combined = brain_extraction_combined + ants.iMath_ME(
            brain_mask, morphological_radius) + brain_mask

        return brain_extraction_combined

    if modality != "t1nobrainer":

        #####################
        #
        # ANTs-based
        #
        #####################

        weights_file_name_prefixes = {
            "t1v0": "brainExtraction",
            "t1": "brainExtractionT1",
            "t2": "brainExtractionT2",
            "flair": "brainExtractionFLAIR",
            "bold": "brainExtractionBOLD",
            "fa": "brainExtractionFA",
            "t1t2infant": "brainExtractionInfantT1T2",
            "t1infant": "brainExtractionInfantT1",
            "t2infant": "brainExtractionInfantT2",
        }
        weights_file_name_prefix = weights_file_name_prefixes.get(modality)
        if weights_file_name_prefix is None:
            raise ValueError("Unknown modality type.")

        weights_file_name = get_pretrained_network(
            weights_file_name_prefix,
            antsxnet_cache_directory=antsxnet_cache_directory)

        if verbose:
            print("Brain extraction:  retrieving template.")
        reorient_template_file_name_path = get_antsxnet_data(
            "S_template3", antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)
        resampled_image_size = reorient_template.shape

        # The "t1" network has an additional "head" output class.
        if modality == "t1":
            classes = ("background", "head", "brain")
            number_of_classification_labels = len(classes)

        unet_model = create_unet_model_3d(
            (*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            number_of_layers=4,
            number_of_filters_at_base_layer=8,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5)

        unet_model.load_weights(weights_file_name)

        if verbose:
            print("Brain extraction:  normalizing image to the template.")

        # Align the centers of mass of the first input image and the template.
        center_of_mass_template = ants.get_center_of_mass(reorient_template)
        center_of_mass_image = ants.get_center_of_mass(primary_image)
        translation = np.asarray(center_of_mass_image) - np.asarray(
            center_of_mass_template)
        xfrm = ants.create_ants_transform(
            transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_template),
            translation=translation)

        # One channel per input image, each standardized to zero mean and
        # unit standard deviation.
        batchX = np.zeros((1, *resampled_image_size, channel_size))
        for i, input_image in enumerate(input_images):
            warped_image = ants.apply_ants_transform_to_image(
                xfrm, input_image, reorient_template)
            warped_array = warped_image.numpy()
            batchX[0, :, :, :, i] = (warped_array -
                                     warped_array.mean()) / warped_array.std()

        if verbose:
            print("Brain extraction:  prediction and decoding.")

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        # Only the last output channel ("brain") is returned, so only that
        # channel is converted back into an image on the template grid.
        brain_channel = number_of_classification_labels - 1
        template_space_probability_image = ants.from_numpy(
            np.squeeze(predicted_data[0, :, :, :, brain_channel]),
            origin=reorient_template.origin,
            spacing=reorient_template.spacing,
            direction=reorient_template.direction)

        if verbose:
            print(
                "Brain extraction:  renormalize probability mask to native space."
            )
        probability_image = ants.apply_ants_transform_to_image(
            ants.invert_ants_transform(xfrm),
            template_space_probability_image,
            primary_image)

        return probability_image

    else:

        #####################
        #
        # NoBrainer
        #
        #####################

        if verbose:
            print("NoBrainer:  generating network.")

        model = create_nobrainer_unet_model_3d((None, None, None, 1))

        weights_file_name = get_pretrained_network(
            "brainExtractionNoBrainer",
            antsxnet_cache_directory=antsxnet_cache_directory)
        model.load_weights(weights_file_name)

        if verbose:
            print(
                "NoBrainer:  preprocessing (intensity truncation and resampling)."
            )

        # Threshold at 10% of the robust (2%-98% quantile) range of the
        # non-zero intensities and zero out everything not above it.
        image_array = primary_image.numpy()
        image_robust_range = np.quantile(
            image_array[np.where(image_array != 0)], (0.02, 0.98))
        threshold_value = 0.10 * (image_robust_range[1] - image_robust_range[0]
                                  ) + image_robust_range[0]

        thresholded_mask = ants.threshold_image(primary_image, -10000,
                                                threshold_value, 0, 1)
        thresholded_image = primary_image * thresholded_mask

        image_resampled = ants.resample_image(thresholded_image,
                                              (256, 256, 256),
                                              use_voxels=True)
        # Add the leading batch axis and trailing channel axis for the model.
        image_array = np.expand_dims(image_resampled.numpy(), axis=0)
        image_array = np.expand_dims(image_array, axis=-1)

        if verbose:
            print("NoBrainer:  predicting mask.")

        brain_mask_array = np.squeeze(
            model.predict(image_array, verbose=verbose))
        brain_mask_resampled = ants.copy_image_info(
            image_resampled, ants.from_numpy(brain_mask_array))
        brain_mask_image = ants.resample_image(brain_mask_resampled,
                                               primary_image.shape,
                                               use_voxels=True,
                                               interp_type=1)

        # Convert the hard-coded volume constant into a voxel count using
        # the voxel volume of the input image.
        spacing = ants.get_spacing(primary_image)
        spacing_product = spacing[0] * spacing[1] * spacing[2]
        minimum_brain_volume = round(649933.7 / spacing_product)
        brain_mask_labeled = ants.label_clusters(brain_mask_image,
                                                 minimum_brain_volume)

        return brain_mask_labeled