def main(args=None):
    """
    Command-line entry point: normalize one image or a directory of images.

    Parses ``args`` with ``arg_parser()``, sets the logging level from
    ``args.verbosity`` (1 -> INFO, 2 or more -> DEBUG, otherwise WARNING) and
    calls ``process`` once per image/brain-mask pair. If ``args.plot_hist``
    is set, the histograms of the images in ``args.output_dir`` are overlaid
    and saved there as ``hist.png``.

    Args:
        args (list of str or None): command-line arguments; ``None`` lets
            argparse read ``sys.argv``

    Returns:
        int: 0 on success, 1 if any exception was raised (the exception is
            logged instead of propagated)
    """
    args = arg_parser().parse_args(args)
    if args.verbosity == 1:
        level = logging.getLevelName('INFO')
    elif args.verbosity >= 2:
        level = logging.getLevelName('DEBUG')
    else:
        level = logging.getLevelName('WARNING')
    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=level)
    logger = logging.getLogger(__name__)
    try:
        if not args.single_img:
            if not os.path.isdir(args.image) or not os.path.isdir(
                    args.brain_mask):
                raise NormalizationError(
                    'if single-img option off, then image and brain-mask must be directories'
                )
            img_fns = io.glob_nii(args.image)
            mask_fns = io.glob_nii(args.brain_mask)
            # the message promises both correspondence AND a non-empty set of
            # images, so either violation must raise (previously an empty
            # directory silently "succeeded")
            if len(img_fns) != len(mask_fns) or len(img_fns) == 0:
                raise NormalizationError(
                    'input images and masks must be in correspondence and greater than zero '
                    '({:d} != {:d})'.format(len(img_fns), len(mask_fns)))
            for i, (img, mask) in enumerate(zip(img_fns, mask_fns), 1):
                logger.info('Normalizing image {} ({:d}/{:d})'.format(
                    img, i, len(img_fns)))
                process(img, mask, args, logger)
        else:
            if not os.path.isfile(args.image) or not os.path.isfile(
                    args.brain_mask):
                raise NormalizationError(
                    'if single-img option on, then image and brain-mask must be files'
                )
            process(args.image, args.brain_mask, args, logger)

        if args.plot_hist:
            # plotting dependencies are imported lazily so the CLI works
            # without matplotlib unless a histogram is requested
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore', category=FutureWarning)
                from intensity_normalization.plot.hist import all_hists
                import matplotlib.pyplot as plt
            ax = all_hists(args.output_dir, args.brain_mask)
            ax.set_title('GMM')
            plt.savefig(os.path.join(args.output_dir, 'hist.png'))

        return 0
    except Exception as e:
        logger.exception(e)
        return 1
Esempio n. 2
0
def whitestripe(img,
                contrast,
                mask=None,
                width=0.05,
                width_l=None,
                width_u=None):
    """
    find the "(normal appearing) white (matter) stripe" of the input MR image
    and return the indices

    The stripe is the band of intensities whose quantiles lie between
    ``width_l`` below and ``width_u`` above the quantile of the histogram
    mode selected by ``contrast`` (quantiles clipped to [0, 1]).

    Args:
        img (nibabel.nifti1.Nifti1Image): target MR image
        contrast (str): contrast of img; selects the histogram mode used as
            the stripe centre (t1/last -> last mode, t2/flair/largest ->
            largest mode, md/first -> first mode)
        mask (nibabel.nifti1.Nifti1Image): brainmask for img (None is default, for skull-stripped img)
        width (float): width quantile for the "white (matter) stripe"
        width_l (float): lower width quantile (default None, uses width)
        width_u (float): upper width quantile (default None, uses width)

    Returns:
        ws_ind (np.ndarray): the white stripe indices (boolean mask with the
            same shape as img)

    Raises:
        NormalizationError: if ``contrast`` is not recognized or no voxel
            falls strictly inside the stripe
    """
    # each bound falls back to ``width`` on its own; previously giving only
    # one of them left the other as None and crashed in the arithmetic below
    if width_l is None:
        width_l = width
    if width_u is None:
        width_u = width
    # get_fdata replaces get_data, which is deprecated and removed in nibabel 5
    img_data = img.get_fdata()
    if mask is not None:
        mask_data = mask.get_fdata()
        # voxels outside the mask are zeroed in the array that gets thresholded
        masked = img_data * mask_data
        voi = img_data[mask_data == 1]
    else:
        masked = img_data
        # without a mask, voxels above the mean act as a rough foreground estimate
        voi = img_data[img_data > img_data.mean()]
    if contrast.lower() in ['t1', 'last']:
        mode = hist.get_last_mode(voi)
    elif contrast.lower() in ['t2', 'flair', 'largest']:
        mode = hist.get_largest_mode(voi)
    elif contrast.lower() in ['md', 'first']:
        mode = hist.get_first_mode(voi)
    else:
        raise NormalizationError(
            'Contrast {} not valid, needs to be `t1`,`t2`,`flair`,`md`,`first`,`largest`,`last`'
            .format(contrast))
    # empirical quantile of the mode within the volume of interest
    img_mode_q = np.mean(voi < mode)
    ws = np.percentile(voi, (max(img_mode_q - width_l, 0) * 100,
                             min(img_mode_q + width_u, 1) * 100))
    ws_ind = np.logical_and(masked > ws[0], masked < ws[1])
    # ws_ind always has the image's shape, so len() is never 0; test whether
    # any voxel was actually selected instead
    if not np.any(ws_ind):
        raise NormalizationError(
            'WhiteStripe failed to find any valid indices!')
    return ws_ind
Esempio n. 3
0
def kde_normalize(img, mask=None, contrast='t1', norm_value=1):
    """
    Scale ``img`` so that its white-matter (WM) histogram peak sits at ``norm_value``.

    The peak is located by a kernel-density-based mode finder over the voxels
    inside ``mask`` (voxels equal to 1). Without a mask, voxels brighter than
    the image mean serve as a rough foreground estimate. The whole image, not
    just the masked region, is then divided by the peak and multiplied by
    ``norm_value``.

    Args:
        img (nibabel.nifti1.Nifti1Image): target MR image
        mask (nibabel.nifti1.Nifti1Image): brain mask of img
        contrast (str): contrast of img; picks which mode is the WM peak
            (t1/flair/last -> last mode, t2/largest -> largest mode,
            md/first -> first mode)
        norm_value (float): value at which to place WM peak

    Returns:
        normalized (nibabel.nifti1.Nifti1Image): WM normalized img

    Raises:
        NormalizationError: if ``contrast`` is not a recognized name
    """
    data = img.get_fdata()
    if mask is None:
        foreground = data > data.mean()
    else:
        foreground = mask.get_fdata() == 1
    voi = data[foreground].flatten()

    key = contrast.lower()
    if key in ['t1', 'flair', 'last']:
        find_peak = hist.get_last_mode
    elif key in ['t2', 'largest']:
        find_peak = hist.get_largest_mode
    elif key in ['md', 'first']:
        find_peak = hist.get_first_mode
    else:
        raise NormalizationError(
            'Contrast {} not valid, needs to be `t1`,`t2`,`flair`,`md`,`first`,`largest`,`last`'.format(contrast))
    wm_peak = find_peak(voi)

    scaled = (data / wm_peak) * norm_value
    return nib.Nifti1Image(scaled, img.affine, img.header)
Esempio n. 4
0
def csf_mask_intersection(img_dir, masks=None, prob=1):
    """
    use all nifti images in img_dir to create a csf mask of their common areas

    Args:
        img_dir (str): directory containing MR images to be normalized
        masks (str or ants.core.ants_image.ANTsImage): if images are not skull-stripped,
            then provide brain mask as either a corresponding directory or an individual mask
            (an individual mask is shared by every image)
        prob (float): given all data, proportion of data labeled as csf to be
            used for intersection; a voxel is kept when it is labeled csf in at
            least floor(number of images * prob) images

    Returns:
        intersection (np.ndarray): binary mask of common csf areas for all provided imgs

    Raises:
        NormalizationError: if prob is outside [0, 1], img_dir contains no
            images, or a mask directory does not contain exactly one mask per image
    """
    if not (0 <= prob <= 1):
        raise NormalizationError(
            'prob must be between 0 and 1. {} given.'.format(prob))
    data = io.glob_nii(img_dir)
    # reduce() below has no initial value, so an empty list would fail with
    # an unhelpful TypeError
    if len(data) == 0:
        raise NormalizationError(
            'No images found in {}'.format(img_dir))
    if isinstance(masks, str):
        masks = io.glob_nii(masks)
        # zip() would silently drop unmatched images while the threshold
        # below still counts all of them
        if len(masks) != len(data):
            raise NormalizationError(
                'Number of images and masks must be equal ({:d} != {:d})'.format(
                    len(data), len(masks)))
    else:
        masks = [masks] * len(data)
    csf = []
    for i, (img, mask) in enumerate(zip(data, masks)):
        _, base, _ = io.split_filename(img)
        logger.info('Creating CSF mask for image {} ({:d}/{:d})'.format(
            base, i + 1, len(data)))
        imgn = ants.image_read(img)
        maskn = ants.image_read(mask) if isinstance(mask, str) else mask
        csf.append(csf_mask(imgn, maskn))
    csf_sum = reduce(
        add, csf)  # need to use reduce instead of sum b/c data structure
    intersection = np.zeros(csf_sum.shape)
    intersection[csf_sum >= np.floor(len(data) * prob)] = 1
    return intersection
Esempio n. 5
0
def get_tissue_mode(data: Array, modality: str) -> float:
    """Return the tissue mode of ``data`` that matches ``modality``.

    The lower-cased modality is looked up in ``PEAK["last"]``,
    ``PEAK["largest"]`` and ``PEAK["first"]``, in that order. The first group
    that contains it decides which histogram peak is returned.

    Raises:
        NormalizationError: if the modality belongs to none of those groups.
    """
    key = modality.lower()
    finders = (
        ("last", get_last_tissue_mode),
        ("largest", get_largest_tissue_mode),
        ("first", get_first_tissue_mode),
    )
    for peak_name, find_mode in finders:
        if key in PEAK[peak_name]:
            return find_mode(data)
    modalities = ", ".join(VALID_PEAKS)
    raise NormalizationError(
        f"Modality {modality} not valid. Needs to be one of {modalities}.")
Esempio n. 6
0
def main(args=None):
    """
    Command-line entry point: register every image in ``args.img_dir`` to a template.

    Without ``args.template_dir``, every image is registered to the ANTs MNI
    template reoriented to ``args.orientation``. Otherwise the i-th image is
    registered to the i-th image of ``args.template_dir``. Unless
    ``args.no_rigid`` is set, a rigid registration runs first and its forward
    transform initializes the ``args.registration`` registration. Each warped
    image is written to ``args.output_dir`` as ``<base>_reg.nii.gz``.

    Args:
        args (list of str or None): command-line arguments; ``None`` lets
            argparse read ``sys.argv``

    Returns:
        int: 0 on success, 1 if an exception was raised (and logged)
    """
    args = arg_parser().parse_args(args)
    level_name = 'WARNING'
    if args.verbosity == 1:
        level_name = 'INFO'
    elif args.verbosity >= 2:
        level_name = 'DEBUG'
    level = logging.getLevelName(level_name)
    logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=level)
    logger = logging.getLogger(__name__)
    try:
        img_fns = glob_nii(args.img_dir)
        n_imgs = len(img_fns)
        if not os.path.exists(args.output_dir):
            logger.info('Making Output Directory: {}'.format(args.output_dir))
            os.mkdir(args.output_dir)
        per_image_template = args.template_dir is not None
        if not per_image_template:
            logger.info('Registering image to MNI template')
            mni = ants.image_read(ants.get_ants_data('mni'))
            template = mni.reorient_image2(args.orientation)
            orientation = args.orientation
        else:
            template_fns = glob_nii(args.template_dir)
            if len(template_fns) != n_imgs:
                raise NormalizationError('If template images are provided, they must be in '
                                         'correspondence (i.e., equal number) with the source images')
        for idx, img_fn in enumerate(img_fns, 1):
            _, base, _ = split_filename(img_fn)
            logger.info('Registering image to template: {} ({:d}/{:d})'.format(base, idx, n_imgs))
            if per_image_template:
                template = ants.image_read(template_fns[idx - 1])
                orientation = getattr(template, 'orientation', None)
            moving = ants.image_read(img_fn)
            if orientation is not None:
                moving = moving.reorient_image2(orientation)
            initial_tx = None
            if not args.no_rigid:
                logger.info('Starting rigid registration: {} ({:d}/{:d})'.format(base, idx, n_imgs))
                rigid = ants.registration(fixed=template, moving=moving, type_of_transform="Rigid")
                initial_tx = rigid['fwdtransforms'][0]
            logger.info('Starting {} registration: {} ({:d}/{:d})'.format(args.registration, base, idx, n_imgs))
            result = ants.registration(fixed=template, moving=moving, initial_transform=initial_tx,
                                       type_of_transform=args.registration)
            logger.debug(result)
            warped = ants.apply_transforms(template, moving, result['fwdtransforms'], interpolator='bSpline')
            ants.image_write(warped, os.path.join(args.output_dir, base + '_reg.nii.gz'))
        return 0
    except Exception as e:
        logger.exception(e)
        return 1
Esempio n. 7
0
def pairwise_jsd(img_dir, mask_dir, nbins=200):
    """
    Calculate the Jensen-Shannon Divergence for all pairs of images in the image directory

    Each image's intensities inside its mask (voxels equal to 1) are binned
    into a density histogram over the intensity range shared by all images.
    A small epsilon is added to every bin so that no bin is exactly zero.

    Args:
        img_dir (str): path to directory of images
        mask_dir (str): path to directory of masks (one per image, same order)
        nbins (int): number of bins to use in the histograms

    Returns:
        pairwise_jsd (np.ndarray): 1-D array with the divergence of every
            unordered pair (i, j), i < j, ordered by i then j

    Raises:
        NormalizationError: if the numbers of images and masks differ or the
            directories contain no images
    """
    from itertools import combinations

    eps = np.finfo(np.float32).eps

    img_fns = io.glob_nii(img_dir)
    mask_fns = io.glob_nii(mask_dir)

    if len(img_fns) != len(mask_fns):
        raise NormalizationError(
            f'Number of images ({len(img_fns)}) must be equal to the number of masks ({len(mask_fns)}).'
        )
    # without this check an empty directory fails later with an opaque
    # "min() arg is an empty sequence" ValueError
    if len(img_fns) == 0:
        raise NormalizationError(f'No images found in {img_dir}.')

    def _masked_intensities(img_fn, mask_fn):
        # intensities of the image at voxels where its mask equals 1
        return nib.load(img_fn).get_fdata()[nib.load(mask_fn).get_fdata() == 1]

    # first pass: common intensity range so every histogram uses the same bins
    # (images are reloaded in the second pass instead of being kept in memory)
    min_intensities, max_intensities = [], []
    for img_fn, mask_fn in zip(img_fns, mask_fns):
        data = _masked_intensities(img_fn, mask_fn)
        min_intensities.append(np.min(data))
        max_intensities.append(np.max(data))
    intensity_range = (min(min_intensities), max(max_intensities))

    # second pass: density histograms (local named `density` so it does not
    # shadow the module-level name `hist`)
    hists = []
    for img_fn, mask_fn in zip(img_fns, mask_fns):
        density, _ = np.histogram(_masked_intensities(img_fn, mask_fn).flatten(),
                                  nbins,
                                  range=intensity_range,
                                  density=True)
        hists.append(density + eps)

    return np.array([jsd(p, q) for p, q in combinations(hists, 2)])
Esempio n. 8
0
def all_hists(img_dir, mask_dir=None, alpha=0.8, figsize=(12, 10), **kwargs):
    """
    Overlay the histograms of every image in ``img_dir`` on one axis.

    This shows the intensity spread across a sample/population. Each
    histogram is drawn by ``hist`` using the image's brain mask when
    ``mask_dir`` is given. Otherwise it is drawn without a mask and ``hist``
    decides the foreground (per the original note, all intensities above the
    mean).

    Args:
        img_dir (str): path to images
        mask_dir (str): path to corresponding masks of imgs
        alpha (float): controls alpha parameter of individual line plots (default: 0.8)
        figsize (tuple): size of figure (default: (12,10))
        **kwargs: for numpy histogram routine

    Returns:
        ax (matplotlib.axes.Axes): plotted on ax obj

    Raises:
        NormalizationError: if the numbers of images and masks differ
    """
    imgs = glob_nii(img_dir)
    masks = [None] * len(imgs) if mask_dir is None else glob_nii(mask_dir)
    n_imgs, n_masks = len(imgs), len(masks)
    if n_imgs != n_masks:
        raise NormalizationError(
            'Number of images and masks must be equal ({:d} != {:d})'.format(
                n_imgs, n_masks))
    _, ax = plt.subplots(figsize=figsize)
    for idx, (img_fn, mask_fn) in enumerate(zip(imgs, masks), 1):
        logger.info('Creating histogram for image {:d}/{:d}'.format(
            idx, n_imgs))
        image = nib.load(img_fn)
        brain = None if mask_fn is None else nib.load(mask_fn)
        hist(image, brain, ax=ax, alpha=alpha, **kwargs)
    ax.set_xlabel('Intensity')
    ax.set_ylabel(r'Log$_{10}$ Count')
    ax.set_ylim((0, None))
    return ax
def ws_normalize(img_dir,
                 contrast,
                 mask_dir=None,
                 output_dir=None,
                 write_to_disk=True):
    """
    Use WhiteStripe normalization method ([1]) to normalize the intensities of
    a set of MR images by normalizing an area around the white matter peak of the histogram

    Args:
        img_dir (str): directory containing MR images to be normalized
        contrast (str): contrast of MR images to be normalized (T1, T2, or FLAIR)
        mask_dir (str): if images are not skull-stripped, then provide brain mask
        output_dir (str): directory to save images if you do not want them saved in
            same directory as img_dir (default None: each normalized image is
            saved next to its input as ``<base>_ws<ext>``)
        write_to_disk (bool): write the normalized data to disk or nah

    Returns:
        normalized (np.ndarray): last normalized image data from img_dir
            I know this is an odd behavior, but yolo

    Raises:
        NormalizationError: if img_dir contains no images or the number of
            masks differs from the number of images

    References:
        [1] R. T. Shinohara, E. M. Sweeney, J. Goldsmith, N. Shiee,
            F. J. Mateen, P. A. Calabresi, S. Jarso, D. L. Pham,
            D. S. Reich, and C. M. Crainiceanu, “Statistical normalization
            techniques for magnetic resonance imaging,” NeuroImage Clin.,
            vol. 6, pp. 9–19, 2014.
    """

    # grab the file names for the images of interest
    data = io.glob_nii(img_dir)
    # without any image the loop never binds `normalized` and the return
    # below would raise UnboundLocalError
    if len(data) == 0:
        raise NormalizationError('No images found in {}'.format(img_dir))

    # define and get the brain masks for the images, if defined
    if mask_dir is None:
        masks = [None] * len(data)
    else:
        masks = io.glob_nii(mask_dir)
        if len(data) != len(masks):
            raise NormalizationError(
                'Number of images and masks must be equal, Images: {}, Masks: {}'
                .format(len(data), len(masks)))

    # define the output file names; without output_dir each image is written
    # next to its input (previously None was passed on to io.save_nii)
    output_files = []
    for fn in data:
        dirname, base, ext = io.split_filename(fn)
        target_dir = dirname if output_dir is None else output_dir
        output_files.append(os.path.join(target_dir, base + '_ws' + ext))
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)

    # do whitestripe normalization and save the results
    for i, (img_fn, mask_fn,
            output_fn) in enumerate(zip(data, masks, output_files), 1):
        logger.info('Normalizing image: {} ({:d}/{:d})'.format(
            img_fn, i, len(data)))
        img = io.open_nii(img_fn)
        mask = io.open_nii(mask_fn) if mask_fn is not None else None
        indices = whitestripe(img, contrast, mask=mask)
        normalized = whitestripe_norm(img, indices)
        if write_to_disk:
            logger.info('Saving normalized image: {} ({:d}/{:d})'.format(
                output_fn, i, len(data)))
            io.save_nii(normalized, output_fn)

    # output the last normalized image (mostly for testing purposes)
    return normalized
def main(args=None):
    """
    Command-line entry point for fuzzy c-means (FCM) based normalization.

    Exactly one of ``args.brain_mask`` and ``args.tissue_mask`` must be given.
    For T1 images without a tissue mask, ``process`` is first run to create
    tissue masks. In directory mode they go into a ``tissue_masks`` directory
    placed next to the output directory. Other contrasts require an existing
    tissue mask (directory or file). Each image is then normalized by
    ``process`` using its tissue mask. If ``args.plot_hist`` is set, a
    ``hist.png`` overlay is saved in ``args.output_dir``.

    Args:
        args (list of str or None): command-line arguments; ``None`` lets
            argparse read ``sys.argv``

    Returns:
        int: 0 on success, 1 if an exception was raised inside the main body
            (the exception is logged instead of propagated)

    Raises:
        NormalizationError: if both or neither of brain mask and tissue mask
            are given (checked before logging is configured)
    """
    args = arg_parser().parse_args(args)
    if not (args.brain_mask is None) ^ (args.tissue_mask is None):
        raise NormalizationError(
            'Only one of {brain mask, tissue mask} should be given')
    if args.verbosity == 1:
        level = logging.getLevelName('INFO')
    elif args.verbosity >= 2:
        level = logging.getLevelName('DEBUG')
    else:
        level = logging.getLevelName('WARNING')
    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=level)
    logger = logging.getLogger(__name__)
    try:
        if not args.single_img:
            if not os.path.isdir(
                    args.image) or (False if args.brain_mask is None else
                                    not os.path.isdir(args.brain_mask)):
                raise NormalizationError(
                    'if single-img option off, then image and brain-mask must be directories'
                )
            img_fns = io.glob_nii(args.image)
            mask_fns = io.glob_nii(
                args.brain_mask
            ) if args.brain_mask is not None else [None] * len(img_fns)
            # the message promises correspondence AND a non-empty image set,
            # so either violation must raise
            if len(img_fns) != len(mask_fns) or len(img_fns) == 0:
                raise NormalizationError(
                    'input images and masks must be in correspondence and greater than zero '
                    '({:d} != {:d})'.format(len(img_fns), len(mask_fns)))
            args.output_dir = args.output_dir or 'fcm'
            output_dir_base = os.path.abspath(
                os.path.join(args.output_dir, '..'))

            if args.contrast.lower() == 't1' and args.tissue_mask is None:
                tissue_mask_dir = os.path.join(output_dir_base, 'tissue_masks')
                if os.path.exists(tissue_mask_dir):
                    logger.warning(
                        'Tissue mask directory already exists, may overwrite existing tissue masks!'
                    )
                else:
                    logger.info('Creating tissue mask directory: {}'.format(
                        tissue_mask_dir))
                    os.mkdir(tissue_mask_dir)
                for i, (img, mask) in enumerate(zip(img_fns, mask_fns), 1):
                    _, base, _ = io.split_filename(img)
                    _, mask_base, _ = io.split_filename(mask)
                    logger.info(
                        'Creating tissue mask for {} ({:d}/{:d})'.format(
                            base, i, len(img_fns)))
                    logger.debug('Tissue mask {} ({:d}/{:d})'.format(
                        mask_base, i, len(img_fns)))
                    process(img, mask, None, tissue_mask_dir, args, logger)
            # the None guard matters: os.path.exists(None) raises TypeError,
            # which hid the intended error message below
            elif args.tissue_mask is not None and os.path.exists(
                    args.tissue_mask):
                tissue_mask_dir = args.tissue_mask
            else:
                raise NormalizationError(
                    'If contrast is not t1, then tissue mask directory ({}) '
                    'must already be created!'.format(args.tissue_mask))

            tissue_masks = io.glob_nii(tissue_mask_dir)
            for i, (img, tissue_mask) in enumerate(zip(img_fns, tissue_masks),
                                                   1):
                _, base, _ = io.split_filename(img)
                _, tissue_base, _ = io.split_filename(tissue_mask)
                logger.info('Normalizing image {} ({:d}/{:d})'.format(
                    base, i, len(img_fns)))
                logger.debug('Tissue mask {} ({:d}/{:d})'.format(
                    tissue_base, i, len(img_fns)))
                # args.output_dir was defaulted to 'fcm' above, so it is
                # always a non-empty string here
                process(img, None, tissue_mask, args.output_dir, args, logger)

        else:
            if not os.path.isfile(args.image):
                raise NormalizationError(
                    'if single-img option on, then image must be a file')
            if args.tissue_mask is None and args.contrast.lower() == 't1':
                logger.info('Creating tissue mask for {}'.format(args.image))
                process(args.image, args.brain_mask, None, args.output_dir,
                        args, logger)
            # os.path.isfile(None) raises TypeError, so check for None first
            elif args.tissue_mask is not None and os.path.isfile(
                    args.tissue_mask):
                pass
            else:
                raise NormalizationError(
                    'If contrast is not t1, then tissue mask must be provided!'
                )
            logger.info('Normalizing image {}'.format(args.image))
            dirname, base, _ = io.split_filename(args.image)
            dirname = args.output_dir or dirname
            if args.tissue_mask is None:
                tissue_mask = os.path.join(
                    dirname, base + '_{}_mask.nii.gz'.format(args.tissue_type))
            else:
                tissue_mask = args.tissue_mask
            process(args.image, args.brain_mask, tissue_mask, dirname, args,
                    logger)

        if args.plot_hist:
            # plotting dependencies are imported only when a plot is requested
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore', category=FutureWarning)
                from intensity_normalization.plot.hist import all_hists
                import matplotlib.pyplot as plt
            ax = all_hists(args.output_dir, args.brain_mask)
            ax.set_title('Fuzzy C-Means')
            plt.savefig(os.path.join(args.output_dir, 'hist.png'))

        return 0
    except Exception as e:
        logger.exception(e)
        return 1
Esempio n. 11
0
def gmm_class_mask(img,
                   brain_mask=None,
                   contrast='t1',
                   return_wm_peak=True,
                   hard_seg=False):
    """
    get a tissue class mask using a 3-component gmm (or just the WM peak, for legacy use)

    Args:
        img (nibabel.nifti1.Nifti1Image): target img
        brain_mask (nibabel.nifti1.Nifti1Image): brain mask for img
            (none if already skull-stripped; then voxels above the image mean
            are used as the brain)
        contrast (str): string to describe img's MR contrast (t1, t2 or
            flair; only used when return_wm_peak is true)
        return_wm_peak (bool): if true, return only the wm peak
        hard_seg (bool): if true and return_wm_peak false, then return
            hard segmentation of tissue classes

    Returns:
        if return_wm_peak true:
            wm_peak (float): mean of the GMM component taken as WM (highest
                mean for t1, middle for flair, lowest for t2)
        else:
            mask (np.ndarray):
                if hard_seg, then mask is the same size as img, 0 outside the
                brain and one label per class inside (classes ordered by
                increasing mixture weight)
                else, mask has shape img.shape + (3,), where the last axis
                holds the per-class probabilities in the same class order

    Raises:
        NormalizationError: if return_wm_peak is true and contrast is not
            t1, t2 or flair
    """
    # get_fdata replaces get_data, which is deprecated and removed in nibabel 5
    img_data = img.get_fdata()
    if brain_mask is not None:
        mask_data = brain_mask.get_fdata() > 0
    else:
        mask_data = img_data > img_data.mean()

    # fit on a (n_voxels, 1) column of brain intensities
    brain = np.expand_dims(img_data[mask_data].flatten(), 1)
    gmm = GaussianMixture(3)
    gmm.fit(brain)

    if return_wm_peak:
        means = sorted(gmm.means_.T.squeeze())
        if contrast.lower() == 't1':
            wm_peak = means[2]
        elif contrast.lower() == 'flair':
            wm_peak = means[1]
        elif contrast.lower() == 't2':
            wm_peak = means[0]
        else:
            raise NormalizationError(
                'Invalid contrast type: {}. Must be t1, t2, or flair.'.format(
                    contrast))
        return wm_peak
    else:
        # component indices ordered by increasing mixture weight
        classes = np.argsort(gmm.weights_)
        if hard_seg:
            tmp_predicted = gmm.predict(brain)
            predicted = np.zeros(tmp_predicted.shape)
            for i, c in enumerate(classes):
                predicted[tmp_predicted == c] = i + 1
            mask = np.zeros(img_data.shape)
            # NOTE(review): labels are already 1-based above, so this extra
            # "+ 1" yields brain labels 2-4 and leaves 1 unused; kept as-is to
            # preserve the existing output contract -- confirm intent
            mask[mask_data] = predicted + 1
        else:
            predicted_proba = gmm.predict_proba(brain)
            mask = np.zeros((*img_data.shape, 3))
            for i, c in enumerate(classes):
                mask[mask_data, i] = predicted_proba[:, c]
        return mask
def main(args=None):
    """
    Command-line entry point for RAVEL normalization of a directory of images.

    The unwanted-variation factors ``Z`` are estimated by
    ``ravel.ravel_normalize`` with the CLI options. They are removed from the
    image matrix of ``args.img_dir`` with ``ravel.ravel_correction``, and each
    corrected image is written as ``<base>_ravel<ext>`` into
    ``args.output_dir`` (the current working directory when not given). If
    ``args.plot_hist`` is set, a ``hist.png`` overlay is saved there too.

    Args:
        args (list of str or None): command-line arguments; ``None`` lets
            argparse read ``sys.argv``

    Returns:
        int: 0 on success, 1 if an exception was raised (and logged)
    """
    args = arg_parser().parse_args(args)
    level_name = 'WARNING'
    if args.verbosity == 1:
        level_name = 'INFO'
    elif args.verbosity >= 2:
        level_name = 'DEBUG'
    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=logging.getLevelName(level_name))
    logger = logging.getLogger(__name__)
    try:
        img_fns = io.glob_nii(args.img_dir)
        mask_fns = io.glob_nii(args.mask_dir)
        n_imgs = len(img_fns)
        if n_imgs != len(mask_fns) or n_imgs == 0:
            raise NormalizationError(
                'Image directory ({}) and mask directory ({}) must contain the same '
                '(positive) number of images!'.format(args.img_dir,
                                                      args.mask_dir))

        logger.info('Normalizing the images according to RAVEL')
        ravel_options = dict(
            do_whitestripe=args.no_whitestripe,
            b=args.num_unwanted_factors,
            membership_thresh=args.control_membership_threshold,
            do_registration=args.no_registration,
            segmentation_smoothness=args.segmentation_smoothness,
            use_fcm=not args.use_atropos,
            sparse_svd=args.sparse_svd,
            csf_masks=args.csf_masks,
        )
        Z, _ = ravel.ravel_normalize(args.img_dir, args.mask_dir,
                                     args.contrast, **ravel_options)

        V = ravel.image_matrix(img_fns, args.contrast, masks=mask_fns)
        corrected = ravel.ravel_correction(V, Z)
        normalized = ravel.image_matrix_to_images(corrected, img_fns)

        # save the normalized images to disk
        if args.output_dir is None:
            output_dir = os.getcwd()
        else:
            output_dir = args.output_dir
        out_fns = [
            os.path.join(output_dir, base + '_ravel' + ext)
            for _, base, ext in (io.split_filename(fn) for fn in img_fns)
        ]
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)
        for norm_img, out_fn in zip(normalized, out_fns):
            norm_img.to_filename(out_fn)

        if args.plot_hist:
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore', category=FutureWarning)
                from intensity_normalization.plot.hist import all_hists
                import matplotlib.pyplot as plt
            ax = all_hists(output_dir, args.mask_dir)
            ax.set_title('RAVEL')
            plt.savefig(os.path.join(output_dir, 'hist.png'))

        return 0
    except Exception as e:
        logger.exception(e)
        return 1
Esempio n. 13
0
def image_matrix(imgs,
                 contrast,
                 masks=None,
                 do_whitestripe=True,
                 return_ctrl_matrix=False,
                 membership_thresh=0.99,
                 smoothness=0.25,
                 max_ctrl_vox=10000,
                 do_registration=False,
                 ctrl_prob=1,
                 use_fcm=False):
    """
    create a matrix of images where each row corresponds to a voxel and each
    column to one image (optionally also a matrix of control voxels)

    Args:
        imgs (list): list of paths to MR images of interest; all must have
            the same shape as the first image
        contrast (str): contrast of the set of imgs (e.g., T1)
        masks (list or str): list of corresponding brain masks or just one (template) mask
            NOTE(review): the code zips ``imgs`` with ``masks`` and indexes
            ``masks[0]``, so a single str path would be iterated character by
            character; a list of paths appears to be required -- confirm
        do_whitestripe (bool): apply WhiteStripe normalization to each image before storing it
        return_ctrl_matrix (bool): also return the control-voxel matrix Vc
        membership_thresh (float): threshold of membership for control voxels (want this very high);
            passed as ``csf_thresh`` to ``csf.csf_mask`` in every branch
        smoothness (float): smoothness parameter for segmentation for control voxels;
            passed as ``mrf`` to ``csf.csf_mask`` in every branch
        max_ctrl_vox (int): maximum number of control voxels per image (if too high, everything
            crashes depending on available memory); only used if do_registration is false
        do_registration (bool): register every image to the first image (SyN) and take the
            intersection of the csf masks (as done in the original paper, note that this takes
            much longer); only has an effect when return_ctrl_matrix is true
        ctrl_prob (float): given all data, proportion of data labeled as csf to be
            used for intersection (i.e., when do_registration is true)
        use_fcm (bool): use FCM for segmentation instead of atropos (may be less accurate)

    Returns:
        V (np.ndarray): image matrix (rows are voxels, columns are images)
        Vc (np.ndarray): image matrix of control voxels (rows are voxels, columns are images)
            Vc only returned if return_ctrl_matrix is True

    Raises:
        NormalizationError: if return_ctrl_matrix is true but masks is None, an
            image's shape differs from the first image's, or (without
            registration) no control voxels are found for an image
    """
    # every image becomes one flattened column of V, so all images must share
    # the first image's shape (checked per image below)
    # NOTE(review): get_data is deprecated (removed in nibabel 5); get_fdata
    # is the replacement
    img_shape = io.open_nii(imgs[0]).get_data().shape
    V = np.zeros((int(np.prod(img_shape)), len(imgs)))

    if return_ctrl_matrix:
        # per-image 1-D arrays of control-voxel intensities (non-registration path)
        ctrl_vox = []

    if masks is None and return_ctrl_matrix:
        raise NormalizationError(
            'Brain masks must be provided if returning control memberships')
    if masks is None:
        masks = [None] * len(imgs)

    for i, (img_fn, mask_fn) in enumerate(zip(imgs, masks)):
        _, base, _ = io.split_filename(img_fn)
        img = io.open_nii(img_fn)
        mask = io.open_nii(mask_fn) if mask_fn is not None else None
        # do whitestripe on the image before applying RAVEL (if desired)
        if do_whitestripe:
            logger.info('Applying WhiteStripe to image {} ({:d}/{:d})'.format(
                base, i + 1, len(imgs)))
            inds = whitestripe(img, contrast, mask)
            img = whitestripe_norm(img, inds)
        img_data = img.get_data()
        if img_data.shape != img_shape:
            raise NormalizationError(
                'Cannot normalize because image {} needs to have same dimension '
                'as all other images ({} != {})'.format(
                    base, img_data.shape, img_shape))
        V[:, i] = img_data.flatten()
        if return_ctrl_matrix:
            if do_registration and i == 0:
                # the first image is the registration target: it is used
                # unregistered and initializes the lists later images append to
                logger.info(
                    'Creating control mask for image {} ({:d}/{:d})'.format(
                        base, i + 1, len(imgs)))
                # `verbose` is assigned only here and reused by the
                # registrations of all later images
                verbose = True if logger.getEffectiveLevel(
                ) == logging.getLevelName('DEBUG') else False
                ctrl_masks = []
                reg_imgs = []
                reg_imgs.append(csf.nibabel_to_ants(img))
                ctrl_masks.append(
                    csf.csf_mask(img,
                                 mask,
                                 contrast=contrast,
                                 csf_thresh=membership_thresh,
                                 mrf=smoothness,
                                 use_fcm=use_fcm))
            elif do_registration and i != 0:
                # NOTE(review): the template and its mask are re-read from disk
                # on every iteration; they could be loaded once before the loop
                template = ants.image_read(imgs[0])
                tmask = ants.image_read(masks[0])
                img = csf.nibabel_to_ants(img)
                logger.info(
                    'Starting registration for image {} ({:d}/{:d})'.format(
                        base, i + 1, len(imgs)))
                reg_result = ants.registration(template,
                                               img,
                                               type_of_transform='SyN',
                                               mask=tmask,
                                               verbose=verbose)
                img = reg_result['warpedmovout']
                # NOTE(review): this mask is not warped, so it stays in the
                # moving image's original space while `img` is now in template
                # space -- confirm they are meant to be combined in csf_mask
                mask = csf.nibabel_to_ants(mask)
                reg_imgs.append(img)
                logger.info(
                    'Creating control mask for image {} ({:d}/{:d})'.format(
                        base, i + 1, len(imgs)))
                ctrl_masks.append(
                    csf.csf_mask(img,
                                 mask,
                                 contrast=contrast,
                                 csf_thresh=membership_thresh,
                                 mrf=smoothness,
                                 use_fcm=use_fcm))
            else:
                logger.info(
                    'Finding control voxels for image {} ({:d}/{:d})'.format(
                        base, i + 1, len(imgs)))
                ctrl_mask = csf.csf_mask(img,
                                         mask,
                                         contrast=contrast,
                                         csf_thresh=membership_thresh,
                                         mrf=smoothness,
                                         use_fcm=use_fcm)
                if np.sum(ctrl_mask) == 0:
                    raise NormalizationError(
                        'No control voxels found for image ({}) at threshold ({})'
                        .format(base, membership_thresh))
                elif np.sum(ctrl_mask) < 100:
                    logger.warning(
                        'Few control voxels found ({:d}) (potentially a problematic image ({}) or '
                        'threshold ({}) too high)'.format(
                            int(np.sum(ctrl_mask)), base, membership_thresh))
                ctrl_vox.append(img_data[ctrl_mask == 1].flatten())

    if return_ctrl_matrix and not do_registration:
        # truncate every image's control voxels to a common length (capped at
        # max_ctrl_vox) so they fit in one matrix; voxels are taken in
        # flattened order, not sampled
        min_len = min(min(map(len, ctrl_vox)), max_ctrl_vox)
        logger.info('Using {:d} control voxels'.format(min_len))
        Vc = np.zeros((min_len, len(imgs)))
        for i in range(len(imgs)):
            ctrl_voxs = ctrl_vox[i][:min_len]
            logger.info(
                'Image {:d} control voxel stats -  mean: {:.3f}, std: {:.3f}'.
                format(i + 1, np.mean(ctrl_voxs), np.std(ctrl_voxs)))
            Vc[:, i] = ctrl_voxs
    elif return_ctrl_matrix and do_registration:
        # keep voxels labeled as control in at least
        # floor(n_masks * ctrl_prob) of the per-image control masks
        ctrl_sum = reduce(
            add,
            ctrl_masks)  # need to use reduce instead of sum b/c data structure
        intersection = np.zeros(ctrl_sum.shape)
        intersection[ctrl_sum >= np.floor(len(ctrl_masks) * ctrl_prob)] = 1
        num_ctrl_vox = int(np.sum(intersection))
        Vc = np.zeros((num_ctrl_vox, len(imgs)))
        # control intensities come from the registered images (template space)
        for i, img in enumerate(reg_imgs):
            ctrl_voxs = img.numpy()[intersection == 1]
            logger.info(
                'Image {:d} control voxel stats -  mean: {:.3f}, std: {:.3f}'.
                format(i + 1, np.mean(ctrl_voxs), np.std(ctrl_voxs)))
            Vc[:, i] = ctrl_voxs
        del ctrl_masks, reg_imgs
        import gc
        gc.collect(
        )  # force a garbage collection, since we just used the majority of the system memory

    return V if not return_ctrl_matrix else (V, Vc)