def image_matrix(imgs,
                 contrast,
                 masks=None,
                 do_whitestripe=True,
                 return_ctrl_matrix=False,
                 membership_thresh=0.99,
                 smoothness=0.25,
                 max_ctrl_vox=10000,
                 do_registration=False,
                 ctrl_prob=1,
                 use_fcm=False):
    """
    Create a matrix of images where the rows correspond to the voxels of
    each image and the columns are the images.

    Args:
        imgs (list): list of paths to MR images of interest
        contrast (str): contrast of the set of imgs (e.g., T1)
        masks (list or str): list of corresponding brain masks, or a single
            (template) mask path that is shared by every image
        do_whitestripe (bool): apply WhiteStripe to each image before storing
            it in the matrix
        return_ctrl_matrix (bool): also return the control matrix for imgs
            (i.e., a subset of V's rows)
        membership_thresh (float): threshold passed to the control-voxel
            segmentation as ``csf_thresh`` (want this very high)
        smoothness (float): smoothness (``mrf``) parameter passed to the
            control-voxel segmentation
        max_ctrl_vox (int): maximum number of control voxels per image (if
            too high, memory use can explode); only used if do_registration
            is False
        do_registration (bool): register every image to the first one (SyN)
            and take the intersection of the control masks (as done in the
            original paper; note that this takes much longer)
        ctrl_prob (float): proportion of images that must label a voxel as
            control for it to be kept in the intersection (only used when
            do_registration is True); at least one image is always required
        use_fcm (bool): use FCM for segmentation instead of atropos (may be
            less accurate)

    Returns:
        V (np.ndarray): image matrix (rows are voxels, columns are images)
        Vc (np.ndarray): image matrix of control voxels (rows are voxels,
            columns are images); only returned if return_ctrl_matrix is True

    Raises:
        NormalizationError: if masks are missing while return_ctrl_matrix is
            True, if the number of masks does not match the number of images,
            if an image's shape differs from the first image's, or if no
            control voxels are found
    """
    n_imgs = len(imgs)
    img_shape = io.open_nii(imgs[0]).get_data().shape
    V = np.zeros((int(np.prod(img_shape)), n_imgs))

    if masks is None and return_ctrl_matrix:
        raise NormalizationError(
            'Brain masks must be provided if returning control memberships')
    if masks is None:
        masks = [None] * n_imgs
    elif isinstance(masks, str):
        # one template mask shared by all images; without this, zip() below
        # would pair each image with a single *character* of the mask path
        masks = [masks] * n_imgs
    if len(masks) != n_imgs:
        # zip() would otherwise silently drop images, leaving zero columns in V
        raise NormalizationError(
            'Number of masks ({:d}) must match number of images ({:d})'.format(
                len(masks), n_imgs))

    use_registration = return_ctrl_matrix and do_registration
    ctrl_vox = []    # control voxel intensities per image (no registration)
    ctrl_masks = []  # control masks in the first image's space (registration)
    reg_imgs = []    # images resampled into the first image's space (registration)
    if use_registration:
        verbose = logger.getEffectiveLevel() == logging.DEBUG
        # the first image and its mask define the template space; read them
        # once instead of once per registered image
        template = ants.image_read(imgs[0])
        tmask = ants.image_read(masks[0])

    for i, (img_fn, mask_fn) in enumerate(zip(imgs, masks)):
        _, base, _ = io.split_filename(img_fn)
        img = io.open_nii(img_fn)
        mask = io.open_nii(mask_fn) if mask_fn is not None else None
        # do whitestripe on the image before applying RAVEL (if desired)
        if do_whitestripe:
            logger.info('Applying WhiteStripe to image {} ({:d}/{:d})'.format(
                base, i + 1, n_imgs))
            inds = whitestripe(img, contrast, mask)
            img = whitestripe_norm(img, inds)
        img_data = img.get_data()
        if img_data.shape != img_shape:
            raise NormalizationError(
                'Cannot normalize because image {} needs to have same dimension '
                'as all other images ({} != {})'.format(
                    base, img_data.shape, img_shape))
        V[:, i] = img_data.flatten()

        if not return_ctrl_matrix:
            continue

        if use_registration and i == 0:
            # the first image is the registration target; use it as-is
            logger.info(
                'Creating control mask for image {} ({:d}/{:d})'.format(
                    base, i + 1, n_imgs))
            reg_imgs.append(csf.nibabel_to_ants(img))
            ctrl_masks.append(
                csf.csf_mask(img,
                             mask,
                             contrast=contrast,
                             csf_thresh=membership_thresh,
                             mrf=smoothness,
                             use_fcm=use_fcm))
        elif use_registration:
            logger.info(
                'Starting registration for image {} ({:d}/{:d})'.format(
                    base, i + 1, n_imgs))
            reg_result = ants.registration(template,
                                           csf.nibabel_to_ants(img),
                                           type_of_transform='SyN',
                                           mask=tmask,
                                           verbose=verbose)
            warped = reg_result['warpedmovout']
            reg_imgs.append(warped)
            logger.info(
                'Creating control mask for image {} ({:d}/{:d})'.format(
                    base, i + 1, n_imgs))
            # the warped image lives in the template's space, so segment it
            # with the template's brain mask; this image's own (unwarped)
            # mask would be spatially misaligned with the warped data
            ctrl_masks.append(
                csf.csf_mask(warped,
                             tmask,
                             contrast=contrast,
                             csf_thresh=membership_thresh,
                             mrf=smoothness,
                             use_fcm=use_fcm))
        else:
            logger.info(
                'Finding control voxels for image {} ({:d}/{:d})'.format(
                    base, i + 1, n_imgs))
            ctrl_mask = csf.csf_mask(img,
                                     mask,
                                     contrast=contrast,
                                     csf_thresh=membership_thresh,
                                     mrf=smoothness,
                                     use_fcm=use_fcm)
            n_found = np.sum(ctrl_mask)
            if n_found == 0:
                raise NormalizationError(
                    'No control voxels found for image ({}) at threshold ({})'
                    .format(base, membership_thresh))
            elif n_found < 100:
                logger.warning(
                    'Few control voxels found ({:d}) (potentially a problematic image ({}) or '
                    'threshold ({}) too high)'.format(
                        int(n_found), base, membership_thresh))
            ctrl_vox.append(img_data[ctrl_mask == 1].flatten())

    if not return_ctrl_matrix:
        return V

    if not do_registration:
        # truncate every image to the same number of control voxels
        min_len = min(min(map(len, ctrl_vox)), max_ctrl_vox)
        logger.info('Using {:d} control voxels'.format(min_len))
        Vc = np.zeros((min_len, n_imgs))
        for i, voxels in enumerate(ctrl_vox):
            ctrl_voxs = voxels[:min_len]
            logger.info(
                'Image {:d} control voxel stats -  mean: {:.3f}, std: {:.3f}'.
                format(i + 1, np.mean(ctrl_voxs), np.std(ctrl_voxs)))
            Vc[:, i] = ctrl_voxs
        return V, Vc

    # need to use reduce instead of sum b/c data structure
    ctrl_sum = reduce(add, ctrl_masks)
    # floor(n * ctrl_prob) is 0 for small ctrl_prob, which would select every
    # voxel (background included); always require at least one vote
    min_votes = max(1, np.floor(len(ctrl_masks) * ctrl_prob))
    intersection = np.zeros(ctrl_sum.shape)
    intersection[ctrl_sum >= min_votes] = 1
    num_ctrl_vox = int(np.sum(intersection))
    if num_ctrl_vox == 0:
        raise NormalizationError(
            'No control voxels found in the intersection of the control masks '
            '(ctrl_prob={})'.format(ctrl_prob))
    logger.info('Using {:d} control voxels'.format(num_ctrl_vox))
    Vc = np.zeros((num_ctrl_vox, n_imgs))
    for i, reg_img in enumerate(reg_imgs):
        ctrl_voxs = reg_img.numpy()[intersection == 1]
        logger.info(
            'Image {:d} control voxel stats -  mean: {:.3f}, std: {:.3f}'.
            format(i + 1, np.mean(ctrl_voxs), np.std(ctrl_voxs)))
        Vc[:, i] = ctrl_voxs
    del ctrl_masks, reg_imgs, template, tmask
    import gc
    # force a garbage collection, since we just used the majority of the system memory
    gc.collect()
    return V, Vc
def run_intensity_ravel(outfolder):
    """
    Run RAVEL intensity normalization (T1 contrast) on the NIfTI files in
    ``outfolder``.

    For every ``.nii``/``.nii.gz`` file directly inside ``outfolder`` this:
      1. loads the image and its brain mask from
         ``outfolder/robex_masks/<name>_mask.nii.gz`` (all inputs are opened
         before anything is written or moved);
      2. writes a CSF mask to ``outfolder/csf_masks/<name>_csfmask.nii.gz``;
      3. moves the original file into ``outfolder/Robex``.
    It then calls ``ravel.ravel_normalize`` on ``outfolder/Robex`` with the CSF
    masks and ``output_dir=outfolder``, and finally renames each file in
    ``outfolder`` that carries an original file name to
    ``<name>_RAVEL.nii.gz``.

    ``<name>`` is the file name up to its first ``.``. The mask naming under
    ``robex_masks`` must follow the same convention.

    Args:
        outfolder (str): directory containing the input images and the
            ``robex_masks`` subdirectory

    On any error (other than interpreter exit / keyboard interrupt), prints
    the exception and exits the process with status 2.
    """
    from intensity_normalization.normalize import ravel
    from intensity_normalization.utilities import io, csf

    try:
        _, _, filenames = next(walk(outfolder))
        # only NIfTI files are images; anything else (logs, hidden files)
        # would make open_nii fail, or be renamed to *_RAVEL.nii.gz below
        filenames = [f for f in filenames if f.endswith(('.nii', '.nii.gz'))]

        # open every image and mask up front so a missing mask fails before
        # any file is written or moved
        images = []
        brainMasks = []
        for f in filenames:
            filename = f.split(".")[0]
            images.append(io.open_nii(join(outfolder, f)))
            brainMasks.append(
                io.open_nii(
                    join(outfolder, 'robex_masks', filename + "_mask.nii.gz")))

        makedirs(join(outfolder, "Robex"), exist_ok=True)
        makedirs(join(outfolder, "csf_masks"), exist_ok=True)

        print("creating csf masks...")
        for image, brainMask, f in zip(images, brainMasks, filenames):
            filename = f.split(".")[0]
            csfMask = csf.csf_mask(image,
                                   brainMask,
                                   contrast='T1',
                                   csf_thresh=0.9,
                                   return_prob=False,
                                   mrf=0.25,
                                   use_fcm=False)
            # keep the source image's affine so the saved mask stays
            # spatially aligned with it (affine=None discards orientation)
            output = nib.Nifti1Image(csfMask, image.affine)
            io.save_nii(
                output,
                join(outfolder, 'csf_masks', filename + "_csfmask.nii.gz"))
            shutil.move(join(outfolder, f), join(outfolder, "Robex"))

        print('running intensity ravel...')
        ravel.ravel_normalize(join(outfolder, 'Robex'),
                              join(outfolder, 'csf_masks'),
                              'T1',
                              output_dir=outfolder,
                              write_to_disk=True,
                              do_whitestripe=True,
                              b=1,
                              membership_thresh=0.99,
                              segmentation_smoothness=0.25,
                              do_registration=False,
                              use_fcm=True,
                              sparse_svd=False,
                              csf_masks=True)

        for f in filenames:
            rename(join(outfolder, f),
                   join(outfolder, f.split(".")[0] + "_RAVEL.nii.gz"))
    except Exception:
        # a bare `except:` would also swallow KeyboardInterrupt / SystemExit
        e = sys.exc_info()
        print("Error: ", str(e[0]))
        print("Error: ", str(e[1]))
        print("Error: executing ravel method")
        sys.exit(2)