import os

import ants
import tensorflow as tf


def mri_super_resolution(image, output_directory=None, verbose=False):
    """
    Perform super-resolution (2x) of MRI data using a deep back projection network.

    Arguments
    ---------
    image : ANTsImage
        magnetic resonance image

    output_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if this is None, the data are
        downloaded to a tempfile.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    The super-resolved image.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> image_sr = mri_super_resolution(image)
    """

    from ..utilities import get_pretrained_network
    from ..utilities import apply_super_resolution_model_to_image
    from ..utilities import regression_match_image

    if image.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    model_and_weights_file_name = None
    if output_directory is not None:
        model_and_weights_file_name = output_directory + "/mindmapsSR_16_ANINN222_0.h5"
        if not os.path.exists(model_and_weights_file_name):
            if verbose == True:
                print("MRI super-resolution:  downloading model weights.")
            model_and_weights_file_name = get_pretrained_network(
                "mriSuperResolution", model_and_weights_file_name)
    else:
        model_and_weights_file_name = get_pretrained_network(
            "mriSuperResolution")

    model_sr = tf.keras.models.load_model(model_and_weights_file_name,
                                          compile=False)

    image_sr = apply_super_resolution_model_to_image(image,
                                                     model_sr,
                                                     target_range=(-127.5,
                                                                   127.5))
    image_sr = regression_match_image(image_sr,
                                      ants.resample_image_to_target(
                                          image, image_sr),
                                      poly_order=1)

    return image_sr
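
A minimal usage sketch for the variant above (file names hypothetical); pointing
output_directory at a persistent folder lets repeated calls reuse the downloaded
weights:

import ants

t1 = ants.image_read("t1.nii.gz")   # hypothetical input file
t1_sr = mri_super_resolution(t1, output_directory="/tmp/antsxnet", verbose=True)
ants.image_write(t1_sr, "t1_sr.nii.gz")
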
import ants
import tensorflow as tf


def mri_super_resolution(image, antsxnet_cache_directory=None, verbose=False):
    """
    Perform super-resolution (2x) of MRI data using a deep back projection network.

    Arguments
    ---------
    image : ANTsImage
        magnetic resonance image

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if this is None, the data are
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    The super-resolved image.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> image_sr = mri_super_resolution(image)
    """

    from ..utilities import get_pretrained_network
    from ..utilities import apply_super_resolution_model_to_image
    from ..utilities import regression_match_image

    if image.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    model_and_weights_file_name = get_pretrained_network("mriSuperResolution", antsxnet_cache_directory=antsxnet_cache_directory)
    model_sr = tf.keras.models.load_model(model_and_weights_file_name, compile=False)

    image_sr = apply_super_resolution_model_to_image(
        image, model_sr, target_range=(-127.5, 127.5)
    )
    image_sr = regression_match_image(
        image_sr, ants.resample_image_to_target(image, image_sr), poly_order=1
    )

    return image_sr
Example #3
    def reconstruct_to_imgsize(bids_file: ants.ANTsImage,
                               ori_shape: Tuple[int], mask_pred: np.ndarray,
                               resampled_bids: ants.ANTsImage):
        """
        Reconstruct the predicted mask to the original bids image size.
        First resize the mask to the resampled bids shape, then resample it into the original voxel space.
        """
        resized_array = reconstruct_image(ori_shape, mask_pred)

        resized_mask = resampled_bids.new_image_like(resized_array)

        return ants.resample_image_to_target(
            image=resized_mask,
            target=bids_file,
            interp_type='nearestNeighbor',
        )
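
The method above assumes a module-level helper reconstruct_image(ori_shape,
mask_pred) defined elsewhere in its module. A minimal sketch of that assumed
context; the helper body is a guess (nearest-neighbour zoom), not the original:

from typing import Tuple

import numpy as np
from scipy.ndimage import zoom

def reconstruct_image(ori_shape: Tuple[int], mask_pred: np.ndarray) -> np.ndarray:
    # Hypothetical: resize the predicted mask back to the resampled image
    # shape, using order=0 (nearest neighbour) to keep labels intact.
    factors = [o / p for o, p in zip(ori_shape, mask_pred.shape)]
    return zoom(mask_pred, factors, order=0)
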
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

number_of_image_volumes = len(input_image_list)

output_image_list = list()
for i in range(number_of_image_volumes):
    print("Applying super resolution to image", i, "of",
          number_of_image_volumes)
    start_time = time.time()

    input_image = ants.iMath(input_image_list[i], "TruncateIntensity", 0.0001,
                             0.995)
    output_sr = antspynet.apply_super_resolution_model_to_image(
        input_image, model, target_range=(127.5, -127.5))
    input_image_resampled = ants.resample_image_to_target(
        input_image, output_sr)
    output_image_list.append(
        antspynet.regression_match_image(output_sr,
                                         input_image_resampled,
                                         poly_order=2))

    end_time = time.time()
    elapsed_time = end_time - start_time
    print("   (elapsed time:", elapsed_time, "seconds)")

print("Writing output image.")
if number_of_image_volumes == 1:
    ants.image_write(output_image_list[0], output_file_name)
else:
    output_image = ants.list_to_ndimage(input_image, output_image_list)
    ants.image_write(output_image, output_file_name)
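
The script above begins mid-stream; it assumes earlier setup roughly like the
following (variable and file names inferred from the loop, not taken from the
original):

import time

import ants
import antspynet
import tensorflow as tf

input_file_name = "input.nii.gz"        # hypothetical paths
output_file_name = "output_sr.nii.gz"

image = ants.image_read(input_file_name)
input_image_list = ants.ndimage_to_list(image) if image.dimension == 4 else [image]

start_time = time.time()
model_file = antspynet.get_pretrained_network("mriSuperResolution")
model = tf.keras.models.load_model(model_file, compile=False)
end_time = time.time()
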
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

print(
    "    Refine step 4: Average monte carlo results and write probability mask image."
)
start_time = time.time()
prediction_refine_stage_array = np.zeros(image_resampled.shape)
prediction_refine_stage_array[lower[0]:upper[0], lower[1]:upper[1],
                              lower[2]:upper[2]] = prediction_refine_stage
probability_mask_refine_stage_resampled = ants.from_numpy(
    prediction_refine_stage_array,
    origin=image_resampled.origin,
    spacing=image_resampled.spacing,
    direction=image_resampled.direction)
probability_mask_refine_stage = ants.resample_image_to_target(
    probability_mask_refine_stage_resampled, image)
ants.image_write(probability_mask_refine_stage, output_file_name)
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

print("Renormalize to native space")
start_time = time.time()
probability_image = ants.apply_ants_transform_to_image(
    ants.invert_ants_transform(xfrm), probability_mask_refine_stage, image)
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

print("Writing", output_file_name)
start_time = time.time()
import numpy as np
import ants


def randomly_transform_image_data(
        reference_image,
        input_image_list,
        segmentation_image_list=None,
        number_of_simulations=10,
        transform_type='affine',
        sd_affine=0.02,
        deformation_transform_type="bspline",
        number_of_random_points=1000,
        sd_noise=10.0,
        number_of_fitting_levels=4,
        mesh_size=1,
        sd_smoothing=4.0,
        input_image_interpolator='linear',
        segmentation_image_interpolator='nearestNeighbor'):
    """
    Randomly transform image data (optional: with corresponding segmentations).

    Apply rigid, affine and/or deformable maps to an input set of training
    images.  The reference image domain defines the space in which this 
    happens.

    Arguments
    ---------
    reference_image : ANTsImage
        Defines the spatial domain for all output images.  If an input image
        does not match the spatial domain of the reference image, it is
        internally resampled to the reference image, which could have
        unexpected consequences.  Consistency is tested with
        ants.image_physical_space_consistency and, on failure,
        ants.resample_image_to_target is called.

    input_image_list : list of lists of ANTsImages
        List of lists of input images to warp.  Each inner list contains one
        or more images (per subject) which are assumed to be mutually aligned.
        The outer list contains multiple subject lists which are randomly
        sampled to produce the output image list.
        
    segmentation_image_list : list of ANTsImages
        List of segmentation images corresponding to the input image list (optional).

    number_of_simulations : integer
        Number of output images. 

    transform_type : string
        One of the following options: "translation", "rigid", "scaleShear", "affine",
        "deformation", "affineAndDeformation".

    sd_affine : float
        Parameter dictating deviation amount from identity for random linear 
        transformations.

    deformation_transform_type : string
        "bspline" or "exponential".

    number_of_random_points : integer
        Number of displacement points for the deformation field.

    sd_noise : float
        Standard deviation of the displacement field.

    number_of_fitting_levels : integer
        Number of fitting levels (bspline deformation only).    

    mesh_size : int or n-D tuple
        Determines fitting resolution (bspline deformation only).    

    sd_smoothing : float 
        Standard deviation of the Gaussian smoothing in mm (exponential field only).

    input_image_interpolator : string
        One of the following options "linear", "gaussian", "bspline".

    segmentation_image_interpolator : string
        One of the following options "nearestNeighbor" or "genericLabel".

    Returns
    -------
    list of lists of transformed images

    Example
    -------
    >>> import ants
    >>> import antspynet
    >>> image1 = ants.image_read(ants.get_ants_data("r16"))
    >>> image2 = ants.image_read(ants.get_ants_data("r64"))
    >>> image1_list = list()
    >>> image1_list.append(image1)
    >>> image2_list = list()
    >>> image2_list.append(image2)
    >>> input_segmentations = list()
    >>> input_segmentations.append(ants.threshold_image(image1, "Otsu", 3))
    >>> input_segmentations.append(ants.threshold_image(image2, "Otsu", 3))
    >>> input_images = list()
    >>> input_images.append(image1_list)
    >>> input_images.append(image2_list)
    >>> data = antspynet.randomly_transform_image_data(image1, 
    >>>     input_images, input_segmentations, sd_affine=0.02,
    >>>     transform_type = "affineAndDeformation" )
    """
    def polar_decomposition(X):
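        # Polar decomposition X = P @ Z, where P = U diag(d) U^T is the
        # symmetric positive semi-definite (scale/shear) factor and Z is the
        # nearest orthogonal (rotation) factor.  NumPy's svd returns V^T
        # directly (unlike R's svd), and the sign flip below keeps det(Z)
        # positive so reflections are excluded.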
        U, d, Vt = np.linalg.svd(X, full_matrices=False)
        P = np.matmul(U, np.matmul(np.diag(d), np.transpose(U)))
        Z = np.matmul(U, Vt)
        if np.linalg.det(Z) < 0:
            Z = -Z
        return ({"P": P, "Z": Z, "Xtilde": np.matmul(P, Z)})

    def create_random_linear_transform(image,
                                       fixed_parameters,
                                       transform_type='affine',
                                       sd_affine=1.0):
        transform = ants.create_ants_transform(
            transform_type="AffineTransform",
            precision='float',
            dimension=image.dimension)
        ants.set_ants_transform_fixed_parameters(transform, fixed_parameters)
        identity_parameters = ants.get_ants_transform_parameters(transform)
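        # AffineTransform parameters are the flattened (dim x dim) matrix
        # followed by the translation vector, so the first
        # len(identity_parameters) - image.dimension entries form the matrix.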
        random_epsilon = np.random.normal(loc=0,
                                          scale=sd_affine,
                                          size=len(identity_parameters))

        if transform_type == 'translation':
            random_epsilon[:(len(identity_parameters) - image.dimension)] = 0

        random_parameters = identity_parameters + random_epsilon
        random_matrix = np.reshape(
            random_parameters[:(len(identity_parameters) - image.dimension)],
            newshape=(image.dimension, image.dimension))
        decomposition = polar_decomposition(random_matrix)

        if transform_type == "rigid":
            random_matrix = decomposition['Z']
        elif transform_type == "affine":
            random_matrix = decomposition['Xtilde']
        elif transform_type == "scaleShear":
            random_matrix = decomposition['P']

        random_parameters[:(len(identity_parameters) - image.dimension)] = \
            np.reshape(random_matrix, newshape=(len(identity_parameters) - image.dimension))
        ants.set_ants_transform_parameters(transform, random_parameters)
        return (transform)

    def create_random_displacement_field_transform(
            image,
            field_type="bspline",
            number_of_random_points=1000,
            sd_noise=10.0,
            number_of_fitting_levels=4,
            mesh_size=1,
            sd_smoothing=4.0):
        displacement_field = ants.simulate_displacement_field(
            image,
            field_type=field_type,
            number_of_random_points=number_of_random_points,
            sd_noise=sd_noise,
            enforce_stationary_boundary=True,
            number_of_fitting_levels=number_of_fitting_levels,
            mesh_size=mesh_size,
            sd_smoothing=sd_smoothing)
        return (ants.transform_from_displacement_field(displacement_field))

    admissible_transforms = ("translation", "rigid", "scaleShear", "affine",
                             "affineAndDeformation", "deformation")
    if transform_type not in admissible_transforms:
        raise ValueError(
            "The specified transform is not a possible option.  Please see the help documentation."
        )

    # Get the fixed parameters from the reference image.

    fixed_parameters = ants.get_center_of_mass(reference_image)
    number_of_subjects = len(input_image_list)

    random_indices = np.random.choice(number_of_subjects,
                                      size=number_of_simulations,
                                      replace=True)

    simulated_image_list = list()
    simulated_segmentation_image_list = list()
    simulated_transforms = list()

    for i in range(number_of_simulations):
        single_subject_image_list = input_image_list[random_indices[i]]
        single_subject_segmentation_image = None
        if segmentation_image_list is not None:
            single_subject_segmentation_image = segmentation_image_list[
                random_indices[i]]

        if not ants.image_physical_space_consistency(
                reference_image, single_subject_image_list[0]):
            for j in range(len(single_subject_image_list)):
                # Replace each image in place (appending would leave the
                # unresampled originals in the list).
                single_subject_image_list[j] = \
                    ants.resample_image_to_target(
                        single_subject_image_list[j],
                        reference_image,
                        interp_type=input_image_interpolator)
            if single_subject_segmentation_image is not None:
                single_subject_segmentation_image = \
                    ants.resample_image_to_target(single_subject_segmentation_image, reference_image,
                        interp_type=segmentation_image_interpolator)

        transforms = list()

        if transform_type == 'deformation':
            deformable_transform = create_random_displacement_field_transform(
                reference_image, deformation_transform_type,
                number_of_random_points, sd_noise, number_of_fitting_levels,
                mesh_size, sd_smoothing)
            transforms.append(deformable_transform)
        elif transform_type == 'affineAndDeformation':
            deformable_transform = create_random_displacement_field_transform(
                reference_image, deformation_transform_type,
                number_of_random_points, sd_noise, number_of_fitting_levels,
                mesh_size, sd_smoothing)
            linear_transform = create_random_linear_transform(
                reference_image, fixed_parameters, 'affine', sd_affine)
            transforms.append(deformable_transform)
            transforms.append(linear_transform)
        else:
            linear_transform = create_random_linear_transform(
                reference_image, fixed_parameters, transform_type, sd_affine)
            transforms.append(linear_transform)

        simulated_transforms.append(ants.compose_ants_transforms(transforms))

        single_subject_simulated_image_list = list()
        for j in range(len(single_subject_image_list)):
            single_subject_simulated_image_list.append(
                ants.apply_ants_transform_to_image(
                    simulated_transforms[i],
                    single_subject_image_list[j],
                    reference=reference_image))

        simulated_image_list.append(single_subject_simulated_image_list)

        if single_subject_segmentation_image is not None:
            simulated_segmentation_image_list.append(
                ants.apply_ants_transform_to_image(
                    simulated_transforms[i],
                    single_subject_segmentation_image,
                    reference=reference_image))

    if segmentation_image_list is None:
        return ({
            'simulated_images': simulated_image_list,
            'simulated_transforms': simulated_transforms
        })
    else:
        return ({
            'simulated_images': simulated_image_list,
            'simulated_segmentation_images': simulated_segmentation_image_list,
            'simulated_transforms': simulated_transforms
        })
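
A minimal usage sketch (the r16/r64 sample data ship with ANTsPy, so this runs
as-is once the function above is defined):

import ants

ref = ants.image_read(ants.get_ants_data("r16"))
mov = ants.image_read(ants.get_ants_data("r64"))

sim = randomly_transform_image_data(ref, [[mov]],
                                    number_of_simulations=2,
                                    transform_type="affine",
                                    sd_affine=0.02)
warped = sim['simulated_images'][0][0]   # first simulation, first channel
xfrm = sim['simulated_transforms'][0]    # corresponding composite transform
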
Example #7
import numpy as np
import ants


def ew_david(flair,
             t1,
             do_preprocessing=True,
             do_slicewise=True,
             antsxnet_cache_directory=None,
             verbose=False):

    """
    Perform white matter hyperintensity probabilistic segmentation
    using deep learning.

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e., set
    do_preprocessing=True.

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    do_preprocessing : boolean
        perform n4 bias correction?

    do_slicewise : boolean
        apply the 2-D model along the direction of maximal slice thickness.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if this is None, the data are
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> flair = ants.image_read("flair.nii.gz")
    >>> t1 = ants.image_read("t1.nii.gz")
    >>> probability_mask = ew_david(flair, t1)
    """

    from ..architectures import create_unet_model_2d
    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import extract_image_patches
    from ..utilities import reconstruct_image_from_patches
    from ..utilities import pad_or_crop_image_to_size

    if flair.dimension != 3:
        raise ValueError("FLAIR image dimension must be 3.")

    if t1.dimension != 3:
        raise ValueError("T1 image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    if do_slicewise == False:

        ################################
        #
        # Preprocess images
        #
        ################################

        t1_preprocessed = t1
        t1_preprocessing = None
        if do_preprocessing == True:
            t1_preprocessing = preprocess_brain_image(t1,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=True,
                template="croppedMni152",
                template_transform_type="AffineFast",
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            t1_preprocessed = t1_preprocessing["preprocessed_image"] * t1_preprocessing['brain_mask']

        flair_preprocessed = flair
        if do_preprocessing == True:
            flair_preprocessing = preprocess_brain_image(flair,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            flair_preprocessed = ants.apply_transforms(fixed=t1_preprocessed,
                moving=flair_preprocessing["preprocessed_image"],
                transformlist=t1_preprocessing['template_transforms']['fwdtransforms'])
            flair_preprocessed = flair_preprocessed * t1_preprocessing['brain_mask']

        ################################
        #
        # Build model and load weights
        #
        ################################

        patch_size = (112, 112, 112)
        stride_length = (t1_preprocessed.shape[0] - patch_size[0],
                        t1_preprocessed.shape[1] - patch_size[1],
                        t1_preprocessed.shape[2] - patch_size[2])

        classes = ("background", "wmh" )
        number_of_classification_labels = len(classes)
        labels = (0, 1)

        image_modalities = ("T1", "FLAIR")
        channel_size = len(image_modalities)

        unet_model = create_unet_model_3d((*patch_size, channel_size),
            number_of_outputs = number_of_classification_labels,
            number_of_layers = 4, number_of_filters_at_base_layer = 16, dropout_rate = 0.0,
            convolution_kernel_size = (3, 3, 3), deconvolution_kernel_size = (2, 2, 2),
            weight_decay = 1e-5, nn_unet_activation_style=False, add_attention_gating=True)

        weights_file_name = get_pretrained_network("ewDavidWmhSegmentationWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Do prediction and normalize to native space
        #
        ################################

        if verbose == True:
            print("ew_david:  prediction.")

        batchX = np.zeros((8, *patch_size, channel_size))

        t1_preprocessed = (t1_preprocessed - t1_preprocessed.mean()) / t1_preprocessed.std()
        t1_patches = extract_image_patches(t1_preprocessed, patch_size=patch_size,
                                            max_number_of_patches="all", stride_length=stride_length,
                                            return_as_array=True)
        batchX[:,:,:,:,0] = t1_patches

        flair_preprocessed = (flair_preprocessed - flair_preprocessed.mean()) / flair_preprocessed.std()
        flair_patches = extract_image_patches(flair_preprocessed, patch_size=patch_size,
                                            max_number_of_patches="all", stride_length=stride_length,
                                            return_as_array=True)
        batchX[:,:,:,:,1] = flair_patches

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        probability_images = list()
        for i in range(len(labels)):
            print("Reconstructing image", classes[i])
            reconstructed_image = reconstruct_image_from_patches(predicted_data[:,:,:,:,i],
                domain_image=t1_preprocessed, stride_length=stride_length)

            if do_preprocessing == True:
                probability_images.append(ants.apply_transforms(fixed=t1,
                    moving=reconstructed_image,
                    transformlist=t1_preprocessing['template_transforms']['invtransforms'],
                    whichtoinvert=[True], interpolator="linear", verbose=verbose))
            else:
                probability_images.append(reconstructed_image)

        return(probability_images[1])

    else:  # do_slicewise

        ################################
        #
        # Preprocess images
        #
        ################################

        t1_preprocessed = t1
        t1_preprocessing = None
        if do_preprocessing == True:
            t1_preprocessing = preprocess_brain_image(t1,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            t1_preprocessed = t1_preprocessing["preprocessed_image"]

        flair_preprocessed = flair
        if do_preprocessing == True:
            flair_preprocessing = preprocess_brain_image(flair,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            flair_preprocessed = flair_preprocessing["preprocessed_image"]

        resampling_params = list(ants.get_spacing(flair_preprocessed))

        do_resampling = False
        for d in range(len(resampling_params)):
            if resampling_params[d] < 0.8:
                resampling_params[d] = 1.0
                do_resampling = True

        resampling_params = tuple(resampling_params)

        if do_resampling:
            flair_preprocessed = ants.resample_image(flair_preprocessed, resampling_params, use_voxels=False, interp_type=0)
            t1_preprocessed = ants.resample_image(t1_preprocessed, resampling_params, use_voxels=False, interp_type=0)

        flair_preprocessed = (flair_preprocessed - flair_preprocessed.mean()) / flair_preprocessed.std()
        t1_preprocessed = (t1_preprocessed - t1_preprocessed.mean()) / t1_preprocessed.std()

        ################################
        #
        # Build model and load weights
        #
        ################################

        template_size = (256, 256)

        classes = ("background", "wmh" )
        number_of_classification_labels = len(classes)
        labels = (0, 1)

        image_modalities = ("T1", "FLAIR")
        channel_size = len(image_modalities)

        unet_model = create_unet_model_2d((*template_size, channel_size),
            number_of_outputs = number_of_classification_labels,
            number_of_layers = 4, number_of_filters_at_base_layer = 32, dropout_rate = 0.0,
            convolution_kernel_size = (3, 3), deconvolution_kernel_size = (2, 2),
            weight_decay = 1e-5, nn_unet_activation_style=True, add_attention_gating=True)

        if verbose == True:
            print("ewDavid:  retrieving model weights.")

        weights_file_name = get_pretrained_network("ewDavidWmhSegmentationSlicewiseWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Extract slices
        #
        ################################

        use_coarse_slices_only = True

        spacing = ants.get_spacing(flair_preprocessed)
        dimensions_to_predict = (spacing.index(max(spacing)),)
        if use_coarse_slices_only == False:
            dimensions_to_predict = list(range(3))

        total_number_of_slices = 0
        for d in range(len(dimensions_to_predict)):
            total_number_of_slices += flair_preprocessed.shape[dimensions_to_predict[d]]

        batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

        slice_count = 0
        for d in range(len(dimensions_to_predict)):
            number_of_slices = flair_preprocessed.shape[dimensions_to_predict[d]]

            if verbose == True:
                print("Extracting slices for dimension ", dimensions_to_predict[d], ".")

            for i in range(number_of_slices):
                flair_slice = pad_or_crop_image_to_size(ants.slice_image(flair_preprocessed, dimensions_to_predict[d], i), template_size)
                batchX[slice_count,:,:,0] = flair_slice.numpy()

                t1_slice = pad_or_crop_image_to_size(ants.slice_image(t1_preprocessed, dimensions_to_predict[d], i), template_size)
                batchX[slice_count,:,:,1] = t1_slice.numpy()

                slice_count += 1


        ################################
        #
        # Do prediction and then restack into the image
        #
        ################################

        if verbose == True:
            print("Prediction.")

        prediction = unet_model.predict(batchX, verbose=verbose)

        permutations = list()
        permutations.append((0, 1, 2))
        permutations.append((1, 0, 2))
        permutations.append((1, 2, 0))

        prediction_image_average = ants.image_clone(flair_preprocessed) * 0

        current_start_slice = 0
        for d in range(len(dimensions_to_predict)):
            current_end_slice = current_start_slice + flair_preprocessed.shape[dimensions_to_predict[d]] - 1
            # range() excludes its end point, so add 1 to include the final slice.
            which_batch_slices = range(current_start_slice, current_end_slice + 1)
            prediction_per_dimension = prediction[which_batch_slices,:,:,1]
            prediction_array = np.transpose(np.squeeze(prediction_per_dimension), permutations[dimensions_to_predict[d]])
            prediction_image = ants.copy_image_info(flair_preprocessed,
                pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                flair_preprocessed.shape))
            prediction_image_average = prediction_image_average + (prediction_image - prediction_image_average) / (d + 1)

            current_start_slice = current_end_slice + 1

        if do_resampling:
            prediction_image_average = ants.resample_image_to_target(prediction_image_average, flair)

        return(prediction_image_average)
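
A hedged end-to-end sketch of calling the example above (file names
hypothetical):

import ants

flair = ants.image_read("flair.nii.gz")
t1 = ants.image_read("t1.nii.gz")

wmh_probability = ew_david(flair, t1, do_preprocessing=True,
                           do_slicewise=True, verbose=True)
ants.image_write(wmh_probability, "wmh_probability.nii.gz")
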
Example #8
    def test_resample_image_to_target_example(self):
        fi = ants.image_read(ants.get_ants_data("r16"))
        fi2mm = ants.resample_image(fi, (2, 2), use_voxels=False, interp_type=1)
        resampled = ants.resample_image_to_target(fi2mm, fi, verbose=True)
import time
from os import path

import numpy as np
import tensorflow as tf
import ants
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

# extract_image_patches and regression_match_image are assumed to come from
# antspynet.utilities, matching the other excerpts on this page.
from antspynet.utilities import extract_image_patches, regression_match_image


def apply_super_resolution_model_to_image(
    image,
    model,
    target_range=(-127.5, 127.5),
    batch_size=32,
    regression_order=None,
    verbose=False,
):
    """
    Apply a pretrained deep back projection model for super resolution.
    The patch-wise trained network can be applied to variable-sized inputs.
    Warning: this function may be better run on the CPU unless the GPU can
    accommodate the full image size.  Warning 2: the global intensity range
    (min to max) of the output will match the input, where the range is
    taken over all channels.

    Arguments
    ---------
    image : ANTs image
        input image.

    model : keras object or string
        pretrained keras model or filename.

    target_range : 2-element tuple
        a tuple or array defining the (min, max) of the input image
        (e.g., -127.5, 127.5).  Output images will be scaled back to original
        intensity. This range should match the mapping used in the training
        of the network.

    batch_size : integer
        Batch size used for the prediction call.

    regression_order : integer
        If specified, apply the function regression_match_image with
        poly_order=regression_order.

    verbose : boolean
        If True, show status messages.

    Returns
    -------
    Super-resolution image upscaled to resolution specified by the network.

    Example
    -------
    >>> import ants
    >>> image = ants.image_read(ants.get_ants_data('r16'))
    >>> image_sr = apply_super_resolution_model_to_image(image, get_pretrained_network("dbpn4x"))
    """
    tflite_flag = False
    channel_axis = 0
    if K.image_data_format() == "channels_last":
        channel_axis = -1

    if target_range[0] > target_range[1]:
        target_range = target_range[::-1]

    start_time = time.time()
    if isinstance(model, str):
        if path.isfile(model):
            if verbose:
                print("Load model.")
            if path.splitext(model)[1] == '.tflite':
                tflite_flag = True
                interpreter = tf.lite.Interpreter(model)
                interpreter.allocate_tensors()
                input_details = interpreter.get_input_details()
                output_details = interpreter.get_output_details()
                shape_length = len(input_details[0]['shape'])
            else:
                model = load_model(model)
                shape_length = len(model.input_shape)

            if verbose:
                elapsed_time = time.time() - start_time
                print("  (elapsed time: ", elapsed_time, ")")
        else:
            raise ValueError("Model not found.")
    else:
        shape_length = len(model.input_shape)

    
    if shape_length < 4 | shape_length > 5:
        raise ValueError("Unexpected input shape.")
    else:
        if shape_length == 5 & image.dimension != 3:
            raise ValueError("Expecting 3D input for this model.")
        elif shape_length == 4 & image.dimension != 2:
            raise ValueError("Expecting 2D input for this model.")

    if channel_axis == -1:
        # Resolve the negative index to an explicit axis position.
        channel_axis = shape_length - 1
    if tflite_flag:
        channel_size = interpreter.get_input_details()[0]['shape'][channel_axis]
    else:
        channel_size = model.input_shape[channel_axis]

    if channel_size != image.components:
        raise ValueError(
            "Channel size of model (%s) does not match ncomponents=%s of the input image."
            % (channel_size, image.components))

    image_patches = extract_image_patches(
        image,
        patch_size=image.shape,
        max_number_of_patches=1,
        stride_length=image.shape,
        return_as_array=True,
    )
    if image.components == 1:
        image_patches = np.expand_dims(image_patches, axis=-1)

    image_patches = image_patches - image_patches.min()
    image_patches = (
        image_patches / image_patches.max() * (target_range[1] - target_range[0])
        + target_range[0]
    )

    if verbose:
        print("Prediction")

    start_time = time.time()

    if tflite_flag:
        image_patches = image_patches.astype('float32')
        interpreter.set_tensor(input_details[0]['index'], image_patches)
        interpreter.invoke()
        out = interpreter.tensor(output_details[0]['index'])
        prediction = out()
    else:
        prediction = model.predict(image_patches, batch_size=batch_size)

    if verbose:
        elapsed_time = time.time() - start_time
        print("  (elapsed time: ", elapsed_time, ")")

    if verbose:
        print("Reconstruct intensities")

    intensity_range = image.range()
    prediction = prediction - prediction.min()
    prediction = (
        prediction / prediction.max() * (intensity_range[1] - intensity_range[0])
        + intensity_range[0]
    )

    def slice_array_channel(input_array, slice_index, channel_axis=-1):
        if channel_axis == 0:
            if shape_length == 4:
                return input_array[slice_index, :, :, :]
            else:
                return input_array[slice_index, :, :, :, :]
        else:
            if shape_length == 4:
                return input_array[:, :, :, slice_index]
            else:
                return input_array[:, :, :, :, slice_index]

    expansion_factor = np.asarray(prediction.shape) / np.asarray(image_patches.shape)
    if channel_axis == 0:
        # channels-first: the spatial dimensions follow the batch and channel
        # axes (an assumption; the original code left this branch as FIXME).
        expansion_factor = expansion_factor[2:]
    else:
        # channels-last: strip the batch (first) and channel (last) entries.
        expansion_factor = expansion_factor[1:(len(expansion_factor) - 1)]

    if verbose:
        print("ExpansionFactor:", str(expansion_factor))

    if image.components == 1:
        image_array = slice_array_channel(prediction, 0, channel_axis)
        prediction_image = ants.make_image(
            (np.asarray(image.shape) * np.asarray(expansion_factor)).astype(int),
            image_array,
        )
        if regression_order is not None:
            reference_image = ants.resample_image_to_target(image, prediction_image)
            prediction_image = regression_match_image(
                prediction_image, reference_image, poly_order=regression_order
            )
    else:
        image_component_list = list()
        for k in range(image.components):
            image_array = slice_array_channel(prediction, k, channel_axis)
            image_component_list.append(
                ants.make_image(
                    (np.asarray(image.shape) * np.asarray(expansion_factor)).astype(
                        int
                    ),
                    image_array,
                )
            )
        prediction_image = ants.merge_channels(image_component_list)

    prediction_image = ants.copy_image_info(image, prediction_image)
    ants.set_spacing(
        prediction_image,
        tuple(np.asarray(image.spacing) / np.asarray(expansion_factor)),
    )

    return prediction_image
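
A minimal sketch of driving this helper with the pretrained 2-D network named
in its docstring (assuming antspynet provides the "dbpn4x" weights):

import ants
import antspynet

image = ants.image_read(ants.get_ants_data("r16"))
model_file = antspynet.get_pretrained_network("dbpn4x")
image_sr = apply_super_resolution_model_to_image(
    image, model_file, target_range=(-127.5, 127.5), regression_order=1)
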
Example #10
import numpy as np
import ants


def ew_david(flair,
             t1,
             do_preprocessing=True,
             which_model="sysu",
             which_axes=2,
             number_of_simulations=0,
             sd_affine=0.01,
             antsxnet_cache_directory=None,
             verbose=False):
    """
    Perform White matter hyperintensity probabilistic segmentation
    using deep learning

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * intensity truncation,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e., set
    do_preprocessing=True.

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    do_preprocessing : boolean
        perform n4 bias correction, intensity truncation, brain extraction.

    which_model : string
        one of:
            * "sysu" -- same as the original sysu network (without site specific preprocessing),
            * "sysu-ri" -- same as "sysu" but using ranked intensity scaling for input images,
            * "sysuWithAttention" -- "sysu" with attention gating,
            * "sysuWithAttentionAndSite" -- "sysu" with attention gating with site branch (see "sysuWithSite"),
            * "sysuPlus" -- "sysu" with attention gating and nn-Unet activation,
            * "sysuPlusSeg" -- "sysuPlus" with deep_atropos segmentation in an additional channel, and
            * "sysuWithSite" -- "sysu" with global pooling on encoding channels to predict "site".
            * "sysuPlusSegWithSite" -- "sysuPlusSeg" combined with "sysuWithSite"
        In addition to both modalities, all models have T1-only and flair-only variants except
        for "sysuPlusSeg" (which only has a T1-only variant) or "sysu-ri" (which has neither single
        modality variant).

    which_axes : string or scalar or tuple/vector
        apply 2-D model to 1 or more axes.  In addition to a scalar
        or vector, e.g., which_axes = (0, 2), one can use "max" for the
        axis with maximum anisotropy (default) or "all" for all axes.

    number_of_simulations : integer
        Number of random affine perturbations to transform the input.

    sd_affine : float
        Standard deviation of the affine transformation parameters.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if this is None, the data are
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> flair = ants.image_read("flair.nii.gz")
    >>> t1 = ants.image_read("t1.nii.gz")
    >>> probability_mask = ew_david(flair, t1)
    """

    from ..architectures import create_unet_model_2d
    from ..utilities import deep_atropos
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import randomly_transform_image_data
    from ..utilities import pad_or_crop_image_to_size

    do_t1_only = False
    do_flair_only = False

    if flair is None and t1 is None:
        raise ValueError("At least one of flair or t1 must be provided.")
    elif flair is None and t1 is not None:
        do_t1_only = True
    elif flair is not None and t1 is None:
        do_flair_only = True

    use_t1_segmentation = False
    if "Seg" in which_model:
        if do_flair_only:
            raise ValueError("Segmentation requires T1.")
        else:
            use_t1_segmentation = True

    if use_t1_segmentation and do_preprocessing == False:
        raise ValueError(
            "Using the t1 segmentation requires do_preprocessing=True.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    do_slicewise = True

    if do_slicewise == False:

        raise ValueError("Not available.")

        # ################################
        # #
        # # Preprocess images
        # #
        # ################################

        # t1_preprocessed = t1
        # t1_preprocessing = None
        # if do_preprocessing == True:
        #     t1_preprocessing = preprocess_brain_image(t1,
        #         truncate_intensity=(0.01, 0.99),
        #         brain_extraction_modality="t1",
        #         template="croppedMni152",
        #         template_transform_type="antsRegistrationSyNQuickRepro[a]",
        #         do_bias_correction=True,
        #         do_denoising=False,
        #         antsxnet_cache_directory=antsxnet_cache_directory,
        #         verbose=verbose)
        #     t1_preprocessed = t1_preprocessing["preprocessed_image"] * t1_preprocessing['brain_mask']

        # flair_preprocessed = flair
        # if do_preprocessing == True:
        #     flair_preprocessing = preprocess_brain_image(flair,
        #         truncate_intensity=(0.01, 0.99),
        #         brain_extraction_modality="t1",
        #         do_bias_correction=True,
        #         do_denoising=False,
        #         antsxnet_cache_directory=antsxnet_cache_directory,
        #         verbose=verbose)
        #     flair_preprocessed = ants.apply_transforms(fixed=t1_preprocessed,
        #         moving=flair_preprocessing["preprocessed_image"],
        #         transformlist=t1_preprocessing['template_transforms']['fwdtransforms'])
        #     flair_preprocessed = flair_preprocessed * t1_preprocessing['brain_mask']

        # ################################
        # #
        # # Build model and load weights
        # #
        # ################################

        # patch_size = (112, 112, 112)
        # stride_length = (t1_preprocessed.shape[0] - patch_size[0],
        #                 t1_preprocessed.shape[1] - patch_size[1],
        #                 t1_preprocessed.shape[2] - patch_size[2])

        # classes = ("background", "wmh" )
        # number_of_classification_labels = len(classes)
        # labels = (0, 1)

        # image_modalities = ("T1", "FLAIR")
        # channel_size = len(image_modalities)

        # unet_model = create_unet_model_3d((*patch_size, channel_size),
        #     number_of_outputs = number_of_classification_labels,
        #     number_of_layers = 4, number_of_filters_at_base_layer = 16, dropout_rate = 0.0,
        #     convolution_kernel_size = (3, 3, 3), deconvolution_kernel_size = (2, 2, 2),
        #     weight_decay = 1e-5, additional_options=("attentionGating"))

        # weights_file_name = get_pretrained_network("ewDavidWmhSegmentationWeights",
        #     antsxnet_cache_directory=antsxnet_cache_directory)
        # unet_model.load_weights(weights_file_name)

        # ################################
        # #
        # # Do prediction and normalize to native space
        # #
        # ################################

        # if verbose == True:
        #     print("ew_david:  prediction.")

        # batchX = np.zeros((8, *patch_size, channel_size))

        # t1_preprocessed = (t1_preprocessed - t1_preprocessed.mean()) / t1_preprocessed.std()
        # t1_patches = extract_image_patches(t1_preprocessed, patch_size=patch_size,
        #                                     max_number_of_patches="all", stride_length=stride_length,
        #                                     return_as_array=True)
        # batchX[:,:,:,:,0] = t1_patches

        # flair_preprocessed = (flair_preprocessed - flair_preprocessed.mean()) / flair_preprocessed.std()
        # flair_patches = extract_image_patches(flair_preprocessed, patch_size=patch_size,
        #                                     max_number_of_patches="all", stride_length=stride_length,
        #                                     return_as_array=True)
        # batchX[:,:,:,:,1] = flair_patches

        # predicted_data = unet_model.predict(batchX, verbose=verbose)

        # probability_images = list()
        # for i in range(len(labels)):
        #     print("Reconstructing image", classes[i])
        #     reconstructed_image = reconstruct_image_from_patches(predicted_data[:,:,:,:,i],
        #         domain_image=t1_preprocessed, stride_length=stride_length)

        #     if do_preprocessing == True:
        #         probability_images.append(ants.apply_transforms(fixed=t1,
        #             moving=reconstructed_image,
        #             transformlist=t1_preprocessing['template_transforms']['invtransforms'],
        #             whichtoinvert=[True], interpolator="linear", verbose=verbose))
        #     else:
        #         probability_images.append(reconstructed_image)

        # return(probability_images[1])

    else:  # do_slicewise

        ################################
        #
        # Preprocess images
        #
        ################################

        use_rank_intensity_scaling = False
        if "-ri" in which_model:
            use_rank_intensity_scaling = True

        t1_preprocessed = None
        t1_preprocessing = None
        brain_mask = None
        if t1 is not None:
            if do_preprocessing == True:
                t1_preprocessing = preprocess_brain_image(
                    t1,
                    truncate_intensity=(0.01, 0.995),
                    brain_extraction_modality="t1",
                    do_bias_correction=False,
                    do_denoising=False,
                    antsxnet_cache_directory=antsxnet_cache_directory,
                    verbose=verbose)
                brain_mask = ants.threshold_image(
                    t1_preprocessing["brain_mask"], 0.5, 1, 1, 0)
                t1_preprocessed = t1_preprocessing["preprocessed_image"]

        t1_segmentation = None
        if use_t1_segmentation:
            atropos_seg = deep_atropos(t1,
                                       do_preprocessing=True,
                                       verbose=verbose)
            t1_segmentation = atropos_seg['segmentation_image']

        flair_preprocessed = None
        if flair is not None:
            flair_preprocessed = flair
            if do_preprocessing == True:
                if brain_mask is None:
                    flair_preprocessing = preprocess_brain_image(
                        flair,
                        truncate_intensity=(0.01, 0.995),
                        brain_extraction_modality="flair",
                        do_bias_correction=False,
                        do_denoising=False,
                        antsxnet_cache_directory=antsxnet_cache_directory,
                        verbose=verbose)
                    brain_mask = ants.threshold_image(
                        flair_preprocessing["brain_mask"], 0.5, 1, 1, 0)
                else:
                    flair_preprocessing = preprocess_brain_image(
                        flair,
                        truncate_intensity=None,
                        brain_extraction_modality=None,
                        do_bias_correction=False,
                        do_denoising=False,
                        antsxnet_cache_directory=antsxnet_cache_directory,
                        verbose=verbose)
                flair_preprocessed = flair_preprocessing["preprocessed_image"]

        if t1_preprocessed is not None:
            t1_preprocessed = t1_preprocessed * brain_mask
        if flair_preprocessed is not None:
            flair_preprocessed = flair_preprocessed * brain_mask

        if t1_preprocessed is not None:
            resampling_params = list(ants.get_spacing(t1_preprocessed))
        else:
            resampling_params = list(ants.get_spacing(flair_preprocessed))

        do_resampling = False
        for d in range(len(resampling_params)):
            if resampling_params[d] < 0.8:
                resampling_params[d] = 1.0
                do_resampling = True

        resampling_params = tuple(resampling_params)

        if do_resampling:
            if flair_preprocessed is not None:
                flair_preprocessed = ants.resample_image(flair_preprocessed,
                                                         resampling_params,
                                                         use_voxels=False,
                                                         interp_type=0)
            if t1_preprocessed is not None:
                t1_preprocessed = ants.resample_image(t1_preprocessed,
                                                      resampling_params,
                                                      use_voxels=False,
                                                      interp_type=0)
            if t1_segmentation is not None:
                t1_segmentation = ants.resample_image(t1_segmentation,
                                                      resampling_params,
                                                      use_voxels=False,
                                                      interp_type=1)
            if brain_mask is not None:
                brain_mask = ants.resample_image(brain_mask,
                                                 resampling_params,
                                                 use_voxels=False,
                                                 interp_type=1)

        ################################
        #
        # Build model and load weights
        #
        ################################

        template_size = (208, 208)

        image_modalities = ("T1", "FLAIR")
        if do_flair_only:
            image_modalities = ("FLAIR", )
        elif do_t1_only:
            image_modalities = ("T1", )
        if use_t1_segmentation:
            image_modalities = (*image_modalities, "T1Seg")
        channel_size = len(image_modalities)

        unet_model = None
        if which_model == "sysu" or which_model == "sysu-ri":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("initialConvolutionKernelSize[5]", ))
        elif which_model == "sysuWithAttention":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("attentionGating",
                                    "initialConvolutionKernelSize[5]"))
        elif which_model == "sysuWithAttentionAndSite":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                scalar_output_size=3,
                scalar_output_activation="softmax",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("attentionGating",
                                    "initialConvolutionKernelSize[5]"))
        elif which_model == "sysuWithSite":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                scalar_output_size=3,
                scalar_output_activation="softmax",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("initialConvolutionKernelSize[5]", ))
        elif which_model == "sysuPlusSegWithSite":
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                scalar_output_size=3,
                scalar_output_activation="softmax",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("nnUnetActivationStyle", "attentionGating",
                                    "initialConvolutionKernelSize[5]"))
        else:
            unet_model = create_unet_model_2d(
                (*template_size, channel_size),
                number_of_outputs=1,
                mode="sigmoid",
                number_of_filters=(64, 96, 128, 256, 512),
                dropout_rate=0.0,
                convolution_kernel_size=(3, 3),
                deconvolution_kernel_size=(2, 2),
                weight_decay=0,
                additional_options=("nnUnetActivationStyle", "attentionGating",
                                    "initialConvolutionKernelSize[5]"))

        if verbose == True:
            print("ewDavid:  retrieving model weights.")

        # Select the pretrained weights from the model variant and which
        # modalities are available: (which_model, has FLAIR, has T1).
        network_names = {
            ("sysu", True, True): "ewDavidSysu",
            ("sysu-ri", True, True): "ewDavidSysuRankedIntensity",
            ("sysu", False, True): "ewDavidSysuT1Only",
            ("sysu", True, False): "ewDavidSysuFlairOnly",
            ("sysuWithAttention", True, True): "ewDavidSysuWithAttention",
            ("sysuWithAttention", False, True):
                "ewDavidSysuWithAttentionT1Only",
            ("sysuWithAttention", True, False):
                "ewDavidSysuWithAttentionFlairOnly",
            ("sysuWithAttentionAndSite", True, True):
                "ewDavidSysuWithAttentionAndSite",
            ("sysuWithAttentionAndSite", False, True):
                "ewDavidSysuWithAttentionAndSiteT1Only",
            ("sysuWithAttentionAndSite", True, False):
                "ewDavidSysuWithAttentionAndSiteFlairOnly",
            ("sysuPlus", True, True): "ewDavidSysuPlus",
            ("sysuPlus", False, True): "ewDavidSysuPlusT1Only",
            ("sysuPlus", True, False): "ewDavidSysuPlusFlairOnly",
            ("sysuPlusSeg", True, True): "ewDavidSysuPlusSeg",
            ("sysuPlusSeg", False, True): "ewDavidSysuPlusSegT1Only",
            ("sysuPlusSegWithSite", True, True):
                "ewDavidSysuPlusSegWithSite",
            ("sysuPlusSegWithSite", False, True):
                "ewDavidSysuPlusSegWithSiteT1Only",
            ("sysuWithSite", True, True): "ewDavidSysuWithSite",
            ("sysuWithSite", False, True): "ewDavidSysuWithSiteT1Only",
            ("sysuWithSite", True, False): "ewDavidSysuWithSiteFlairOnly",
        }
        key = (which_model, flair is not None, t1 is not None)
        if key not in network_names:
            raise ValueError(
                "Incorrect model specification or image combination.")
        weights_file_name = get_pretrained_network(
            network_names[key],
            antsxnet_cache_directory=antsxnet_cache_directory)

        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Data augmentation and extract slices
        #
        ################################
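        # Strategy: test-time augmentation. The network is applied to the
        # original image plus `number_of_simulations` randomly
        # affine-perturbed copies (spread controlled by `sd_affine`), and the
        # per-voxel probabilities are averaged over all runs and axes.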

        # Accumulator for the running-average WMH probability map.
        if t1 is not None:
            wmh_probability_image = ants.image_clone(t1_preprocessed) * 0
        else:
            wmh_probability_image = ants.image_clone(flair_preprocessed) * 0

        # Accumulator for the three acquisition-site probabilities (used
        # only by the "*WithSite" model variants).
        wmh_site = np.array([0, 0, 0])

        data_augmentation = None
        if number_of_simulations > 0:
            if do_flair_only:
                data_augmentation = randomly_transform_image_data(
                    reference_image=flair_preprocessed,
                    input_image_list=[[flair_preprocessed]],
                    number_of_simulations=number_of_simulations,
                    transform_type='affine',
                    sd_affine=sd_affine,
                    input_image_interpolator='linear')
            elif do_t1_only:
                if use_t1_segmentation:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[t1_preprocessed]],
                        segmentation_image_list=[t1_segmentation],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear',
                        segmentation_image_interpolator='nearestNeighbor')
                else:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[t1_preprocessed]],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear')
            else:
                if use_t1_segmentation:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[
                            flair_preprocessed, t1_preprocessed
                        ]],
                        segmentation_image_list=[t1_segmentation],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear',
                        segmentation_image_interpolator='nearestNeighbor')
                else:
                    data_augmentation = randomly_transform_image_data(
                        reference_image=t1_preprocessed,
                        input_image_list=[[
                            flair_preprocessed, t1_preprocessed
                        ]],
                        number_of_simulations=number_of_simulations,
                        transform_type='affine',
                        sd_affine=sd_affine,
                        input_image_interpolator='linear')

        # Determine which axes to slice along for 2-D prediction.
        if which_axes == "max":
            spacing = ants.get_spacing(wmh_probability_image)
            dimensions_to_predict = [spacing.index(max(spacing))]
        elif which_axes == "all":
            dimensions_to_predict = list(range(3))
        elif isinstance(which_axes, int):
            dimensions_to_predict = [which_axes]
        else:
            dimensions_to_predict = list(which_axes)

        total_number_of_slices = 0
        for d in dimensions_to_predict:
            total_number_of_slices += wmh_probability_image.shape[d]

        # One batch row per 2-D slice across all prediction axes; channels
        # hold (FLAIR, T1, T1 segmentation), as available.
        batchX = np.zeros(
            (total_number_of_slices, *template_size, channel_size))

        # n == 0 processes the original preprocessed images; n > 0 processes
        # the augmented copies generated above.
        for n in range(number_of_simulations + 1):

            batch_flair = flair_preprocessed
            batch_t1 = t1_preprocessed
            batch_t1_segmentation = t1_segmentation
            batch_brain_mask = brain_mask

            if n > 0:

                if do_flair_only:
                    batch_flair = data_augmentation['simulated_images'][n - 1][0]
                    batch_brain_mask = ants.apply_ants_transform_to_image(
                        data_augmentation['simulated_transforms'][n - 1],
                        brain_mask,
                        flair_preprocessed,
                        interpolation="nearestneighbor")

                elif do_t1_only:
                    batch_t1 = data_augmentation['simulated_images'][n - 1][0]
                    batch_brain_mask = ants.apply_ants_transform_to_image(
                        data_augmentation['simulated_transforms'][n - 1],
                        brain_mask,
                        t1_preprocessed,
                        interpolation="nearestneighbor")
                else:
                    batch_flair = data_augmentation['simulated_images'][n - 1][0]
                    batch_t1 = data_augmentation['simulated_images'][n - 1][1]
                    batch_brain_mask = ants.apply_ants_transform_to_image(
                        data_augmentation['simulated_transforms'][n - 1],
                        brain_mask,
                        flair_preprocessed,
                        interpolation="nearestneighbor")
                if use_t1_segmentation:
                    batch_t1_segmentation = data_augmentation[
                        'simulated_segmentation_images'][n - 1]

            # Normalize intensities within the brain mask: rank-intensity
            # scaling (zero-centered) or per-image z-scoring.
            if use_rank_intensity_scaling:
                if batch_t1 is not None:
                    batch_t1 = ants.rank_intensity(batch_t1,
                                                   batch_brain_mask) - 0.5
                if batch_flair is not None:
                    # Rank the (possibly augmented) batch_flair, not
                    # flair_preprocessed, so simulated runs are scaled
                    # consistently with the T1 branch above.
                    batch_flair = ants.rank_intensity(batch_flair,
                                                      batch_brain_mask) - 0.5
            else:
                if batch_t1 is not None:
                    batch_t1 = (batch_t1 -
                                batch_t1[batch_brain_mask == 1].mean()
                                ) / batch_t1[batch_brain_mask == 1].std()
                if batch_flair is not None:
                    batch_flair = (
                        batch_flair -
                        batch_flair[batch_brain_mask == 1].mean()
                    ) / batch_flair[batch_brain_mask == 1].std()

            slice_count = 0
            for d in range(len(dimensions_to_predict)):

                number_of_slices = None
                if batch_t1 is not None:
                    number_of_slices = batch_t1.shape[dimensions_to_predict[d]]
                else:
                    number_of_slices = batch_flair.shape[
                        dimensions_to_predict[d]]

                if verbose:
                    print("Extracting slices for dimension",
                          dimensions_to_predict[d])

                for i in range(number_of_slices):

                    brain_mask_slice = pad_or_crop_image_to_size(
                        ants.slice_image(batch_brain_mask,
                                         dimensions_to_predict[d], i),
                        template_size)

                    channel_count = 0
                    if batch_flair is not None:
                        flair_slice = pad_or_crop_image_to_size(
                            ants.slice_image(batch_flair,
                                             dimensions_to_predict[d], i),
                            template_size)
                        flair_slice[brain_mask_slice == 0] = 0
                        batchX[slice_count, :, :,
                               channel_count] = flair_slice.numpy()
                        channel_count += 1
                    if batch_t1 is not None:
                        t1_slice = pad_or_crop_image_to_size(
                            ants.slice_image(batch_t1,
                                             dimensions_to_predict[d], i),
                            template_size)
                        t1_slice[brain_mask_slice == 0] = 0
                        batchX[slice_count, :, :,
                               channel_count] = t1_slice.numpy()
                        channel_count += 1
                    if t1_segmentation is not None:
                        t1_segmentation_slice = pad_or_crop_image_to_size(
                            ants.slice_image(batch_t1_segmentation,
                                             dimensions_to_predict[d], i),
                            template_size)
                        t1_segmentation_slice[brain_mask_slice == 0] = 0
                        # Rescale the six-label tissue segmentation from
                        # {0, ..., 6} to approximately [-0.5, 0.5].
                        batchX[slice_count, :, :, channel_count] = (
                            t1_segmentation_slice.numpy() / 6 - 0.5)

                    slice_count += 1

            ################################
            #
            # Do prediction and then restack into the image
            #
            ################################

            if verbose:
                if n == 0:
                    print("Prediction")
                else:
                    print("Prediction (simulation " + str(n) + ")")

            prediction = unet_model.predict(batchX, verbose=verbose)

            # Axis permutations that restack per-slice predictions (taken
            # along axis 0, 1, or 2) back into the 3-D volume orientation.
            permutations = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]

            prediction_image_average = ants.image_clone(
                wmh_probability_image) * 0

            current_start_slice = 0
            for d in range(len(dimensions_to_predict)):
                current_end_slice = current_start_slice + wmh_probability_image.shape[
                    dimensions_to_predict[d]]
                which_batch_slices = range(current_start_slice,
                                           current_end_slice)
                if isinstance(prediction, list):
                    prediction_per_dimension = prediction[0][
                        which_batch_slices, :, :, 0]
                else:
                    prediction_per_dimension = prediction[
                        which_batch_slices, :, :, 0]
                prediction_array = np.transpose(
                    np.squeeze(prediction_per_dimension),
                    permutations[dimensions_to_predict[d]])
                prediction_image = ants.copy_image_info(
                    wmh_probability_image,
                    pad_or_crop_image_to_size(
                        ants.from_numpy(prediction_array),
                        wmh_probability_image.shape))
                # Incremental mean over the prediction axes:
                # avg_d = avg_{d-1} + (x_d - avg_{d-1}) / (d + 1)
                prediction_image_average = prediction_image_average + (
                    prediction_image - prediction_image_average) / (d + 1)
                current_start_slice = current_end_slice

            # Fold this run (original or simulation) into the running mean.
            wmh_probability_image = wmh_probability_image + (
                prediction_image_average - wmh_probability_image) / (n + 1)
            if isinstance(prediction, list):
                # Dual-head models also emit site probabilities; average the
                # per-slice softmax outputs the same way.
                wmh_site = wmh_site + (np.mean(prediction[1], axis=0) -
                                       wmh_site) / (n + 1)

        if do_resampling:
            # Resample back to the native input space, preferring the FLAIR
            # geometry when both inputs are supplied; resampling once avoids
            # compounding interpolation error.
            if flair is not None:
                wmh_probability_image = ants.resample_image_to_target(
                    wmh_probability_image, flair)
            elif t1 is not None:
                wmh_probability_image = ants.resample_image_to_target(
                    wmh_probability_image, t1)

        if isinstance(prediction, list):
            return [wmh_probability_image, wmh_site]
        else:
            return wmh_probability_image
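

# A minimal usage sketch for the function above. Hedged: its name and full
# signature are not shown in this excerpt; `ew_david` and the keyword
# arguments below are assumptions inferred from the "ewDavid" weights and the
# variables used in the body (which_model, which_axes, number_of_simulations).
#
# >>> flair = ants.image_read("flair.nii.gz")
# >>> t1 = ants.image_read("t1.nii.gz")
# >>> wmh_probability = ew_david(flair, t1, which_model="sysu",
# ...                            which_axes="max", number_of_simulations=4)
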
def arterial_lesion_segmentation(image,
                                 antsxnet_cache_directory=None,
                                 verbose=False):
    """
    Perform arterial lesion segmentation using U-net.

    Arguments
    ---------
    image : ANTsImage
        input image

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights. Since these can be reused, if this is None, the data will be
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Foreground probability image.

    Example
    -------
    >>> output = arterial_lesion_segmentation(histology_image)
    """

    from ..architectures import create_unet_model_2d
    from ..utilities import get_pretrained_network

    if image.dimension != 2:
        raise ValueError("Image dimension must be 2.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    channel_size = 1

    weights_file_name = get_pretrained_network(
        "arterialLesionWeibinShi",
        antsxnet_cache_directory=antsxnet_cache_directory)

    resampled_image_size = (512, 512)

    unet_model = create_unet_model_2d(
        (*resampled_image_size, channel_size),
        number_of_outputs=1,
        mode="sigmoid",
        number_of_filters=(64, 96, 128, 256, 512),
        convolution_kernel_size=(3, 3),
        deconvolution_kernel_size=(2, 2),
        dropout_rate=0.0,
        weight_decay=0,
        additional_options=("initialConvolutionKernelSize[5]",
                            "attentionGating"))
    unet_model.load_weights(weights_file_name)

    if verbose:
        print("Preprocessing:  Resampling and N4 bias correction.")

    preprocessed_image = ants.image_clone(image)
    preprocessed_image = preprocessed_image / preprocessed_image.max()
    preprocessed_image = ants.resample_image(preprocessed_image,
                                             resampled_image_size,
                                             use_voxels=True,
                                             interp_type=0)
    # Whole-image mask so that N4 bias correction runs over the entire
    # field of view.
    mask = ants.image_clone(preprocessed_image) * 0 + 1
    preprocessed_image = ants.n4_bias_field_correction(preprocessed_image,
                                                       mask=mask,
                                                       shrink_factor=2,
                                                       return_bias_field=False,
                                                       verbose=verbose)

    # Shape the input as (batch, height, width, channels) and min-max
    # normalize to [0, 1].
    batchX = np.expand_dims(preprocessed_image.numpy(), axis=0)
    batchX = np.expand_dims(batchX, axis=-1)
    batchX = (batchX - batchX.min()) / (batchX.max() - batchX.min())

    predicted_data = unet_model.predict(batchX, verbose=int(verbose))

    origin = preprocessed_image.origin
    spacing = preprocessed_image.spacing
    direction = preprocessed_image.direction

    foreground_probability_image = ants.from_numpy(
        np.squeeze(predicted_data[0, :, :, 0]),
        origin=origin, spacing=spacing, direction=direction)

    if verbose:
        print("Post-processing:  resampling to original space.")

    foreground_probability_image = ants.resample_image_to_target(
        foreground_probability_image, image)

    return foreground_probability_image
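

# A minimal end-to-end sketch (assuming a 2-D histology image stored at the
# hypothetical path "histology.nii.gz"; the 0.5 cutoff used to binarize the
# probability map is an illustrative choice, not part of the function):
#
# >>> import ants
# >>> image = ants.image_read("histology.nii.gz", dimension=2)
# >>> probability = arterial_lesion_segmentation(image, verbose=True)
# >>> lesion_mask = ants.threshold_image(probability, 0.5, 1.0, 1, 0)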