Example #1
def nyul_normalize(img_dir, mask_dir=None, output_dir=None, standard_hist=None, write_to_disk=True):
    """
    Use the Nyúl and Udupa method ([1,2]) to normalize the intensities of a set of MR images

    Args:
        img_dir (str): directory containing MR images
        mask_dir (str): directory containing masks for MR images
        output_dir (str): directory in which to save the normalized images if you do not
            want them saved alongside the input images
        standard_hist (str): path at which to save (or from which to load) the standard
            histogram landmarks
        write_to_disk (bool): write the normalized data to disk if True

    Returns:
        normalized (np.ndarray): last normalized image from img_dir

    References:
        [1] L. G. Nyúl and J. K. Udupa, “On Standardizing the MR Image
            Intensity Scale,” Magn. Reson. Med., vol. 42, pp. 1072–1081,
            1999.
        [2] M. Shah, Y. Xiao, N. Subbanna, S. Francis, D. L. Arnold,
            D. L. Collins, and T. Arbel, “Evaluating intensity
            normalization on MRIs of human brain with multiple sclerosis,”
            Med. Image Anal., vol. 15, no. 2, pp. 267–282, 2011.
    """
    input_files = io.glob_nii(img_dir)
    if output_dir is None:
        out_fns = [None] * len(input_files)
    else:
        out_fns = []
        for fn in input_files:
            _, base, ext = io.split_filename(fn)
            out_fns.append(os.path.join(output_dir, base + '_hm' + ext))
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)

    mask_files = [None] * len(input_files) if mask_dir is None else io.glob_nii(mask_dir)

    if standard_hist is None:
        logger.info('Learning standard scale for the set of images')
        standard_scale, percs = train(input_files, mask_files)
    elif not os.path.isfile(standard_hist):
        logger.info('Learning standard scale for the set of images')
        standard_scale, percs = train(input_files, mask_files)
        np.save(standard_hist, np.vstack((standard_scale, percs)))
    else:
        logger.info('Loading standard scale ({}) for the set of images'.format(standard_hist))
        standard_scale, percs = np.load(standard_hist)

    for i, (img_fn, mask_fn, out_fn) in enumerate(zip(input_files, mask_files, out_fns)):
        _, base, _ = io.split_filename(img_fn)
        logger.info('Transforming image {} to standard scale ({:d}/{:d})'.format(base, i+1, len(input_files)))
        img = io.open_nii(img_fn)
        mask = io.open_nii(mask_fn) if mask_fn is not None else None
        normalized = do_hist_norm(img, percs, standard_scale, mask)
        if write_to_disk:
            io.save_nii(normalized, out_fn, is_nii=True)

    return normalized
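A minimal usage sketch for nyul_normalize follows; the directory and file paths are placeholders, not part of the library:

# hypothetical paths; the landmarks are trained and saved to standard_hist on the first run
normalized = nyul_normalize(img_dir='data/images',
                            mask_dir='data/masks',
                            output_dir='data/nyul_normalized',
                            standard_hist='data/standard_hist.npy',
                            write_to_disk=True)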
def lsq_normalize(img_dir, mask_dir=None, output_dir=None, write_to_disk=True):
    """
    normalize intensities of a set of MR images by minimizing the squared distance
    between CSF, GM, and WM means within the set

    Args:
        img_dir (str): directory containing MR images
        mask_dir (str): directory containing masks for MR images
        output_dir (str): directory in which to save the normalized images if you do not
            want them saved alongside the input images
        write_to_disk (bool): write the normalized data to disk if True

    Returns:
        normalized (np.ndarray): last normalized image from img_dir
    """
    input_files = io.glob_nii(img_dir)
    if output_dir is None:
        out_fns = [None] * len(input_files)
    else:
        out_fns = []
        for fn in input_files:
            _, base, ext = io.split_filename(fn)
            out_fns.append(os.path.join(output_dir, base + '_lsq' + ext))
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)

    mask_files = [None] * len(
        input_files) if mask_dir is None else io.glob_nii(mask_dir)

    standard_tissue_means = None
    normalized = None
    for i, (img_fn, mask_fn,
            out_fn) in enumerate(zip(input_files, mask_files, out_fns)):
        _, base, _ = io.split_filename(img_fn)
        logger.info(
            'Transforming image {} to standard scale ({:d}/{:d})'.format(
                base, i + 1, len(input_files)))
        img = io.open_nii(img_fn)
        mask = io.open_nii(mask_fn) if mask_fn is not None else None
        tissue_mem = mask_util.fcm_class_mask(img, mask)
        if standard_tissue_means is None:
            csf_tissue_mask = find_tissue_mask(img, mask, tissue_type='csf')
            csf_normed_data = fcm_normalize(img, csf_tissue_mask).get_fdata()
            standard_tissue_means = calc_tissue_means(csf_normed_data,
                                                      tissue_mem)
            del csf_tissue_mask, csf_normed_data
        img_data = img.get_fdata()
        tissue_means = calc_tissue_means(img_data, tissue_mem)
        sf = find_scaling_factor(tissue_means, standard_tissue_means)
        logger.debug('Scaling factor for {}: {:0.3e}'.format(base, sf))
        normalized = nib.Nifti1Image(sf * img_data, img.affine, img.header)
        if write_to_disk:
            io.save_nii(normalized, out_fn, is_nii=True)

    return normalized
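A similar sketch for lsq_normalize; the first image in img_dir defines the standard tissue means (the paths are placeholders):

# hypothetical directories; masks are optional but help the FCM tissue segmentation
normalized = lsq_normalize(img_dir='data/images',
                           mask_dir='data/masks',
                           output_dir='data/lsq_normalized',
                           write_to_disk=True)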
def process(image, brain_mask, args, logger):
    img = io.open_nii(image)
    mask = io.open_nii(brain_mask)
    dirname, base, ext = io.split_filename(image)
    if args.output_dir is not None:
        dirname = args.output_dir
        if not os.path.exists(dirname):
            logger.info('Making output directory: {}'.format(dirname))
            os.mkdir(dirname)
    if args.find_background_mask:
        bg_mask = background_mask(img)
        bgfile = os.path.join(dirname, base + '_bgmask' + ext)
        io.save_nii(bg_mask, bgfile, is_nii=True)
    if args.wm_peak is not None:
        logger.info('Loading WM peak: {}'.format(args.wm_peak))
        peak = float(np.load(args.wm_peak))
    else:
        peak = gmm_class_mask(img, brain_mask=mask, contrast=args.contrast)
        if args.save_wm_peak:
            np.save(os.path.join(dirname, base + '_wmpeak.npy'), peak)
    normalized = gmm.gmm_normalize(img, mask, args.norm_value, args.contrast,
                                   args.background_mask, peak)
    outfile = os.path.join(dirname, base + '_gmm' + ext)
    logger.info('Normalized image saved: {}'.format(outfile))
    io.save_nii(normalized, outfile, is_nii=True)
Example #4
def robex(img, out_mask, skull_stripped=False):
    """
    perform skull-stripping on the registered image using the
    ROBEX algorithm

    Args:
        img (str): path to image to skull strip
        out_mask (str): path to output mask file
        skull_stripped (bool): if True, also write the skull-stripped image
            to disk alongside the mask [default = False]

    Returns:
        mask (ants.ANTsImage): binary brain mask
    """

    with warnings.catch_warnings():
        warnings.filterwarnings('ignore')
        _ = ROBEX.robex(img, outfile=out_mask)
    skull_stripped_img = ants.image_read(out_mask)
    mask = skull_stripped_img.get_mask(low_thresh=1)
    ants.image_write(mask, out_mask)
    if skull_stripped:
        # write the skull-stripped image to disk if desired (in addition to mask)
        dirname, base, _ = split_filename(out_mask)
        base = base.replace(
            'mask', 'stripped') if 'mask' in base else base + '_stripped'
        ants.image_write(skull_stripped_img,
                         os.path.join(dirname, base + '.nii.gz'))
    return mask
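A short sketch of calling robex on a single image, assuming ROBEX and ANTsPy are available (the file names are placeholders):

# hypothetical file names; out_mask is overwritten with the binary brain mask
mask = robex('data/images/subject01.nii.gz',
             'data/masks/subject01_mask.nii.gz',
             skull_stripped=True)  # also writes a *_stripped.nii.gz image next to the mask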
Example #5
def csf_mask_intersection(img_dir, masks=None, prob=1):
    """
    use all NIfTI T1w images in img_dir to create a CSF mask of common areas

    Args:
        img_dir (str): directory containing MR images to be normalized
        masks (str or ants.core.ants_image.ANTsImage): if images are not skull-stripped,
            then provide brain mask as either a corresponding directory or an individual mask
        prob (float): proportion of the images in which a voxel must be labeled
            as CSF for it to be included in the intersection

    Returns:
        intersection (np.ndarray): binary mask of common csf areas for all provided imgs
    """
    if not (0 <= prob <= 1):
        raise NormalizationError(
            'prob must be between 0 and 1. {} given.'.format(prob))
    data = io.glob_nii(img_dir)
    masks = io.glob_nii(masks) if isinstance(masks,
                                             str) else [masks] * len(data)
    csf = []
    for i, (img, mask) in enumerate(zip(data, masks)):
        _, base, _ = io.split_filename(img)
        logger.info('Creating CSF mask for image {} ({:d}/{:d})'.format(
            base, i + 1, len(data)))
        imgn = ants.image_read(img)
        maskn = ants.image_read(mask) if isinstance(mask, str) else mask
        csf.append(csf_mask(imgn, maskn))
    csf_sum = reduce(
        add, csf)  # need to use reduce instead of sum b/c data structure
    intersection = np.zeros(csf_sum.shape)
    intersection[csf_sum >= np.floor(len(data) * prob)] = 1
    return intersection
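A sketch of csf_mask_intersection; with prob=0.9, a voxel is kept if it is labeled CSF in at least 90% of the images (the directories are placeholders):

# hypothetical directories of T1w images and corresponding brain masks
csf_common = csf_mask_intersection('data/images',
                                   masks='data/masks',
                                   prob=0.9)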
def match_histograms(reference_path, apply_to_path):

    print('Start')

    for name in ['InPhase', 'OutPhase']:

        print(f'Working on: {name}')
        ref_files = list()
        for folder in os.listdir(reference_path):
            ref_files.append(f'{reference_path}/{folder}/{name}.nii')

        print('Training')
        mask_files = [None] * len(ref_files)
        standard_scale, percs = nyul.train(ref_files, mask_files)

        input_files, output_files = list(), list()
        for folder in os.listdir(apply_to_path):
            input_files.append(f'{apply_to_path}/{folder}/{name}.nii')
            output_files.append(f'{apply_to_path}/{folder}/{name}_HM.nii')

        print('Normalizing')
        for img_fn, out_fn in zip(input_files, output_files):
            print(img_fn, '->', out_fn)
            _, base, _ = io.split_filename(img_fn)
            img = io.open_nii(img_fn)
            normalized = nyul.do_hist_norm(img,
                                           percs,
                                           standard_scale,
                                           mask=None)
            io.save_nii(normalized, out_fn, is_nii=True)

    print('Done')
def main(args=None):
    args = arg_parser().parse_args(args)
    if args.verbosity == 1:
        level = logging.getLevelName('INFO')
    elif args.verbosity >= 2:
        level = logging.getLevelName('DEBUG')
    else:
        level = logging.getLevelName('WARNING')
    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=level)
    logger = logging.getLogger(__name__)
    try:
        if not os.path.isdir(args.img_dir):
            raise ValueError(
                '(-i / --img-dir) argument needs to be a directory of NIfTI images.'
            )
        if args.mask_dir is not None:
            if not os.path.isdir(args.mask_dir):
                raise ValueError(
                    '(-m / --mask-dir) argument needs to be a directory of NIfTI images.'
                )

        img_fns = io.glob_nii(args.img_dir)
        if args.mask_dir is not None:
            mask_fns = io.glob_nii(args.mask_dir)
        else:
            mask_fns = [None] * len(img_fns)
        if not os.path.exists(args.output_dir):
            logger.info('Making Output Directory: {}'.format(args.output_dir))
            os.mkdir(args.output_dir)
        hard_seg = not args.memberships
        for i, (img_fn, mask_fn) in enumerate(zip(img_fns, mask_fns), 1):
            _, base, _ = io.split_filename(img_fn)
            logger.info('Creating Mask for Image: {}, ({:d}/{:d})'.format(
                base, i, len(img_fns)))
            img = io.open_nii(img_fn)
            mask = io.open_nii(mask_fn)
            if not args.gmm:
                tm = fcm_class_mask(img, mask, hard_seg)
            else:
                tm = gmm_class_mask(img, mask, 't1', False, hard_seg)
            tissue_mask = os.path.join(args.output_dir, base + '_tm')
            if args.memberships:
                classes = ('csf', 'gm', 'wm')
                for j, c in enumerate(classes):
                    io.save_nii(img, tissue_mask + '_' + c + '.nii.gz', tm[..., j])
            else:
                io.save_nii(img, tissue_mask + '.nii.gz', tm)
        return 0
    except Exception as e:
        logger.exception(e)
        return 1
def main(args=None):
    args = arg_parser().parse_args(args)
    if args.verbosity == 1:
        level = logging.getLevelName('INFO')
    elif args.verbosity >= 2:
        level = logging.getLevelName('DEBUG')
    else:
        level = logging.getLevelName('WARNING')
    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=level)
    logger = logging.getLogger(__name__)
    try:
        if not args.single_img:
            if not os.path.isdir(args.image) or not os.path.isdir(
                    args.brain_mask):
                raise NormalizationError(
                    'if single-img option off, then image and brain-mask must be directories'
                )
            img_fns = io.glob_nii(args.image)
            mask_fns = io.glob_nii(args.brain_mask)
            if len(img_fns) != len(mask_fns) and len(img_fns) > 0:
                raise NormalizationError(
                    'input images and masks must be in correspondence and greater than zero '
                    '({:d} != {:d})'.format(len(img_fns), len(mask_fns)))
            for i, (img, mask) in enumerate(zip(img_fns, mask_fns), 1):
                _, base, _ = io.split_filename(img)
                logger.info('Normalizing image {} ({:d}/{:d})'.format(
                    img, i, len(img_fns)))
                process(img, mask, args, logger)
        else:
            if not os.path.isfile(args.image) or not os.path.isfile(
                    args.brain_mask):
                raise NormalizationError(
                    'if single-img option on, then image and brain-mask must be files'
                )
            process(args.image, args.brain_mask, args, logger)

        if args.plot_hist:
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore', category=FutureWarning)
                from intensity_normalization.plot.hist import all_hists
                import matplotlib.pyplot as plt
            ax = all_hists(args.output_dir, args.brain_mask)
            ax.set_title('GMM')
            plt.savefig(os.path.join(args.output_dir, 'hist.png'))

        return 0
    except Exception as e:
        logger.exception(e)
        return 1
Example #9
def main(args=None):
    args = arg_parser().parse_args(args)
    if args.verbosity == 1:
        level = logging.getLevelName('INFO')
    elif args.verbosity >= 2:
        level = logging.getLevelName('DEBUG')
    else:
        level = logging.getLevelName('WARNING')
    logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=level)
    logger = logging.getLogger(__name__)
    try:
        img_fns = glob_nii(args.img_dir)
        if not os.path.exists(args.output_dir):
            logger.info('Making Output Directory: {}'.format(args.output_dir))
            os.mkdir(args.output_dir)
        if args.template_dir is None:
            logger.info('Registering image to MNI template')
            template = ants.image_read(ants.get_ants_data('mni')).reorient_image2(args.orientation)
            orientation = args.orientation
        else:
            template_fns = glob_nii(args.template_dir)
            if len(template_fns) != len(img_fns):
                raise NormalizationError('If template images are provided, they must be in '
                                         'correspondence (i.e., equal number) with the source images')
        for i, img in enumerate(img_fns):
            _, base, _ = split_filename(img)
            logger.info('Registering image to template: {} ({:d}/{:d})'.format(base, i+1, len(img_fns)))
            if args.template_dir is not None:
                template = ants.image_read(template_fns[i])
                orientation = template.orientation if hasattr(template, 'orientation') else None
            input_img = ants.image_read(img)
            input_img = input_img.reorient_image2(orientation) if orientation is not None else input_img
            if not args.no_rigid:
                logger.info('Starting rigid registration: {} ({:d}/{:d})'.format(base, i+1, len(img_fns)))
                mytx = ants.registration(fixed=template, moving=input_img, type_of_transform="Rigid")
                tx = mytx['fwdtransforms'][0]
            else:
                tx = None
            logger.info('Starting {} registration: {} ({:d}/{:d})'.format(args.registration, base, i+1, len(img_fns)))
            mytx = ants.registration(fixed=template, moving=input_img, initial_transform=tx, type_of_transform=args.registration)
            logger.debug(mytx)
            moved = ants.apply_transforms(template, input_img, mytx['fwdtransforms'], interpolator='bSpline')
            registered = os.path.join(args.output_dir, base + '_reg.nii.gz')
            ants.image_write(moved, registered)
        return 0
    except Exception as e:
        logger.exception(e)
        return 1
Example #10
def process(image_fn, brain_mask_fn, args, logger):
    img = io.open_nii(image_fn)
    if args.brain_mask is not None:
        mask = io.open_nii(brain_mask_fn)
    else:
        mask = None
    dirname, base, _ = io.split_filename(image_fn)
    if args.output_dir is not None:
        dirname = args.output_dir
        if not os.path.exists(dirname):
            logger.info('Making output directory: {}'.format(dirname))
            os.mkdir(dirname)
    normalized = kde.kde_normalize(img, mask, args.contrast, args.norm_value)
    outfile = os.path.join(dirname, base + '_kde.nii.gz')
    logger.info('Normalized image saved: {}'.format(outfile))
    io.save_nii(normalized, outfile, is_nii=True)
Example #11
def process(image_fn, brain_mask_fn, output_dir, logger):
    img = io.open_nii(image_fn)
    dirname, base, _ = io.split_filename(image_fn)
    if output_dir is not None:
        dirname = output_dir
        if not os.path.exists(dirname):
            logger.info('Making output directory: {}'.format(dirname))
            os.mkdir(dirname)
    if brain_mask_fn is None:
        mask = None
    else:
        if brain_mask_fn == 'nomask':
            mask = 'nomask'
        else:
            mask = io.open_nii(brain_mask_fn)
    normalized = zscore.zscore_normalize(img, mask)
    outfile = os.path.join(dirname, base + '_zscore.nii.gz')
    logger.info('Normalized image saved: {}'.format(outfile))
    io.save_nii(normalized, outfile, is_nii=True)
Example #12
def process(image_fn, brain_mask_fn, wm_mask_fn, output_dir, args, logger):
    img = io.open_nii(image_fn)
    dirname, base, _ = io.split_filename(image_fn)
    if output_dir is not None:
        dirname = output_dir
        if not os.path.exists(dirname):
            logger.info('Making output directory: {}'.format(dirname))
            os.mkdir(dirname)
    if brain_mask_fn is not None:
        mask = io.open_nii(brain_mask_fn)
        wm_mask = fcm.find_wm_mask(img, mask)
        outfile = os.path.join(dirname, base + '_wmmask.nii.gz')
        io.save_nii(wm_mask, outfile, is_nii=True)
    if wm_mask_fn is not None:
        wm_mask = io.open_nii(wm_mask_fn)
        normalized = fcm.fcm_normalize(img, wm_mask, args.norm_value)
        outfile = os.path.join(dirname, base + '_fcm.nii.gz')
        logger.info('Normalized image saved: {}'.format(outfile))
        io.save_nii(normalized, outfile, is_nii=True)
Example #13
def main(args=None):
    args = arg_parser().parse_args(args)
    if args.verbosity == 1:
        level = logging.getLevelName('INFO')
    elif args.verbosity >= 2:
        level = logging.getLevelName('DEBUG')
    else:
        level = logging.getLevelName('WARNING')
    logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=level)
    logger = logging.getLogger(__name__)
    try:
        img_fns = glob_nii(args.img_dir)
        if not os.path.exists(args.mask_dir):
            logger.info('Making Output Mask Directory: {}'.format(args.mask_dir))
            os.mkdir(args.mask_dir)
        for i, img in enumerate(img_fns, 1):
            _, base, _ = split_filename(img)
            logger.info('Creating Mask for Image: {}, ({:d}/{:d})'.format(base, i, len(img_fns)))
            mask = os.path.join(args.mask_dir, base + '_mask.nii.gz')
            _ = robex(os.path.abspath(img), os.path.abspath(mask), args.return_skull_stripped)
        return 0
    except Exception as e:
        logger.exception(e)
        return 1
Example #14
def preprocess(img_dir,
               out_dir,
               mask_dir=None,
               res=(1., 1., 1.),
               orientation='RAI',
               n4_opts=None):
    """
    Preprocess MR images according to a simple scheme,
    that is:
        1) N4 bias field correction
        2) resample to x mm x y mm x z mm
        3) reorient images to the given orientation (default: RAI)

    Args:
        img_dir (str): path to directory containing images
        out_dir (str): path to directory for output preprocessed files
        mask_dir (str): path to directory containing masks
        res (tuple): resolution for resampling (default: (1., 1., 1.) in mm)
        orientation (str): orientation to which images are reoriented (default: 'RAI')
        n4_opts (dict): N4 processing options; see ANTsPy for details (default: None)

    Returns:
        None, outputs preprocessed images to file in given out_dir
    """

    if n4_opts is None:
        n4_opts = {'iters': [200, 200, 200, 200], 'tol': 0.0005}
    logger.debug('N4 Options are: {}'.format(n4_opts))

    # get and check the images and masks
    img_fns = glob_nii(img_dir)
    mask_fns = glob_nii(
        mask_dir) if mask_dir is not None else [None] * len(img_fns)
    assert len(img_fns) == len(mask_fns), 'Number of images and masks must be equal ({:d} != {:d})' \
        .format(len(img_fns), len(mask_fns))

    # create the output directory structure
    out_img_dir = os.path.join(out_dir, 'imgs')
    out_mask_dir = os.path.join(out_dir, 'masks')
    if not os.path.exists(out_dir):
        logger.info('Making output directory structure: {}'.format(out_dir))
        os.mkdir(out_dir)
    if not os.path.exists(out_img_dir):
        logger.info('Making image output directory: {}'.format(out_img_dir))
        os.mkdir(out_img_dir)
    if not os.path.exists(out_mask_dir) and mask_dir is not None:
        logger.info('Making mask output directory: {}'.format(out_mask_dir))
        os.mkdir(out_mask_dir)

    # preprocess the images by n4 correction, resampling, and reorientation
    for i, (img_fn, mask_fn) in enumerate(zip(img_fns, mask_fns), 1):
        _, img_base, img_ext = split_filename(img_fn)
        logger.info('Preprocessing image: {} ({:d}/{:d})'.format(
            img_base, i, len(img_fns)))
        img = ants.image_read(img_fn)
        if mask_dir is not None:
            _, mask_base, mask_ext = split_filename(mask_fn)
            mask = ants.image_read(mask_fn)
            smoothed_mask = ants.smooth_image(mask, 1)
            # this should be a second n4 after an initial n4 (and coregistration), once masks are obtained
            img = ants.n4_bias_field_correction(img,
                                                convergence=n4_opts,
                                                weight_mask=smoothed_mask)
            if res is not None:
                if res != img.spacing:
                    mask = ants.resample_image(mask, res, False, 1)
            mask = mask.reorient_image2(orientation) if hasattr(mask, 'reorient_image2') else \
                mask.reorient_image((1, 0, 0))['reoimage']
            out_mask = os.path.join(out_mask_dir, mask_base + mask_ext)
            ants.image_write(mask, out_mask)
        else:
            img = ants.n4_bias_field_correction(img, convergence=n4_opts)
        if res is not None:
            if res != img.spacing:
                img = ants.resample_image(img, res, False, 4)
        if hasattr(img, 'reorient_image2'):
            img = img.reorient_image2(orientation)
        else:
            logger.info(
                'Cannot reorient image to a custom orientation. Update ANTsPy to a version >= 0.1.5.'
            )
            img = img.reorient_image((1, 0, 0))['reoimage']
        logger.info('Writing preprocessed image: {} ({:d}/{:d})'.format(
            img_base, i, len(img_fns)))
        out_img = os.path.join(out_img_dir, img_base + img_ext)
        ants.image_write(img, out_img)
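A usage sketch for preprocess with 1 mm isotropic resampling (the directories and N4 options are placeholders):

# hypothetical directories; outputs are written to out_dir/imgs and out_dir/masks
preprocess(img_dir='data/raw_images',
           out_dir='data/preprocessed',
           mask_dir='data/raw_masks',
           res=(1., 1., 1.),
           orientation='RAI',
           n4_opts={'iters': [200, 200, 200, 200], 'tol': 0.0005})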
def ravel_normalize(img_dir,
                    mask_dir,
                    contrast,
                    output_dir=None,
                    write_to_disk=False,
                    do_whitestripe=True,
                    b=1,
                    membership_thresh=0.99,
                    segmentation_smoothness=0.25,
                    do_registration=False,
                    use_fcm=True):
    """
    Use RAVEL [1] to normalize the intensities of a set of MR images to eliminate
    unwanted technical variation in images (but, hopefully, preserve biological variation)

    This function has an option, modified from [1], in which no registration is done;
    instead, the control mask is defined dynamically by segmenting the brain tissue and
    thresholding the CSF membership at a very high level. This is much faster and seems
    to work well, but the results can be somewhat less consistent.

    Args:
        img_dir (str): directory containing MR images to be normalized
        mask_dir (str): brain masks for imgs
        contrast (str): contrast of MR images to be normalized (T1, T2, or FLAIR)
        output_dir (str): directory in which to save the normalized images if you do not
            want them saved alongside the input images
        write_to_disk (bool): write the normalized data to disk if True
        do_whitestripe (bool): whitestripe normalize the images before applying RAVEL correction
        b (int): number of unwanted factors to estimate
        membership_thresh (float): threshold of membership for control voxels
        segmentation_smoothness (float): segmentation smoothness parameter for atropos ANTsPy
            segmentation scheme (i.e., mrf parameter)
        do_registration (bool): deformably register images to find control mask
        use_fcm (bool): use FCM for segmentation instead of atropos (may be less accurate)

    Returns:
        Z (np.ndarray): unwanted factors (used in ravel correction)
        normalized (np.ndarray): set of normalized images from img_dir

    References:
        [1] J. P. Fortin, E. M. Sweeney, J. Muschelli, C. M. Crainiceanu,
            and R. T. Shinohara, “Removing inter-subject technical variability
            in magnetic resonance imaging studies,” Neuroimage, vol. 132,
            pp. 198–212, 2016.
    """
    img_fns = io.glob_nii(img_dir)
    mask_fns = io.glob_nii(mask_dir)

    if output_dir is None or not write_to_disk:
        out_fns = None
    else:
        out_fns = []
        for fn in img_fns:
            _, base, ext = io.split_filename(fn)
            out_fns.append(os.path.join(output_dir, base + ext))
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)

    # get parameters necessary and setup the V array
    V, Vc = image_matrix(img_fns,
                         contrast,
                         masks=mask_fns,
                         do_whitestripe=do_whitestripe,
                         return_ctrl_matrix=True,
                         membership_thresh=membership_thresh,
                         do_registration=do_registration,
                         smoothness=segmentation_smoothness,
                         use_fcm=use_fcm)

    # estimate the unwanted factors Z
    _, _, vh = np.linalg.svd(Vc)
    Z = vh.T[:, 0:b]

    # perform the ravel correction
    V_norm = ravel_correction(V, Z)

    # save the results to disk if desired
    if write_to_disk:
        for i, (img_fn, out_fn) in enumerate(zip(img_fns, out_fns)):
            img = io.open_nii(img_fn)
            norm = V_norm[:, i].reshape(img.get_data().shape)
            io.save_nii(img, out_fn, data=norm)

    return Z, V_norm
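A sketch of calling ravel_normalize with the faster, registration-free control mask (the directories are placeholders):

# hypothetical directories; Z holds the estimated unwanted factors, V_norm the corrected image matrix
Z, V_norm = ravel_normalize('data/images',
                            'data/masks',
                            contrast='T1',
                            output_dir='data/ravel_normalized',
                            write_to_disk=True,
                            b=1,
                            do_registration=False)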
def ws_normalize(img_dir,
                 contrast,
                 mask_dir=None,
                 output_dir=None,
                 write_to_disk=True):
    """
    Use WhiteStripe normalization method ([1]) to normalize the intensities of
    a set of MR images by normalizing an area around the white matter peak of the histogram

    Args:
        img_dir (str): directory containing MR images to be normalized
        contrast (str): contrast of MR images to be normalized (T1, T2, or FLAIR)
        mask_dir (str): if images are not skull-stripped, then provide brain mask
        output_dir (str): directory to save images if you do not want them saved in
            same directory as img_dir
        write_to_disk (bool): write the normalized data to disk if True

    Returns:
        normalized (np.ndarray): last normalized image data from img_dir
            (returned mainly for testing purposes)

    References:
        [1] R. T. Shinohara, E. M. Sweeney, J. Goldsmith, N. Shiee,
            F. J. Mateen, P. A. Calabresi, S. Jarso, D. L. Pham,
            D. S. Reich, and C. M. Crainiceanu, “Statistical normalization
            techniques for magnetic resonance imaging,” NeuroImage Clin.,
            vol. 6, pp. 9–19, 2014.
    """

    # grab the file names for the images of interest
    data = io.glob_nii(img_dir)

    # define and get the brain masks for the images, if defined
    if mask_dir is None:
        masks = [None] * len(data)
    else:
        masks = io.glob_nii(mask_dir)
        if len(data) != len(masks):
            raise NormalizationError(
                'Number of images and masks must be equal, Images: {}, Masks: {}'
                .format(len(data), len(masks)))

    # define the output directory and corresponding output file names
    if output_dir is None:
        output_files = [None] * len(data)
    else:
        output_files = []
        for fn in data:
            _, base, ext = io.split_filename(fn)
            output_files.append(os.path.join(output_dir, base + '_ws' + ext))
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)

    # do whitestripe normalization and save the results
    for i, (img_fn, mask_fn,
            output_fn) in enumerate(zip(data, masks, output_files), 1):
        logger.info('Normalizing image: {} ({:d}/{:d})'.format(
            img_fn, i, len(data)))
        img = io.open_nii(img_fn)
        mask = io.open_nii(mask_fn) if mask_fn is not None else None
        indices = whitestripe(img, contrast, mask=mask)
        normalized = whitestripe_norm(img, indices)
        if write_to_disk:
            logger.info('Saving normalized image: {} ({:d}/{:d})'.format(
                output_fn, i, len(data)))
            io.save_nii(normalized, output_fn)

    # output the last normalized image (mostly for testing purposes)
    return normalized
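A sketch of ws_normalize on a set of skull-stripped T1w images (the paths are placeholders):

# hypothetical directories; pass mask_dir if the images are not skull-stripped
normalized = ws_normalize('data/images',
                          contrast='T1',
                          mask_dir=None,
                          output_dir='data/ws_normalized',
                          write_to_disk=True)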
def main(args=None):
    args = arg_parser().parse_args(args)
    if args.verbosity == 1:
        level = logging.getLevelName('INFO')
    elif args.verbosity >= 2:
        level = logging.getLevelName('DEBUG')
    else:
        level = logging.getLevelName('WARNING')
    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=level)
    logger = logging.getLogger(__name__)
    try:
        img_fns = io.glob_nii(args.img_dir)
        mask_fns = io.glob_nii(args.mask_dir)
        if len(img_fns) != len(mask_fns) or len(img_fns) == 0:
            raise NormalizationError(
                'Image directory ({}) and mask directory ({}) must contain the same '
                '(positive) number of images!'.format(args.img_dir,
                                                      args.mask_dir))

        logger.info('Normalizing the images according to RAVEL')
        Z, _ = ravel.ravel_normalize(
            args.img_dir,
            args.mask_dir,
            args.contrast,
            do_whitestripe=args.no_whitestripe,
            b=args.num_unwanted_factors,
            membership_thresh=args.control_membership_threshold,
            do_registration=args.no_registration,
            segmentation_smoothness=args.segmentation_smoothness,
            use_fcm=not args.use_atropos,
            sparse_svd=args.sparse_svd,
            csf_masks=args.csf_masks)

        V = ravel.image_matrix(img_fns, args.contrast, masks=mask_fns)
        V_norm = ravel.ravel_correction(V, Z)
        normalized = ravel.image_matrix_to_images(V_norm, img_fns)

        # save the normalized images to disk
        output_dir = os.getcwd() if args.output_dir is None else args.output_dir
        out_fns = []
        for fn in img_fns:
            _, base, ext = io.split_filename(fn)
            out_fns.append(os.path.join(output_dir, base + '_ravel' + ext))
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)
        for norm, out_fn in zip(normalized, out_fns):
            norm.to_filename(out_fn)

        if args.plot_hist:
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore', category=FutureWarning)
                from intensity_normalization.plot.hist import all_hists
                import matplotlib.pyplot as plt
            ax = all_hists(output_dir, args.mask_dir)
            ax.set_title('RAVEL')
            plt.savefig(os.path.join(output_dir, 'hist.png'))

        return 0
    except Exception as e:
        logger.exception(e)
        return 1
def image_matrix(imgs,
                 contrast,
                 masks=None,
                 do_whitestripe=True,
                 return_ctrl_matrix=False,
                 membership_thresh=0.99,
                 smoothness=0.25,
                 max_ctrl_vox=10000,
                 do_registration=False,
                 ctrl_prob=1,
                 use_fcm=False):
    """
    creates a matrix of images where the rows correspond to the voxels of
    each image and the columns are the images

    Args:
        imgs (list): list of paths to MR images of interest
        contrast (str): contrast of the set of imgs (e.g., T1)
        masks (list or str): list of corresponding brain masks or just one (template) mask
        do_whitestripe (bool): whitestripe normalize the images before storing them in the matrix
        return_ctrl_matrix (bool): return control matrix for imgs (i.e., a subset of V's rows)
        membership_thresh (float): threshold of membership for control voxels (want this very high)
            this option is only used if the registration is turned off
        smoothness (float): smoothness parameter for segmentation for control voxels
            this option is only used if the registration is turned off
        max_ctrl_vox (int): maximum number of control voxels (setting this too high can
            exhaust available memory); only used if do_registration is False
        do_registration (bool): register the images together and take the intersection of the csf
            masks (as done in the original paper, note that this takes much longer)
        ctrl_prob (float): given all data, proportion of data labeled as csf to be
            used for intersection (i.e., when do_registration is true)
        use_fcm (bool): use FCM for segmentation instead of atropos (may be less accurate)

    Returns:
        V (np.ndarray): image matrix (rows are voxels, columns are images)
        Vc (np.ndarray): image matrix of control voxels (rows are voxels, columns are images)
            Vc only returned if return_ctrl_matrix is True
    """
    img_shape = io.open_nii(imgs[0]).get_data().shape
    V = np.zeros((int(np.prod(img_shape)), len(imgs)))

    if return_ctrl_matrix:
        ctrl_vox = []

    if masks is None and return_ctrl_matrix:
        raise NormalizationError(
            'Brain masks must be provided if returning control memberships')
    if masks is None:
        masks = [None] * len(imgs)

    for i, (img_fn, mask_fn) in enumerate(zip(imgs, masks)):
        _, base, _ = io.split_filename(img_fn)
        img = io.open_nii(img_fn)
        mask = io.open_nii(mask_fn) if mask_fn is not None else None
        # do whitestripe on the image before applying RAVEL (if desired)
        if do_whitestripe:
            logger.info('Applying WhiteStripe to image {} ({:d}/{:d})'.format(
                base, i + 1, len(imgs)))
            inds = whitestripe(img, contrast, mask)
            img = whitestripe_norm(img, inds)
        img_data = img.get_data()
        if img_data.shape != img_shape:
            raise NormalizationError(
                'Cannot normalize because image {} needs to have same dimension '
                'as all other images ({} != {})'.format(
                    base, img_data.shape, img_shape))
        V[:, i] = img_data.flatten()
        if return_ctrl_matrix:
            if do_registration and i == 0:
                logger.info(
                    'Creating control mask for image {} ({:d}/{:d})'.format(
                        base, i + 1, len(imgs)))
                verbose = logger.getEffectiveLevel() == logging.DEBUG
                ctrl_masks = []
                reg_imgs = []
                reg_imgs.append(csf.nibabel_to_ants(img))
                ctrl_masks.append(
                    csf.csf_mask(img,
                                 mask,
                                 contrast=contrast,
                                 csf_thresh=membership_thresh,
                                 mrf=smoothness,
                                 use_fcm=use_fcm))
            elif do_registration and i != 0:
                template = ants.image_read(imgs[0])
                tmask = ants.image_read(masks[0])
                img = csf.nibabel_to_ants(img)
                logger.info(
                    'Starting registration for image {} ({:d}/{:d})'.format(
                        base, i + 1, len(imgs)))
                reg_result = ants.registration(template,
                                               img,
                                               type_of_transform='SyN',
                                               mask=tmask,
                                               verbose=verbose)
                img = reg_result['warpedmovout']
                mask = csf.nibabel_to_ants(mask)
                reg_imgs.append(img)
                logger.info(
                    'Creating control mask for image {} ({:d}/{:d})'.format(
                        base, i + 1, len(imgs)))
                ctrl_masks.append(
                    csf.csf_mask(img,
                                 mask,
                                 contrast=contrast,
                                 csf_thresh=membership_thresh,
                                 mrf=smoothness,
                                 use_fcm=use_fcm))
            else:
                logger.info(
                    'Finding control voxels for image {} ({:d}/{:d})'.format(
                        base, i + 1, len(imgs)))
                ctrl_mask = csf.csf_mask(img,
                                         mask,
                                         contrast=contrast,
                                         csf_thresh=membership_thresh,
                                         mrf=smoothness,
                                         use_fcm=use_fcm)
                if np.sum(ctrl_mask) == 0:
                    raise NormalizationError(
                        'No control voxels found for image ({}) at threshold ({})'
                        .format(base, membership_thresh))
                elif np.sum(ctrl_mask) < 100:
                    logger.warning(
                        'Few control voxels found ({:d}) (potentially a problematic image ({}) or '
                        'threshold ({}) too high)'.format(
                            int(np.sum(ctrl_mask)), base, membership_thresh))
                ctrl_vox.append(img_data[ctrl_mask == 1].flatten())

    if return_ctrl_matrix and not do_registration:
        min_len = min(min(map(len, ctrl_vox)), max_ctrl_vox)
        logger.info('Using {:d} control voxels'.format(min_len))
        Vc = np.zeros((min_len, len(imgs)))
        for i in range(len(imgs)):
            ctrl_voxs = ctrl_vox[i][:min_len]
            logger.info(
                'Image {:d} control voxel stats -  mean: {:.3f}, std: {:.3f}'.
                format(i + 1, np.mean(ctrl_voxs), np.std(ctrl_voxs)))
            Vc[:, i] = ctrl_voxs
    elif return_ctrl_matrix and do_registration:
        ctrl_sum = reduce(
            add,
            ctrl_masks)  # need to use reduce instead of sum b/c data structure
        intersection = np.zeros(ctrl_sum.shape)
        intersection[ctrl_sum >= np.floor(len(ctrl_masks) * ctrl_prob)] = 1
        num_ctrl_vox = int(np.sum(intersection))
        Vc = np.zeros((num_ctrl_vox, len(imgs)))
        for i, img in enumerate(reg_imgs):
            ctrl_voxs = img.numpy()[intersection == 1]
            logger.info(
                'Image {:d} control voxel stats -  mean: {:.3f}, std: {:.3f}'.
                format(i + 1, np.mean(ctrl_voxs), np.std(ctrl_voxs)))
            Vc[:, i] = ctrl_voxs
        del ctrl_masks, reg_imgs
        import gc
        gc.collect()  # force a garbage collection, since we just used the majority of the system memory

    return V if not return_ctrl_matrix else (V, Vc)
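A sketch of building the image and control-voxel matrices with image_matrix (the directories are placeholders):

# hypothetical directories; Vc is only returned when return_ctrl_matrix is True
img_fns = io.glob_nii('data/images')
mask_fns = io.glob_nii('data/masks')
V, Vc = image_matrix(img_fns,
                     contrast='T1',
                     masks=mask_fns,
                     return_ctrl_matrix=True,
                     membership_thresh=0.99,
                     do_registration=False)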
def main(args=None):
    args = arg_parser().parse_args(args)
    if not (args.brain_mask is None) ^ (args.tissue_mask is None):
        raise NormalizationError(
            'Only one of {brain mask, tissue mask} should be given')
    if args.verbosity == 1:
        level = logging.getLevelName('INFO')
    elif args.verbosity >= 2:
        level = logging.getLevelName('DEBUG')
    else:
        level = logging.getLevelName('WARNING')
    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=level)
    logger = logging.getLogger(__name__)
    try:
        if not args.single_img:
            if not os.path.isdir(
                    args.image) or (False if args.brain_mask is None else
                                    not os.path.isdir(args.brain_mask)):
                raise NormalizationError(
                    'if single-img option off, then image and brain-mask must be directories'
                )
            img_fns = io.glob_nii(args.image)
            mask_fns = io.glob_nii(
                args.brain_mask
            ) if args.brain_mask is not None else [None] * len(img_fns)
            if len(img_fns) != len(mask_fns) and len(img_fns) > 0:
                raise NormalizationError(
                    'input images and masks must be in correspondence and greater than zero '
                    '({:d} != {:d})'.format(len(img_fns), len(mask_fns)))
            args.output_dir = args.output_dir or 'fcm'
            output_dir_base = os.path.abspath(
                os.path.join(args.output_dir, '..'))

            if args.contrast.lower() == 't1' and args.tissue_mask is None:
                tissue_mask_dir = os.path.join(output_dir_base, 'tissue_masks')
                if os.path.exists(tissue_mask_dir):
                    logger.warning(
                        'Tissue mask directory already exists, may overwrite existing tissue masks!'
                    )
                else:
                    logger.info('Creating tissue mask directory: {}'.format(
                        tissue_mask_dir))
                    os.mkdir(tissue_mask_dir)
                for i, (img, mask) in enumerate(zip(img_fns, mask_fns), 1):
                    _, base, _ = io.split_filename(img)
                    _, mask_base, _ = io.split_filename(mask)
                    logger.info(
                        'Creating tissue mask for {} ({:d}/{:d})'.format(
                            base, i, len(img_fns)))
                    logger.debug('Tissue mask {} ({:d}/{:d})'.format(
                        mask_base, i, len(img_fns)))
                    process(img, mask, None, tissue_mask_dir, args, logger)
            elif os.path.exists(args.tissue_mask):
                tissue_mask_dir = args.tissue_mask
            else:
                raise NormalizationError(
                    'If contrast is not t1, then tissue mask directory ({}) '
                    'must already be created!'.format(args.tissue_mask))

            tissue_masks = io.glob_nii(tissue_mask_dir)
            for i, (img, tissue_mask) in enumerate(zip(img_fns, tissue_masks),
                                                   1):
                dirname, base, _ = io.split_filename(img)
                _, tissue_base, _ = io.split_filename(tissue_mask)
                logger.info('Normalizing image {} ({:d}/{:d})'.format(
                    base, i, len(img_fns)))
                logger.debug('Tissue mask {} ({:d}/{:d})'.format(
                    tissue_base, i, len(img_fns)))
                if args.output_dir is not None:
                    dirname = args.output_dir
                process(img, None, tissue_mask, dirname, args, logger)

        else:
            if not os.path.isfile(args.image):
                raise NormalizationError(
                    'if single-img option on, then image must be a file')
            if args.tissue_mask is None and args.contrast.lower() == 't1':
                logger.info('Creating tissue mask for {}'.format(args.image))
                process(args.image, args.brain_mask, None, args.output_dir,
                        args, logger)
            elif os.path.isfile(args.tissue_mask):
                pass
            else:
                raise NormalizationError(
                    'If contrast is not t1, then tissue mask must be provided!'
                )
            logger.info('Normalizing image {}'.format(args.image))
            dirname, base, _ = io.split_filename(args.image)
            dirname = args.output_dir or dirname
            if args.tissue_mask is None:
                tissue_mask = os.path.join(
                    dirname, base + '_{}_mask.nii.gz'.format(args.tissue_type))
            else:
                tissue_mask = args.tissue_mask
            process(args.image, args.brain_mask, tissue_mask, dirname, args,
                    logger)

        if args.plot_hist:
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore', category=FutureWarning)
                from intensity_normalization.plot.hist import all_hists
                import matplotlib.pyplot as plt
            ax = all_hists(args.output_dir, args.brain_mask)
            ax.set_title('Fuzzy C-Means')
            plt.savefig(os.path.join(args.output_dir, 'hist.png'))

        return 0
    except Exception as e:
        logger.exception(e)
        return 1