Example #1
def iso_resample(img, spacing, islabel=False):
    if islabel:
        img_re = ants.resample_image(img, spacing, False, 1)
    else:
        img_re = ants.resample_image(img, spacing, False, 0)

    return img_re
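
A minimal sketch of why the interp_type argument differs for labels, assuming the standard ANTsPy r16 sample image; an Otsu threshold stands in for a real label map:

import ants

img = ants.image_read(ants.get_ants_data('r16'))
labels = ants.threshold_image(img, 'Otsu', 3)  # integer label map

# Intensity image: linear interpolation (interp_type=0).
img_re = ants.resample_image(img, (1.0, 1.0), False, 0)
# Label map: nearest neighbor (interp_type=1), so no fractional
# label values are invented at region boundaries.
labels_re = ants.resample_image(labels, (1.0, 1.0), False, 1)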
Example #2
    def test_example(self):
        # test ANTsPy/ANTsR example
        fixed = ants.image_read(ants.get_ants_data("r16"))
        moving = ants.image_read(ants.get_ants_data("r64"))
        fixed = ants.resample_image(fixed, (64, 64), 1, 0)
        moving = ants.resample_image(moving, (64, 64), 1, 0)
        mytx = ants.registration(fixed=fixed,
                                 moving=moving,
                                 type_of_transform="SyN")
        mywarpedimage = ants.apply_transforms(
            fixed=fixed, moving=moving, transformlist=mytx["fwdtransforms"])

        # bad interpolator
        with self.assertRaises(Exception):
            mywarpedimage = ants.apply_transforms(
                fixed=fixed,
                moving=moving,
                transformlist=mytx["fwdtransforms"],
                interpolator="unsupported-interp",
            )

        # transform doesn't exist
        with self.assertRaises(Exception):
            mywarpedimage = ants.apply_transforms(
                fixed=fixed,
                moving=moving,
                transformlist=["blah-blah.mat"],
                interpolator="unsupported-interp",
            )
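
Any valid interpolator name avoids the exception above; per the ants.apply_transforms docstring, 'linear', 'nearestNeighbor', 'bSpline', and 'genericLabel' are among the accepted options. A minimal sketch continuing from the fixtures above:

mywarped_nn = ants.apply_transforms(
    fixed=fixed,
    moving=moving,
    transformlist=mytx["fwdtransforms"],
    interpolator="nearestNeighbor",  # a valid interpolator name
)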
Example #3
def iso_resample(img, spacing=None, islabel=False):
    if spacing is None:
        spacing = [1.25, 1.25, 10]
    if islabel:
        img_re = ants.resample_image(img, spacing, False, 1)
    else:
        img_re = ants.resample_image(img, spacing, False, 0)

    return img_re
Example #4
 def test_example(self):
     fi = ants.image_read( ants.get_ants_data('r16'))
     mi = ants.image_read( ants.get_ants_data('r64'))
     fi = ants.resample_image(fi, (128, 128), 1, 0)
     mi = ants.resample_image(mi, (128, 128), 1, 0)
     mytx = ants.registration(fixed=fi, moving=mi, type_of_transform='SyN')
     try:
         jac = ants.create_jacobian_determinant_image(fi, mytx['fwdtransforms'][0], 1)
     except Exception:
         pass
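
create_jacobian_determinant_image takes the fixed (domain) image, the path to the dense warp field (fwdtransforms[0] for a SyN result), and a do_log flag, as the call above shows; a quick sanity-check sketch under those assumptions:

jac = ants.create_jacobian_determinant_image(fi, mytx['fwdtransforms'][0], True)
# For the log-Jacobian, 0 means no local volume change; positive values
# indicate expansion and negative values compression.
print(jac.mean(), jac.min(), jac.max())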
Example #5
    def test_registration_types(self):
        print('Starting long registration interface test')
        fi = ants.image_read(ants.get_ants_data('r16'))
        mi = ants.image_read(ants.get_ants_data('r64'))
        fi = ants.resample_image(fi, (60,60), 1, 0)
        mi = ants.resample_image(mi, (60,60), 1, 0)

        for ttype in self.transform_types:
            mytx = ants.registration(fixed=fi, moving=mi, type_of_transform=ttype)

            # with mask
            fimask = fi > fi.mean()
            mytx = ants.registration(fixed=fi, moving=mi, mask=fimask, type_of_transform=ttype)   
        print('Finished long registration interface test')
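
The test relies on a self.transform_types attribute that is not shown; a plausible setUp, assuming a representative subset of the transform names documented for ants.registration (the class name here is hypothetical):

import unittest
import ants

class TestRegistrationTypes(unittest.TestCase):
    def setUp(self):
        # A representative subset of ants.registration transform names.
        self.transform_types = ['Translation', 'Rigid', 'Similarity',
                                'Affine', 'SyN', 'ElasticSyN']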
Example #6
    def test_resample_returns_NaNs(self):
        """
        Test that resampling an image doesn't cause the resampled
        image to have NaNs - previously caused by resampling an
        image of type DOUBLE
        """
        img2d = ants.image_read(ants.get_ants_data('r16'))
        img2dr = ants.resample_image(img2d, (2, 2), 0, 0)

        self.assertTrue(np.sum(np.isnan(img2dr.numpy())) == 0)

        img3d = ants.image_read(ants.get_ants_data('mni'))
        img3dr = ants.resample_image(img3d, (2, 2, 2), 0, 0)

        self.assertTrue(np.sum(np.isnan(img3dr.numpy())) == 0)
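
If an input does arrive as double, casting it before resampling is a cheap guard; a minimal sketch using ANTsImage.clone, which accepts a target pixel type:

img = ants.image_read(ants.get_ants_data('mni'))
if img.pixeltype == 'double':
    img = img.clone('float')  # sidestep the historical NaN-on-double issue
img_r = ants.resample_image(img, (2, 2, 2), 0, 0)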
Example #7
def main(args):

    logfile = args['logfile']
    save_directory = args['save_directory']
    warp_directory = args['warp_directory']

    fixed_path = args['fixed_path']
    fixed_fly = args['fixed_fly']
    fixed_resolution = args['fixed_resolution']

    moving_path = args['moving_path']
    moving_fly = args['moving_fly']
    moving_resolution = args['moving_resolution']

    ###################
    ### Load Brains ###
    ###################
    fixed = np.asarray(nib.load(fixed_path).get_fdata().squeeze(),
                       dtype='float32')
    fixed = ants.from_numpy(fixed)
    fixed.set_spacing(fixed_resolution)
    fixed = ants.resample_image(fixed, (256, 128, 49), 1, 0)

    moving = np.asarray(nib.load(moving_path).get_fdata().squeeze(),
                        dtype='float32')
    moving = ants.from_numpy(moving)
    moving.set_spacing(moving_resolution)

    ###########################
    ### Organize Transforms ###
    ###########################
    affine_file = os.listdir(
        os.path.join(warp_directory, 'func-to-anat_fwdtransforms_lowres'))[0]
    affine_path = os.path.join(warp_directory,
                               'func-to-anat_fwdtransforms_lowres',
                               affine_file)

    syn_files = os.listdir(
        os.path.join(warp_directory, 'anat-to-meanbrain_fwdtransforms_lowres'))
    syn_linear_path = os.path.join(warp_directory,
                                   'anat-to-meanbrain_fwdtransforms_lowres',
                                   [x for x in syn_files if '.mat' in x][0])
    syn_nonlinear_path = os.path.join(
        warp_directory, 'anat-to-meanbrain_fwdtransforms_lowres',
        [x for x in syn_files if '.nii.gz' in x][0])

    transforms = [affine_path, syn_linear_path, syn_nonlinear_path]

    ########################
    ### Apply Transforms ###
    ########################
    moco = ants.apply_transforms(fixed, moving, transforms, imagetype=3)

    ############
    ### Save ###
    ############
    save_file = os.path.join(
        save_directory, 'functional_channel_2_moco_zscore_highpass_warped.nii'
    )  # NOTE: the output filename is hard-coded
    nib.Nifti1Image(moco.numpy(), np.eye(4)).to_filename(save_file)
Example #8
    def test_example(self):
        ref = ants.image_read(ants.get_ants_data('r16'))
        ref = ants.resample_image(ref, (50, 50), 1, 0)
        ref = ants.iMath(ref, 'Normalize')
        mi = ants.image_read(ants.get_ants_data('r27'))
        mi2 = ants.image_read(ants.get_ants_data('r30'))
        mi3 = ants.image_read(ants.get_ants_data('r62'))
        mi4 = ants.image_read(ants.get_ants_data('r64'))
        mi5 = ants.image_read(ants.get_ants_data('r85'))
        refmask = ants.get_mask(ref)
        refmask = ants.iMath(refmask, 'ME', 2)  # just to speed things up
        ilist = [mi, mi2, mi3, mi4, mi5]
        seglist = [None] * len(ilist)
        for i in range(len(ilist)):
            ilist[i] = ants.iMath(ilist[i], 'Normalize')
            mytx = ants.registration(fixed=ref,
                                     moving=ilist[i],
                                     type_of_transform='Affine')
            mywarpedimage = ants.apply_transforms(
                fixed=ref,
                moving=ilist[i],
                transformlist=mytx['fwdtransforms'])
            ilist[i] = mywarpedimage
            seg = ants.threshold_image(ilist[i], 'Otsu', 3)
            seglist[i] = seg + ants.threshold_image(seg, 1, 3).morphology(
                operation='dilate', radius=3)

        r = 2
        pp = ants.joint_label_fusion(ref,
                                     refmask,
                                     ilist,
                                     r_search=2,
                                     label_list=seglist,
                                     rad=[r] * ref.dimension)
        pp = ants.joint_label_fusion(ref, refmask, ilist, r_search=2, rad=2)
Example #9
def read_nib(addr):
    # img = np.array(nib.load(addr).dataobj)
    img = nib.load(addr)
    ants_data = ants.from_nibabel(img)
    resampled = ants.resample_image(ants_data, (1.5, 1.5, 1.5)).numpy()  # resample to 1.5 mm isotropic resolution
    print("Resampled Size: ", resampled.shape)
    np_resampled_pad = np.pad(resampled, 100, 'constant', constant_values=resampled.min())
    # np_resampled_pad = np.pad(img, 100, 'constant', constant_values=img.min())
    return crop_center(np_resampled_pad, 256, 256, 400)
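
crop_center is not shown in this example; a hypothetical implementation consistent with the call crop_center(np_resampled_pad, 256, 256, 400), cropping a 3-D array to the requested size around its center:

import numpy as np

def crop_center(arr, cx, cy, cz):
    # Hypothetical helper: center-crop a 3-D array to (cx, cy, cz).
    x, y, z = arr.shape
    sx = max((x - cx) // 2, 0)
    sy = max((y - cy) // 2, 0)
    sz = max((z - cz) // 2, 0)
    return arr[sx:sx + cx, sy:sy + cy, sz:sz + cz]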
Example #10
 def test_example(self):
     img = ants.image_read(ants.get_ants_data('r16'), 2)
     img = ants.resample_image(img, (64, 64), 1, 0)
     mask = ants.get_mask(img)
     segs = ants.kmeans_segmentation(img, k=3, kmask=mask)
     thick = ants.kelly_kapowski(s=segs['segmentation'],
                                 g=segs['probabilityimages'][1],
                                 w=segs['probabilityimages'][2],
                                 its=45,
                                 r=0.5,
                                 m=1)
Example #11
def Transform_Image(imagePath):
    # Needed so that the relative path resolves correctly.
    my_path = os.path.abspath(os.path.dirname(__file__))
    # Use the MNI152 template.
    templatePath = os.path.join(my_path, "../Datos/MNI152.nii.gz")
    # Load the path received as a parameter; if it is wrong, it must be changed in the Busqueda code.
    path = os.path.join(my_path, imagePath)
    # Read the paths and turn them into images.
    template = ants.image_read(templatePath)
    image = ants.image_read(path)
    # Resample both images.
    template = ants.resample_image(template, (64, 64, 64), 1, 0)
    image = ants.resample_image(image, (64, 64, 64), 1, 0)
    # Perform the registration.
    mytx = ants.registration(fixed=template,
                             moving=image,
                             type_of_transform='SyN')
    mywarpedimage = ants.apply_transforms(fixed=template,
                                          moving=image,
                                          transformlist=mytx['fwdtransforms'])
    return mywarpedimage, mytx
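
A usage sketch for Transform_Image, assuming a NIfTI file reachable from the module's directory (the file names here are hypothetical):

warped, tx = Transform_Image("../Datos/subject01.nii.gz")  # hypothetical input
ants.image_write(warped, "subject01_warped.nii.gz")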
Example #12
    def resampleANTs(mm_spacing, ANTsImageObject, file_id, method):
        """Function aiming at rescaling imaging to a resolution specified somewhere, e.g. in the cfg-file"""
        import math, ants

        resolution = [mm_spacing] * 3
        if not all([
                math.isclose(float(mm_spacing), x, abs_tol=10**-2)
                for x in ANTsImageObject.spacing
        ]):
            print(
                '\tImage spacing {:.4f}x{:.4f}x{:.4f} unequal to specified value ({}mm). '
                '\n\t\tRescaling {}'.format(ANTsImageObject.spacing[0],
                                            ANTsImageObject.spacing[1],
                                            ANTsImageObject.spacing[2],
                                            mm_spacing, file_id))

            if len(ANTsImageObject.spacing) == 4:
                resolution.append(ANTsImageObject.spacing[-1])
            elif len(ANTsImageObject.spacing) > 4:
                Output.msg_box(
                    text='\tSequences of >4 dimensions are not possible.',
                    title='Too many dimensions')

            resampled_image = ants.resample_image(ANTsImageObject,
                                                  resolution,
                                                  use_voxels=False,
                                                  interp_type=method)
            ants.image_write(resampled_image, filename=file_id)
        else:
            print(
                '\tImage spacing for sequence: {} is {:.4f}x{:.4f}x{:.4f} as specified in options, '
                '\t -> proceeding!'.format(file_id, ANTsImageObject.spacing[0],
                                           ANTsImageObject.spacing[1],
                                           ANTsImageObject.spacing[2]))
            resampled_image = ANTsImageObject

        return resampled_image
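
A usage sketch for resampleANTs, assuming it is exposed as a static method or plain function and following the interp_type convention used throughout this page (0 linear, 1 nearest neighbor); the file names are hypothetical:

img = ants.image_read('t1.nii.gz')  # hypothetical input
resampled = resampleANTs(mm_spacing=1.0, ANTsImageObject=img,
                         file_id='t1_1mm.nii.gz', method=0)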
Example #13
    def test_multiple_inputs(self):
        img = ants.image_read(ants.get_ants_data("r16"))
        img = ants.resample_image(img, (64, 64), 1, 0)
        mask = ants.get_mask(img)
        segs1 = ants.atropos(a=img,
                             m='[0.2,1x1]',
                             c='[2,0]',
                             i='kmeans[3]',
                             x=mask)

        # Use probabilities from k-means seg as priors
        segs2 = ants.atropos(a=img,
                             m='[0.2,1x1]',
                             c='[2,0]',
                             i=segs1['probabilityimages'],
                             x=mask)

        # multiple inputs
        feats = [img, ants.iMath(img, "Laplacian"), ants.iMath(img, "Grad")]
        segs3 = ants.atropos(a=feats,
                             m='[0.2,1x1]',
                             c='[2,0]',
                             i=segs1['probabilityimages'],
                             x=mask)
Example #14
print("Preprocessing: bias correction")
image_n4 = ants.n4_bias_field_correction(image)
image_n4 = ants.iMath(image_n4, 'Normalize') * 255.0

print("Preprocessing:  thresholding")
image_n4_array = ((image_n4.numpy()).flatten())
image_n4_nonzero = image_n4_array[(image_n4_array > 0).nonzero()]
image_robust_range = np.quantile(image_n4_nonzero, (0.02, 0.98))
threshold_value = 0.10 * (image_robust_range[1] -
                          image_robust_range[0]) + image_robust_range[0]
thresholded_mask = ants.threshold_image(image_n4, -10000, threshold_value, 0,
                                        1)
thresholded_image = image_n4 * thresholded_mask

print("Preprocessing:  resampling")
image_resampled = ants.resample_image(thresholded_image, (256, 256, 256), True)
batchX = np.expand_dims(image_resampled.numpy(), axis=0)
batchX = np.expand_dims(batchX, axis=-1)

print("Prediction and write to disk.")
brain_mask_array = model.predict(batchX, verbose=0)
brain_mask_resampled = ants.from_numpy(
    np.squeeze(brain_mask_array[0, :, :, :, 0]),
    origin=image_resampled.origin,
    spacing=image_resampled.spacing,
    direction=image_resampled.direction)
brain_mask_image = ants.resample_image(brain_mask_resampled, image.shape,
                                       True, 1)
minimum_brain_volume = round(649933.7)
brain_mask_labeled = ants.label_clusters(brain_mask_image,
                                         minimum_brain_volume)
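
ants.label_clusters relabels connected components; assuming, as the stage-two refinement snippet later on this page does, that label 1 is the largest cluster, the final brain mask can be reduced to that single component:

# Keep only the largest connected component as the brain mask.
brain_mask_final = ants.threshold_image(brain_mask_labeled, 1, 1, 1, 0)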
Example #15
 def test_example(self):
     # test ANTsPy/ANTsR example
     img = ants.image_read(ants.get_ants_data('r16'))
     img = ants.resample_image(img, (64, 64), 1, 0)
     mask = ants.get_mask(img)
     ants.atropos(a=img, m='[0.2,1x1]', c='[2,0]', i='kmeans[3]', x=mask)
Example #16
 def test_label_stats_example(self):
     image = ants.image_read(ants.get_ants_data('r16'), 2)
     image = ants.resample_image(image, (64, 64), 1, 0)
     mask = image > image.mean()
     segs1 = ants.kmeans_segmentation(image, 3)
     stats = ants.label_stats(image, segs1['segmentation'])
Example #17
def preprocess(img_dir,
               out_dir,
               mask_dir=None,
               res=(1., 1., 1.),
               orientation='RAI',
               n4_opts=None):
    """
    preprocess.py MR images according to a simple scheme,
    that is:
        1) N4 bias field correction
        2) resample to x mm x y mm x z mm
        3) reorient images to RAI

    Args:
        img_dir (str): path to directory containing images
        out_dir (str): path to directory for output preprocessed files
        mask_dir (str): path to directory containing masks
        res (tuple): resolution for resampling (default: (1,1,1) in mm)
        orientation (str): target orientation for reorientation (default: 'RAI')
        n4_opts (dict): n4 processing options. See ANTsPy for details. (default: None)

    Returns:
        None, outputs preprocessed images to file in given out_dir
    """

    if n4_opts is None:
        n4_opts = {'iters': [200, 200, 200, 200], 'tol': 0.0005}
    logger.debug('N4 Options are: {}'.format(n4_opts))

    # get and check the images and masks
    img_fns = glob_nii(img_dir)
    mask_fns = glob_nii(
        mask_dir) if mask_dir is not None else [None] * len(img_fns)
    assert len(img_fns) == len(mask_fns), 'Number of images and masks must be equal ({:d} != {:d})' \
        .format(len(img_fns), len(mask_fns))

    # create the output directory structure
    out_img_dir = os.path.join(out_dir, 'imgs')
    out_mask_dir = os.path.join(out_dir, 'masks')
    if not os.path.exists(out_dir):
        logger.info('Making output directory structure: {}'.format(out_dir))
        os.mkdir(out_dir)
    if not os.path.exists(out_img_dir):
        logger.info('Making image output directory: {}'.format(out_img_dir))
        os.mkdir(out_img_dir)
    if not os.path.exists(out_mask_dir) and mask_dir is not None:
        logger.info('Making mask output directory: {}'.format(out_mask_dir))
        os.mkdir(out_mask_dir)

    # preprocess the images by n4 correction, resampling, and reorientation
    for i, (img_fn, mask_fn) in enumerate(zip(img_fns, mask_fns), 1):
        _, img_base, img_ext = split_filename(img_fn)
        logger.info('Preprocessing image: {} ({:d}/{:d})'.format(
            img_base, i, len(img_fns)))
        img = ants.image_read(img_fn)
        if mask_dir is not None:
            _, mask_base, mask_ext = split_filename(mask_fn)
            mask = ants.image_read(mask_fn)
            smoothed_mask = ants.smooth_image(mask, 1)
            # this should be a second n4 after an initial n4 (and coregistration), once masks are obtained
            img = ants.n4_bias_field_correction(img,
                                                convergence=n4_opts,
                                                weight_mask=smoothed_mask)
            if res is not None:
                if res != img.spacing:
                    mask = ants.resample_image(mask, res, False, 1)
            mask = mask.reorient_image2(orientation) if hasattr(img, 'reorient_image2') else \
                mask.reorient_image((1, 0, 0))['reoimage']
            out_mask = os.path.join(out_mask_dir, mask_base + mask_ext)
            ants.image_write(mask, out_mask)
        else:
            img = ants.n4_bias_field_correction(img, convergence=n4_opts)
        if res is not None:
            if res != img.spacing:
                img = ants.resample_image(img, res, False, 4)
        if hasattr(img, 'reorient_image2'):
            img = img.reorient_image2(orientation)
        else:
            logger.info(
                'Cannot reorient image to a custom orientation. Update ANTsPy to a version >= 0.1.5.'
            )
            img = img.reorient_image((1, 0, 0))['reoimage']
        logger.info('Writing preprocessed image: {} ({:d}/{:d})'.format(
            img_base, i, len(img_fns)))
        out_img = os.path.join(out_img_dir, img_base + img_ext)
        ants.image_write(img, out_img)
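
A usage sketch for the pipeline above; the directory names are hypothetical:

preprocess(img_dir='data/imgs',
           out_dir='data/preprocessed',
           mask_dir='data/masks',  # or None when no masks exist
           res=(1., 1., 1.),
           orientation='RAI')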
Example #18
 def test_resample_image_to_target_example(self):
     fi = ants.image_read(ants.get_ants_data("r16"))
     fi2mm = ants.resample_image(fi, (2, 2), use_voxels=0, interp_type=1)
     resampled = ants.resample_image_to_target(fi2mm, fi, verbose=True)
Example #19
 def test_resample_image_example(self):
     fi = ants.image_read(ants.get_ants_data("r16"))
     finn = ants.resample_image(fi, (50, 60), True, 0)
     filin = ants.resample_image(fi, (1.5, 1.5), False, 1)
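
The two calls above illustrate the two meanings of the second argument: with use_voxels=True it is a target grid size, with use_voxels=False a target spacing in mm. A short annotated sketch of the signature ants.resample_image(image, resample_params, use_voxels, interp_type), where interp_type is 0 (linear), 1 (nearest neighbor), 2 (gaussian), 3 (windowed sinc), or 4 (B-spline) per the ANTsPy docstring:

fi = ants.image_read(ants.get_ants_data('r16'))

by_voxels = ants.resample_image(fi, (50, 60), use_voxels=True, interp_type=0)
by_spacing = ants.resample_image(fi, (1.5, 1.5), use_voxels=False, interp_type=4)

print(by_voxels.shape)     # target grid size: (50, 60)
print(by_spacing.spacing)  # target spacing: (1.5, 1.5)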
Example #20
 def test_example(self):
     fi = ants.image_read(ants.get_ants_data("r16"))
     mi = ants.image_read(ants.get_ants_data("r64"))
     fi = ants.resample_image(fi, (60, 60), 1, 0)
     mi = ants.resample_image(mi, (60, 60), 1, 0)
     mytx = ants.registration(fixed=fi, moving=mi, type_of_transform="SyN")
Example #21
def mainRegScript(patientPath,SA_name,LA_4CH_name,LA_2CH_name,pathSave,typeRe):
    # Load initial images
    SA = ants.image_read(patientPath + SA_name)
    SA = ants.resample_image(SA, [min(SA.spacing)] * 3)
    LA_4CH = ants.image_read(patientPath + LA_4CH_name)
    LA_2CH = ants.image_read(patientPath + LA_2CH_name)

    # Registration
    regSA_4CH = ants.registration(SA, LA_4CH, type_of_transform=typeRe, aff_metric='mattes')
    regSA_2CH = ants.registration(SA, LA_2CH, type_of_transform=typeRe, aff_metric='mattes')
    print('Registration Done')

    #ROI extraction
    clone4CH = ants.image_clone(LA_4CH)
    clone2CH = ants.image_clone(LA_2CH)
    roi_4CH = overROI(clone4CH,regSA_4CH['fwdtransforms'],SA)
    roi_2CH = overROI(clone2CH,regSA_2CH['fwdtransforms'],SA)
    roi_4CHArr = roi_4CH.view()
    roi_4CHArr = roi_4CHArr == 1
    roi_2CHArr = roi_2CH.view()
    roi_2CHArr = roi_2CHArr == 1

    print('ROI extraction Done')

    new4CH = ants.to_nibabel(regSA_4CH['warpedmovout'])
    new4CH.set_data_dtype('int16')
    new2CH = ants.to_nibabel(regSA_2CH['warpedmovout'])
    new2CH.set_data_dtype('int16')
    nib.save(new4CH,pathSave + '4CH.nii')
    nib.save(new2CH,pathSave + '2CH.nii')

    #Normalization

    new4CH_CH = regSA_4CH['warpedmovout'].view()
    new2CH_CH = regSA_2CH['warpedmovout'].view()
    short_CH = SA.view()
    for t in range(0,new2CH_CH.shape[2]):
        short_CH[:,:,t] = normOver(short_CH[:,:,t],roi_4CHArr[:,:,t]+roi_2CHArr[:,:,t])
        new4CH_CH[:,:,t] = normOver(new4CH_CH[:,:,t],roi_4CHArr[:,:,t])
        new2CH_CH[:,:,t] = normOver(new2CH_CH[:,:,t],roi_2CHArr[:,:,t])


    new4CH = ants.to_nibabel(regSA_4CH['warpedmovout'])
    new4CH.set_data_dtype('int16')
    new2CH = ants.to_nibabel(regSA_2CH['warpedmovout'])
    new2CH.set_data_dtype('int16')
    shortNorm = ants.to_nibabel(SA)
    shortNorm.set_data_dtype('int16')
    nib.save(new4CH,pathSave + 'norm_4CH.nii')
    nib.save(new2CH,pathSave + 'norm_2CH.nii')
    nib.save(shortNorm,pathSave + 'norm_SA.nii')
    print('Normalization Done')

    # Checkerboard
    ch_4CH = np.zeros(short_CH.shape)
    ch_2CH = np.zeros(short_CH.shape)
    for t in np.arange(0, short_CH.shape[2]):
        ch_4CH[:, :, t] = cheBoard(short_CH[:, :, t], new4CH_CH[:, :, t], 16, roi_4CHArr[:, :, t])
        ch_2CH[:, :, t] = cheBoard(short_CH[:, :, t], new2CH_CH[:, :, t], 16, roi_2CHArr[:, :, t])
    chest_4CH = nib.Nifti1Image(ch_4CH, new4CH.affine, new4CH.header)
    chest_2CH = nib.Nifti1Image(ch_2CH, new2CH.affine, new2CH.header)
    print("Checkerboard filter applied")
    nib.save(chest_4CH,pathSave + 'chest_4CH.nii')
    nib.save(chest_2CH,pathSave + 'chest_2CH.nii')
    print("Done")
    final = (SA,regSA_4CH,regSA_2CH)
    return final
Example #22
import ants
import os

# Constants for path names
FIXED_IMG = "/content/drive/My Drive/cs8395_deep_learning/assignment3/data/Train/img/0007.nii.gz"
OLD_TRAIN_IMG = "/content/drive/My Drive/cs8395_deep_learning/assignment3/data/Train/img/"
NEW_TRAIN_IMG = "/content/drive/My Drive/cs8395_deep_learning/assignment3/data/Train/img_registered/"
OLD_TRAIN_LABELS = "/content/drive/My Drive/cs8395_deep_learning/assignment3/data/Train/label/"
NEW_TRAIN_LABELS = "/content/drive/My Drive/cs8395_deep_learning/assignment3/data/Train/label_registered/"
OLD_VAL_IMG = "/content/drive/My Drive/cs8395_deep_learning/assignment3/data/Val/img/"
NEW_VAL_IMG = "/content/drive/My Drive/cs8395_deep_learning/assignment3/data/Val/img_registered/"
OLD_VAL_LABELS = "/content/drive/My Drive/cs8395_deep_learning/assignment3/data/Val/label/"
NEW_VAL_LABELS = "/content/drive/My Drive/cs8395_deep_learning/assignment3/data/Val/label_registered/"

fixed = ants.image_read(FIXED_IMG)
fixed = ants.resample_image(fixed, [224, 224, 70], True, 1)

# Register all the training images
for file_name in os.listdir(OLD_TRAIN_IMG):
    moving_image = ants.image_read(OLD_TRAIN_IMG + file_name)
    moving_image = ants.resample_image(moving_image, [224, 224, 70], True, 1)
    label = ants.image_read(OLD_TRAIN_LABELS + file_name)
    label = ants.resample_image(label, [224, 224, 70], True, 1)
    transform = ants.registration(fixed=fixed , moving=moving_image,
                                 type_of_transform='Affine' )
    transformed_image = ants.apply_transforms( fixed=fixed, moving=moving_image,
                                               transformlist=transform['fwdtransforms'],
                                               interpolator='nearestNeighbor')
    transformed_image.to_file(NEW_TRAIN_IMG + file_name)
    transformed_label = ants.apply_transforms( fixed=fixed, moving=label,
                                               transformlist=transform['fwdtransforms'],
                                               interpolator='nearestNeighbor')
    transformed_label.to_file(NEW_TRAIN_LABELS + file_name)
Example #23
data_initial_stage = image_resampled.numpy()
data_initial_stage = np.expand_dims(data_initial_stage, 0)
data_initial_stage = np.expand_dims(data_initial_stage, -1)

prediction_initial_stage = np.squeeze(
    model_initial_stage.predict(data_initial_stage))
prediction_initial_stage[np.where(prediction_initial_stage >= 0.5)] = 1
prediction_initial_stage[np.where(prediction_initial_stage < 0.5)] = 0
mask_initial_stage = ants.from_numpy(prediction_initial_stage,
                                     origin=image_resampled.origin,
                                     spacing=image_resampled.spacing,
                                     direction=image_resampled.direction)
mask_initial_stage = ants.label_clusters(mask_initial_stage,
                                         min_cluster_size=10)
mask_initial_stage = ants.threshold_image(mask_initial_stage, 1, 2, 1, 0)
mask_initial_stage_original_space = ants.resample_image(
    mask_initial_stage, image_n4.shape, True, 1)
end_time = time.time()
elapsed_time = end_time - start_time
print("  (elapsed time: ", elapsed_time, " seconds)")

#########################################
#
# Perform refinement (stage 2) segmentation
#

print("")
print("")
print("*************  Refine stage segmentation  ***************")
# print("  (warning:  These steps need closer inspection.)")
print("")
Example #24
def main(args):

    logfile = args['logfile']
    save_directory = args['save_directory']
    flip_X = args['flip_X']
    flip_Z = args['flip_Z']
    type_of_transform = args['type_of_transform']  # SyN or Affine
    save_warp_params = args['save_warp_params']

    fixed_path = args['fixed_path']
    fixed_fly = args['fixed_fly']
    fixed_resolution = args['fixed_resolution']

    moving_path = args['moving_path']
    moving_fly = args['moving_fly']
    moving_resolution = args['moving_resolution']

    low_res = args['low_res']
    very_low_res = args['very_low_res']

    iso_2um_fixed = args['iso_2um_fixed']
    iso_2um_moving = args['iso_2um_moving']

    grad_step = args['grad_step']
    flow_sigma = args['flow_sigma']
    total_sigma = args['total_sigma']
    syn_sampling = args['syn_sampling']

    try:
        mimic_path = args['mimic_path']
        mimic_fly = args['mimic_fly']
        mimic_resolution = args['mimic_resolution']
    except KeyError:
        mimic_path = None
        mimic_fly = None
        mimic_resolution = None

    width = 120
    printlog = getattr(flow.Printlog(logfile=logfile), 'print_to_log')

    ###################
    ### Load Brains ###
    ###################

    ### Fixed
    fixed = np.asarray(nib.load(fixed_path).get_fdata().squeeze(),
                       dtype='float32')
    fixed = ants.from_numpy(fixed)
    fixed.set_spacing(fixed_resolution)
    if low_res:
        fixed = ants.resample_image(fixed, (256, 128, 49), 1, 0)
    elif very_low_res:
        fixed = ants.resample_image(fixed, (128, 64, 49), 1, 0)
    elif iso_2um_fixed:
        fixed = ants.resample_image(fixed, (2, 2, 2), use_voxels=False)

    ### Moving
    moving = np.asarray(nib.load(moving_path).get_fdata().squeeze(),
                        dtype='float32')
    if flip_X:
        moving = moving[::-1, :, :]
    if flip_Z:
        moving = moving[:, :, ::-1]
    moving = ants.from_numpy(moving)
    moving.set_spacing(moving_resolution)
    if low_res:
        moving = ants.resample_image(moving, (256, 128, 49), 1, 0)
    elif very_low_res:
        moving = ants.resample_image(moving, (128, 64, 49), 1, 0)
    elif iso_2um_moving:
        moving = ants.resample_image(moving, (2, 2, 2), use_voxels=False)

    ### Mimic
    if mimic_path is not None:
        mimic = np.asarray(nib.load(mimic_path).get_fdata().squeeze(),
                           dtype='float32')
        if flip_X:
            mimic = mimic[::-1, :, :]
        if flip_Z:
            mimic = mimic[:, :, ::-1]
        mimic = ants.from_numpy(mimic)
        mimic.set_spacing(mimic_resolution)
        printlog('Starting {} to {}, with mimic {}'.format(
            moving_fly, fixed_fly, mimic_fly))
    else:
        printlog('Starting {} to {}'.format(moving_fly, fixed_fly))

    #############
    ### Align ###
    #############

    t0 = time()
    with stderr_redirected():  # suppress repeated ITK Gaussian warnings
        moco = ants.registration(fixed,
                                 moving,
                                 type_of_transform=type_of_transform,
                                 grad_step=grad_step,
                                 flow_sigma=flow_sigma,
                                 total_sigma=total_sigma,
                                 syn_sampling=syn_sampling)

    printlog('Fixed: {}, {} | Moving: {}, {} | {} | {}'.format(
        fixed_fly,
        fixed_path.split('/')[-1], moving_fly,
        moving_path.split('/')[-1], type_of_transform,
        sec_to_hms(time() - t0)))

    ################################
    ### Save warp params if True ###
    ################################

    if save_warp_params:
        fwdtransformlist = moco['fwdtransforms']
        fwdtransforms_save_dir = os.path.join(
            save_directory,
            '{}-to-{}_fwdtransforms'.format(moving_fly, fixed_fly))
        if low_res:
            fwdtransforms_save_dir += '_lowres'
        if not os.path.exists(fwdtransforms_save_dir):
            os.mkdir(fwdtransforms_save_dir)
        for source_path in fwdtransformlist:
            source_file = source_path.split('/')[-1]
            target_path = os.path.join(fwdtransforms_save_dir, source_file)
            copyfile(source_path, target_path)

    # Added this saving of inv transforms 2020 Dec 19
    if save_warp_params:
        invtransformlist = moco['invtransforms']
        invtransforms_save_dir = os.path.join(
            save_directory,
            '{}-to-{}_invtransforms'.format(moving_fly, fixed_fly))
        if low_res:
            invtransforms_save_dir += '_lowres'
        if not os.path.exists(invtransforms_save_dir):
            os.mkdir(invtransforms_save_dir)
        for source_path in invtransformlist:
            source_file = source_path.split('/')[-1]
            target_path = os.path.join(invtransforms_save_dir, source_file)
            copyfile(source_path, target_path)

    ##################################
    ### Apply warp params to mimic ###
    ##################################

    if mimic_path is not None:
        mimic_moco = ants.apply_transforms(fixed, mimic, moco['fwdtransforms'])

    ############
    ### Save ###
    ############

    # NOTE: the mimic result is not saved below; this may need to change.
    if flip_X:
        save_file = os.path.join(save_directory,
                                 moving_fly + '_m' + '-to-' + fixed_fly)
        #save_file = os.path.join(save_directory, mimic_fly + '_m' + '-to-' + fixed_fly + '.nii')
    else:
        save_file = os.path.join(save_directory,
                                 moving_fly + '-to-' + fixed_fly)
        #save_file = os.path.join(save_directory, mimic_fly + '-to-' + fixed_fly + '.nii')
    #nib.Nifti1Image(mimic_moco.numpy(), np.eye(4)).to_filename(save_file)
    if low_res:
        save_file += '_lowres'
    save_file += '.nii'
    nib.Nifti1Image(moco['warpedmovout'].numpy(),
                    np.eye(4)).to_filename(save_file)
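
printlog, stderr_redirected, and sec_to_hms come from the surrounding project and are not shown; a hypothetical sec_to_hms consistent with how it is used above:

def sec_to_hms(seconds):
    # Hypothetical helper: format elapsed seconds as H:MM:SS.
    m, s = divmod(int(seconds), 60)
    h, m = divmod(m, 60)
    return '{}:{:02d}:{:02d}'.format(h, m, s)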
Example #25
def desikan_killiany_tourville_labeling(t1,
                                        do_preprocessing=True,
                                        return_probability_images=False,
                                        antsxnet_cache_directory=None,
                                        verbose=False):
    """
    Cortical and deep gray matter labeling using the Desikan-Killiany-Tourville protocol.

    Perform DKT labeling using deep learning.

    The labeling is as follows:

    Inner labels:
    Label 0: background
    Label 4: left lateral ventricle
    Label 5: left inferior lateral ventricle
    Label 6: left cerebellum exterior
    Label 7: left cerebellum white matter
    Label 10: left thalamus proper
    Label 11: left caudate
    Label 12: left putamen
    Label 13: left pallidum
    Label 15: 4th ventricle
    Label 16: brain stem
    Label 17: left hippocampus
    Label 18: left amygdala
    Label 24: CSF
    Label 25: left lesion
    Label 26: left accumbens area
    Label 28: left ventral DC
    Label 30: left vessel
    Label 43: right lateral ventricle
    Label 44: right inferior lateral ventricle
    Label 45: right cerebellum exterior
    Label 46: right cerebellum white matter
    Label 49: right thalamus proper
    Label 50: right caudate
    Label 51: right putamen
    Label 52: right pallidum
    Label 53: right hippocampus
    Label 54: right amygdala
    Label 57: right lesion
    Label 58: right accumbens area
    Label 60: right ventral DC
    Label 62: right vessel
    Label 72: 5th ventricle
    Label 85: optic chiasm
    Label 91: left basal forebrain
    Label 92: right basal forebrain
    Label 630: cerebellar vermal lobules I-V
    Label 631: cerebellar vermal lobules VI-VII
    Label 632: cerebellar vermal lobules VIII-X

    Outer labels:
    Label 1002: left caudal anterior cingulate
    Label 1003: left caudal middle frontal
    Label 1005: left cuneus
    Label 1006: left entorhinal
    Label 1007: left fusiform
    Label 1008: left inferior parietal
    Label 1009: left inferior temporal
    Label 1010: left isthmus cingulate
    Label 1011: left lateral occipital
    Label 1012: left lateral orbitofrontal
    Label 1013: left lingual
    Label 1014: left medial orbitofrontal
    Label 1015: left middle temporal
    Label 1016: left parahippocampal
    Label 1017: left paracentral
    Label 1018: left pars opercularis
    Label 1019: left pars orbitalis
    Label 1020: left pars triangularis
    Label 1021: left pericalcarine
    Label 1022: left postcentral
    Label 1023: left posterior cingulate
    Label 1024: left precentral
    Label 1025: left precuneus
    Label 1026: left rostral anterior cingulate
    Label 1027: left rostral middle frontal
    Label 1028: left superior frontal
    Label 1029: left superior parietal
    Label 1030: left superior temporal
    Label 1031: left supramarginal
    Label 1034: left transverse temporal
    Label 1035: left insula
    Label 2002: right caudal anterior cingulate
    Label 2003: right caudal middle frontal
    Label 2005: right cuneus
    Label 2006: right entorhinal
    Label 2007: right fusiform
    Label 2008: right inferior parietal
    Label 2009: right inferior temporal
    Label 2010: right isthmus cingulate
    Label 2011: right lateral occipital
    Label 2012: right lateral orbitofrontal
    Label 2013: right lingual
    Label 2014: right medial orbitofrontal
    Label 2015: right middle temporal
    Label 2016: right parahippocampal
    Label 2017: right paracentral
    Label 2018: right pars opercularis
    Label 2019: right pars orbitalis
    Label 2020: right pars triangularis
    Label 2021: right pericalcarine
    Label 2022: right postcentral
    Label 2023: right posterior cingulate
    Label 2024: right precentral
    Label 2025: right precuneus
    Label 2026: right rostral anterior cingulate
    Label 2027: right rostral middle frontal
    Label 2028: right superior frontal
    Label 2029: right superior parietal
    Label 2030: right superior temporal
    Label 2031: right supramarginal
    Label 2034: right transverse temporal
    Label 2035: right insula

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    do_preprocessing = True

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    return_probability_images : boolean
        Whether to return the two sets of probability images for the inner and outer
        labels.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    List consisting of the segmentation image and probability images for
    each label.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> dkt = desikan_killiany_tourville_labeling(image)
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import categorical_focal_loss
    from ..utilities import preprocess_brain_image
    from ..utilities import crop_image_center

    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            do_brain_extraction=True,
            template="croppedMni152",
            template_transform_type="AffineFast",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing[
            "preprocessed_image"] * t1_preprocessing['brain_mask']

    ################################
    #
    # Download spatial priors for outer model
    #
    ################################

    spatial_priors_file_name_path = get_antsxnet_data(
        "priorDktLabels", antsxnet_cache_directory=antsxnet_cache_directory)
    spatial_priors = ants.image_read(spatial_priors_file_name_path)
    priors_image_list = ants.ndimage_to_list(spatial_priors)

    ################################
    #
    # Build outer model and load weights
    #
    ################################

    template_size = (96, 112, 96)
    labels = (0, 1002, 1003, *tuple(range(1005, 1032)), 1034, 1035, 2002, 2003,
              *tuple(range(2005, 2032)), 2034, 2035)
    channel_size = 1 + len(priors_image_list)

    unet_model = create_unet_model_3d((*template_size, channel_size),
                                      number_of_outputs=len(labels),
                                      number_of_layers=4,
                                      number_of_filters_at_base_layer=16,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      weight_decay=1e-5,
                                      add_attention_gating=True)

    weights_file_name = None
    weights_file_name = get_pretrained_network(
        "dktOuterWithSpatialPriors",
        antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Outer model prediction.")

    downsampled_image = ants.resample_image(t1_preprocessed,
                                            template_size,
                                            use_voxels=True,
                                            interp_type=0)
    image_array = downsampled_image.numpy()
    image_array = (image_array - image_array.mean()) / image_array.std()

    batchX = np.zeros((1, *template_size, channel_size))
    batchX[0, :, :, :, 0] = image_array

    for i in range(len(priors_image_list)):
        resampled_prior_image = ants.resample_image(priors_image_list[i],
                                                    template_size,
                                                    use_voxels=True,
                                                    interp_type=0)
        batchX[0, :, :, :, i + 1] = resampled_prior_image.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    origin = downsampled_image.origin
    spacing = downsampled_image.spacing
    direction = downsampled_image.direction

    inner_probability_images = list()
    for i in range(len(labels)):
        probability_image = \
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
            origin=origin, spacing=spacing, direction=direction)
        resampled_image = ants.resample_image(probability_image,
                                              t1_preprocessed.shape,
                                              use_voxels=True,
                                              interp_type=0)
        if do_preprocessing:
            inner_probability_images.append(
                ants.apply_transforms(
                    fixed=t1,
                    moving=resampled_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose))
        else:
            inner_probability_images.append(resampled_image)

    image_matrix = ants.image_list_to_matrix(inner_probability_images,
                                             t1 * 0 + 1)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    dkt_label_image = ants.image_clone(segmentation_image)
    for i in range(len(labels)):
        dkt_label_image[segmentation_image == i] = labels[i]

    ################################
    #
    # Build inner model and load weights
    #
    ################################

    template_size = (160, 192, 160)
    labels = (0, 4, 6, 7, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 28, 30,
              43, 44, 45, 46, 49, 50, 51, 52, 53, 54, 58, 60, 91, 92, 630, 631,
              632)

    unet_model = create_unet_model_3d((*template_size, 1),
                                      number_of_outputs=len(labels),
                                      number_of_layers=4,
                                      number_of_filters_at_base_layer=8,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      weight_decay=1e-5,
                                      add_attention_gating=True)

    weights_file_name = get_pretrained_network(
        "dktInner", antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Inner model prediction.")

    cropped_image = ants.crop_indices(t1_preprocessed, (12, 14, 0),
                                      (172, 206, 160))

    batchX = np.expand_dims(cropped_image.numpy(), axis=0)
    batchX = np.expand_dims(batchX, axis=-1)
    batchX = (batchX - batchX.mean()) / batchX.std()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    origin = cropped_image.origin
    spacing = cropped_image.spacing
    direction = cropped_image.direction

    outer_probability_images = list()
    for i in range(len(labels)):
        probability_image = \
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
            origin=origin, spacing=spacing, direction=direction)
        if i > 0:
            decropped_image = ants.decrop_image(probability_image,
                                                t1_preprocessed * 0)
        else:
            decropped_image = ants.decrop_image(probability_image,
                                                t1_preprocessed * 0 + 1)

        if do_preprocessing:
            outer_probability_images.append(
                ants.apply_transforms(
                    fixed=t1,
                    moving=decropped_image,
                    transformlist=t1_preprocessing['template_transforms']
                    ['invtransforms'],
                    whichtoinvert=[True],
                    interpolator="linear",
                    verbose=verbose))
        else:
            outer_probability_images.append(decropped_image)

    image_matrix = ants.image_list_to_matrix(outer_probability_images,
                                             t1 * 0 + 1)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

    ################################
    #
    # Incorporate the inner model results into the final label image.
    # Note that we purposely prioritize the inner label results.
    #
    ################################

    for i in range(len(labels)):
        if labels[i] > 0:
            dkt_label_image[segmentation_image == i] = labels[i]

    if return_probability_images:
        return_dict = {
            'segmentation_image': dkt_label_image,
            'inner_probability_images': inner_probability_images,
            'outer_probability_images': outer_probability_images
        }
        return return_dict
    else:
        return dkt_label_image
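
A usage sketch matching the docstring example; the file names are hypothetical:

t1 = ants.image_read('t1.nii.gz')
dkt = desikan_killiany_tourville_labeling(t1, do_preprocessing=True, verbose=True)
ants.image_write(dkt, 'dkt_labels.nii.gz')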
Example #26
# Register all training volumes to 0007.nii.gz. No resizing in this version

import ants
import os

# Constants for path names
FIXED_IMG = "/content/drive/My Drive/cs8395_deep_learning/assignment3/data/Train/img/0007.nii.gz"
OLD_TEST_IMG = "/content/drive/My Drive/cs8395_deep_learning/assignment3/data/Testing/img/"
NEW_TEST_IMG = "/content/drive/My Drive/cs8395_deep_learning/assignment3/data/Testing/img_registered_syn/"
fixed = ants.image_read(FIXED_IMG)
fixed = ants.resample_image(fixed, [256, 256, 80], True, 1)

# Register all the training images
for file_name in os.listdir(OLD_TEST_IMG):
    moving_image = ants.image_read(OLD_TEST_IMG + file_name)
    # Downsample for faster registration
    moving_image = ants.resample_image(moving_image, [256, 256, 80], True, 1)
    print("Registering ", file_name)
    transform = ants.registration(fixed=fixed,
                                  moving=moving_image,
                                  type_of_transform='SyN')
    print(transform)
    transformed_image = ants.apply_transforms(
        fixed=fixed,
        moving=moving_image,
        transformlist=transform['fwdtransforms'],
        interpolator='nearestNeighbor')
    # Upsample again
    transformed_image = ants.resample_image(transformed_image, [512, 512, 160],
                                            True, 1)
    transformed_image.to_file(NEW_TEST_IMG + file_name)
Example #27
def lung_extraction(image,
                    modality="proton",
                    antsxnet_cache_directory=None,
                    verbose=False):

    """
    Perform proton or ct lung extraction using U-net.

    Arguments
    ---------
    image : ANTsImage
        input image

    modality : string
        Modality image type.  Options include "ct", "proton", "protonLobes", 
        "maskLobes", and "ventilation".

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if None, the data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary of ANTs segmentation and probability images.

    Example
    -------
    >>> output = lung_extraction(lung_image, modality="proton")
    """

    from ..architectures import create_unet_model_2d
    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import pad_or_crop_image_to_size

    if image.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    image_mods = [modality]
    channel_size = len(image_mods)

    weights_file_name = None
    unet_model = None

    if modality == "proton":
        weights_file_name = get_pretrained_network("protonLungMri",
            antsxnet_cache_directory=antsxnet_cache_directory)

        classes = ("background", "left_lung", "right_lung")
        number_of_classification_labels = len(classes)

        reorient_template_file_name_path = get_antsxnet_data("protonLungTemplate",
            antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)

        resampled_image_size = reorient_template.shape

        unet_model = create_unet_model_3d((*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels,
            number_of_layers=4, number_of_filters_at_base_layer=16, dropout_rate=0.0,
            convolution_kernel_size=(7, 7, 5), deconvolution_kernel_size=(7, 7, 5))
        unet_model.load_weights(weights_file_name)

        if verbose:
            print("Lung extraction:  normalizing image to the template.")

        center_of_mass_template = ants.get_center_of_mass(reorient_template * 0 + 1)
        center_of_mass_image = ants.get_center_of_mass(image * 0 + 1)
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_template), translation=translation)
        warped_image = ants.apply_ants_transform_to_image(xfrm, image, reorient_template)

        batchX = np.expand_dims(warped_image.numpy(), axis=0)
        batchX = np.expand_dims(batchX, axis=-1)
        batchX = (batchX - batchX.mean()) / batchX.std()

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        origin = warped_image.origin
        spacing = warped_image.spacing
        direction = warped_image.direction

        probability_images_array = list()
        for i in range(number_of_classification_labels):
            probability_images_array.append(
            ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                origin=origin, spacing=spacing, direction=direction))

        if verbose:
            print("Lung extraction:  renormalize probability mask to native space.")

        for i in range(number_of_classification_labels):
            probability_images_array[i] = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), probability_images_array[i], image)

        image_matrix = ants.image_list_to_matrix(probability_images_array, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        return_dict = {'segmentation_image' : segmentation_image,
                       'probability_images' : probability_images_array}
        return return_dict

    if modality == "protonLobes" or modality == "maskLobes":
        reorient_template_file_name_path = get_antsxnet_data("protonLungTemplate",
            antsxnet_cache_directory=antsxnet_cache_directory)
        reorient_template = ants.image_read(reorient_template_file_name_path)

        resampled_image_size = reorient_template.shape

        spatial_priors_file_name_path = get_antsxnet_data("protonLobePriors",
            antsxnet_cache_directory=antsxnet_cache_directory)
        spatial_priors = ants.image_read(spatial_priors_file_name_path)
        priors_image_list = ants.ndimage_to_list(spatial_priors)

        channel_size = 1 + len(priors_image_list)
        number_of_classification_labels = 1 + len(priors_image_list)

        unet_model = create_unet_model_3d((*resampled_image_size, channel_size),
            number_of_outputs=number_of_classification_labels, mode="classification", 
            number_of_filters_at_base_layer=16, number_of_layers=4,
            convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
            dropout_rate=0.0, weight_decay=0, additional_options=("attentionGating",))

        if modality == "protonLobes":
            penultimate_layer = unet_model.layers[-2].output
            outputs2 = Conv3D(filters=1,
                            kernel_size=(1, 1, 1),
                            activation='sigmoid',
                            kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)
            unet_model = Model(inputs=unet_model.input, outputs=[unet_model.output, outputs2])
            weights_file_name = get_pretrained_network("protonLobes",
                antsxnet_cache_directory=antsxnet_cache_directory)
        else:
            weights_file_name = get_pretrained_network("maskLobes",
                antsxnet_cache_directory=antsxnet_cache_directory)

        unet_model.load_weights(weights_file_name)

        if verbose:
            print("Lung extraction:  normalizing image to the template.")

        center_of_mass_template = ants.get_center_of_mass(reorient_template * 0 + 1)
        center_of_mass_image = ants.get_center_of_mass(image * 0 + 1)
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_template), translation=translation)
        warped_image = ants.apply_ants_transform_to_image(xfrm, image, reorient_template)
        warped_array = warped_image.numpy()
        if modality == "protonLobes":
            warped_array = (warped_array - warped_array.mean()) / warped_array.std()
        else:
            warped_array[warped_array != 0] = 1
       
        batchX = np.zeros((1, *warped_array.shape, channel_size))
        batchX[0,:,:,:,0] = warped_array
        for i in range(len(priors_image_list)):
            batchX[0,:,:,:,i+1] = priors_image_list[i].numpy()

        predicted_data = unet_model.predict(batchX, verbose=int(verbose))

        origin = warped_image.origin
        spacing = warped_image.spacing
        direction = warped_image.direction

        probability_images_array = list()
        for i in range(number_of_classification_labels):
            if modality == "protonLobes":
                probability_images_array.append(
                    ants.from_numpy(np.squeeze(predicted_data[0][0, :, :, :, i]),
                    origin=origin, spacing=spacing, direction=direction))
            else:
                probability_images_array.append(
                    ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]),
                    origin=origin, spacing=spacing, direction=direction))

        if verbose:
            print("Lung extraction:  renormalize probability images to native space.")

        for i in range(number_of_classification_labels):
            probability_images_array[i] = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), probability_images_array[i], image)

        image_matrix = ants.image_list_to_matrix(probability_images_array, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        if modality == "protonLobes":
            whole_lung_mask = ants.from_numpy(np.squeeze(predicted_data[1][0, :, :, :, 0]),
                origin=origin, spacing=spacing, direction=direction)
            whole_lung_mask = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), whole_lung_mask, image)

            return_dict = {'segmentation_image' : segmentation_image,
                           'probability_images' : probability_images_array,
                           'whole_lung_mask_image' : whole_lung_mask}
            return return_dict
        else:
            return_dict = {'segmentation_image' : segmentation_image,
                           'probability_images' : probability_images_array}
            return return_dict


    elif modality == "ct":

        ################################
        #
        # Preprocess image
        #
        ################################

        if verbose == True:
            print("Preprocess CT image.")

        def closest_simplified_direction_matrix(direction):
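            # Round each direction cosine to the nearest of -1, 0, or +1 (axis-aligned orientation).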
            closest = np.floor(np.abs(direction) + 0.5)
            closest[direction < 0] *= -1.0
            return closest

        simplified_direction = closest_simplified_direction_matrix(image.direction)

        reference_image_size = (128, 128, 128)

        ct_preprocessed = ants.resample_image(image, reference_image_size, use_voxels=True, interp_type=0)
        ct_preprocessed[ct_preprocessed < -1000] = -1000
        ct_preprocessed[ct_preprocessed > 400] = 400
        ct_preprocessed.set_direction(simplified_direction)
        ct_preprocessed.set_origin((0, 0, 0))
        ct_preprocessed.set_spacing((1, 1, 1))

        ################################
        #
        # Reorient image
        #
        ################################

        reference_image = ants.make_image(reference_image_size,
                                          voxval=0,
                                          spacing=(1, 1, 1),
                                          origin=(0, 0, 0),
                                          direction=np.identity(3))
        center_of_mass_reference = np.floor(ants.get_center_of_mass(reference_image * 0 + 1))
        center_of_mass_image = np.floor(ants.get_center_of_mass(ct_preprocessed * 0 + 1))
        translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_reference)
        xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
            center=np.asarray(center_of_mass_reference), translation=translation)
        ct_preprocessed = ((ct_preprocessed - ct_preprocessed.min()) /
            (ct_preprocessed.max() - ct_preprocessed.min()))
        ct_preprocessed_warped = ants.apply_ants_transform_to_image(
            xfrm, ct_preprocessed, reference_image, interpolation="nearestneighbor")
        ct_preprocessed_warped = ((ct_preprocessed_warped - ct_preprocessed_warped.min()) /
            (ct_preprocessed_warped.max() - ct_preprocessed_warped.min())) - 0.5

        ################################
        #
        # Build models and load weights
        #
        ################################

        if verbose == True:
            print("Build model and load weights.")

        weights_file_name = get_pretrained_network("lungCtWithPriorsSegmentationWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)

        classes = ("background", "left lung", "right lung", "airways")
        number_of_classification_labels = len(classes)

        luna16_priors = ants.ndimage_to_list(ants.image_read(get_antsxnet_data("luna16LungPriors")))
        for i in range(len(luna16_priors)):
            luna16_priors[i] = ants.resample_image(luna16_priors[i], reference_image_size, use_voxels=True)
        channel_size = len(luna16_priors) + 1

        unet_model = create_unet_model_3d((*reference_image_size, channel_size),
            number_of_outputs=number_of_classification_labels, mode="classification",
            number_of_layers=4, number_of_filters_at_base_layer=16, dropout_rate=0.0,
            convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
            weight_decay=1e-5, additional_options=("attentionGating",))
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Do prediction and normalize to native space
        #
        ################################

        if verbose == True:
            print("Prediction.")

        batchX = np.zeros((1, *reference_image_size, channel_size))
        batchX[:,:,:,:,0] = ct_preprocessed_warped.numpy()

        for i in range(len(luna16_priors)):
            batchX[:,:,:,:,i+1] = luna16_priors[i].numpy() - 0.5

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        probability_images = list()
        for i in range(number_of_classification_labels):
            if verbose == True:
                print("Reconstructing image", classes[i])
            probability_image = ants.from_numpy(np.squeeze(predicted_data[:,:,:,:,i]),
                origin=ct_preprocessed_warped.origin, spacing=ct_preprocessed_warped.spacing,
                direction=ct_preprocessed_warped.direction)
            probability_image = ants.apply_ants_transform_to_image(
                ants.invert_ants_transform(xfrm), probability_image, ct_preprocessed)
            probability_image = ants.resample_image(probability_image,
               resample_params=image.shape, use_voxels=True, interp_type=0)
            probability_image = ants.copy_image_info(image, probability_image)
            probability_images.append(probability_image)

        image_matrix = ants.image_list_to_matrix(probability_images, image * 0 + 1)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), image * 0 + 1)[0]

        return_dict = {'segmentation_image' : segmentation_image,
                       'probability_images' : probability_images}
        return(return_dict)

    elif modality == "ventilation":

        ################################
        #
        # Preprocess image
        #
        ################################

        if verbose == True:
            print("Preprocess ventilation image.")

        template_size = (256, 256)

        image_modalities = ("Ventilation",)
        channel_size = len(image_modalities)

        preprocessed_image = (image - image.mean()) / image.std()
        ants.set_direction(preprocessed_image, np.identity(3))

        ################################
        #
        # Build models and load weights
        #
        ################################

        unet_model = create_unet_model_2d((*template_size, channel_size),
            number_of_outputs=1, mode='sigmoid',
            number_of_layers=4, number_of_filters_at_base_layer=32, dropout_rate=0.0,
            convolution_kernel_size=(3, 3), deconvolution_kernel_size=(2, 2),
            weight_decay=0)

        if verbose == True:
            print("Whole lung mask: retrieving model weights.")

        weights_file_name = get_pretrained_network("wholeLungMaskFromVentilation",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Extract slices
        #
        ################################

        spacing = ants.get_spacing(preprocessed_image)
        dimensions_to_predict = (spacing.index(max(spacing)),)

        total_number_of_slices = 0
        for d in range(len(dimensions_to_predict)):
            total_number_of_slices += preprocessed_image.shape[dimensions_to_predict[d]]

        batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

        slice_count = 0
        for d in range(len(dimensions_to_predict)):
            number_of_slices = preprocessed_image.shape[dimensions_to_predict[d]]

            if verbose == True:
                print("Extracting slices for dimension ", dimensions_to_predict[d], ".")

            for i in range(number_of_slices):
                ventilation_slice = pad_or_crop_image_to_size(ants.slice_image(preprocessed_image, dimensions_to_predict[d], i), template_size)
                batchX[slice_count,:,:,0] = ventilation_slice.numpy()
                slice_count += 1

        ################################
        #
        # Do prediction and then restack into the image
        #
        ################################

        if verbose == True:
            print("Prediction.")

        prediction = unet_model.predict(batchX, verbose=verbose)

        permutations = list()
        permutations.append((0, 1, 2))
        permutations.append((1, 0, 2))
        permutations.append((1, 2, 0))

        probability_image = ants.image_clone(image) * 0

        current_start_slice = 0
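        # Restack the per-slice predictions; the update below keeps a running mean over prediction directions.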
        for d in range(len(dimensions_to_predict)):
            current_end_slice = current_start_slice + preprocessed_image.shape[dimensions_to_predict[d]] - 1
            # Inclusive upper bound: current_end_slice indexes the last slice for this dimension.
            which_batch_slices = range(current_start_slice, current_end_slice + 1)

            prediction_per_dimension = prediction[which_batch_slices,:,:,0]
            prediction_array = np.transpose(np.squeeze(prediction_per_dimension), permutations[dimensions_to_predict[d]])
            prediction_image = ants.copy_image_info(image,
                pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                image.shape))
            probability_image = probability_image + (prediction_image - probability_image) / (d + 1)

            current_start_slice = current_end_slice + 1

        return(probability_image)

    else:
        raise ValueError("Unrecognized modality.")
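
A minimal usage sketch for the function above, assuming this is the tail of ANTsPyNet's lung_extraction (signature lung_extraction(image, modality, antsxnet_cache_directory=None, verbose=False); the input path is a placeholder):

import ants
from antspynet.utilities import lung_extraction

# Hypothetical input; the "ct" branch returns a dict with a segmentation
# image and per-class probability images.
ct = ants.image_read("chest_ct.nii.gz")
lung = lung_extraction(ct, modality="ct", verbose=True)
ants.image_write(lung['segmentation_image'], "lung_segmentation.nii.gz")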
Example #28
import matplotlib.pyplot as plt
import ants
import os
from pydicom import dcmread
from pydicom.data import get_testdata_file
from nipype import Node, Workflow
from nipype.interfaces.ants import N4BiasFieldCorrection
import nibabel as nib
import SimpleITK as sitk
from Utils import *

images_short = ants.image_read(
    "/media/sf_VB_Folder/s3D_IR-TFE_2BH_SENSE-1401_s3D_IR-TFE_2_BH_SENSE_20160711081415_1401_t682.nii"
)
images_short = ants.resample_image(images_short, [1.1875, 1.1875, 1.1875])
images_hLong = ants.image_read(
    "s3D_IR-TFE_BH_20slSENSE-1201_s3D_IR-TFE_BH_20sl_SENSE_20160711081415_1201_t681.nii"
)
# images_hLong = ants.resample_image(images_hLong,[320,320,320],True)
images_vLong = ants.image_read(
    "/media/sf_VB_Folder/s3D_IR-TFE_BH_20slSENSE-1301_s3D_IR-TFE_BH_20sl_SENSE_20160711081415_1301_t681.nii"
)
# images_vLong = ants.resample_image(images_vLong,[320,320,320],True)

shortO = ants.to_nibabel(images_short)
vLongO = ants.to_nibabel(images_vLong)
hLongO = ants.to_nibabel(images_hLong)
shortO.set_data_dtype('int16')
vLongO.set_data_dtype('int16')
hLongO.set_data_dtype('int16')
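
The script stops after the dtype conversion; a hedged continuation that writes the converted volumes out with nibabel (the output file names are assumptions):

import nibabel as nib

# Assumed output paths; the original script does not show a save step.
nib.save(shortO, "short_axis_int16.nii.gz")
nib.save(vLongO, "vertical_long_axis_int16.nii.gz")
nib.save(hLongO, "horizontal_long_axis_int16.nii.gz")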
Example #29
def preprocess_images(
    atlas_id: int,
    atlas_csf_id: int,
    atlas_grey_id: int,
    atlas_white_id: int,
    dataset_id: int,
    replace: bool = False,
    downsample: float = 3.0,
):
    atlas = models.Atlas.objects.get(pk=atlas_id)
    atlas_csf = models.Atlas.objects.get(pk=atlas_csf_id)
    atlas_grey = models.Atlas.objects.get(pk=atlas_grey_id)
    atlas_white = models.Atlas.objects.get(pk=atlas_white_id)
    dataset = models.Dataset.objects.get(pk=dataset_id)

    print('Downloading atlas files')
    with NamedTemporaryFile(
            suffix='atlas.nii') as tmp, atlas.blob.open() as blob:
        for chunk in blob.chunks():
            tmp.write(chunk)
        atlas_img = ants.image_read(tmp.name)

    with NamedTemporaryFile(
            suffix='atlas_csf.nii') as tmp, atlas_csf.blob.open() as blob:
        for chunk in blob.chunks():
            tmp.write(chunk)
        atlas_csf_img = ants.image_read(tmp.name)

    with NamedTemporaryFile(
            suffix='atlas_grey.nii') as tmp, atlas_grey.blob.open() as blob:
        for chunk in blob.chunks():
            tmp.write(chunk)
        atlas_grey_img = ants.image_read(tmp.name)

    with NamedTemporaryFile(
            suffix='atlas_white.nii') as tmp, atlas_white.blob.open() as blob:
        for chunk in blob.chunks():
            tmp.write(chunk)
        atlas_white_img = ants.image_read(tmp.name)

    print('Creating mask')
    priors = [atlas_csf_img, atlas_grey_img, atlas_white_img]
    mask = priors[0].copy()
    mask_view = mask.view()
    for i in range(1, len(priors)):
        mask_view[priors[i].numpy() > 0] = 1
    mask_view[mask_view > 0] = 1

    for image in dataset.images.all():
        if replace:
            _delete_preprocessing_artifacts(image)
        elif _already_preprocessed(image):
            continue
        with NamedTemporaryFile(
                suffix=image.name) as tmp, image.blob.open() as blob:
            for chunk in blob.chunks():
                tmp.write(chunk)
            input_img = ants.image_read(tmp.name)

        print(f'Running N4 bias correction: {image.name}')
        im_n4 = ants.n4_bias_field_correction(input_img)
        del input_img
        print(f'Running registration: {image.name}')
        reg = ants.registration(atlas_img, im_n4)
        del im_n4
        jac_img = ants.create_jacobian_determinant_image(
            atlas_img, reg['fwdtransforms'][0], 1)
        jac_img = jac_img.apply(np.abs)

        reg_model = models.RegisteredImage(source_image=image, atlas=atlas)
        reg_img = reg['warpedmovout']
        with NamedTemporaryFile(suffix='registered.nii') as tmp:
            ants.image_write(reg_img, tmp.name)
            reg_model.blob = File(tmp, name='registered.nii')
            reg_model.save()

        jac_model = models.JacobianImage(source_image=image, atlas=atlas)
        with NamedTemporaryFile(suffix='jacobian.nii') as tmp:
            ants.image_write(jac_img, tmp.name)
            jac_model.blob = File(tmp, name='jacobian.nii')
            jac_model.save()

        print(f'Running segmentation: {image.name}')
        seg = ants.prior_based_segmentation(reg_img, priors, mask)
        del reg_img

        seg_model = models.SegmentedImage(source_image=image, atlas=atlas)
        with NamedTemporaryFile(suffix='segmented.nii') as tmp:
            ants.image_write(seg['segmentation'], tmp.name)
            seg_model.blob = File(tmp, name='segmented.nii')
            seg_model.save()

        print(f'Creating feature image: {image.name}')
        seg_img_view = seg['segmentation'].view()
        feature_img = seg['segmentation'].copy()
        feature_img_view = feature_img.view()
        feature_img_view.fill(0)
        feature_img_view[seg_img_view == 3] = 1  # 3 is white matter label

        intensity_img_view = jac_img.view()
        feature_img_view *= intensity_img_view

        if downsample > 1:
            # Voxel counts must be integers when use_voxels=True.
            shape = np.round(np.asarray(feature_img.shape) / downsample).astype(int)
            feature_img = ants.resample_image(feature_img, tuple(shape), use_voxels=True)

        feature_model = models.FeatureImage(source_image=image,
                                            atlas=atlas,
                                            downsample_factor=downsample)
        with NamedTemporaryFile(suffix='feature.nii') as tmp:
            ants.image_write(feature_img, tmp.name)
            feature_model.blob = File(tmp, name='feature.nii')
            feature_model.save()

    dataset.preprocessing_complete = True
    dataset.save()
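
A hedged invocation sketch for preprocess_images; the primary-key values are placeholders for rows that would already exist in the project's database:

# Hypothetical IDs; substitute the primary keys of your own Atlas and
# Dataset records. A downsample factor of 3.0 shrinks each axis ~3x.
preprocess_images(
    atlas_id=1,
    atlas_csf_id=2,
    atlas_grey_id=3,
    atlas_white_id=4,
    dataset_id=5,
    replace=False,
    downsample=3.0,
)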
Example #30
def ew_david(flair,
             t1,
             do_preprocessing=True,
             do_slicewise=True,
             antsxnet_cache_directory=None,
             verbose=False):

    """
    Perform White matter hypterintensity probabilistic segmentation
    using deep learning

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    \code{doPreprocessing = TRUE}

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    do_preprocessing : boolean
        perform n4 bias correction?

    do_slicewise : boolean
        apply 2-D modal along direction of maximal slice thickness.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image

    Example
    -------
    >>> image = ants.image_read("flair.nii.gz")
    >>> probability_mask = sysu_media_wmh_segmentation(image)
    """

    from ..architectures import create_unet_model_2d
    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import extract_image_patches
    from ..utilities import reconstruct_image_from_patches
    from ..utilities import pad_or_crop_image_to_size

    if flair.dimension != 3:
        raise ValueError("FLAIR image dimension must be 3.")

    if t1.dimension != 3:
        raise ValueError("T1 image dimension must be 3.")

    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    if do_slicewise == False:

        ################################
        #
        # Preprocess images
        #
        ################################

        t1_preprocessed = t1
        t1_preprocessing = None
        if do_preprocessing == True:
            t1_preprocessing = preprocess_brain_image(t1,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=True,
                template="croppedMni152",
                template_transform_type="AffineFast",
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            t1_preprocessed = t1_preprocessing["preprocessed_image"] * t1_preprocessing['brain_mask']

        flair_preprocessed = flair
        if do_preprocessing == True:
            flair_preprocessing = preprocess_brain_image(flair,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            flair_preprocessed = ants.apply_transforms(fixed=t1_preprocessed,
                moving=flair_preprocessing["preprocessed_image"],
                transformlist=t1_preprocessing['template_transforms']['fwdtransforms'])
            flair_preprocessed = flair_preprocessed * t1_preprocessing['brain_mask']

        ################################
        #
        # Build model and load weights
        #
        ################################

        patch_size = (112, 112, 112)
        stride_length = (t1_preprocessed.shape[0] - patch_size[0],
                        t1_preprocessed.shape[1] - patch_size[1],
                        t1_preprocessed.shape[2] - patch_size[2])
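        # With stride = (shape - patch size), extracting "all" patches yields the 8 corner patches (2 per axis).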

        classes = ("background", "wmh" )
        number_of_classification_labels = len(classes)
        labels = (0, 1)

        image_modalities = ("T1", "FLAIR")
        channel_size = len(image_modalities)

        unet_model = create_unet_model_3d((*patch_size, channel_size),
            number_of_outputs = number_of_classification_labels,
            number_of_layers = 4, number_of_filters_at_base_layer = 16, dropout_rate = 0.0,
            convolution_kernel_size = (3, 3, 3), deconvolution_kernel_size = (2, 2, 2),
            weight_decay = 1e-5, nn_unet_activation_style=False, add_attention_gating=True)

        weights_file_name = get_pretrained_network("ewDavidWmhSegmentationWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Do prediction and normalize to native space
        #
        ################################

        if verbose == True:
            print("ew_david:  prediction.")

        batchX = np.zeros((8, *patch_size, channel_size))

        t1_preprocessed = (t1_preprocessed - t1_preprocessed.mean()) / t1_preprocessed.std()
        t1_patches = extract_image_patches(t1_preprocessed, patch_size=patch_size,
                                            max_number_of_patches="all", stride_length=stride_length,
                                            return_as_array=True)
        batchX[:,:,:,:,0] = t1_patches

        flair_preprocessed = (flair_preprocessed - flair_preprocessed.mean()) / flair_preprocessed.std()
        flair_patches = extract_image_patches(flair_preprocessed, patch_size=patch_size,
                                            max_number_of_patches="all", stride_length=stride_length,
                                            return_as_array=True)
        batchX[:,:,:,:,1] = flair_patches

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        probability_images = list()
        for i in range(len(labels)):
            print("Reconstructing image", classes[i])
            reconstructed_image = reconstruct_image_from_patches(predicted_data[:,:,:,:,i],
                domain_image=t1_preprocessed, stride_length=stride_length)

            if do_preprocessing == True:
                probability_images.append(ants.apply_transforms(fixed=t1,
                    moving=reconstructed_image,
                    transformlist=t1_preprocessing['template_transforms']['invtransforms'],
                    whichtoinvert=[True], interpolator="linear", verbose=verbose))
            else:
                probability_images.append(reconstructed_image)

        return(probability_images[1])

    else:  # do_slicewise

        ################################
        #
        # Preprocess images
        #
        ################################

        t1_preprocessed = t1
        t1_preprocessing = None
        if do_preprocessing == True:
            t1_preprocessing = preprocess_brain_image(t1,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            t1_preprocessed = t1_preprocessing["preprocessed_image"]

        flair_preprocessed = flair
        if do_preprocessing == True:
            flair_preprocessing = preprocess_brain_image(flair,
                truncate_intensity=(0.01, 0.99),
                do_brain_extraction=False,
                do_bias_correction=True,
                do_denoising=False,
                antsxnet_cache_directory=antsxnet_cache_directory,
                verbose=verbose)
            flair_preprocessed = flair_preprocessing["preprocessed_image"]

        resampling_params = list(ants.get_spacing(flair_preprocessed))

        do_resampling = False
        for d in range(len(resampling_params)):
            if resampling_params[d] < 0.8:
                resampling_params[d] = 1.0
                do_resampling = True

        resampling_params = tuple(resampling_params)

        if do_resampling:
            flair_preprocessed = ants.resample_image(flair_preprocessed, resampling_params, use_voxels=False, interp_type=0)
            t1_preprocessed = ants.resample_image(t1_preprocessed, resampling_params, use_voxels=False, interp_type=0)

        flair_preprocessed = (flair_preprocessed - flair_preprocessed.mean()) / flair_preprocessed.std()
        t1_preprocessed = (t1_preprocessed - t1_preprocessed.mean()) / t1_preprocessed.std()

        ################################
        #
        # Build model and load weights
        #
        ################################

        template_size = (256, 256)

        classes = ("background", "wmh" )
        number_of_classification_labels = len(classes)
        labels = (0, 1)

        image_modalities = ("T1", "FLAIR")
        channel_size = len(image_modalities)

        unet_model = create_unet_model_2d((*template_size, channel_size),
            number_of_outputs = number_of_classification_labels,
            number_of_layers = 4, number_of_filters_at_base_layer = 32, dropout_rate = 0.0,
            convolution_kernel_size = (3, 3), deconvolution_kernel_size = (2, 2),
            weight_decay = 1e-5, nn_unet_activation_style=True, add_attention_gating=True)

        if verbose == True:
            print("ewDavid:  retrieving model weights.")

        weights_file_name = get_pretrained_network("ewDavidWmhSegmentationSlicewiseWeights",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        ################################
        #
        # Extract slices
        #
        ################################

        use_coarse_slices_only = True

        spacing = ants.get_spacing(flair_preprocessed)
        dimensions_to_predict = (spacing.index(max(spacing)),)
        if use_coarse_slices_only == False:
            dimensions_to_predict = list(range(3))

        total_number_of_slices = 0
        for d in range(len(dimensions_to_predict)):
            total_number_of_slices += flair_preprocessed.shape[dimensions_to_predict[d]]

        batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

        slice_count = 0
        for d in range(len(dimensions_to_predict)):
            number_of_slices = flair_preprocessed.shape[dimensions_to_predict[d]]

            if verbose == True:
                print("Extracting slices for dimension ", dimensions_to_predict[d], ".")

            for i in range(number_of_slices):
                flair_slice = pad_or_crop_image_to_size(ants.slice_image(flair_preprocessed, dimensions_to_predict[d], i), template_size)
                batchX[slice_count,:,:,0] = flair_slice.numpy()

                t1_slice = pad_or_crop_image_to_size(ants.slice_image(t1_preprocessed, dimensions_to_predict[d], i), template_size)
                batchX[slice_count,:,:,1] = t1_slice.numpy()

                slice_count += 1


        ################################
        #
        # Do prediction and then restack into the image
        #
        ################################

        if verbose == True:
            print("Prediction.")

        prediction = unet_model.predict(batchX, verbose=verbose)

        permutations = list()
        permutations.append((0, 1, 2))
        permutations.append((1, 0, 2))
        permutations.append((1, 2, 0))

        prediction_image_average = ants.image_clone(flair_preprocessed) * 0

        current_start_slice = 0
        for d in range(len(dimensions_to_predict)):
            current_end_slice = current_start_slice + flair_preprocessed.shape[dimensions_to_predict[d]] - 1
            # Inclusive upper bound: current_end_slice indexes the last slice for this dimension.
            which_batch_slices = range(current_start_slice, current_end_slice + 1)
            prediction_per_dimension = prediction[which_batch_slices,:,:,1]
            prediction_array = np.transpose(np.squeeze(prediction_per_dimension), permutations[dimensions_to_predict[d]])
            prediction_image = ants.copy_image_info(flair_preprocessed,
                pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                flair_preprocessed.shape))
            prediction_image_average = prediction_image_average + (prediction_image - prediction_image_average) / (d + 1)

            current_start_slice = current_end_slice + 1

        if do_resampling:
            prediction_image_average = ants.resample_image_to_target(prediction_image_average, flair)

        return(prediction_image_average)
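
A minimal usage sketch for ew_david, matching the docstring above (file names are assumptions; the slicewise branch returns a single WMH probability image):

import ants

# Assumed input paths.
flair = ants.image_read("flair.nii.gz")
t1 = ants.image_read("t1.nii.gz")
wmh_probability = ew_david(flair, t1, do_slicewise=True, verbose=True)
ants.image_write(wmh_probability, "wmh_probability.nii.gz")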