Example #1
# Module-level imports assumed by this snippet.
import ants
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

def regression_match_image(source_image,
                           reference_image,
                           poly_order=1,
                           truncate=True):
    """
    Image intensity normalization using polynomial regression (linear by default).

    Arguments
    ---------
    source_image : ANTsImage
        Image whose intensities are matched to the reference
        image.

    reference_image : ANTsImage
        Defines the reference intensity function.

    poly_order : integer
        Polynomial order of fit.  Default is 1 (linear fit).

    truncate : boolean
        If True, clip the matched intensities to the reference intensity range.

    Returns
    -------
    ANTs image: source_image intensity-matched to reference_image.

    Example
    -------
    >>> import ants
    >>> source_image = ants.image_read(ants.get_ants_data('r16'))
    >>> reference_image = ants.image_read(ants.get_ants_data('r64'))
    >>> matched_image = regression_match_image(source_image, reference_image)
    """

    if source_image.shape != reference_image.shape:
        raise ValueError("Source and reference images must have the same shape.")

    source_intensities = np.expand_dims((source_image.numpy()).flatten(), axis=1)
    reference_intensities = np.expand_dims((reference_image.numpy()).flatten(), axis=1)

    poly_features = PolynomialFeatures(degree=poly_order)
    source_intensities_poly = poly_features.fit_transform(source_intensities)

    model = LinearRegression()
    model.fit(source_intensities_poly, reference_intensities)

    matched_source_intensities = model.predict(source_intensities_poly)

    if truncate:
        min_reference_value = reference_intensities.min()
        max_reference_value = reference_intensities.max()
        matched_source_intensities[matched_source_intensities < min_reference_value] = min_reference_value
        matched_source_intensities[matched_source_intensities > max_reference_value] = max_reference_value

    matched_source_image = ants.make_image(source_image.shape, matched_source_intensities)
    matched_source_image = ants.copy_image_info(source_image,  matched_source_image)

    return matched_source_image
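
A minimal usage sketch (assuming the imports above and ANTsPy's bundled sample
data; the output file name is hypothetical):

source_image = ants.image_read(ants.get_ants_data('r16'))
reference_image = ants.image_read(ants.get_ants_data('r64'))

# Match the source intensities to the reference; the result keeps
# source_image's physical header via ants.copy_image_info.
matched = regression_match_image(source_image, reference_image, poly_order=2)
ants.image_write(matched, 'r16_matched_to_r64.nii.gz')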
Example #2
    def test_copy_image_info(self):
        for img in self.imgs:
            img2 = img.clone()
            img2.set_spacing([6.9]*img.dimension)
            img2.set_origin([6.9]*img.dimension)
            self.assertFalse(ants.image_physical_space_consistency(img, img2))

            img3 = ants.copy_image_info(reference=img, target=img2)
            self.assertTrue(ants.image_physical_space_consistency(img, img3))
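
The behavior exercised by this test can be reproduced as a standalone sketch
(assuming ANTsPy's bundled sample data):

import ants

img = ants.image_read(ants.get_ants_data('r16'))
img2 = img.clone()
img2.set_spacing([6.9] * img.dimension)
img2.set_origin([6.9] * img.dimension)
assert not ants.image_physical_space_consistency(img, img2)

# copy_image_info copies origin, spacing, and direction from the
# reference image onto the target and returns the target.
img3 = ants.copy_image_info(reference=img, target=img2)
assert ants.image_physical_space_consistency(img, img3)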
Example #3
# Module-level imports assumed by this snippet; get_inference_patches and
# reconstruct_image are project-local helpers from the surrounding package.
import os

import ants
import numpy as np
import torch

def evaluate_test_subject(scan_path: str, scan_name: str, options: dict,
                          model):
    # Read the image to analyze:
    scan = ants.image_read(os.path.join(scan_path, scan_name + ".nii.gz"))
    scan_np = scan.numpy()

    # define the torch.device
    device = torch.device('cuda') if options['gpu_use'] else torch.device(
        'cpu')

    # Create the patches of the image to evaluate
    infer_patches, coordinates = get_inference_patches(
        scan_path=scan_path,
        input_data=[scan_name + "_norm.nii.gz"],
        roi=scan_name + "_ROI.nii.gz",
        patch_shape=options['patch_size'],
        step=options['sampling_step'],
        normalize=options['normalize'])

    # Get the shape of the patches
    sh = infer_patches.shape
    segmentation_pred = np.zeros((sh[0], 4, sh[2], sh[3], sh[4]))
    batch_size = options['batch_size']

    # Model evaluation
    model.eval()
    with torch.no_grad():
        for b in range(0, len(segmentation_pred), batch_size):
            x = torch.tensor(infer_patches[b:b + batch_size]).to(device)
            pred = model(x)
            # save the result back from GPU to CPU --> numpy
            segmentation_pred[b:b + batch_size] = pred.cpu().numpy()

    # reconstruct_image takes the inferred patches, the patch coordinates, and the image size as inputs
    all_probs = np.zeros(scan_np.shape + (4, ))
    for i in range(4):
        all_probs[:, :, :, i] = reconstruct_image(segmentation_pred[:, i],
                                                  coordinates, scan.shape)

    segmented = np.argmax(all_probs, axis=3).astype(np.uint8)

    # Create a nifti image
    segm_img = ants.from_numpy(segmented)
    segm_img = ants.copy_image_info(scan, segm_img)

    # Save the segmentation mask
    output_name = os.path.join(scan_path, scan_name + '_result.nii.gz')
    ants.image_write(segm_img, output_name)

    return segm_img
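
A hypothetical call site for this function: the options keys follow the ones
read above, the values are placeholders, and scan_path is expected to contain
<scan_name>.nii.gz, <scan_name>_norm.nii.gz, and <scan_name>_ROI.nii.gz:

options = {
    'gpu_use': torch.cuda.is_available(),  # placeholder settings
    'patch_size': (32, 32, 32),
    'sampling_step': (16, 16, 16),
    'normalize': True,
    'batch_size': 8,
}
# model is a trained PyTorch segmentation network with 4 output channels.
segmentation = evaluate_test_subject('/data/subject01', 'subject01', options, model)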
Example #4
# Module-level imports assumed by this snippet.
import ants
import numpy as np
from tensorflow.keras import regularizers
from tensorflow.keras.layers import Conv3D
from tensorflow.keras.models import Model

def lung_extraction(image,
                    modality="proton",
                    antsxnet_cache_directory=None,
                    verbose=False):

    """
    Perform proton or CT lung extraction using U-net.

    Arguments
    ---------
    image : ANTsImage
        input image

    modality : string
        Modality image type.  Options include "ct", "proton", "protonLobes", 
        "maskLobes", and "ventilation".

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if this is None, the data will be downloaded
        to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary of ANTs segmentation and probability images.

    Example
    -------
    >>> output = lung_extraction(lung_image, modality="proton")
    """

    from ..architectures import create_unet_model_2d
    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import pad_or_crop_image_to_size

    if image.dimension != 3:
        raise ValueError( "Image dimension must be 3." )

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    image_mods = [modality]
    channel_size = len(image_mods)

    weights_file_name = None
    unet_model = None

    if modality == "proton":
        weights_file_name = get_pretrained_network("protonLungMri",
            antsxnet_cache_directory=antsxnet_cache_directory)

        classes = ("background", "left_lung", "right_lung")
        number_of_classification_labels = len(classes)

        reorient_template_file_name_path = get_antsxnet_data("protonLungTemplate",
            antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)

        resampled_image_size = reorient_template.shape

        unet_model = create_unet_model_3d((*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            number_of_layers=4, number_of_filters_at_base_layer=16, dropout_rate=0.0,
            convolution_kernel_size=(7, 7, 5), deconvolution_kernel_size=(7, 7, 5))
        unet_model.load_weights(weights_file_name)

        if verbose:
            print("Lung extraction:  normalizing image to the template.")

        center_of_mass_template = ants.get_center_of_mass(reorient_template * 0 + 1)
        center_of_mass_image = ants.get_center_of_mass(image * 0 + 1)
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_template), translation=translation)
        warped_image = ants.apply_ants_transform_to_image(xfrm, image, reorient_template)

        batchX = np.expand_dims(warped_image.numpy(), axis=0)
        batchX = np.expand_dims(batchX, axis=-1)
        batchX = (batchX - batchX.mean()) / batchX.std()

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        origin = warped_image.origin
        spacing = warped_image.spacing
        direction = warped_image.direction

        probability_images_array = list()
        for i in range(number_of_classification_labels):
            probability_images_array.append(
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                origin=origin, spacing=spacing, direction=direction))

        if verbose:
            print("Lung extraction:  renormalize probability mask to native space.")

        for i in range(number_of_classification_labels):
            probability_images_array[i] = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), probability_images_array[i], image)

        image_matrix = ants.image_list_to_matrix(probability_images_array, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        return_dict = {'segmentation_image' : segmentation_image,
                       'probability_images' : probability_images_array}
        return return_dict

    if modality == "protonLobes" or modality == "maskLobes":
        reorient_template_file_name_path = get_antsxnet_data("protonLungTemplate",
            antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)

        resampled_image_size = reorient_template.shape

        spatial_priors_file_name_path = get_antsxnet_data("protonLobePriors",
            antsxnet_cache_directory=antsxnet_cache_directory)
        spatial_priors = ants.image_read(spatial_priors_file_name_path)
        priors_image_list = ants.ndimage_to_list(spatial_priors)

        channel_size = 1 + len(priors_image_list)
        number_of_classification_labels = 1 + len(priors_image_list)

        unet_model = create_unet_model_3d((*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels, mode="classification", 
            number_of_filters_at_base_layer=16, number_of_layers=4,
            convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
            dropout_rate=0.0, weight_decay=0, additional_options=("attentionGating",))

        if modality == "protonLobes":
            penultimate_layer = unet_model.layers[-2].output
            outputs2 = Conv3D(filters=1,
                            kernel_size=(1, 1, 1),
                            activation='sigmoid',
                            kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)
            unet_model = Model(inputs=unet_model.input, outputs=[unet_model.output, outputs2])
            weights_file_name = get_pretrained_network("protonLobes",
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            weights_file_name = get_pretrained_network("maskLobes",
                antsxnet_cache_directory=antsxnet_cache_directory)

        unet_model.load_weights(weights_file_name)

        if verbose:
            print("Lung extraction:  normalizing image to the template.")

        center_of_mass_template = ants.get_center_of_mass(reorient_template * 0 + 1)
        center_of_mass_image = ants.get_center_of_mass(image * 0 + 1)
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_template), translation=translation)
        warped_image = ants.apply_ants_transform_to_image(xfrm, image, reorient_template)
        warped_array = warped_image.numpy()
        if modality == "protonLobes":
            warped_array = (warped_array - warped_array.mean()) / warped_array.std()
        else:
            warped_array[warped_array != 0] = 1
       
        batchX = np.zeros((1, *warped_array.shape, channel_size))
        batchX[0,:,:,:,0] = warped_array
        for i in range(len(priors_image_list)):
            batchX[0,:,:,:,i+1] = priors_image_list[i].numpy()

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        origin = warped_image.origin
        spacing = warped_image.spacing
        direction = warped_image.direction

        probability_images_array = list()
        for i in range(number_of_classification_labels):
            if modality == "protonLobes":
                probability_images_array.append(
                    ants.from_numpy(np.squeeze(predicted_data[0][0, :, :, :, i]),
                    origin=origin, spacing=spacing, direction=direction))
            else:
                probability_images_array.append(
                    ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                    origin=origin, spacing=spacing, direction=direction))

        if verbose:
            print("Lung extraction:  renormalize probability images to native space.")

        for i in range(number_of_classification_labels):
            probability_images_array[i] = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), probability_images_array[i], image)

        image_matrix = ants.image_list_to_matrix(probability_images_array, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        if modality == "protonLobes":
            whole_lung_mask = ants.from_numpy(np.squeeze(predicted_data[1][0, :, :, :, 0]),
                origin=origin, spacing=spacing, direction=direction)
            whole_lung_mask = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), whole_lung_mask, image)

            return_dict = {'segmentation_image' : segmentation_image,
                           'probability_images' : probability_images_array,
                           'whole_lung_mask_image' : whole_lung_mask}
            return return_dict
        else:
            return_dict = {'segmentation_image' : segmentation_image,
                           'probability_images' : probability_images_array}
            return return_dict


    elif modality == "ct":

        ################################
        #
        # Preprocess image
        #
        ################################

        if verbose:
            print("Preprocess CT image.")

        def closest_simplified_direction_matrix(direction):
            closest = np.floor(np.abs(direction) + 0.5)
            closest[direction < 0] *= -1.0
            return closest

        simplified_direction = closest_simplified_direction_matrix(image.direction)

        reference_image_size = (128, 128, 128)

        ct_preprocessed = ants.resample_image(image, reference_image_size, use_voxels=True, interp_type=0)
        ct_preprocessed[ct_preprocessed < -1000] = -1000
        ct_preprocessed[ct_preprocessed > 400] = 400
        ct_preprocessed.set_direction(simplified_direction)
        ct_preprocessed.set_origin((0, 0, 0))
        ct_preprocessed.set_spacing((1, 1, 1))

        ################################
        #
        # Reorient image
        #
        ################################

        reference_image = ants.make_image(reference_image_size,
                                          voxval=0,
                                          spacing=(1, 1, 1),
                                          origin=(0, 0, 0),
                                          direction=np.identity(3))
        center_of_mass_reference = np.floor(ants.get_center_of_mass(reference_image * 0 + 1))
        center_of_mass_image = np.floor(ants.get_center_of_mass(ct_preprocessed * 0 + 1))
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_reference)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_reference), translation=translation)
        ct_preprocessed = ((ct_preprocessed - ct_preprocessed.min()) /
            (ct_preprocessed.max() - ct_preprocessed.min()))
        ct_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm, ct_preprocessed, reference_image, interpolation="nearestneighbor")
        ct_preprocessed_warped = ((ct_preprocessed_warped - ct_preprocessed_warped.min()) /
            (ct_preprocessed_warped.max() - ct_preprocessed_warped.min())) - 0.5

        ################################
        #
        # Build models and load weights
        #
        ################################

        if verbose:
            print("Build model and load weights.")

        weights_file_name = get_pretrained_network("lungCtWithPriorsSegmentationWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)

        classes = ("background", "left lung", "right lung", "airways")
        number_of_classification_labels = len(classes)

        luna16_priors = ants.ndimage_to_list(ants.image_read(get_antsxnet_data("luna16LungPriors")))
        for i in range(len(luna16_priors)):
            luna16_priors[i] = ants.resample_image(luna16_priors[i], reference_image_size, use_voxels=True)
        channel_size = len(luna16_priors) + 1

        unet_model = create_unet_model_3d((*reference_image_size, channel_size),
            number_of_outputs=number_of_classification_labels, mode="classification",
            number_of_layers=4, number_of_filters_at_base_layer=16, dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5, additional_options=("attentionGating",))
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Do prediction and normalize to native space
        #
        ################################

        if verbose:
            print("Prediction.")

        batchX = np.zeros((1, *reference_image_size, channel_size))
        batchX[:,:,:,:,0] = ct_preprocessed_warped.numpy()

        for i in range(len(luna16_priors)):
            batchX[:,:,:,:,i+1] = luna16_priors[i].numpy() - 0.5

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        probability_images = list()
        for i in range(number_of_classification_labels):
            if verbose:
                print("Reconstructing image", classes[i])
            probability_image = ants.from_numpy(np.squeeze(predicted_data[:,:,:,:,i]),
                origin=ct_preprocessed_warped.origin, spacing=ct_preprocessed_warped.spacing,
                direction=ct_preprocessed_warped.direction)
            probability_image = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), probability_image, ct_preprocessed)
            probability_image = ants.resample_image(probability_image,
               resample_params=image.shape, use_voxels=True, interp_type=0)
            probability_image = ants.copy_image_info(image, probability_image)
            probability_images.append(probability_image)

        image_matrix = ants.image_list_to_matrix(probability_images, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        return_dict = {'segmentation_image' : segmentation_image,
                       'probability_images' : probability_images}
        return return_dict

    elif modality == "ventilation":

        ################################
        #
        # Preprocess image
        #
        ################################

        if verbose:
            print("Preprocess ventilation image.")

        template_size = (256, 256)

        image_modalities = ("Ventilation",)
        channel_size = len(image_modalities)

        preprocessed_image = (image - image.mean()) / image.std()
        ants.set_direction(preprocessed_image, np.identity(3))

        ################################
        #
        # Build models and load weights
        #
        ################################

        unet_model = create_unet_model_2d((*template_size, channel_size),
            number_of_outputs=1, mode='sigmoid',
            number_of_layers=4, number_of_filters_at_base_layer=32, dropout_rate=0.0,
            convolution_kernel_size=(3, 3), deconvolution_kernel_size=(2, 2),
            weight_decay=0)

        if verbose:
            print("Whole lung mask: retrieving model weights.")

        weights_file_name = get_pretrained_network("wholeLungMaskFromVentilation",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Extract slices
        #
        ################################

        spacing = ants.get_spacing(preprocessed_image)
        dimensions_to_predict = (spacing.index(max(spacing)),)

        total_number_of_slices = 0
        for d in range(len(dimensions_to_predict)):
            total_number_of_slices += preprocessed_image.shape[dimensions_to_predict[d]]

        batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

        slice_count = 0
        for d in range(len(dimensions_to_predict)):
            number_of_slices = preprocessed_image.shape[dimensions_to_predict[d]]

            if verbose:
                print("Extracting slices for dimension ", dimensions_to_predict[d], ".")

            for i in range(number_of_slices):
                ventilation_slice = pad_or_crop_image_to_size(ants.slice_image(preprocessed_image, dimensions_to_predict[d], i), template_size)
                batchX[slice_count,:,:,0] = ventilation_slice.numpy()
                slice_count += 1

        ################################
        #
        # Do prediction and then restack into the image
        #
        ################################

        if verbose:
            print("Prediction.")

        prediction = unet_model.predict(batchX, verbose=verbose)

        permutations = list()
        permutations.append((0, 1, 2))
        permutations.append((1, 0, 2))
        permutations.append((1, 2, 0))

        probability_image = ants.image_clone(image) * 0

        current_start_slice = 0
        for d in range(len(dimensions_to_predict)):
            current_end_slice = current_start_slice + preprocessed_image.shape[dimensions_to_predict[d]] - 1
            which_batch_slices = range(current_start_slice, current_end_slice + 1)  # include the final slice

            prediction_per_dimension = prediction[which_batch_slices,:,:,0]
            prediction_array = np.transpose(np.squeeze(prediction_per_dimension), permutations[dimensions_to_predict[d]])
            prediction_image = ants.copy_image_info(image,
                pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                image.shape))
            probability_image = probability_image + (prediction_image - probability_image) / (d + 1)

            current_start_slice = current_end_slice + 1

        return probability_image

    else:
        raise ValueError("Unrecognized modality.")
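
A usage sketch following the function's return convention (a dictionary with
'segmentation_image' and 'probability_images'); the file names are hypothetical:

ct = ants.image_read('lung_ct.nii.gz')
output = lung_extraction(ct, modality='ct', verbose=True)
ants.image_write(output['segmentation_image'], 'lung_ct_segmentation.nii.gz')
# Probability images are ordered as the classes above:
# (background, left lung, right lung, airways).
left_lung_probability = output['probability_images'][1]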
Example #5
# Module-level imports assumed by this snippet.
import ants
import numpy as np

def sysu_media_wmh_segmentation(flair,
                                t1=None,
                                do_preprocessing=True,
                                use_ensemble=True,
                                use_axial_slices_only=True,
                                antsxnet_cache_directory=None,
                                verbose=False):
    """
    Perform WMH segmentation using the winning submission in the MICCAI
    2017 challenge by the sysu_media team using FLAIR or T1/FLAIR.  The
    MICCAI challenge is discussed in

    https://pubmed.ncbi.nlm.nih.gov/30908194/

    and the sysu_media team's entry is discussed in

    https://pubmed.ncbi.nlm.nih.gov/30125711/

    with the original implementation available here:

    https://github.com/hongweilibran/wmh_ibbmTum

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    do_preprocessing : boolean
        perform preprocessing (intensity truncation and N4 bias correction)?

    use_ensemble : boolean
        whether to use all 3 sets of network weights (ensemble averaging).

    use_axial_slices_only : boolean
        If True, use original implementation which was trained on axial slices.
        If False, use ANTsXNet variant implementation which applies the slice-by-slice
        models to all 3 dimensions and averages the results.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if this is None, the data will be downloaded
        to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> image = ants.image_read("flair.nii.gz")
    >>> probability_mask = sysu_media_wmh_segmentation(image)
    """

    from ..architectures import create_sysu_media_unet_model_2d
    from ..utilities import brain_extraction
    from ..utilities import crop_image_center
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import pad_or_crop_image_to_size

    if flair.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    flair_preprocessed = flair
    if do_preprocessing:
        flair_preprocessing = preprocess_brain_image(
            flair,
            truncate_intensity=(0.01, 0.99),
            do_brain_extraction=False,
            do_bias_correction=True,
            do_denoising=False,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        flair_preprocessed = flair_preprocessing["preprocessed_image"]

    number_of_channels = 1
    if t1 is not None:
        t1_preprocessed = t1
        if do_preprocessing:
            t1_preprocessing = preprocess_brain_image(
                t1,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            t1_preprocessed = t1_preprocessing["preprocessed_image"]
        number_of_channels = 2

    ################################
    #
    # Estimate mask
    #
    ################################

    brain_mask = None
    if verbose:
        print("Estimating brain mask.")
    if t1 is not None:
        brain_mask = brain_extraction(t1, modality="t1")
    else:
        brain_mask = brain_extraction(flair, modality="flair")

    reference_image = ants.make_image((200, 200, 200),
                                      voxval=1,
                                      spacing=(1, 1, 1),
                                      origin=(0, 0, 0),
                                      direction=np.identity(3))

    center_of_mass_reference = ants.get_center_of_mass(reference_image)
    center_of_mass_image = ants.get_center_of_mass(brain_mask)
    translation = np.asarray(center_of_mass_image) - np.asarray(
        center_of_mass_reference)
    xfrm = ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_reference),
        translation=translation)
    flair_preprocessed_warped = ants.apply_ants_transform_to_image(
        xfrm, flair_preprocessed, reference_image)
    brain_mask_warped = ants.threshold_image(
        ants.apply_ants_transform_to_image(xfrm, brain_mask, reference_image),
        0.5, 1.1, 1, 0)

    if t1 is not None:
        t1_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm, t1_preprocessed, reference_image)

    ################################
    #
    # Gaussian normalize intensity based on brain mask
    #
    ################################

    mean_flair = flair_preprocessed_warped[brain_mask_warped > 0].mean()
    std_flair = flair_preprocessed_warped[brain_mask_warped > 0].std()
    flair_preprocessed_warped = (flair_preprocessed_warped -
                                 mean_flair) / std_flair

    if number_of_channels == 2:
        mean_t1 = t1_preprocessed_warped[brain_mask_warped > 0].mean()
        std_t1 = t1_preprocessed_warped[brain_mask_warped > 0].std()
        t1_preprocessed_warped = (t1_preprocessed_warped - mean_t1) / std_t1

    ################################
    #
    # Build models and load weights
    #
    ################################

    number_of_models = 1
    if use_ensemble:
        number_of_models = 3

    unet_models = list()
    for i in range(number_of_models):
        if number_of_channels == 1:
            weights_file_name = get_pretrained_network(
                "sysuMediaWmhFlairOnlyModel" + str(i),
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            weights_file_name = get_pretrained_network(
                "sysuMediaWmhFlairT1Model" + str(i),
                antsxnet_cache_directory=antsxnet_cache_directory)
        unet_models.append(
            create_sysu_media_unet_model_2d((200, 200, number_of_channels)))
        unet_models[i].load_weights(weights_file_name)

    ################################
    #
    # Extract slices
    #
    ################################

    dimensions_to_predict = [2]
    if not use_axial_slices_only:
        dimensions_to_predict = list(range(3))

    total_number_of_slices = 0
    for d in range(len(dimensions_to_predict)):
        total_number_of_slices += flair_preprocessed_warped.shape[
            dimensions_to_predict[d]]

    batchX = np.zeros((total_number_of_slices, 200, 200, number_of_channels))

    slice_count = 0
    for d in range(len(dimensions_to_predict)):
        number_of_slices = flair_preprocessed_warped.shape[
            dimensions_to_predict[d]]

        if verbose:
            print("Extracting slices for dimension ", dimensions_to_predict[d],
                  ".")

        for i in range(number_of_slices):
            flair_slice = pad_or_crop_image_to_size(
                ants.slice_image(flair_preprocessed_warped,
                                 dimensions_to_predict[d], i), (200, 200))
            batchX[slice_count, :, :, 0] = flair_slice.numpy()
            if number_of_channels == 2:
                t1_slice = pad_or_crop_image_to_size(
                    ants.slice_image(t1_preprocessed_warped,
                                     dimensions_to_predict[d], i), (200, 200))
                batchX[slice_count, :, :, 1] = t1_slice.numpy()

            slice_count += 1

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose:
        print("Prediction.")

    prediction = unet_models[0].predict(batchX, verbose=verbose)
    if number_of_models > 1:
        for i in range(1, number_of_models, 1):
            prediction += unet_models[i].predict(batchX, verbose=verbose)
    prediction /= number_of_models

    permutations = list()
    permutations.append((0, 1, 2))
    permutations.append((1, 0, 2))
    permutations.append((1, 2, 0))

    prediction_image_average = ants.image_clone(flair_preprocessed_warped) * 0

    current_start_slice = 0
    for d in range(len(dimensions_to_predict)):
        current_end_slice = current_start_slice + flair_preprocessed_warped.shape[
            dimensions_to_predict[d]] - 1
        which_batch_slices = range(current_start_slice, current_end_slice + 1)  # include the final slice
        prediction_per_dimension = prediction[which_batch_slices, :, :, :]
        prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                                        permutations[dimensions_to_predict[d]])
        prediction_image = ants.copy_image_info(
            flair_preprocessed_warped,
            pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                      flair_preprocessed_warped.shape))
        prediction_image_average = prediction_image_average + (
            prediction_image - prediction_image_average) / (d + 1)
        current_start_slice = current_end_slice + 1

    probability_image = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), prediction_image_average, flair)

    return probability_image
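
Note that the restacking loop above keeps a running mean over the predicted
dimensions: average + (new - average) / (d + 1) is the incremental form of the
arithmetic mean. A usage sketch with the optional T1 channel and the ensemble
enabled (file names are hypothetical):

flair = ants.image_read('flair.nii.gz')
t1 = ants.image_read('t1.nii.gz')
probability_mask = sysu_media_wmh_segmentation(flair, t1=t1,
                                               use_ensemble=True, verbose=True)
# Binarize the WMH probability image at 0.5.
wmh_mask = ants.threshold_image(probability_mask, 0.5, 1.0, 1, 0)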
Example #6
# Module-level imports assumed by this snippet.
import ants
import numpy as np

def el_bicho(ventilation_image,
             mask,
             use_coarse_slices_only=True,
             antsxnet_cache_directory=None,
             verbose=False):
    """
    Perform functional lung segmentation using hyperpolarized gases.

    https://pubmed.ncbi.nlm.nih.gov/30195415/

    Arguments
    ---------
    ventilation_image : ANTsImage
        input ventilation image.

    mask : ANTsImage
        input mask.

    use_coarse_slices_only : boolean
        If True, apply network only in the dimension of greatest slice thickness.
        If False, apply to all dimensions and average the results.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if this is None, the data will be downloaded
        to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Ventilation segmentation and corresponding probability images

    Example
    -------
    >>> image = ants.image_read("ventilation.nii.gz")
    >>> mask = ants.image_read("mask.nii.gz")
    >>> lung_seg = el_bicho(image, mask, use_coarse_slices_only=True, verbose=False)
    """

    from ..architectures import create_unet_model_2d
    from ..utilities import get_pretrained_network
    from ..utilities import pad_or_crop_image_to_size

    if ventilation_image.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if ventilation_image.shape != mask.shape:
        raise ValueError(
            "Ventilation image and mask are not the same size.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess image
    #
    ################################

    template_size = (256, 256)
    classes = (0, 1, 2, 3, 4)
    number_of_classification_labels = len(classes)

    image_modalities = ("Ventilation", "Mask")
    channel_size = len(image_modalities)

    preprocessed_image = (ventilation_image -
                          ventilation_image.mean()) / ventilation_image.std()
    ants.set_direction(preprocessed_image, np.identity(3))

    mask_identity = ants.image_clone(mask)
    ants.set_direction(mask_identity, np.identity(3))

    ################################
    #
    # Build models and load weights
    #
    ################################

    unet_model = create_unet_model_2d(
        (*template_size, channel_size),
        number_of_outputs=number_of_classification_labels,
        number_of_layers=4,
        number_of_filters_at_base_layer=32,
        dropout_rate=0.0,
        convolution_kernel_size=(3, 3),
        deconvolution_kernel_size=(2, 2),
        weight_decay=1e-5,
        additional_options=("attentionGating"))

    if verbose:
        print("El Bicho: retrieving model weights.")

    weights_file_name = get_pretrained_network(
        "elBicho", antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Extract slices
    #
    ################################

    spacing = ants.get_spacing(preprocessed_image)
    dimensions_to_predict = (spacing.index(max(spacing)), )
    if not use_coarse_slices_only:
        dimensions_to_predict = list(range(3))

    total_number_of_slices = 0
    for d in range(len(dimensions_to_predict)):
        total_number_of_slices += preprocessed_image.shape[
            dimensions_to_predict[d]]

    batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

    slice_count = 0
    for d in range(len(dimensions_to_predict)):
        number_of_slices = preprocessed_image.shape[dimensions_to_predict[d]]

        if verbose:
            print("Extracting slices for dimension ", dimensions_to_predict[d],
                  ".")

        for i in range(number_of_slices):
            ventilation_slice = pad_or_crop_image_to_size(
                ants.slice_image(preprocessed_image, dimensions_to_predict[d],
                                 i), template_size)
            batchX[slice_count, :, :, 0] = ventilation_slice.numpy()

            mask_slice = pad_or_crop_image_to_size(
                ants.slice_image(mask_identity, dimensions_to_predict[d], i),
                template_size)
            batchX[slice_count, :, :, 1] = mask_slice.numpy()

            slice_count += 1

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose:
        print("Prediction.")

    prediction = unet_model.predict(batchX, verbose=verbose)

    permutations = list()
    permutations.append((0, 1, 2))
    permutations.append((1, 0, 2))
    permutations.append((1, 2, 0))

    probability_images = list()
    for l in range(number_of_classification_labels):
        probability_images.append(ants.image_clone(mask) * 0)

    current_start_slice = 0
    for d in range(len(dimensions_to_predict)):
        current_end_slice = current_start_slice + preprocessed_image.shape[
            dimensions_to_predict[d]] - 1
        which_batch_slices = range(current_start_slice, current_end_slice + 1)  # include the final slice

        for l in range(number_of_classification_labels):
            prediction_per_dimension = prediction[which_batch_slices, :, :, l]
            prediction_array = np.transpose(
                np.squeeze(prediction_per_dimension),
                permutations[dimensions_to_predict[d]])
            prediction_image = ants.copy_image_info(
                ventilation_image,
                pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                          ventilation_image.shape))
            probability_images[l] = probability_images[l] + (
                prediction_image - probability_images[l]) / (d + 1)

        current_start_slice = current_end_slice + 1

    ################################
    #
    # Convert probability images to segmentation
    #
    ################################

    image_matrix = ants.image_list_to_matrix(
        probability_images[1:], mask * 0 + 1)
    background_foreground_matrix = np.stack([
        ants.image_list_to_matrix([probability_images[0]], mask * 0 + 1),
        np.expand_dims(np.sum(image_matrix, axis=0), axis=0)
    ])
    foreground_matrix = np.argmax(background_foreground_matrix, axis=0)
    segmentation_matrix = (np.argmax(image_matrix, axis=0) +
                           1) * foreground_matrix
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), mask * 0 + 1)[0]

    return_dict = {
        'segmentation_image': segmentation_image,
        'probability_images': probability_images
    }
    return return_dict
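
A usage sketch mirroring the docstring example (file names are hypothetical):

ventilation = ants.image_read('ventilation.nii.gz')
mask = ants.image_read('mask.nii.gz')
lung_seg = el_bicho(ventilation, mask, use_coarse_slices_only=True)
ants.image_write(lung_seg['segmentation_image'], 'ventilation_segmentation.nii.gz')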
Example #7
# Module-level imports assumed by this snippet.
import ants
import numpy as np
from tensorflow.keras import regularizers
from tensorflow.keras.layers import Conv3D
from tensorflow.keras.models import Model

def deep_flash(t1,
               t2=None,
               do_preprocessing=True,
               use_rank_intensity=True,
               antsxnet_cache_directory=None,
               verbose=False):
    """
    Hippocampal/Entorhinal segmentation using "Deep Flash"

    Perform hippocampal/entorhinal segmentation in T1 and T1/T2 images using
    labels from Mike Yassa's lab

    https://faculty.sites.uci.edu/myassa/

    The labeling is as follows:
    Label 0 :  background
    Label 5 :  left aLEC
    Label 6 :  right aLEC
    Label 7 :  left pMEC
    Label 8 :  right pMEC
    Label 9 :  left perirhinal
    Label 10:  right perirhinal
    Label 11:  left parahippocampal
    Label 12:  right parahippocampal
    Label 13:  left DG/CA2/CA3/CA4
    Label 14:  right DG/CA2/CA3/CA4
    Label 15:  left CA1
    Label 16:  right CA1
    Label 17:  left subiculum
    Label 18:  right subiculum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * affine registration to the "deep flash" template.
    which is performed on the input images if do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    t2 : ANTsImage
        Optional 3-D T2-weighted brain image.  If specified, it is assumed to be
        pre-aligned to the t1.

    do_preprocessing : boolean
        See description above.

    use_rank_intensity : boolean
        If False, use histogram matching with the cropped template ROI.  Otherwise,
        use a rank intensity transform on the cropped ROI.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if this is None, the data will be downloaded
        to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    List consisting of the segmentation image and probability images for
    each label and foreground.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> flash = deep_flash(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import brain_extraction

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Options temporarily taken from the user
    #
    ################################

    # use_hierarchical_parcellation : boolean
    #     If True, use u-net model with additional outputs of the medial temporal lobe
    #     region, hippocampal, and entorhinal/perirhinal/parahippocampal regions.  Otherwise
    #     the only additional output is the medial temporal lobe.
    #
    # use_contralaterality : boolean
    #     Use both hemispherical models to also predict the corresponding contralateral
    #     segmentation and use both sets of priors to produce the results.  Mainly used
    #     for debugging.

    use_hierarchical_parcellation = True
    use_contralaterality = True

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    t1_mask = None
    t1_preprocessed_flipped = None
    t1_template = ants.image_read(
        get_antsxnet_data("deepFlashTemplateT1SkullStripped"))
    template_transforms = None
    if do_preprocessing:

        if verbose:
            print("Preprocessing T1.")

        # Brain extraction
        probability_mask = brain_extraction(
            t1_preprocessed,
            modality="t1",
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_mask = ants.threshold_image(probability_mask, 0.5, 1, 1, 0)
        t1_preprocessed = t1_preprocessed * t1_mask

        # Do bias correction
        t1_preprocessed = ants.n4_bias_field_correction(t1_preprocessed,
                                                        t1_mask,
                                                        shrink_factor=4,
                                                        verbose=verbose)

        # Warp to template
        registration = ants.registration(
            fixed=t1_template,
            moving=t1_preprocessed,
            type_of_transform="antsRegistrationSyNQuickRepro[a]",
            verbose=verbose)
        template_transforms = dict(fwdtransforms=registration['fwdtransforms'],
                                   invtransforms=registration['invtransforms'])
        t1_preprocessed = registration['warpedmovout']

    if use_contralaterality:
        t1_preprocessed_array = t1_preprocessed.numpy()
        t1_preprocessed_array_flipped = np.flip(t1_preprocessed_array, axis=0)
        t1_preprocessed_flipped = ants.from_numpy(
            t1_preprocessed_array_flipped,
            origin=t1_preprocessed.origin,
            spacing=t1_preprocessed.spacing,
            direction=t1_preprocessed.direction)

    t2_preprocessed = t2
    t2_preprocessed_flipped = None
    t2_template = None
    if t2 is not None:
        t2_template = ants.image_read(
            get_antsxnet_data("deepFlashTemplateT2SkullStripped"))
        t2_template = ants.copy_image_info(t1_template, t2_template)
        if do_preprocessing:

            if verbose:
                print("Preprocessing T2.")

            # Brain extraction
            t2_preprocessed = t2_preprocessed * t1_mask

            # Do bias correction
            t2_preprocessed = ants.n4_bias_field_correction(t2_preprocessed,
                                                            t1_mask,
                                                            shrink_factor=4,
                                                            verbose=verbose)

            # Warp to template
            t2_preprocessed = ants.apply_transforms(
                fixed=t1_template,
                moving=t2_preprocessed,
                transformlist=template_transforms['fwdtransforms'],
                verbose=verbose)

        if use_contralaterality:
            t2_preprocessed_array = t2_preprocessed.numpy()
            t2_preprocessed_array_flipped = np.flip(t2_preprocessed_array,
                                                    axis=0)
            t2_preprocessed_flipped = ants.from_numpy(
                t2_preprocessed_array_flipped,
                origin=t2_preprocessed.origin,
                spacing=t2_preprocessed.spacing,
                direction=t2_preprocessed.direction)

    probability_images = list()
    labels = (0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)
    image_size = (64, 64, 96)

    ################################
    #
    # Process left/right in split networks
    #
    ################################

    ################################
    #
    # Download spatial priors
    #
    ################################

    spatial_priors_file_name_path = get_antsxnet_data(
        "deepFlashPriors", antsxnet_cache_directory=antsxnet_cache_directory)
    spatial_priors = ants.image_read(spatial_priors_file_name_path)
    priors_image_list = ants.ndimage_to_list(spatial_priors)
    for i in range(len(priors_image_list)):
        priors_image_list[i] = ants.copy_image_info(t1_preprocessed,
                                                    priors_image_list[i])

    labels_left = labels[1::2]
    priors_image_left_list = priors_image_list[1::2]
    probability_images_left = list()
    foreground_probability_images_left = list()
    lower_bound_left = (76, 74, 56)
    upper_bound_left = (140, 138, 152)
    tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_left,
                                    upper_bound_left)
    origin_left = tmp_cropped.origin

    spacing = tmp_cropped.spacing
    direction = tmp_cropped.direction

    t1_template_roi_left = ants.crop_indices(t1_template, lower_bound_left,
                                             upper_bound_left)
    t1_template_roi_left = (t1_template_roi_left - t1_template_roi_left.min(
    )) / (t1_template_roi_left.max() - t1_template_roi_left.min()) * 2.0 - 1.0
    t2_template_roi_left = None
    if t2_template is not None:
        t2_template_roi_left = ants.crop_indices(t2_template, lower_bound_left,
                                                 upper_bound_left)
        t2_template_roi_left = (t2_template_roi_left -
                                t2_template_roi_left.min()) / (
                                    t2_template_roi_left.max() -
                                    t2_template_roi_left.min()) * 2.0 - 1.0

    labels_right = labels[2::2]
    priors_image_right_list = priors_image_list[2::2]
    probability_images_right = list()
    foreground_probability_images_right = list()
    lower_bound_right = (20, 74, 56)
    upper_bound_right = (84, 138, 152)
    tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_right,
                                    upper_bound_right)
    origin_right = tmp_cropped.origin

    t1_template_roi_right = ants.crop_indices(t1_template, lower_bound_right,
                                              upper_bound_right)
    t1_template_roi_right = (
        t1_template_roi_right - t1_template_roi_right.min()
    ) / (t1_template_roi_right.max() - t1_template_roi_right.min()) * 2.0 - 1.0
    t2_template_roi_right = None
    if t2_template is not None:
        t2_template_roi_right = ants.crop_indices(t2_template,
                                                  lower_bound_right,
                                                  upper_bound_right)
        t2_template_roi_right = (t2_template_roi_right -
                                 t2_template_roi_right.min()) / (
                                     t2_template_roi_right.max() -
                                     t2_template_roi_right.min()) * 2.0 - 1.0

    ################################
    #
    # Create model
    #
    ################################

    channel_size = 1 + len(labels_left)
    if t2 is not None:
        channel_size += 1

    number_of_classification_labels = 1 + len(labels_left)

    unet_model = create_unet_model_3d(
        (*image_size, channel_size),
        number_of_outputs=number_of_classification_labels,
        mode="classification",
        number_of_filters=(32, 64, 96, 128, 256),
        convolution_kernel_size=(3, 3, 3),
        deconvolution_kernel_size=(2, 2, 2),
        dropout_rate=0.0,
        weight_decay=0)

    penultimate_layer = unet_model.layers[-2].output

    # medial temporal lobe
    output1 = Conv3D(
        filters=1,
        kernel_size=(1, 1, 1),
        activation='sigmoid',
        kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

    if use_hierarchical_parcellation:

        # EC, perirhinal, and parahippo.
        output2 = Conv3D(
            filters=1,
            kernel_size=(1, 1, 1),
            activation='sigmoid',
            kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

        # Hippocampus
        output3 = Conv3D(
            filters=1,
            kernel_size=(1, 1, 1),
            activation='sigmoid',
            kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

        unet_model = Model(
            inputs=unet_model.input,
            outputs=[unet_model.output, output1, output2, output3])
    else:
        unet_model = Model(inputs=unet_model.input,
                           outputs=[unet_model.output, output1])

    ################################
    #
    # Left:  build model and load weights
    #
    ################################

    network_name = 'deepFlashLeftT1'
    if t2 is not None:
        network_name = 'deepFlashLeftBoth'

    if use_hierarchical_parcellation:
        network_name += "Hierarchical"

    if use_rank_intensity:
        network_name += "_ri"

    if verbose:
        print("DeepFlash: retrieving model weights (left).")
    weights_file_name = get_pretrained_network(
        network_name, antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Left:  do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction (left).")

    batchX = None
    if use_contralaterality:
        batchX = np.zeros((2, *image_size, channel_size))
    else:
        batchX = np.zeros((1, *image_size, channel_size))

    t1_cropped = ants.crop_indices(t1_preprocessed, lower_bound_left,
                                   upper_bound_left)
    if use_rank_intensity:
        t1_cropped = ants.rank_intensity(t1_cropped)
    else:
        t1_cropped = ants.histogram_match_image(t1_cropped,
                                                t1_template_roi_left, 255, 64,
                                                False)
    batchX[0, :, :, :, 0] = t1_cropped.numpy()
    if use_contralaterality:
        t1_cropped = ants.crop_indices(t1_preprocessed_flipped,
                                       lower_bound_left, upper_bound_left)
        if use_rank_intensity:
            t1_cropped = ants.rank_intensity(t1_cropped)
        else:
            t1_cropped = ants.histogram_match_image(t1_cropped,
                                                    t1_template_roi_left, 255,
                                                    64, False)
        batchX[1, :, :, :, 0] = t1_cropped.numpy()
    if t2 is not None:
        t2_cropped = ants.crop_indices(t2_preprocessed, lower_bound_left,
                                       upper_bound_left)
        if use_rank_intensity:
            t2_cropped = ants.rank_intensity(t2_cropped)
        else:
            t2_cropped = ants.histogram_match_image(t2_cropped,
                                                    t2_template_roi_left, 255,
                                                    64, False)
        batchX[0, :, :, :, 1] = t2_cropped.numpy()
        if use_contralaterality:
            t2_cropped = ants.crop_indices(t2_preprocessed_flipped,
                                           lower_bound_left, upper_bound_left)
            if use_rank_intensity:
                t2_cropped = ants.rank_intensity(t2_cropped)
            else:
                t2_cropped = ants.histogram_match_image(
                    t2_cropped, t2_template_roi_left, 255, 64, False)
            batchX[1, :, :, :, 1] = t2_cropped.numpy()

    for i in range(len(priors_image_left_list)):
        cropped_prior = ants.crop_indices(priors_image_left_list[i],
                                          lower_bound_left, upper_bound_left)
        for j in range(batchX.shape[0]):
            batchX[j, :, :, :, i +
                   (channel_size - len(labels_left))] = cropped_prior.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    for i in range(1 + len(labels_left)):
        for j in range(predicted_data[0].shape[0]):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[0][j, :, :, :, i]),
                origin=origin_left, spacing=spacing, direction=direction)
            if i > 0:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0)
            else:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0 + 1)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                probability_images_left.append(probability_image)
            else:  # flipped
                probability_images_right.append(probability_image)

    ################################
    #
    # Left:  do prediction of mtl, hippocampal, and ec regions and normalize to native space
    #
    ################################

    for i in range(1, len(predicted_data)):
        for j in range(predicted_data[i].shape[0]):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[i][j, :, :, :, 0]),
                origin=origin_left, spacing=spacing, direction=direction)
            probability_image = ants.decrop_image(probability_image,
                                                  t1_preprocessed * 0)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                foreground_probability_images_left.append(probability_image)
            else:
                foreground_probability_images_right.append(probability_image)

    ################################
    #
    # Right:  build model and load weights
    #
    ################################

    network_name = 'deepFlashRightT1'
    if t2 is not None:
        network_name = 'deepFlashRightBoth'

    if use_hierarchical_parcellation:
        network_name += "Hierarchical"

    if use_rank_intensity:
        network_name += "_ri"

    if verbose:
        print("DeepFlash: retrieving model weights (right).")
    weights_file_name = get_pretrained_network(
        network_name, antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Right:  do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction (right).")

    batchX = None
    if use_contralaterality:
        batchX = np.zeros((2, *image_size, channel_size))
    else:
        batchX = np.zeros((1, *image_size, channel_size))

    t1_cropped = ants.crop_indices(t1_preprocessed, lower_bound_right,
                                   upper_bound_right)
    if use_rank_intensity:
        t1_cropped = ants.rank_intensity(t1_cropped)
    else:
        t1_cropped = ants.histogram_match_image(t1_cropped,
                                                t1_template_roi_right, 255, 64,
                                                False)
    batchX[0, :, :, :, 0] = t1_cropped.numpy()
    if use_contralaterality:
        t1_cropped = ants.crop_indices(t1_preprocessed_flipped,
                                       lower_bound_right, upper_bound_right)
        if use_rank_intensity:
            t1_cropped = ants.rank_intensity(t1_cropped)
        else:
            t1_cropped = ants.histogram_match_image(t1_cropped,
                                                    t1_template_roi_right, 255,
                                                    64, False)
        batchX[1, :, :, :, 0] = t1_cropped.numpy()
    if t2 is not None:
        t2_cropped = ants.crop_indices(t2_preprocessed, lower_bound_right,
                                       upper_bound_right)
        if use_rank_intensity:
            t2_cropped = ants.rank_intensity(t2_cropped)
        else:
            t2_cropped = ants.histogram_match_image(t2_cropped,
                                                    t2_template_roi_right, 255,
                                                    64, False)
        batchX[0, :, :, :, 1] = t2_cropped.numpy()
        if use_contralaterality:
            t2_cropped = ants.crop_indices(t2_preprocessed_flipped,
                                           lower_bound_right,
                                           upper_bound_right)
            if use_rank_intensity:
                t2_cropped = ants.rank_intensity(t2_cropped)
            else:
                t2_cropped = ants.histogram_match_image(
                    t2_cropped, t2_template_roi_right, 255, 64, False)
            batchX[1, :, :, :, 1] = t2_cropped.numpy()

    for i in range(len(priors_image_right_list)):
        cropped_prior = ants.crop_indices(priors_image_right_list[i],
                                          lower_bound_right, upper_bound_right)
        for j in range(batchX.shape[0]):
            batchX[j, :, :, :, i +
                   (channel_size - len(labels_right))] = cropped_prior.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    for i in range(1 + len(labels_right)):
        for j in range(predicted_data[0].shape[0]):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[0][j, :, :, :, i]),
                origin=origin_right, spacing=spacing, direction=direction)
            if i > 0:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0)
            else:
                probability_image = ants.decrop_image(probability_image,
                                                      t1_preprocessed * 0 + 1)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

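            # With contralaterality enabled, j == 0 is the native image (its
            # right-side estimate is averaged with the earlier flipped result),
            # while j == 1 is the flipped image, whose prediction refines the
            # left-side estimates from the first pass.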
            if j == 0:  # not flipped
                if use_contralaterality:
                    probability_images_right[i] = (
                        probability_images_right[i] + probability_image) / 2
                else:
                    probability_images_right.append(probability_image)
            else:  # flipped
                probability_images_left[i] = (probability_images_left[i] +
                                              probability_image) / 2

    ################################
    #
    # Right:  do prediction of mtl, hippocampal, and ec regions and normalize to native space
    #
    ################################

    for i in range(1, len(predicted_data)):
        for j in range(predicted_data[i].shape[0]):
            probability_image = \
                ants.from_numpy(np.squeeze(predicted_data[i][j, :, :, :, 0]),
                origin=origin_right, spacing=spacing, direction=direction)
            probability_image = ants.decrop_image(probability_image,
                                                  t1_preprocessed * 0)

            if j == 1:  # flipped
                probability_array_flipped = np.flip(probability_image.numpy(),
                                                    axis=0)
                probability_image = ants.from_numpy(
                    probability_array_flipped,
                    origin=probability_image.origin,
                    spacing=probability_image.spacing,
                    direction=probability_image.direction)

            if do_preprocessing:
                probability_image = ants.apply_transforms(
                    fixed=t1,
                    moving=probability_image,
                    transformlist=template_transforms['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose)

            if j == 0:  # not flipped
                if use_contralaterality:
                    foreground_probability_images_right[
                        i - 1] = (foreground_probability_images_right[i - 1] +
                                  probability_image) / 2
                else:
                    foreground_probability_images_right.append(
                        probability_image)
            else:
                foreground_probability_images_left[
                    i - 1] = (foreground_probability_images_left[i - 1] +
                              probability_image) / 2

    ################################
    #
    # Combine left/right probability images
    #
    ################################

    probability_background_image = ants.image_clone(t1) * 0
    for i in range(1, len(probability_images_left)):
        probability_background_image += probability_images_left[i]
    for i in range(1, len(probability_images_right)):
        probability_background_image += probability_images_right[i]

    probability_images.append(probability_background_image * -1 + 1)
    for i in range(1, len(probability_images_left)):
        probability_images.append(probability_images_left[i])
        probability_images.append(probability_images_right[i])

    ################################
    #
    # Convert probability images to segmentation
    #
    ################################

    # image_matrix = ants.image_list_to_matrix(probability_images, t1 * 0 + 1)
    # segmentation_matrix = np.argmax(image_matrix, axis=0)
    # segmentation_image = ants.matrix_to_images(
    #     np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

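    # Two-stage labeling: compare the background probability against the sum
    # of all foreground probabilities; where foreground wins, assign the
    # argmax foreground label (offset by 1), otherwise leave background (0).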
    image_matrix = ants.image_list_to_matrix(
        probability_images[1:(len(probability_images))], t1 * 0 + 1)
    background_foreground_matrix = np.stack([
        ants.image_list_to_matrix([probability_images[0]], t1 * 0 + 1),
        np.expand_dims(np.sum(image_matrix, axis=0), axis=0)
    ])
    foreground_matrix = np.argmax(background_foreground_matrix, axis=0)
    segmentation_matrix = (np.argmax(image_matrix, axis=0) +
                           1) * foreground_matrix
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    relabeled_image = ants.image_clone(segmentation_image)
    for i in range(len(labels)):
        relabeled_image[segmentation_image == i] = labels[i]

    foreground_probability_images = list()
    for i in range(len(foreground_probability_images_left)):
        foreground_probability_images.append(
            foreground_probability_images_left[i] +
            foreground_probability_images_right[i])

    return_dict = None
    if use_hierarchical_parcellation:
        return_dict = {
            'segmentation_image':
            relabeled_image,
            'probability_images':
            probability_images,
            'medial_temporal_lobe_probability_image':
            foreground_probability_images[0],
            'other_region_probability_image':
            foreground_probability_images[1],
            'hippocampal_probability_image':
            foreground_probability_images[2]
        }
    else:
        return_dict = {
            'segmentation_image':
            relabeled_image,
            'probability_images':
            probability_images,
            'medial_temporal_lobe_probability_image':
            foreground_probability_images[0]
        }

    return (return_dict)
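
# Usage sketch (not part of the original snippet): the enclosing function is
# assumed to be ANTsXNet's deep_flash (its signature appears earlier in this
# example), and the file names are placeholders.
import ants
t1 = ants.image_read("t1.nii.gz")
flash = deep_flash(t1, do_preprocessing=True, verbose=True)
mtl_probability = flash['medial_temporal_lobe_probability_image']
ants.image_write(flash['segmentation_image'], "deep_flash_segmentation.nii.gz")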
Example #8
0
def claustrum_segmentation(t1,
                           do_preprocessing=True,
                           use_ensemble=True,
                           antsxnet_cache_directory=None,
                           verbose=False):
    """
    Claustrum segmentation

    Described here:

        https://arxiv.org/abs/2008.03465

    with the implementation available at:

        https://github.com/hongweilibran/claustrum_multi_view


    Arguments
    ---------
    t1 : ANTsImage
        input 3-D T1 brain image.

    do_preprocessing : boolean
        perform preprocessing (intensity truncation, brain extraction,
        n4 bias correction, and denoising).

    use_ensemble : boolean
        check whether to use all 3 sets of weights.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if this is None, the data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Claustrum segmentation probability image

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> probability_mask = claustrum_segmentation(image)
    """

    from ..architectures import create_sysu_media_unet_model_2d
    from ..utilities import brain_extraction
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import pad_or_crop_image_to_size

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    image_size = (180, 180)

    ################################
    #
    # Preprocess images
    #
    ################################

    number_of_channels = 1
    t1_preprocessed = ants.image_clone(t1)
    brain_mask = ants.threshold_image(t1, 0, 0, 0, 1)
    if do_preprocessing == True:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing["preprocessed_image"]
        brain_mask = t1_preprocessing["brain_mask"]

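    # Rigidly align the head to a fixed 170x256x256 reference grid by
    # translating the image center of mass onto the reference center of mass.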
    reference_image = ants.make_image((170, 256, 256),
                                      voxval=1,
                                      spacing=(1, 1, 1),
                                      origin=(0, 0, 0),
                                      direction=np.identity(3))
    center_of_mass_reference = ants.get_center_of_mass(reference_image)
    center_of_mass_image = ants.get_center_of_mass(brain_mask)
    translation = np.asarray(center_of_mass_image) - np.asarray(
        center_of_mass_reference)
    xfrm = ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_reference),
        translation=translation)
    t1_preprocessed_warped = ants.apply_ants_transform_to_image(
        xfrm, t1_preprocessed, reference_image)
    brain_mask_warped = ants.threshold_image(
        ants.apply_ants_transform_to_image(xfrm, brain_mask, reference_image),
        0.5, 1.1, 1, 0)

    ################################
    #
    # Gaussian normalize intensity based on brain mask
    #
    ################################

    mean_t1 = t1_preprocessed_warped[brain_mask_warped > 0].mean()
    std_t1 = t1_preprocessed_warped[brain_mask_warped > 0].std()
    t1_preprocessed_warped = (t1_preprocessed_warped - mean_t1) / std_t1

    t1_preprocessed_warped = t1_preprocessed_warped * brain_mask_warped

    ################################
    #
    # Build models and load weights
    #
    ################################

    number_of_models = 1
    if use_ensemble == True:
        number_of_models = 3

    if verbose == True:
        print("Claustrum:  retrieving axial model weights.")

    unet_axial_models = list()
    for i in range(number_of_models):
        weights_file_name = get_pretrained_network(
            "claustrum_axial_" + str(i),
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_axial_models.append(
            create_sysu_media_unet_model_2d((*image_size, number_of_channels),
                                            anatomy="claustrum"))
        unet_axial_models[i].load_weights(weights_file_name)

    if verbose == True:
        print("Claustrum:  retrieving coronal model weights.")

    unet_coronal_models = list()
    for i in range(number_of_models):
        weights_file_name = get_pretrained_network(
            "claustrum_coronal_" + str(i),
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_coronal_models.append(
            create_sysu_media_unet_model_2d((*image_size, number_of_channels),
                                            anatomy="claustrum"))
        unet_coronal_models[i].load_weights(weights_file_name)

    ################################
    #
    # Extract slices
    #
    ################################

    dimensions_to_predict = [1, 2]

    batch_coronal_X = np.zeros(
        (t1_preprocessed_warped.shape[1], *image_size, number_of_channels))
    batch_axial_X = np.zeros(
        (t1_preprocessed_warped.shape[2], *image_size, number_of_channels))

    for d in range(len(dimensions_to_predict)):
        number_of_slices = t1_preprocessed_warped.shape[
            dimensions_to_predict[d]]

        if verbose == True:
            print("Extracting slices for dimension ", dimensions_to_predict[d],
                  ".")

        for i in range(number_of_slices):
            t1_slice = pad_or_crop_image_to_size(
                ants.slice_image(t1_preprocessed_warped,
                                 dimensions_to_predict[d], i), image_size)
            if dimensions_to_predict[d] == 1:
                batch_coronal_X[i, :, :, 0] = np.rot90(t1_slice.numpy(), k=-1)
            else:
                batch_axial_X[i, :, :, 0] = np.rot90(t1_slice.numpy())

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose == True:
        print("Coronal prediction.")

    prediction_coronal = unet_coronal_models[0].predict(batch_coronal_X,
                                                        verbose=verbose)
    if number_of_models > 1:
        for i in range(1, number_of_models, 1):
            prediction_coronal += unet_coronal_models[i].predict(
                batch_coronal_X, verbose=verbose)
    prediction_coronal /= number_of_models

    for i in range(t1_preprocessed_warped.shape[1]):
        prediction_coronal[i, :, :, 0] = np.rot90(
            np.squeeze(prediction_coronal[i, :, :, 0]))

    if verbose == True:
        print("Axial prediction.")

    prediction_axial = unet_axial_models[0].predict(batch_axial_X,
                                                    verbose=verbose)
    if number_of_models > 1:
        for i in range(1, number_of_models, 1):
            prediction_axial += unet_axial_models[i].predict(batch_axial_X,
                                                             verbose=verbose)
    prediction_axial /= number_of_models

    for i in range(t1_preprocessed_warped.shape[2]):
        prediction_axial[i, :, :,
                         0] = np.rot90(np.squeeze(prediction_axial[i, :, :,
                                                                   0]),
                                       k=-1)

    if verbose == True:
        print("Restack image and transform back to native space.")

    permutations = list()
    permutations.append((0, 1, 2))
    permutations.append((1, 0, 2))
    permutations.append((1, 2, 0))

    prediction_image_average = ants.image_clone(t1_preprocessed_warped) * 0

    for d in range(len(dimensions_to_predict)):
        which_batch_slices = range(
            t1_preprocessed_warped.shape[dimensions_to_predict[d]])
        prediction_per_dimension = None
        if dimensions_to_predict[d] == 1:
            prediction_per_dimension = prediction_coronal[
                which_batch_slices, :, :, :]
        else:
            prediction_per_dimension = prediction_axial[
                which_batch_slices, :, :, :]
        prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                                        permutations[dimensions_to_predict[d]])
        prediction_image = ants.copy_image_info(
            t1_preprocessed_warped,
            pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                      t1_preprocessed_warped.shape))
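        # Running mean across the predicted dimensions:
        # avg_{d+1} = avg_d + (x - avg_d) / (d + 1).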
        prediction_image_average = prediction_image_average + (
            prediction_image - prediction_image_average) / (d + 1)

    probability_image = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), prediction_image_average,
        t1) * ants.threshold_image(brain_mask, 0.5, 1, 1, 0)

    return (probability_image)
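
# Usage sketch (not part of the original snippet): binarize the returned
# probability image; the file names and the 0.5 threshold are illustrative
# assumptions.
import ants
t1 = ants.image_read("t1.nii.gz")
claustrum_probability = claustrum_segmentation(t1, use_ensemble=True, verbose=True)
claustrum_mask = ants.threshold_image(claustrum_probability, 0.5, 1.0, 1, 0)
ants.image_write(claustrum_mask, "claustrum_mask.nii.gz")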
Example #9
0
def ew_david(flair,
             t1,
             do_preprocessing=True,
             do_slicewise=True,
             antsxnet_cache_directory=None,
             verbose=False):

    """
    Perform white matter hyperintensity probabilistic segmentation
    using deep learning.

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e., set
    do_preprocessing = True.

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    do_preprocessing : boolean
        perform n4 bias correction.

    do_slicewise : boolean
        apply the 2-D model along the direction of maximal slice thickness.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> flair = ants.image_read("flair.nii.gz")
    >>> t1 = ants.image_read("t1.nii.gz")
    >>> probability_mask = ew_david(flair, t1)
    """

    from ..architectures import create_unet_model_2d
    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import extract_image_patches
    from ..utilities import reconstruct_image_from_patches
    from ..utilities import pad_or_crop_image_to_size

    if flair.dimension != 3:
        raise ValueError("FLAIR image dimension must be 3.")

    if t1.dimension != 3:
        raise ValueError("T1 image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    if do_slicewise == False:

        ################################
        #
        # Preprocess images
        #
        ################################

        t1_preprocessed = t1
        t1_preprocessing = None
        if do_preprocessing == True:
            t1_preprocessing = preprocess_brain_image(t1,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=True,
                template="croppedMni152",
                template_transform_type="AffineFast",
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            t1_preprocessed = t1_preprocessing["preprocessed_image"] * t1_preprocessing['brain_mask']

        flair_preprocessed = flair
        if do_preprocessing == True:
            flair_preprocessing = preprocess_brain_image(flair,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            flair_preprocessed = ants.apply_transforms(fixed=t1_preprocessed,
                moving=flair_preprocessing["preprocessed_image"],
                transformlist=t1_preprocessing['template_transforms']['fwdtransforms'])
            flair_preprocessed = flair_preprocessed * t1_preprocessing['brain_mask']

        ################################
        #
        # Build model and load weights
        #
        ################################

        patch_size = (112, 112, 112)
        stride_length = (t1_preprocessed.shape[0] - patch_size[0],
                        t1_preprocessed.shape[1] - patch_size[1],
                        t1_preprocessed.shape[2] - patch_size[2])

        classes = ("background", "wmh")
        number_of_classification_labels = len(classes)
        labels = (0, 1)

        image_modalities = ("T1", "FLAIR")
        channel_size = len(image_modalities)

        unet_model = create_unet_model_3d((*patch_size, channel_size),
            number_of_outputs = number_of_classification_labels,
            number_of_layers = 4, number_of_filters_at_base_layer = 16, dropout_rate = 0.0,
            convolution_kernel_size = (3, 3, 3), deconvolution_kernel_size = (2, 2, 2),
            weight_decay = 1e-5, nn_unet_activation_style=False, add_attention_gating=True)

        weights_file_name = get_pretrained_network("ewDavidWmhSegmentationWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Do prediction and normalize to native space
        #
        ################################

        if verbose == True:
            print("ew_david:  prediction.")

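        # stride_length = image shape - patch size gives two patch offsets per
        # axis, i.e., 2**3 = 8 octant patches, which matches the batch size.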
        batchX = np.zeros((8, *patch_size, channel_size))

        t1_preprocessed = (t1_preprocessed - t1_preprocessed.mean()) / t1_preprocessed.std()
        t1_patches = extract_image_patches(t1_preprocessed, patch_size=patch_size,
                                            max_number_of_patches="all", stride_length=stride_length,
                                            return_as_array=True)
        batchX[:,:,:,:,0] = t1_patches

        flair_preprocessed = (flair_preprocessed - flair_preprocessed.mean()) / flair_preprocessed.std()
        flair_patches = extract_image_patches(flair_preprocessed, patch_size=patch_size,
                                            max_number_of_patches="all", stride_length=stride_length,
                                            return_as_array=True)
        batchX[:,:,:,:,1] = flair_patches

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        probability_images = list()
        for i in range(len(labels)):
            print("Reconstructing image", classes[i])
            reconstructed_image = reconstruct_image_from_patches(predicted_data[:,:,:,:,i],
                domain_image=t1_preprocessed, stride_length=stride_length)

            if do_preprocessing == True:
                probability_images.append(ants.apply_transforms(fixed=t1,
                    moving=reconstructed_image,
                    transformlist=t1_preprocessing['template_transforms']['invtransforms'],
                    whichtoinvert=[True], interpolator="linear", verbose=verbose))
            else:
                probability_images.append(reconstructed_image)

        return(probability_images[1])

    else:  # do_slicewise

        ################################
        #
        # Preprocess images
        #
        ################################

        t1_preprocessed = t1
        t1_preprocessing = None
        if do_preprocessing == True:
            t1_preprocessing = preprocess_brain_image(t1,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            t1_preprocessed = t1_preprocessing["preprocessed_image"]

        flair_preprocessed = flair
        if do_preprocessing == True:
            flair_preprocessing = preprocess_brain_image(flair,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            flair_preprocessed = flair_preprocessing["preprocessed_image"]

        resampling_params = list(ants.get_spacing(flair_preprocessed))

        do_resampling = False
        for d in range(len(resampling_params)):
            if resampling_params[d] < 0.8:
                resampling_params[d] = 1.0
                do_resampling = True

        resampling_params = tuple(resampling_params)

        if do_resampling:
            flair_preprocessed = ants.resample_image(flair_preprocessed, resampling_params, use_voxels=False, interp_type=0)
            t1_preprocessed = ants.resample_image(t1_preprocessed, resampling_params, use_voxels=False, interp_type=0)

        flair_preprocessed = (flair_preprocessed - flair_preprocessed.mean()) / flair_preprocessed.std()
        t1_preprocessed = (t1_preprocessed - t1_preprocessed.mean()) / t1_preprocessed.std()

        ################################
        #
        # Build model and load weights
        #
        ################################

        template_size = (256, 256)

        classes = ("background", "wmh")
        number_of_classification_labels = len(classes)
        labels = (0, 1)

        image_modalities = ("T1", "FLAIR")
        channel_size = len(image_modalities)

        unet_model = create_unet_model_2d((*template_size, channel_size),
            number_of_outputs = number_of_classification_labels,
            number_of_layers = 4, number_of_filters_at_base_layer = 32, dropout_rate = 0.0,
            convolution_kernel_size = (3, 3), deconvolution_kernel_size = (2, 2),
            weight_decay = 1e-5, nn_unet_activation_style=True, add_attention_gating=True)

        if verbose == True:
            print("ewDavid:  retrieving model weights.")

        weights_file_name = get_pretrained_network("ewDavidWmhSegmentationSlicewiseWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Extract slices
        #
        ################################

        use_coarse_slices_only = True

        spacing = ants.get_spacing(flair_preprocessed)
        dimensions_to_predict = (spacing.index(max(spacing)),)
        if use_coarse_slices_only == False:
            dimensions_to_predict = list(range(3))

        total_number_of_slices = 0
        for d in range(len(dimensions_to_predict)):
            total_number_of_slices += flair_preprocessed.shape[dimensions_to_predict[d]]

        batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

        slice_count = 0
        for d in range(len(dimensions_to_predict)):
            number_of_slices = flair_preprocessed.shape[dimensions_to_predict[d]]

            if verbose == True:
                print("Extracting slices for dimension ", dimensions_to_predict[d], ".")

            for i in range(number_of_slices):
                flair_slice = pad_or_crop_image_to_size(ants.slice_image(flair_preprocessed, dimensions_to_predict[d], i), template_size)
                batchX[slice_count,:,:,0] = flair_slice.numpy()

                t1_slice = pad_or_crop_image_to_size(ants.slice_image(t1_preprocessed, dimensions_to_predict[d], i), template_size)
                batchX[slice_count,:,:,1] = t1_slice.numpy()

                slice_count += 1


        ################################
        #
        # Do prediction and then restack into the image
        #
        ################################

        if verbose == True:
            print("Prediction.")

        prediction = unet_model.predict(batchX, verbose=verbose)

        permutations = list()
        permutations.append((0, 1, 2))
        permutations.append((1, 0, 2))
        permutations.append((1, 2, 0))

        prediction_image_average = ants.image_clone(flair_preprocessed) * 0

        current_start_slice = 0
        for d in range(len(dimensions_to_predict)):
            current_end_slice = current_start_slice + flair_preprocessed.shape[dimensions_to_predict[d]]
            which_batch_slices = range(current_start_slice, current_end_slice)
            prediction_per_dimension = prediction[which_batch_slices,:,:,1]
            prediction_array = np.transpose(np.squeeze(prediction_per_dimension), permutations[dimensions_to_predict[d]])
            prediction_image = ants.copy_image_info(flair_preprocessed,
                pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                flair_preprocessed.shape))
            prediction_image_average = prediction_image_average + (prediction_image - prediction_image_average) / (d + 1)

            current_start_slice = current_end_slice

        if do_resampling:
            prediction_image_average = ants.resample_image_to_target(prediction_image_average, flair)

        return(prediction_image_average)
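
# Usage sketch (illustrative, not from the original source): slicewise WMH
# segmentation from FLAIR/T1; the file names are placeholders.
import ants
flair = ants.image_read("flair.nii.gz")
t1 = ants.image_read("t1.nii.gz")
wmh_probability = ew_david(flair, t1, do_preprocessing=True, do_slicewise=True,
                           verbose=True)
ants.image_write(wmh_probability, "wmh_probability.nii.gz")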
Example #10
0
def apply_super_resolution_model_to_image(
    image,
    model,
    target_range=(-127.5, 127.5),
    batch_size=32,
    regression_order=None,
    verbose=False,
):

    """
    Apply a pretrained deep back projection model for super resolution.
    Helper function for applying a pretrained deep back projection model.
    Apply a patch-wise trained network to perform super-resolution. Can be applied
    to variable sized inputs. Warning: This function may be better used on CPU
    unless the GPU can accommodate the full image size. Warning 2: The global
    intensity range (min to max) of the output will match the input where the
    range is taken over all channels.

    Arguments
    ---------
    image : ANTs image
        input image.

    model : keras object or string
        pretrained keras model or filename.

    target_range : 2-element tuple
        a tuple or array defining the (min, max) of the input image
        (e.g., -127.5, 127.5).  Output images will be scaled back to original
        intensity. This range should match the mapping used in the training
        of the network.

    batch_size : integer
        Batch size used for the prediction call.

    regression_order : integer
        If specified, apply the function regression_match_image with
        poly_order=regression_order.

    verbose : boolean
        If True, show status messages.

    Returns
    -------
    Super-resolution image upscaled to resolution specified by the network.

    Example
    -------
    >>> import ants
    >>> image = ants.image_read(ants.get_ants_data('r16'))
    >>> image_sr = apply_super_resolution_model_to_image(image, get_pretrained_network("dbpn4x"))
    """
    tflite_flag = False
    channel_axis = 0
    if K.image_data_format() == "channels_last":
        channel_axis = -1

    if target_range[0] > target_range[1]:
        target_range = target_range[::-1]

    start_time = time.time()
    if isinstance(model, str):
        if path.isfile(model):
            if verbose:
                print("Load model.")
            if path.splitext(model)[1] == '.tflite':
                interpreter = tf.lite.Interpreter(model)
                interpreter.allocate_tensors()
                input_details = interpreter.get_input_details()
                output_details = interpreter.get_output_details()
                shape_length = len(input_details[0]['shape'])
                tflite_flag = True
            else:
                model = load_model(model)
                shape_length = len(model.input_shape)

            if verbose:
                elapsed_time = time.time() - start_time
                print("  (elapsed time: ", elapsed_time, ")")
        else:
            raise ValueError("Model not found.")
    else:
        shape_length = len(model.input_shape)

    
    if shape_length < 4 or shape_length > 5:
        raise ValueError("Unexpected input shape.")
    else:
        if shape_length == 5 and image.dimension != 3:
            raise ValueError("Expecting 3D input for this model.")
        elif shape_length == 4 and image.dimension != 2:
            raise ValueError("Expecting 2D input for this model.")

    if channel_axis == -1:
        channel_axis = shape_length - 1
    if tflite_flag:
        channel_size = input_details[0]['shape'][channel_axis]
    else:
        channel_size = model.input_shape[channel_axis]

    if channel_size != image.components:
        raise ValueError(
            "Channel size of model " + str(channel_size) +
            " does not match ncomponents=" + str(image.components) +
            " of the input image.")

    image_patches = extract_image_patches(
        image,
        patch_size=image.shape,
        max_number_of_patches=1,
        stride_length=image.shape,
        return_as_array=True,
    )
    if image.components == 1:
        image_patches = np.expand_dims(image_patches, axis=-1)

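    # Linearly map patch intensities onto target_range; after prediction the
    # output is rescaled to the input image's native intensity range below.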
    image_patches = image_patches - image_patches.min()
    image_patches = (
        image_patches / image_patches.max() * (target_range[1] - target_range[0])
        + target_range[0]
    )

    if verbose:
        print("Prediction")

    start_time = time.time()

    if tflite_flag:
        image_patches = image_patches.astype('float32')
        interpreter.set_tensor(input_details[0]['index'], image_patches)
        interpreter.invoke()
        out = interpreter.tensor(output_details[0]['index'])
        prediction = out()
    else:
        prediction = model.predict(image_patches, batch_size=batch_size)

    if verbose:
        elapsed_time = time.time() - start_time
        print("  (elapsed time: ", elapsed_time, ")")

    if verbose:
        print("Reconstruct intensities")

    intensity_range = image.range()
    prediction = prediction - prediction.min()
    prediction = (
        prediction / prediction.max() * (intensity_range[1] - intensity_range[0])
        + intensity_range[0]
    )

    def slice_array_channel(input_array, slice_index, channel_axis=-1):
        if channel_axis == 0:
            if shape_length == 4:
                return input_array[slice_index, :, :, :]
            else:
                return input_array[slice_index, :, :, :, :]
        else:
            if shape_length == 4:
                return input_array[:, :, :, slice_index]
            else:
                return input_array[:, :, :, :, slice_index]

    expansion_factor = np.asarray(prediction.shape) / np.asarray(image_patches.shape)
    if channel_axis == 0:
        # channels-first layout (batch, channel, spatial...): keep only the
        # spatial entries of the expansion factor.
        expansion_factor = expansion_factor[2:]
    else:
        # channels-last layout (batch, spatial..., channel).
        expansion_factor = expansion_factor[1:(len(expansion_factor) - 1)]

    if verbose:
        print("ExpansionFactor:", str(expansion_factor))

    if image.components == 1:
        image_array = slice_array_channel(prediction, 0, channel_axis)
        prediction_image = ants.make_image(
            (np.asarray(image.shape) * np.asarray(expansion_factor)).astype(int),
            image_array,
        )
        if regression_order is not None:
            reference_image = ants.resample_image_to_target(image, prediction_image)
            prediction_image = regression_match_image(
                prediction_image, reference_image, poly_order=regression_order
            )
    else:
        image_component_list = list()
        for k in range(image.components):
            image_array = slice_array_channel(prediction, k, channel_axis)
            image_component_list.append(
                ants.make_image(
                    (np.asarray(image.shape) * np.asarray(expansion_factor)).astype(
                        int
                    ),
                    image_array,
                )
            )
        prediction_image = ants.merge_channels(image_component_list)

    prediction_image = ants.copy_image_info(image, prediction_image)
    ants.set_spacing(
        prediction_image,
        tuple(np.asarray(image.spacing) / np.asarray(expansion_factor)),
    )

    return prediction_image
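
# Usage sketch (illustrative): apply the pretrained "dbpn4x" network from the
# docstring example and regression-match the result to the input intensities;
# assumes get_pretrained_network is importable alongside this function.
import ants
image = ants.image_read(ants.get_ants_data('r16'))
model_file = get_pretrained_network("dbpn4x")
image_sr = apply_super_resolution_model_to_image(
    image, model_file, target_range=(-127.5, 127.5), regression_order=1,
    verbose=True)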
Example #11
0
def deep_atropos(t1,
                 do_preprocessing=True,
                 use_spatial_priors=1,
                 antsxnet_cache_directory=None,
                 verbose=False):
    """
    Six-tissue segmentation.

    Perform Atropos-style six tissue segmentation using deep learning.

    The labeling is as follows:
    Label 0 :  background
    Label 1 :  CSF
    Label 2 :  gray matter
    Label 3 :  white matter
    Label 4 :  deep gray matter
    Label 5 :  brain stem
    Label 6 :  cerebellum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    do_preprocessing = True

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    use_spatial_priors : integer
        Use MNI spatial tissue priors (0 or 1).  Currently, '0' (no priors) and '1'
        (cerebellar prior only) are the only options.  Default is 1.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if this is None, the data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary consisting of the segmentation image and the probability
    images for each label.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> atropos = deep_atropos(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import preprocess_brain_image
    from ..utilities import extract_image_patches
    from ..utilities import reconstruct_image_from_patches

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    if do_preprocessing == True:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            template="croppedMni152",
            template_transform_type="antsRegistrationSyNQuickRepro[a]",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing[
            "preprocessed_image"] * t1_preprocessing['brain_mask']

    ################################
    #
    # Build model and load weights
    #
    ################################

    patch_size = (112, 112, 112)
    stride_length = (t1_preprocessed.shape[0] - patch_size[0],
                     t1_preprocessed.shape[1] - patch_size[1],
                     t1_preprocessed.shape[2] - patch_size[2])

    classes = ("background", "csf", "gray matter", "white matter",
               "deep gray matter", "brain stem", "cerebellum")

    mni_priors = None
    channel_size = 1
    if use_spatial_priors != 0:
        mni_priors = ants.ndimage_to_list(
            ants.image_read(
                get_antsxnet_data(
                    "croppedMni152Priors",
                    antsxnet_cache_directory=antsxnet_cache_directory)))
        for i in range(len(mni_priors)):
            mni_priors[i] = ants.copy_image_info(t1_preprocessed,
                                                 mni_priors[i])
        channel_size = 2

    unet_model = create_unet_model_3d((*patch_size, channel_size),
                                      number_of_outputs=len(classes),
                                      mode="classification",
                                      number_of_layers=4,
                                      number_of_filters_at_base_layer=16,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      weight_decay=1e-5,
                                      additional_options=("attentionGating",))

    if verbose == True:
        print("DeepAtropos:  retrieving model weights.")

    weights_file_name = ''
    if use_spatial_priors == 0:
        weights_file_name = get_pretrained_network(
            "sixTissueOctantBrainSegmentation",
            antsxnet_cache_directory=antsxnet_cache_directory)
    elif use_spatial_priors == 1:
        weights_file_name = get_pretrained_network(
            "sixTissueOctantBrainSegmentationWithPriors1",
            antsxnet_cache_directory=antsxnet_cache_directory)
    else:
        raise ValueError("use_spatial_priors must be a 0 or 1")
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose == True:
        print("Prediction.")

    t1_preprocessed = (t1_preprocessed -
                       t1_preprocessed.mean()) / t1_preprocessed.std()
    image_patches = extract_image_patches(t1_preprocessed,
                                          patch_size=patch_size,
                                          max_number_of_patches="all",
                                          stride_length=stride_length,
                                          return_as_array=True)
    batchX = np.zeros((*image_patches.shape, channel_size))
    batchX[:, :, :, :, 0] = image_patches
    if channel_size > 1:
        prior_patches = extract_image_patches(mni_priors[6],
                                              patch_size=patch_size,
                                              max_number_of_patches="all",
                                              stride_length=stride_length,
                                              return_as_array=True)
        batchX[:, :, :, :, 1] = prior_patches

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    probability_images = list()
    for i in range(len(classes)):
        if verbose == True:
            print("Reconstructing image", classes[i])
        reconstructed_image = reconstruct_image_from_patches(
            predicted_data[:, :, :, :, i],
            domain_image=t1_preprocessed,
            stride_length=stride_length)

        if do_preprocessing == True:
            probability_images.append(
                ants.apply_transforms(
                    fixed=t1,
                    moving=reconstructed_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose))
        else:
            probability_images.append(reconstructed_image)

    image_matrix = ants.image_list_to_matrix(probability_images, t1 * 0 + 1)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    return_dict = {
        'segmentation_image': segmentation_image,
        'probability_images': probability_images
    }
    return (return_dict)
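
# Usage sketch (the file name is a placeholder): run six-tissue segmentation
# and pull out the gray matter probability, index 2 in the classes tuple above.
import ants
t1 = ants.image_read("t1.nii.gz")
atropos = deep_atropos(t1, do_preprocessing=True, use_spatial_priors=1,
                       verbose=True)
segmentation = atropos['segmentation_image']
gray_matter_probability = atropos['probability_images'][2]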
Example #12
0
def sysu_media_wmh_segmentation(flair,
                                t1=None,
                                use_ensemble=True,
                                antsxnet_cache_directory=None,
                                verbose=False):
    """
    Perform WMH segmentation using the winning submission in the MICCAI
    2017 challenge by the sysu_media team, with FLAIR or T1/FLAIR input.
    The MICCAI challenge is discussed in

    https://pubmed.ncbi.nlm.nih.gov/30908194/

    and the sysu_media team's entry is discussed in

    https://pubmed.ncbi.nlm.nih.gov/30125711/

    with the original implementation available here:

    https://github.com/hongweilibran/wmh_ibbmTum

    The original implementation used global thresholding as a quick
    brain extraction approach.  Due to possible generalization difficulties,
    we leave such post-processing steps to the user.  For brain or white
    matter masking see functions brain_extraction or deep_atropos,
    respectively.

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    use_ensemble : boolean
        check whether to use all 3 sets of weights.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if this is None, the data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> image = ants.image_read("flair.nii.gz")
    >>> probability_mask = sysu_media_wmh_segmentation(image)
    """

    from ..architectures import create_sysu_media_unet_model_2d
    from ..utilities import get_pretrained_network
    from ..utilities import pad_or_crop_image_to_size
    from ..utilities import preprocess_brain_image
    from ..utilities import binary_dice_coefficient

    if flair.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    image_size = (200, 200)

    ################################
    #
    # Preprocess images
    #
    ################################

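    # Snap the direction cosine matrix to the nearest signed axis-aligned
    # matrix so slice extraction is consistent across acquisitions.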
    def closest_simplified_direction_matrix(direction):
        closest = (np.abs(direction) + 0.5).astype(int).astype(float)
        closest[direction < 0] *= -1.0
        return closest

    simplified_direction = closest_simplified_direction_matrix(flair.direction)

    flair_preprocessing = preprocess_brain_image(
        flair,
        truncate_intensity=None,
        brain_extraction_modality=None,
        do_bias_correction=False,
        do_denoising=False,
        antsxnet_cache_directory=antsxnet_cache_directory,
        verbose=verbose)
    flair_preprocessed = flair_preprocessing["preprocessed_image"]
    flair_preprocessed.set_direction(simplified_direction)
    flair_preprocessed.set_origin((0, 0, 0))
    flair_preprocessed.set_spacing((1, 1, 1))
    number_of_channels = 1

    t1_preprocessed = None
    if t1 is not None:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=None,
            brain_extraction_modality=None,
            do_bias_correction=False,
            do_denoising=False,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing["preprocessed_image"]
        t1_preprocessed.set_direction(simplified_direction)
        t1_preprocessed.set_origin((0, 0, 0))
        t1_preprocessed.set_spacing((1, 1, 1))
        number_of_channels = 2

    ################################
    #
    # Reorient images
    #
    ################################

    reference_image = ants.make_image((256, 256, 256),
                                      voxval=0,
                                      spacing=(1, 1, 1),
                                      origin=(0, 0, 0),
                                      direction=np.identity(3))
    center_of_mass_reference = np.floor(
        ants.get_center_of_mass(reference_image * 0 + 1))
    center_of_mass_image = np.floor(
        ants.get_center_of_mass(flair_preprocessed))
    translation = np.asarray(center_of_mass_image) - np.asarray(
        center_of_mass_reference)
    xfrm = ants.create_ants_transform(
        transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_reference),
        translation=translation)
    flair_preprocessed_warped = ants.apply_ants_transform_to_image(
        xfrm,
        flair_preprocessed,
        reference_image,
        interpolation="nearestneighbor")
    crop_image = ants.image_clone(flair_preprocessed) * 0 + 1
    crop_image_warped = ants.apply_ants_transform_to_image(
        xfrm, crop_image, reference_image, interpolation="nearestneighbor")
    flair_preprocessed_warped = ants.crop_image(flair_preprocessed_warped,
                                                crop_image_warped, 1)

    if t1 is not None:
        t1_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm,
            t1_preprocessed,
            reference_image,
            interpolation="nearestneighbor")
        t1_preprocessed_warped = ants.crop_image(t1_preprocessed_warped,
                                                 crop_image_warped, 1)

    ################################
    #
    # Gaussian normalize intensity
    #
    ################################

    mean_flair = flair_preprocessed.mean()
    std_flair = flair_preprocessed.std()
    if number_of_channels == 2:
        mean_t1 = t1_preprocessed.mean()
        std_t1 = t1_preprocessed.std()

    flair_preprocessed_warped = (flair_preprocessed_warped -
                                 mean_flair) / std_flair
    if number_of_channels == 2:
        t1_preprocessed_warped = (t1_preprocessed_warped - mean_t1) / std_t1

    ################################
    #
    # Build models and load weights
    #
    ################################

    number_of_models = 1
    if use_ensemble == True:
        number_of_models = 3

    if verbose == True:
        print("White matter hyperintensity:  retrieving model weights.")

    unet_models = list()
    for i in range(number_of_models):
        if number_of_channels == 1:
            weights_file_name = get_pretrained_network(
                "sysuMediaWmhFlairOnlyModel" + str(i),
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            weights_file_name = get_pretrained_network(
                "sysuMediaWmhFlairT1Model" + str(i),
                antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model = create_sysu_media_unet_model_2d(
            (*image_size, number_of_channels))
        unet_loss = binary_dice_coefficient(smoothing_factor=1.)
        unet_model.compile(optimizer=keras.optimizers.Adam(learning_rate=2e-4),
                           loss=unet_loss)
        unet_model.load_weights(weights_file_name)
        unet_models.append(unet_model)

    ################################
    #
    # Extract slices
    #
    ################################

    dimensions_to_predict = [2]

    total_number_of_slices = 0
    for d in range(len(dimensions_to_predict)):
        total_number_of_slices += flair_preprocessed_warped.shape[
            dimensions_to_predict[d]]

    batchX = np.zeros(
        (total_number_of_slices, *image_size, number_of_channels))

    slice_count = 0
    for d in range(len(dimensions_to_predict)):
        number_of_slices = flair_preprocessed_warped.shape[
            dimensions_to_predict[d]]

        if verbose == True:
            print("Extracting slices for dimension ", dimensions_to_predict[d],
                  ".")

        for i in range(number_of_slices):
            flair_slice = pad_or_crop_image_to_size(
                ants.slice_image(flair_preprocessed_warped,
                                 dimensions_to_predict[d], i), image_size)
            batchX[slice_count, :, :, 0] = flair_slice.numpy()
            if number_of_channels == 2:
                t1_slice = pad_or_crop_image_to_size(
                    ants.slice_image(t1_preprocessed_warped,
                                     dimensions_to_predict[d], i), image_size)
                batchX[slice_count, :, :, 1] = t1_slice.numpy()
            slice_count += 1

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose == True:
        print("Prediction.")

    prediction = unet_models[0].predict(np.transpose(batchX,
                                                     axes=(0, 2, 1, 3)),
                                        verbose=verbose)
    if number_of_models > 1:
        for i in range(1, number_of_models, 1):
            prediction += unet_models[i].predict(np.transpose(batchX,
                                                              axes=(0, 2, 1,
                                                                    3)),
                                                 verbose=verbose)
    prediction /= number_of_models
    prediction = np.transpose(prediction, axes=(0, 2, 1, 3))

    permutations = list()
    permutations.append((0, 1, 2))
    permutations.append((1, 0, 2))
    permutations.append((1, 2, 0))

    prediction_image_average = ants.image_clone(flair_preprocessed_warped) * 0

    current_start_slice = 0
    for d in range(len(dimensions_to_predict)):
        current_end_slice = current_start_slice + flair_preprocessed_warped.shape[
            dimensions_to_predict[d]]
        which_batch_slices = range(current_start_slice, current_end_slice)
        prediction_per_dimension = prediction[which_batch_slices, :, :, :]
        prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                                        permutations[dimensions_to_predict[d]])
        prediction_image = ants.copy_image_info(
            flair_preprocessed_warped,
            pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                      flair_preprocessed_warped.shape))
        prediction_image_average = prediction_image_average + (
            prediction_image - prediction_image_average) / (d + 1)
        current_start_slice = current_end_slice

    probability_image = ants.apply_ants_transform_to_image(
        ants.invert_ants_transform(xfrm), prediction_image_average,
        flair_preprocessed)
    probability_image = ants.copy_image_info(flair, probability_image)

    return (probability_image)
Example #13
def ew_david(flair,
             t1,
             do_preprocessing=True,
             which_model="sysu",
             which_axes=2,
             number_of_simulations=0,
             sd_affine=0.01,
             antsxnet_cache_directory=None,
             verbose=False):
    """
    Perform white matter hyperintensity (WMH) probabilistic segmentation
    using deep learning.

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * intensity truncation,
       * brain extraction, and
       * affine registration to MNI.
    The input images should undergo the same steps.  If the inputs are raw
    images, these steps can be performed by the internal preprocessing, i.e.,
    set do_preprocessing=True.

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    do_preprocessing : boolean
        perform n4 bias correction, intensity truncation, brain extraction.

    which_model : string
        one of:
            * "sysu" -- same as the original sysu network (without site-specific preprocessing),
            * "sysu-ri" -- same as "sysu" but using ranked intensity scaling for the input images,
            * "sysuWithAttention" -- "sysu" with attention gating,
            * "sysuWithAttentionAndSite" -- "sysu" with attention gating and a site branch (see "sysuWithSite"),
            * "sysuPlus" -- "sysu" with attention gating and nn-Unet activation,
            * "sysuPlusSeg" -- "sysuPlus" with deep_atropos segmentation in an additional channel,
            * "sysuWithSite" -- "sysu" with global pooling on the encoding channels to predict "site", and
            * "sysuPlusSegWithSite" -- "sysuPlusSeg" combined with "sysuWithSite".
        In addition to the dual-modality versions, all models have T1-only and
        FLAIR-only variants except "sysuPlusSeg" (which has only a T1-only
        variant) and "sysu-ri" (which has no single-modality variants).

    which_axes : string or scalar or tuple/vector
        apply the 2-D model to one or more axes.  In addition to a scalar
        or vector, e.g., which_axes=(0, 2), one can use "max" for the
        axis with maximum anisotropy or "all" for all axes.  Default is
        axis 2.

    number_of_simulations : integer
        Number of random affine perturbations to transform the input.

    sd_affine : float
        Standard deviation of the affine transformation parameters.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> import ants
    >>> flair = ants.image_read("flair.nii.gz")
    >>> t1 = ants.image_read("t1.nii.gz")
    >>> probability_mask = ew_david(flair, t1)
    """

    from ..architectures import create_unet_model_2d
    from ..utilities import deep_atropos
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import randomly_transform_image_data
    from ..utilities import pad_or_crop_image_to_size

    do_t1_only = False
    do_flair_only = False

    if flair is None and t1 is None:
        raise ValueError("At least one of flair or t1 must be provided.")
    elif flair is None and t1 is not None:
        do_t1_only = True
    elif flair is not None and t1 is None:
        do_flair_only = True

    use_t1_segmentation = False
    if "Seg" in which_model:
        if do_flair_only:
            raise ValueError("Segmentation requires T1.")
        else:
            use_t1_segmentation = True

    if use_t1_segmentation and do_preprocessing == False:
        raise ValueError(
            "Using the t1 segmentation requires do_preprocessing=True.")

    if antsxnet_cache_directory == None:
        antsxnet_cache_directory = "ANTsXNet"

    do_slicewise = True

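    # The 3-D patch-based pathway below is retained for reference but is
    # currently disabled; only the slicewise 2-D pathway is available.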
    if do_slicewise == False:

        raise ValueError("Not available.")

        # ################################
        # #
        # # Preprocess images
        # #
        # ################################

        # t1_preprocessed = t1
        # t1_preprocessing = None
        # if do_preprocessing == True:
        #     t1_preprocessing = preprocess_brain_image(t1,
        #         truncate_intensity=(0.01, 0.99),
        #         brain_extraction_modality="t1",
        #         template="croppedMni152",
        #         template_transform_type="antsRegistrationSyNQuickRepro[a]",
        #         do_bias_correction=True,
        #         do_denoising=False,
        #         antsxnet_cache_directory=antsxnet_cache_directory,
        #         verbose=verbose)
        #     t1_preprocessed = t1_preprocessing["preprocessed_image"] * t1_preprocessing['brain_mask']

        # flair_preprocessed = flair
        # if do_preprocessing == True:
        #     flair_preprocessing = preprocess_brain_image(flair,
        #         truncate_intensity=(0.01, 0.99),
        #         brain_extraction_modality="t1",
        #         do_bias_correction=True,
        #         do_denoising=False,
        #         antsxnet_cache_directory=antsxnet_cache_directory,
        #         verbose=verbose)
        #     flair_preprocessed = ants.apply_transforms(fixed=t1_preprocessed,
        #         moving=flair_preprocessing["preprocessed_image"],
        #         transformlist=t1_preprocessing['template_transforms']['fwdtransforms'])
        #     flair_preprocessed = flair_preprocessed * t1_preprocessing['brain_mask']

        # ################################
        # #
        # # Build model and load weights
        # #
        # ################################

        # patch_size = (112, 112, 112)
        # stride_length = (t1_preprocessed.shape[0] - patch_size[0],
        #                 t1_preprocessed.shape[1] - patch_size[1],
        #                 t1_preprocessed.shape[2] - patch_size[2])

        # classes = ("background", "wmh" )
        # number_of_classification_labels = len(classes)
        # labels = (0, 1)

        # image_modalities = ("T1", "FLAIR")
        # channel_size = len(image_modalities)

        # unet_model = create_unet_model_3d((*patch_size, channel_size),
        #     number_of_outputs = number_of_classification_labels,
        #     number_of_layers = 4, number_of_filters_at_base_layer = 16, dropout_rate = 0.0,
        #     convolution_kernel_size = (3, 3, 3), deconvolution_kernel_size = (2, 2, 2),
        #     weight_decay = 1e-5, additional_options=("attentionGating"))

        # weights_file_name = get_pretrained_network("ewDavidWmhSegmentationWeights",
        #     antsxnet_cache_directory=antsxnet_cache_directory)
        # unet_model.load_weights(weights_file_name)

        # ################################
        # #
        # # Do prediction and normalize to native space
        # #
        # ################################

        # if verbose == True:
        #     print("ew_david:  prediction.")

        # batchX = np.zeros((8, *patch_size, channel_size))

        # t1_preprocessed = (t1_preprocessed - t1_preprocessed.mean()) / t1_preprocessed.std()
        # t1_patches = extract_image_patches(t1_preprocessed, patch_size=patch_size,
        #                                     max_number_of_patches="all", stride_length=stride_length,
        #                                     return_as_array=True)
        # batchX[:,:,:,:,0] = t1_patches

        # flair_preprocessed = (flair_preprocessed - flair_preprocessed.mean()) / flair_preprocessed.std()
        # flair_patches = extract_image_patches(flair_preprocessed, patch_size=patch_size,
        #                                     max_number_of_patches="all", stride_length=stride_length,
        #                                     return_as_array=True)
        # batchX[:,:,:,:,1] = flair_patches

        # predicted_data = unet_model.predict(batchX, verbose=verbose)

        # probability_images = list()
        # for i in range(len(labels)):
        #     print("Reconstructing image", classes[i])
        #     reconstructed_image = reconstruct_image_from_patches(predicted_data[:,:,:,:,i],
        #         domain_image=t1_preprocessed, stride_length=stride_length)

        #     if do_preprocessing == True:
        #         probability_images.append(ants.apply_transforms(fixed=t1,
        #             moving=reconstructed_image,
        #             transformlist=t1_preprocessing['template_transforms']['invtransforms'],
        #             whichtoinvert=[True], interpolator="linear", verbose=verbose))
        #     else:
        #         probability_images.append(reconstructed_image)

        # return(probability_images[1])

    else:  # do_slicewise

        ################################
        #
        # Preprocess images
        #
        ################################

        use_rank_intensity_scaling = False
        if "-ri" in which_model:
            use_rank_intensity_scaling = True

        t1_preprocessed = None
        t1_preprocessing = None
        brain_mask = None
        if t1 is not None:
            if do_preprocessing == True:
                t1_preprocessing = preprocess_brain_image(
                    t1,
                    truncate_intensity=(0.01, 0.995),
                    brain_extraction_modality="t1",
                    do_bias_correction=False,
                    do_denoising=False,
                    antsxnet_cache_directory=antsxnet_cache_directory,
                    verbose=verbose)
                brain_mask = ants.threshold_image(
                    t1_preprocessing["brain_mask"], 0.5, 1, 1, 0)
                t1_preprocessed = t1_preprocessing["preprocessed_image"]

        t1_segmentation = None
        if use_t1_segmentation:
            atropos_seg = deep_atropos(t1,
                                       do_preprocessing=True,
                                       verbose=verbose)
            t1_segmentation = atropos_seg['segmentation_image']

        flair_preprocessed = None
        if flair is not None:
            flair_preprocessed = flair
            if do_preprocessing == True:
                if brain_mask is None:
                    flair_preprocessing = preprocess_brain_image(
                        flair,
                        truncate_intensity=(0.01, 0.995),
                        brain_extraction_modality="flair",
                        do_bias_correction=False,
                        do_denoising=False,
                        antsxnet_cache_directory=antsxnet_cache_directory,
                        verbose=verbose)
                    brain_mask = ants.threshold_image(
                        flair_preprocessing["brain_mask"], 0.5, 1, 1, 0)
                else:
                    flair_preprocessing = preprocess_brain_image(
                        flair,
                        truncate_intensity=None,
                        brain_extraction_modality=None,
                        do_bias_correction=False,
                        do_denoising=False,
                        antsxnet_cache_directory=antsxnet_cache_directory,
                        verbose=verbose)
                flair_preprocessed = flair_preprocessing["preprocessed_image"]

        if t1_preprocessed is not None:
            t1_preprocessed = t1_preprocessed * brain_mask
        if flair_preprocessed is not None:
            flair_preprocessed = flair_preprocessed * brain_mask

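        # Any axis sampled finer than 0.8 mm is resampled to 1.0 mm before
        # slicing; the 2-D models presumably expect roughly millimeter-scale
        # voxels.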
        if t1_preprocessed is not None:
            resampling_params = list(ants.get_spacing(t1_preprocessed))
        else:
            resampling_params = list(ants.get_spacing(flair_preprocessed))

        do_resampling = False
        for d in range(len(resampling_params)):
            if resampling_params[d] < 0.8:
                resampling_params[d] = 1.0
                do_resampling = True

        resampling_params = tuple(resampling_params)

        if do_resampling:
            if flair_preprocessed is not None:
                flair_preprocessed = ants.resample_image(flair_preprocessed,
                                                         resampling_params,
                                                         use_voxels=False,
                                                         interp_type=0)
            if t1_preprocessed is not None:
                t1_preprocessed = ants.resample_image(t1_preprocessed,
                                                      resampling_params,
                                                      use_voxels=False,
                                                      interp_type=0)
            if t1_segmentation is not None:
                t1_segmentation = ants.resample_image(t1_segmentation,
                                                      resampling_params,
                                                      use_voxels=False,
                                                      interp_type=1)
            if brain_mask is not None:
                brain_mask = ants.resample_image(brain_mask,
                                                 resampling_params,
                                                 use_voxels=False,
                                                 interp_type=1)

        ################################
        #
        # Build model and load weights
        #
        ################################

        template_size = (208, 208)

        image_modalities = ("T1", "FLAIR")
        if do_flair_only:
            image_modalities = ("FLAIR", )
        elif do_t1_only:
            image_modalities = ("T1", )
        if use_t1_segmentation:
            image_modalities = (*image_modalities, "T1Seg")
        channel_size = len(image_modalities)

        unet_model = None
        if which_model == "sysu" or which_model == "sysu-ri":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("initialConvolutionKernelSize[5]", ))
        elif which_model == "sysuWithAttention":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("attentionGating",
                                    "initialConvolutionKernelSize[5]"))
        elif which_model == "sysuWithAttentionAndSite":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                scalar_output_size=3,
                scalar_output_activation="softmax",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("attentionGating",
                                    "initialConvolutionKernelSize[5]"))
        elif which_model == "sysuWithSite":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                scalar_output_size=3,
                scalar_output_activation="softmax",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("initialConvolutionKernelSize[5]", ))
        elif which_model == "sysuPlusSegWithSite":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                scalar_output_size=3,
                scalar_output_activation="softmax",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("nnUnetActivationStyle", "attentionGating",
                                    "initialConvolutionKernelSize[5]"))
        else:
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("nnUnetActivationStyle", "attentionGating",
                                    "initialConvolutionKernelSize[5]"))

        if verbose == True:
            print("ewDavid:  retrieving model weights.")

        weights_file_name = None
        if which_model == "sysu" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysu",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysu-ri" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuRankedIntensity",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysu" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysu" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttention" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttention",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttention" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttentionT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttention" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttentionFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttentionAndSite" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttentionAndSite",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttentionAndSite" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttentionAndSiteT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithAttentionAndSite" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithAttentionAndSiteFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlus" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlus",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlus" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlus" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlusSeg" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusSeg",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlusSeg" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusSegT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlusSegWithSite" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusSegWithSite",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuPlusSegWithSite" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuPlusSegWithSiteT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithSite" and flair is not None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithSite",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithSite" and flair is None and t1 is not None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithSiteT1Only",
                antsxnet_cache_directory=antsxnet_cache_directory)
        elif which_model == "sysuWithSite" and flair is not None and t1 is None:
            weights_file_name = get_pretrained_network(
                "ewDavidSysuWithSiteFlairOnly",
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            raise ValueError(
                "Incorrect model specification or image combination.")

        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Data augmentation and extract slices
        #
        ################################

        wmh_probability_image = None
        if t1 is not None:
            wmh_probability_image = ants.image_clone(t1_preprocessed) * 0
        else:
            wmh_probability_image = ants.image_clone(flair_preprocessed) * 0

        wmh_site = np.array([0, 0, 0])

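        # Optional test-time augmentation: simulate random affine perturbations
        # of the inputs and average the predictions over all simulations.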
        data_augmentation = None
        if number_of_simulations > 0:
            if do_flair_only:
                data_augmentation = randomly_transform_image_data(
                    reference_image=flair_preprocessed,
                    input_image_list=[[flair_preprocessed]],
                    number_of_simulations=number_of_simulations,
                    transform_type='affine',
                    sd_affine=sd_affine,
                    input_image_interpolator='linear')
            elif do_t1_only:
                if use_t1_segmentation:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[t1_preprocessed]],
                        segmentation_image_list=[t1_segmentation],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear',
                        segmentation_image_interpolator='nearestNeighbor')
                else:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[t1_preprocessed]],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear')
            else:
                if use_t1_segmentation:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[
                            flair_preprocessed, t1_preprocessed
                        ]],
                        segmentation_image_list=[t1_segmentation],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear',
                        segmentation_image_interpolator='nearestNeighbor')
                else:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[
                            flair_preprocessed, t1_preprocessed
                        ]],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear')

        dimensions_to_predict = list((0, ))
        if which_axes == "max":
            spacing = ants.get_spacing(wmh_probability_image)
            dimensions_to_predict = (spacing.index(max(spacing)), )
        elif which_axes == "all":
            dimensions_to_predict = list(range(3))
        else:
            if isinstance(which_axes, int):
                dimensions_to_predict = list((which_axes, ))
            else:
                dimensions_to_predict = list(which_axes)

        total_number_of_slices = 0
        for d in range(len(dimensions_to_predict)):
            total_number_of_slices += wmh_probability_image.shape[
                dimensions_to_predict[d]]

        batchX = np.zeros(
            (total_number_of_slices, *template_size, channel_size))

        for n in range(number_of_simulations + 1):

            batch_flair = flair_preprocessed
            batch_t1 = t1_preprocessed
            batch_t1_segmentation = t1_segmentation
            batch_brain_mask = brain_mask

            if n > 0:

                if do_flair_only:
                    batch_flair = data_augmentation['simulated_images'][n -
                                                                        1][0]
                    batch_brain_mask = ants.apply_ants_transform_to_image(
                        data_augmentation['simulated_transforms'][n - 1],
                        brain_mask,
                        flair_preprocessed,
                        interpolation="nearestneighbor")

                elif do_t1_only:
                    batch_t1 = data_augmentation['simulated_images'][n - 1][0]
                    batch_brain_mask = ants.apply_ants_transform_to_image(
                        data_augmentation['simulated_transforms'][n - 1],
                        brain_mask,
                        t1_preprocessed,
                        interpolation="nearestneighbor")
                else:
                    batch_flair = data_augmentation['simulated_images'][n -
                                                                        1][0]
                    batch_t1 = data_augmentation['simulated_images'][n - 1][1]
                    batch_brain_mask = ants.apply_ants_transform_to_image(
                        data_augmentation['simulated_transforms'][n - 1],
                        brain_mask,
                        flair_preprocessed,
                        interpolation="nearestneighbor")
                if use_t1_segmentation:
                    batch_t1_segmentation = data_augmentation[
                        'simulated_segmentation_images'][n - 1]

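            # Per-pass intensity normalization: either rank-based scaling
            # centered at zero or a z-score over the brain mask.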
            if use_rank_intensity_scaling:
                if batch_t1 is not None:
                    batch_t1 = ants.rank_intensity(batch_t1,
                                                   batch_brain_mask) - 0.5
                if batch_flair is not None:
                    batch_flair = ants.rank_intensity(batch_flair,
                                                      batch_brain_mask) - 0.5
            else:
                if batch_t1 is not None:
                    batch_t1 = (batch_t1 -
                                batch_t1[batch_brain_mask == 1].mean()
                                ) / batch_t1[batch_brain_mask == 1].std()
                if batch_flair is not None:
                    batch_flair = (
                        batch_flair -
                        batch_flair[batch_brain_mask == 1].mean()
                    ) / batch_flair[batch_brain_mask == 1].std()

            slice_count = 0
            for d in range(len(dimensions_to_predict)):

                number_of_slices = None
                if batch_t1 is not None:
                    number_of_slices = batch_t1.shape[dimensions_to_predict[d]]
                else:
                    number_of_slices = batch_flair.shape[
                        dimensions_to_predict[d]]

                if verbose == True:
                    print("Extracting slices for dimension",
                          dimensions_to_predict[d])

                for i in range(number_of_slices):

                    brain_mask_slice = pad_or_crop_image_to_size(
                        ants.slice_image(batch_brain_mask,
                                         dimensions_to_predict[d], i),
                        template_size)

                    channel_count = 0
                    if batch_flair is not None:
                        flair_slice = pad_or_crop_image_to_size(
                            ants.slice_image(batch_flair,
                                             dimensions_to_predict[d], i),
                            template_size)
                        flair_slice[brain_mask_slice == 0] = 0
                        batchX[slice_count, :, :,
                               channel_count] = flair_slice.numpy()
                        channel_count += 1
                    if batch_t1 is not None:
                        t1_slice = pad_or_crop_image_to_size(
                            ants.slice_image(batch_t1,
                                             dimensions_to_predict[d], i),
                            template_size)
                        t1_slice[brain_mask_slice == 0] = 0
                        batchX[slice_count, :, :,
                               channel_count] = t1_slice.numpy()
                        channel_count += 1
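                    # deep_atropos labels run 0-6; rescale to roughly [-0.5, 0.5].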
                    if t1_segmentation is not None:
                        t1_segmentation_slice = pad_or_crop_image_to_size(
                            ants.slice_image(batch_t1_segmentation,
                                             dimensions_to_predict[d], i),
                            template_size)
                        t1_segmentation_slice[brain_mask_slice == 0] = 0
                        batchX[slice_count, :, :,
                               channel_count] = t1_segmentation_slice.numpy(
                               ) / 6 - 0.5

                    slice_count += 1

            ################################
            #
            # Do prediction and then restack into the image
            #
            ################################

            if verbose == True:
                if n == 0:
                    print("Prediction")
                else:
                    print("Prediction (simulation " + str(n) + ")")

            prediction = unet_model.predict(batchX, verbose=verbose)

            permutations = list()
            permutations.append((0, 1, 2))
            permutations.append((1, 0, 2))
            permutations.append((1, 2, 0))

            prediction_image_average = ants.image_clone(
                wmh_probability_image) * 0

            current_start_slice = 0
            for d in range(len(dimensions_to_predict)):
                current_end_slice = current_start_slice + wmh_probability_image.shape[
                    dimensions_to_predict[d]]
                which_batch_slices = range(current_start_slice,
                                           current_end_slice)
                if isinstance(prediction, list):
                    prediction_per_dimension = prediction[0][
                        which_batch_slices, :, :, 0]
                else:
                    prediction_per_dimension = prediction[
                        which_batch_slices, :, :, 0]
                prediction_array = np.transpose(
                    np.squeeze(prediction_per_dimension),
                    permutations[dimensions_to_predict[d]])
                prediction_image = ants.copy_image_info(
                    wmh_probability_image,
                    pad_or_crop_image_to_size(
                        ants.from_numpy(prediction_array),
                        wmh_probability_image.shape))
                prediction_image_average = prediction_image_average + (
                    prediction_image - prediction_image_average) / (d + 1)
                current_start_slice = current_end_slice

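            # Fold this (possibly augmented) pass into the running mean over
            # the n + 1 passes; models with a scalar site branch also average
            # their per-slice site predictions.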
            wmh_probability_image = wmh_probability_image + (
                prediction_image_average - wmh_probability_image) / (n + 1)
            if isinstance(prediction, list):
                wmh_site = wmh_site + (np.mean(prediction[1], axis=0) -
                                       wmh_site) / (n + 1)

        if do_resampling:
            # Resample once, directly to the final target space, to avoid a
            # second interpolation when both inputs are supplied.
            if flair is not None:
                wmh_probability_image = ants.resample_image_to_target(
                    wmh_probability_image, flair)
            elif t1 is not None:
                wmh_probability_image = ants.resample_image_to_target(
                    wmh_probability_image, t1)

        if isinstance(prediction, list):
            return ([wmh_probability_image, wmh_site])
        else:
            return (wmh_probability_image)
Example #14
def hippmapp3r_segmentation(t1,
                            do_preprocessing=True,
                            antsxnet_cache_directory=None,
                            verbose=False):
    """
    Perform HippMapp3r (hippocampal) segmentation described in

     https://www.ncbi.nlm.nih.gov/pubmed/31609046

    with models and architecture ported from

    https://github.com/mgoubran/HippMapp3r

    Additional documentation and attribution resources found at

    https://hippmapp3r.readthedocs.io/en/latest/

    Preprocessing consists of:
       * n4 bias correction and
       * brain extraction
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e.,
    set do_preprocessing=True.

    Arguments
    ---------
    t1 : ANTsImage
        input image

    do_preprocessing : boolean
        See description above.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    ANTs labeled hippocampal image.

    Example
    -------
    >>> import ants
    >>> t1 = ants.image_read("t1.nii.gz")
    >>> segmentation = hippmapp3r_segmentation(t1)
    """

    from ..architectures import create_hippmapp3r_unet_model_3d
    from ..utilities import preprocess_brain_image
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory == None:
        antsxnet_cache_directory = "ANTsXNet"

    if verbose == True:
        print("*************  Preprocessing  ***************")
        print("")

    t1_preprocessed = t1
    if do_preprocessing == True:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=None,
            brain_extraction_modality="t1",
            template=None,
            do_bias_correction=True,
            do_denoising=False,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing[
            "preprocessed_image"] * t1_preprocessing['brain_mask']

    if verbose == True:
        print("*************  Initial stage segmentation  ***************")
        print("")

    # Normalize to mprage_hippmapp3r space
    if verbose == True:
        print("    HippMapp3r: template normalization.")

    template_file_name_path = get_antsxnet_data(
        "mprage_hippmapp3r", antsxnet_cache_directory=antsxnet_cache_directory)
    template_image = ants.image_read(template_file_name_path)

    registration = ants.registration(
        fixed=template_image,
        moving=t1_preprocessed,
        type_of_transform="antsRegistrationSyNQuickRepro[t]",
        verbose=verbose)
    image = registration['warpedmovout']
    transforms = dict(fwdtransforms=registration['fwdtransforms'],
                      invtransforms=registration['invtransforms'])

    # Threshold at the 10th percentile of the non-zero "robust range" (as in fslmaths)
    if verbose == True:
        print("    HippMapp3r: threshold.")

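    # "Robust range" = the 2nd-98th percentiles of the non-zero intensities;
    # keep voxels brighter than 10% of the way into that range.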
    image_array = image.numpy()
    image_robust_range = np.quantile(image_array[np.where(image_array != 0)],
                                     (0.02, 0.98))
    threshold_value = 0.10 * (image_robust_range[1] -
                              image_robust_range[0]) + image_robust_range[0]
    thresholded_mask = ants.threshold_image(image, -10000, threshold_value, 0,
                                            1)
    thresholded_image = image * thresholded_mask

    # Standardize image
    if verbose == True:
        print("    HippMapp3r: standardize.")

    mean_image = np.mean(thresholded_image[thresholded_mask == 1])
    sd_image = np.std(thresholded_image[thresholded_mask == 1])
    image_normalized = (image - mean_image) / sd_image
    image_normalized = image_normalized * thresholded_mask

    # Trim and resample image
    if verbose == True:
        print("    HippMapp3r: trim and resample to (160, 160, 128).")

    image_cropped = ants.crop_image(image_normalized, thresholded_mask, 1)
    shape_initial_stage = (160, 160, 128)
    image_resampled = ants.resample_image(image_cropped,
                                          shape_initial_stage,
                                          use_voxels=True,
                                          interp_type=1)

    if verbose == True:
        print("    HippMapp3r: generate first network and download weights.")

    model_initial_stage = create_hippmapp3r_unet_model_3d(
        (*shape_initial_stage, 1), do_first_network=True)

    initial_stage_weights_file_name = get_pretrained_network(
        "hippMapp3rInitial", antsxnet_cache_directory=antsxnet_cache_directory)
    model_initial_stage.load_weights(initial_stage_weights_file_name)

    if verbose == True:
        print("    HippMapp3r: prediction.")

    data_initial_stage = np.expand_dims(image_resampled.numpy(), axis=0)
    data_initial_stage = np.expand_dims(data_initial_stage, axis=-1)
    mask_array = model_initial_stage.predict(data_initial_stage,
                                             verbose=verbose)
    mask_image_resampled = ants.copy_image_info(
        image_resampled, ants.from_numpy(np.squeeze(mask_array)))
    mask_image = ants.resample_image(mask_image_resampled,
                                     image.shape,
                                     use_voxels=True,
                                     interp_type=0)
    mask_image[mask_image >= 0.5] = 1
    mask_image[mask_image < 0.5] = 0

    #########################################
    #
    # Perform refined (stage 2) segmentation
    #
    #########################################

    if verbose == True:
        print("")
        print("")
        print("*************  Refine stage segmentation  ***************")
        print("")

    mask_array = np.squeeze(mask_array)
    # Comparing float predictions to exactly 1 is fragile; binarize at 0.5.
    centroid_indices = np.where(mask_array >= 0.5)
    centroid = np.zeros((3, ))
    centroid[0] = centroid_indices[0].mean()
    centroid[1] = centroid_indices[1].mean()
    centroid[2] = centroid_indices[2].mean()

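    # Crop a fixed 112 x 112 x 64 patch centered on the initial-stage centroid.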
    shape_refine_stage = (112, 112, 64)
    lower = (np.floor(centroid - 0.5 * np.array(shape_refine_stage)) -
             1).astype(int)
    upper = (lower + np.array(shape_refine_stage)).astype(int)

    image_trimmed = ants.crop_indices(image_resampled, lower.astype(int),
                                      upper.astype(int))

    if verbose == True:
        print("    HippMapp3r: generate second network and download weights.")

    model_refine_stage = create_hippmapp3r_unet_model_3d(
        (*shape_refine_stage, 1), do_first_network=False)

    refine_stage_weights_file_name = get_pretrained_network(
        "hippMapp3rRefine", antsxnet_cache_directory=antsxnet_cache_directory)
    model_refine_stage.load_weights(refine_stage_weights_file_name)

    data_refine_stage = np.expand_dims(image_trimmed.numpy(), axis=0)
    data_refine_stage = np.expand_dims(data_refine_stage, axis=-1)

    if verbose == True:
        print("    HippMapp3r: Monte Carlo iterations (SpatialDropout).")

    number_of_mci_iterations = 30
    prediction_refine_stage = np.zeros(shape_refine_stage)
    for i in range(number_of_mci_iterations):
        tf.random.set_seed(i)
        if verbose == True:
            print("        Monte Carlo iteration", i + 1, "out of",
                  number_of_mci_iterations)
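        # Incremental mean of the stochastic forward passes:
        #   avg_i = (x_i + i * avg_{i-1}) / (i + 1)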
        prediction_refine_stage = \
            (np.squeeze(model_refine_stage.predict(data_refine_stage, verbose=verbose)) + \
             i * prediction_refine_stage ) / (i + 1)

    prediction_refine_stage_array = np.zeros(image_resampled.shape)
    prediction_refine_stage_array[lower[0]:upper[0], lower[1]:upper[1],
                                  lower[2]:upper[2]] = prediction_refine_stage
    probability_mask_refine_stage_resampled = ants.copy_image_info(
        image_resampled, ants.from_numpy(prediction_refine_stage_array))

    segmentation_image_resampled = ants.label_clusters(ants.threshold_image(
        probability_mask_refine_stage_resampled, 0.0, 0.5, 0, 1),
                                                       min_cluster_size=10)
    segmentation_image_resampled[segmentation_image_resampled > 2] = 0
    geom = ants.label_geometry_measures(segmentation_image_resampled)
    if len(geom['VolumeInMillimeters']) < 2:
        raise ValueError("Error: left and right hippocampus not found.")

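    # label_clusters assigns labels by cluster size, not hemisphere; compare
    # the x centroids and, if needed, swap labels 1 and 2 (via temporary
    # label 3) for a consistent left/right assignment.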
    if geom['Centroid_x'][0] < geom['Centroid_x'][1]:
        segmentation_image_resampled[segmentation_image_resampled == 1] = 3
        segmentation_image_resampled[segmentation_image_resampled == 2] = 1
        segmentation_image_resampled[segmentation_image_resampled == 3] = 2

    segmentation_image = ants.apply_transforms(
        fixed=t1,
        moving=segmentation_image_resampled,
        transformlist=transforms['invtransforms'],
        whichtoinvert=[True],
        interpolator="genericLabel",
        verbose=verbose)

    return (segmentation_image)
Example #15
def brain_extraction(image,
                     modality="t1",
                     antsxnet_cache_directory=None,
                     verbose=False):
    """
    Perform brain extraction using a U-net and ANTs-based training data.
    "NoBrainer" is also possible, in which case brain extraction uses a U-net
    and FreeSurfer training data ported from

    https://github.com/neuronets/nobrainer-models

    Arguments
    ---------
    image : ANTsImage
        input image (or list of images for multi-modal scenarios).

    modality : string
        Modality image type.  Options include:
            * "t1": T1-weighted MRI---ANTs-trained.  Update from "t1v0".
            * "t1v0":  T1-weighted MRI---ANTs-trained.
            * "t1nobrainer": T1-weighted MRI---FreeSurfer-trained: h/t Satra Ghosh and Jakub Kaczmarzyk.
            * "t1combined": Brian's combination of "t1" and "t1nobrainer".  One can also specify
                            "t1combined[X]" where X is the morphological radius.  X = 12 by default.
            * "flair": FLAIR MRI.
            * "t2": T2 MRI.
            * "bold": 3-D BOLD MRI.
            * "fa": Fractional anisotropy.
            * "t1t2infant": Combined T1-w/T2-w infant MRI h/t Martin Styner.
            * "t1infant": T1-w infant MRI h/t Martin Styner.
            * "t2infant": T2-w infant MRI h/t Martin Styner.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    ANTs probability brain mask image.

    Example
    -------
    >>> import ants
    >>> brain_image = ants.image_read("t1.nii.gz")
    >>> probability_brain_mask = brain_extraction(brain_image, modality="t1")
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..architectures import create_nobrainer_unet_model_3d

    classes = ("background", "brain")
    number_of_classification_labels = len(classes)

    channel_size = 1
    if isinstance(image, list):
        channel_size = len(image)

    if antsxnet_cache_directory == None:
        antsxnet_cache_directory = "ANTsXNet"

    input_images = list()
    if channel_size == 1:
        input_images.append(image)
    else:
        input_images = image

    if input_images[0].dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if "t1combined" in modality:

        brain_extraction_t1 = brain_extraction(
            image,
            modality="t1",
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        brain_mask = ants.iMath_get_largest_component(
            ants.threshold_image(brain_extraction_t1, 0.5, 10000))

        # Need to change with voxel resolution
        morphological_radius = 12
        if '[' in modality and ']' in modality:
            morphological_radius = int(modality.split("[")[1].split("]")[0])

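        # Intersect the NoBrainer result with the ANTs mask (computed on the
        # dilated-mask input), keep the largest component, fill holes, then add
        # back the eroded and original ANTs masks; the return value is a summed
        # confidence map rather than a strictly binary mask.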
        brain_extraction_t1nobrainer = brain_extraction(
            image * ants.iMath_MD(brain_mask, radius=morphological_radius),
            modality="t1nobrainer",
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        brain_extraction_combined = ants.iMath_fill_holes(
            ants.iMath_get_largest_component(brain_extraction_t1nobrainer *
                                             brain_mask))

        brain_extraction_combined = brain_extraction_combined + ants.iMath_ME(
            brain_mask, morphological_radius) + brain_mask

        return (brain_extraction_combined)

    if modality != "t1nobrainer":

        #####################
        #
        # ANTs-based
        #
        #####################

        weights_file_name_prefix = None

        if modality == "t1v0":
            weights_file_name_prefix = "brainExtraction"
        elif modality == "t1":
            weights_file_name_prefix = "brainExtractionT1"
        elif modality == "t2":
            weights_file_name_prefix = "brainExtractionT2"
        elif modality == "flair":
            weights_file_name_prefix = "brainExtractionFLAIR"
        elif modality == "bold":
            weights_file_name_prefix = "brainExtractionBOLD"
        elif modality == "fa":
            weights_file_name_prefix = "brainExtractionFA"
        elif modality == "t1t2infant":
            weights_file_name_prefix = "brainExtractionInfantT1T2"
        elif modality == "t1infant":
            weights_file_name_prefix = "brainExtractionInfantT1"
        elif modality == "t2infant":
            weights_file_name_prefix = "brainExtractionInfantT2"
        else:
            raise ValueError("Unknown modality type.")

        weights_file_name = get_pretrained_network(
            weights_file_name_prefix,
            antsxnet_cache_directory=antsxnet_cache_directory)

        if verbose == True:
            print("Brain extraction:  retrieving template.")
        reorient_template_file_name_path = get_antsxnet_data(
            "S_template3", antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)
        resampled_image_size = reorient_template.shape

        if modality == "t1":
            classes = ("background", "head", "brain")
            number_of_classification_labels = len(classes)

        unet_model = create_unet_model_3d(
            (*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            number_of_layers=4,
            number_of_filters_at_base_layer=8,
            dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3),
            deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5)

        unet_model.load_weights(weights_file_name)

        if verbose == True:
            print("Brain extraction:  normalizing image to the template.")

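        # Only a center-of-mass translation (no deformable registration) is
        # used to map the input into the template space before prediction.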
        center_of_mass_template = ants.get_center_of_mass(reorient_template)
        center_of_mass_image = ants.get_center_of_mass(input_images[0])
        translation = np.asarray(center_of_mass_image) - np.asarray(
            center_of_mass_template)
        xfrm = ants.create_ants_transform(
            transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_template),
            translation=translation)

        batchX = np.zeros((1, *resampled_image_size, channel_size))
        for i in range(len(input_images)):
            warped_image = ants.apply_ants_transform_to_image(
                xfrm, input_images[i], reorient_template)
            warped_array = warped_image.numpy()
            batchX[0, :, :, :, i] = (warped_array -
                                     warped_array.mean()) / warped_array.std()

        if verbose == True:
            print("Brain extraction:  prediction and decoding.")

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        origin = reorient_template.origin
        spacing = reorient_template.spacing
        direction = reorient_template.direction

        probability_images_array = list()
        probability_images_array.append(
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, 0]),
                            origin=origin,
                            spacing=spacing,
                            direction=direction))
        probability_images_array.append(
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, 1]),
                            origin=origin,
                            spacing=spacing,
                            direction=direction))
        if modality == "t1":
            probability_images_array.append(
                ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, 2]),
                                origin=origin,
                                spacing=spacing,
                                direction=direction))

        if verbose == True:
            print(
                "Brain extraction:  renormalize probability mask to native space."
            )
        probability_image = ants.apply_ants_transform_to_image(
            ants.invert_ants_transform(xfrm),
            probability_images_array[number_of_classification_labels - 1],
            input_images[0])

        return (probability_image)

    else:

        #####################
        #
        # NoBrainer
        #
        #####################

        if verbose == True:
            print("NoBrainer:  generating network.")

        model = create_nobrainer_unet_model_3d((None, None, None, 1))

        weights_file_name = get_pretrained_network(
            "brainExtractionNoBrainer",
            antsxnet_cache_directory=antsxnet_cache_directory)
        model.load_weights(weights_file_name)

        if verbose == True:
            print(
                "NoBrainer:  preprocessing (intensity truncation and resampling)."
            )

        image_array = image.numpy()
        image_robust_range = np.quantile(
            image_array[np.where(image_array != 0)], (0.02, 0.98))
        threshold_value = 0.10 * (image_robust_range[1] - image_robust_range[0]
                                  ) + image_robust_range[0]

        thresholded_mask = ants.threshold_image(image, -10000, threshold_value,
                                                0, 1)
        thresholded_image = image * thresholded_mask

        image_resampled = ants.resample_image(thresholded_image,
                                              (256, 256, 256),
                                              use_voxels=True)
        image_array = np.expand_dims(image_resampled.numpy(), axis=0)
        image_array = np.expand_dims(image_array, axis=-1)

        if verbose == True:
            print("NoBrainer:  predicting mask.")

        brain_mask_array = np.squeeze(
            model.predict(image_array, verbose=verbose))
        brain_mask_resampled = ants.copy_image_info(
            image_resampled, ants.from_numpy(brain_mask_array))
        brain_mask_image = ants.resample_image(brain_mask_resampled,
                                               image.shape,
                                               use_voxels=True,
                                               interp_type=1)

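        # Convert a fixed minimum brain volume (~649,934 mm^3, presumably from
        # the NoBrainer pipeline) into a voxel count for cluster filtering.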
        spacing = ants.get_spacing(image)
        spacing_product = spacing[0] * spacing[1] * spacing[2]
        minimum_brain_volume = round(649933.7 / spacing_product)
        brain_mask_labeled = ants.label_clusters(brain_mask_image,
                                                 minimum_brain_volume)

        return (brain_mask_labeled)