Пример #1
0
def test_unique_bvals_tolerance():
    """Check ``unique_bvals_tolerance`` against hand-computed shells.

    Each case is ``(bvals, extra positional args, expected unique b-values)``.
    An empty tuple of extra args means the function's default tolerance is
    used; otherwise the single extra arg is the tolerance.
    """
    cases = [
        # Exact duplicates collapse to a single value per shell.
        (np.array([1000, 1000, 1000, 1000, 2000, 2000, 2000, 2000, 0]),
         (),
         np.array([0, 1000, 2000])),
        # Testing the tolerance factor on many b-values that are within tol.
        (np.array([950, 980, 995, 1000, 1000, 1010, 1999, 2000, 2001, 0]),
         (),
         np.array([0, 950, 1000, 2001])),
        # All unique b-values are kept if tolerance is set to zero.
        (np.array([990, 990, 1000, 1000, 2000, 2000, 2050, 2050, 0]),
         (0,),
         np.array([0, 990, 1000, 2000, 2050])),
        # Case that b-values are in ms/um2.
        (np.array([0.995, 0.995, 0.995, 0.995,
                   2.005, 2.005, 2.005, 2.005, 0]),
         (0.5,),
         np.array([0, 0.995, 2.005])),
    ]
    for bvals, extra_args, expected in cases:
        found = unique_bvals_tolerance(bvals, *extra_args)
        npt.assert_array_almost_equal(expected, found)
Пример #2
0
def create_mcsd_model(folder_name, data, gtab, labels, sh_order=8):
    """Fit a multi-shell multi-tissue CSD model from a tissue label map.

    Parameters
    ----------
    folder_name : str
        Directory in which the fitted coefficients are saved as ``coeff.npy``.
    data : ndarray
        Diffusion data used both for response estimation and for the fit.
    gtab : GradientTable
        Gradient table matching ``data``.
    labels : ndarray
        Tissue label map; label 3 is used as WM, 2 as GM and 1 as CSF.
    sh_order : int, optional
        Maximal spherical harmonics order, used for both the fiber response
        and the deconvolution model. Default: 8.

    Returns
    -------
    MSDeconvFit
        Fit rebuilt from the model coefficients with NaN values replaced by 0.
    """
    import os

    from dipy.reconst.mcsd import response_from_mask_msmt
    from dipy.reconst.mcsd import MultiShellDeconvModel, multi_shell_fiber_response, MSDeconvFit
    from dipy.core.gradients import unique_bvals_tolerance

    bvals = gtab.bvals
    mask_wm = (labels == 3).astype(float)
    mask_gm = (labels == 2).astype(float)
    mask_csf = (labels == 1).astype(float)

    response_wm, response_gm, response_csf = response_from_mask_msmt(
        gtab, data, mask_wm, mask_gm, mask_csf)

    ubvals = unique_bvals_tolerance(bvals)
    response_mcsd = multi_shell_fiber_response(sh_order,
                                               bvals=ubvals,
                                               wm_rf=response_wm,
                                               csf_rf=response_csf,
                                               gm_rf=response_gm)
    # The model basis must use the same SH order as the response it inflates;
    # otherwise a non-default sh_order silently mismatches the default of 8.
    mcsd_model = MultiShellDeconvModel(gtab, response_mcsd, sh_order=sh_order)

    mcsd_fit = mcsd_model.fit(data)
    coeff = mcsd_fit.all_shm_coeff
    # Voxels whose fit did not complete carry NaN in their first coefficient.
    nan_count = np.count_nonzero(np.isnan(coeff[..., 0]))
    n_vox = coeff.shape[0] * coeff.shape[1] * coeff.shape[2]
    if nan_count > 0:
        print(
            f'{nan_count / n_vox} of the voxels did not complete fodf calculation, NaN values replaced with 0'
        )
    coeff = np.where(np.isnan(coeff), 0, coeff)
    mcsd_fit = MSDeconvFit(mcsd_model, coeff, None)
    # os.path.join instead of a hard-coded '\' so the file lands inside
    # folder_name on every operating system, not only on Windows.
    np.save(os.path.join(folder_name, 'coeff.npy'), coeff)

    return mcsd_fit
Пример #3
0
    def _msmt_ft(self):
        """Fit an MSMT-CSD model from ``self.tissue_labels`` and store it.

        Reads ``self.gtab``, ``self.data``, ``self.tissue_labels`` and
        ``self.parameters_dict['sh_order']``, and sets ``self.model_fit`` to an
        ``MSDeconvFit`` whose coefficients have NaN values replaced by 0.
        """
        from dipy.reconst.mcsd import response_from_mask_msmt
        from dipy.reconst.mcsd import MultiShellDeconvModel, multi_shell_fiber_response, MSDeconvFit
        from dipy.core.gradients import unique_bvals_tolerance

        sh_order = self.parameters_dict['sh_order']
        bvals = self.gtab.bvals
        # NOTE(review): this label convention (2 = WM, 1 = GM, 3 = CSF)
        # differs from the 3 = WM, 2 = GM, 1 = CSF convention used by the
        # standalone create_mcsd_model snippet; confirm it matches the
        # segmentation that fills tissue_labels.
        wm = self.tissue_labels == 2
        gm = self.tissue_labels == 1
        csf = self.tissue_labels == 3

        mask_wm = wm.astype(float)
        mask_gm = gm.astype(float)
        mask_csf = csf.astype(float)

        response_wm, response_gm, response_csf = response_from_mask_msmt(
            self.gtab, self.data, mask_wm, mask_gm, mask_csf)

        ubvals = unique_bvals_tolerance(bvals)
        response_mcsd = multi_shell_fiber_response(
            sh_order,
            bvals=ubvals,
            wm_rf=response_wm,
            csf_rf=response_csf,
            gm_rf=response_gm)
        # Use the same SH order for the model as for the response; the model
        # default (8) would otherwise mismatch a configured sh_order.
        mcsd_model = MultiShellDeconvModel(self.gtab, response_mcsd,
                                           sh_order=sh_order)

        mcsd_fit = mcsd_model.fit(self.data)
        coeff = mcsd_fit.all_shm_coeff
        # Voxels whose fit did not complete carry NaN in their first coefficient.
        nan_count = np.count_nonzero(np.isnan(coeff[..., 0]))
        n_vox = coeff.shape[0] * coeff.shape[1] * coeff.shape[2]
        if nan_count > 0:
            print(
                f'{nan_count / n_vox} of the voxels did not complete fodf calculation, NaN values replaced with 0'
            )
        coeff = np.where(np.isnan(coeff), 0, coeff)
        mcsd_fit = MSDeconvFit(mcsd_model, coeff, None)
        self.model_fit = mcsd_fit
Пример #4
0
# Fit an MSMT-CSD model using tissue masks from a precomputed segmentation.
# NOTE(review): `data`, `gtab`, `bvals`, `load_nifti`, `response_from_mask_msmt`,
# `unique_bvals_tolerance`, `multi_shell_fiber_response` and
# `MultiShellDeconvModel` must already be defined earlier in the script.
denoised_arr = data

# NOTE(review): machine-specific absolute path; consider making it a parameter.
tissue_mask = r'F:\Hila\Ax3D_Pack\V6\after_file_prep\YA_lab_Yaniv_002334_20210107_1820\r20210107_182004T1wMPRAGERLs008a1001_brain_seg.nii'
t_mask_img = load_nifti(tissue_mask)[0]
# Segmentation labels used here: 3 = WM, 2 = GM, 1 = CSF.
wm = t_mask_img == 3
gm = t_mask_img == 2
csf = t_mask_img == 1

mask_wm = wm.astype(float)
mask_gm = gm.astype(float)
mask_csf = csf.astype(float)

response_wm, response_gm, response_csf = response_from_mask_msmt(
    gtab, data, mask_wm, mask_gm, mask_csf)

ubvals = unique_bvals_tolerance(bvals)
response_mcsd = multi_shell_fiber_response(sh_order=8,
                                           bvals=ubvals,
                                           wm_rf=response_wm,
                                           csf_rf=response_csf,
                                           gm_rf=response_gm)

mcsd_model = MultiShellDeconvModel(gtab, response_mcsd)
mcsd_fit = mcsd_model.fit(denoised_arr)
sh_coeff = mcsd_fit.all_shm_coeff
# Voxels whose fit did not complete carry NaN in their first coefficient.
nan_count = np.count_nonzero(np.isnan(sh_coeff[..., 0]))
n_vox = sh_coeff.shape[0] * sh_coeff.shape[1] * sh_coeff.shape[2]
if nan_count > 0:
    print(
        f'{nan_count/n_vox} of the voxels did not complete fodf calculation, NaN values replaced with 0'
    )
# The message above promises the replacement, so actually perform it.
coeff = np.where(np.isnan(sh_coeff), 0, sh_coeff)
Пример #5
0
def auto_response_msmt(gtab,
                       data,
                       tol=20,
                       roi_center=None,
                       roi_radii=10,
                       wm_fa_thr=0.7,
                       gm_fa_thr=0.3,
                       csf_fa_thr=0.15,
                       gm_md_thr=0.001,
                       csf_md_thr=0.0032):
    """ Automatic estimation of multi-shell multi-tissue (msmt) response
        functions using FA and MD.

    Parameters
    ----------
    gtab : GradientTable
    data : ndarray
        diffusion data
    tol : int, optional
        tolerance gap for b-values clustering. It is used both for the
        high-b-value warning check and for the response computation.
        (Default = 20)
    roi_center : array-like, (3,)
        Center of ROI in data. If center is None, it is assumed that it is
        the center of the volume with shape `data.shape[:3]`.
    roi_radii : int or array-like, (3,)
        radii of cuboid ROI
    wm_fa_thr : float
        FA threshold for WM.
    gm_fa_thr : float
        FA threshold for GM.
    csf_fa_thr : float
        FA threshold for CSF.
    gm_md_thr : float
        MD threshold for GM.
    csf_md_thr : float
        MD threshold for CSF.

    Returns
    -------
    response_wm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for WM for each unique bvalues (except b0).
    response_gm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for GM for each unique bvalues (except b0).
    response_csf : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for CSF for each unique bvalues (except b0).

    Warns
    -----
    UserWarning
        If some clustered b-values exceed 1200, since the tensor fit used to
        build the masks may then be affected.

    Notes
    -----
    In msmt-CSD there is an important pre-processing step: the estimation of
    every tissue's response function. In order to do this, we look for voxels
    corresponding to WM, GM and CSF. We get this information from
    mcsd.mask_for_response_msmt(), which returns masks of selected voxels
    (more details are available in the description of the function).

    With the masks, we compute the response functions by using
    mcsd.response_from_mask_msmt(), which returns the `response` for each
    tissue (more details are available in the description of the function).
    """

    # Cluster with the same tolerance used for the responses so the warning
    # reflects the shells that will actually be used.
    list_bvals = unique_bvals_tolerance(gtab.bvals, tol)
    if not np.all(list_bvals <= 1200):
        msg_bvals = """Some b-values are higher than 1200.
        The DTI fit might be affected. It is advised to use
        mask_for_response_msmt with bvalues lower than 1200, followed by
        response_from_mask_msmt with all bvalues to overcome this."""
        warnings.warn(msg_bvals, UserWarning)
    mask_wm, mask_gm, mask_csf = mask_for_response_msmt(
        gtab, data, roi_center, roi_radii, wm_fa_thr, gm_fa_thr, csf_fa_thr,
        gm_md_thr, csf_md_thr)
    response_wm, response_gm, response_csf = response_from_mask_msmt(
        gtab, data, mask_wm, mask_gm, mask_csf, tol)

    return response_wm, response_gm, response_csf
Пример #6
0
def response_from_mask_msmt(gtab, data, mask_wm, mask_gm, mask_csf, tol=20):
    """ Computation of multi-shell multi-tissue (msmt) response
        functions from given tissues masks.

    Parameters
    ----------
    gtab : GradientTable
    data : ndarray
        diffusion data; the last axis indexes the diffusion volumes.
    mask_wm : ndarray
        mask from where to compute the WM response function.
    mask_gm : ndarray
        mask from where to compute the GM response function.
    mask_csf : ndarray
        mask from where to compute the CSF response function.
    tol : int
        tolerance gap for b-values clustering. (Default = 20)

    Returns
    -------
    response_wm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for WM for each unique bvalues (except b0).
    response_gm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for GM for each unique bvalues (except b0).
    response_csf : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for CSF for each unique bvalues (except b0).

    Notes
    -----
    In msmt-CSD there is an important pre-processing step: the estimation of
    every tissue's response function. In order to do this, we look for voxels
    corresponding to WM, GM and CSF. This information can be obtained by using
    mcsd.mask_for_response_msmt() through masks of selected voxels. The present
    function uses such masks to compute the msmt response functions.

    For the responses, we base our approach on the function
    csdeconv.response_from_mask_ssst(), with the added layers of multishell and
    multi-tissue (see the ssst function for more information about the
    computation of the ssst response function). The lowest clustered b-value
    is treated as the b0 shell and its volumes are averaged into a single
    reference volume. For each other b-value (clustered using the tolerance
    gap), a single-shell gradient table and data array (b0 reference + that
    shell) are built once and shared by the three tissue masks, giving one
    ssst response per (shell, tissue) pair.
    """

    bvals = gtab.bvals
    bvecs = gtab.bvecs
    btens = gtab.btens

    list_bvals = unique_bvals_tolerance(bvals, tol)

    # NOTE(review): assumes the lowest shell contains at least one volume
    # acting as b0; b0_indices[0] below fails otherwise.
    b0_indices = get_bval_indices(bvals, list_bvals[0], tol)
    b0_map = np.mean(data[..., b0_indices], axis=-1)[..., np.newaxis]

    masks = [mask_wm, mask_gm, mask_csf]
    tissue_responses = [[] for _ in masks]
    for bval in list_bvals[1:]:
        indices = get_bval_indices(bvals, bval, tol)

        bvecs_sub = np.concatenate([[bvecs[b0_indices[0]]],
                                    bvecs[indices]])
        bvals_sub = np.concatenate([[0], bvals[indices]])
        if btens is not None:
            btens_b0 = btens[b0_indices[0]].reshape((1, 3, 3))
            btens_sub = np.concatenate([btens_b0, btens[indices]])
        else:
            btens_sub = None

        # The shell's data copy and gradient table do not depend on the
        # tissue, so they are built once per shell rather than once per mask.
        # A separate name keeps the caller's `gtab` from being shadowed.
        data_conc = np.concatenate([b0_map, data[..., indices]], axis=-1)
        gtab_sub = gradient_table(bvals_sub, bvecs_sub, btens=btens_sub)

        for responses, mask in zip(tissue_responses, masks):
            response, _ = response_from_mask_ssst(gtab_sub, data_conc, mask)
            responses.append(list(np.concatenate([response[0],
                                                  [response[1]]])))

    wm_response = np.asarray(tissue_responses[0])
    gm_response = np.asarray(tissue_responses[1])
    csf_response = np.asarray(tissue_responses[2])
    return wm_response, gm_response, csf_response
Пример #7
0
def mask_for_response_msmt(gtab,
                           data,
                           roi_center=None,
                           roi_radii=10,
                           wm_fa_thr=0.7,
                           gm_fa_thr=0.2,
                           csf_fa_thr=0.1,
                           gm_md_thr=0.0007,
                           csf_md_thr=0.002):
    """ Build per-tissue voxel masks for multi-shell multi-tissue (msmt)
        response estimation, based on tensor FA and MD.

    Parameters
    ----------
    gtab : GradientTable
    data : ndarray
        diffusion data (4D)
    roi_center : array-like, (3,)
        Center of ROI in data. If None, the center of the volume with shape
        `data.shape[:3]` (integer division by 2) is used.
    roi_radii : int or array-like, (3,)
        radii of cuboid ROI; a single number is used for all three axes.
    wm_fa_thr : float
        FA threshold for WM.
    gm_fa_thr : float
        FA threshold for GM.
    csf_fa_thr : float
        FA threshold for CSF.
    gm_md_thr : float
        MD threshold for GM.
    csf_md_thr : float
        MD threshold for CSF.

    Returns
    -------
    mask_wm : ndarray (int64)
        1 for voxels in the ROI whose FA is above `wm_fa_thr`, else 0.
    mask_gm : ndarray (int64)
        1 for voxels in the ROI with 0 < FA < `gm_fa_thr` and
        MD < `gm_md_thr`, else 0.
    mask_csf : ndarray (int64)
        1 for voxels in the ROI with 0 < FA < `csf_fa_thr` and
        0 < MD < `csf_md_thr`, else 0.

    Raises
    ------
    ValueError
        If `data` has fewer than 4 dimensions.

    Notes
    -----
    A tensor model is fitted inside the ROI to obtain FA and MD (NaN values
    are set to 0). A UserWarning is emitted when some clustered b-values
    exceed 1200, and when any tissue mask ends up empty.
    """

    if len(data.shape) < 4:
        msg = """Data must be 4D (3D image + directions). To use a 2D image,
        please reshape it into a (N, N, 1, ndirs) array."""
        raise ValueError(msg)

    if roi_center is None:
        roi_center = np.array(data.shape[:3]) // 2

    if isinstance(roi_radii, numbers.Number):
        roi_radii = (roi_radii, roi_radii, roi_radii)

    # Presumably adjusts the radii so the cuboid fits inside the volume.
    roi_radii = _roi_in_volume(data.shape, np.asarray(roi_center),
                               np.asarray(roi_radii))
    roi_mask = _mask_from_roi(data.shape[:3], roi_center, roi_radii)

    clustered_bvals = unique_bvals_tolerance(gtab.bvals)
    if not np.all(clustered_bvals <= 1200):
        msg_bvals = """Some b-values are higher than 1200.
        The DTI fit might be affected."""
        warnings.warn(msg_bvals, UserWarning)

    tensor_fit = TensorModel(gtab).fit(data, mask=roi_mask)
    fa = fractional_anisotropy(tensor_fit.evals)
    fa[np.isnan(fa)] = 0
    md = mean_diffusivity(tensor_fit.evals)
    md[np.isnan(md)] = 0

    fa_nonzero = fa > 0

    mask_wm = (fa > wm_fa_thr).astype(np.int64)
    mask_wm *= roi_mask

    gm_selected = (md < gm_md_thr) & (fa < gm_fa_thr) & fa_nonzero
    mask_gm = gm_selected.astype(np.int64)
    mask_gm *= roi_mask

    csf_selected = ((md < csf_md_thr) & (md > 0)
                    & (fa < csf_fa_thr) & fa_nonzero)
    mask_csf = csf_selected.astype(np.int64)
    mask_csf *= roi_mask

    msg = """No voxel with a {0} than {1} were found.
    Try a larger roi or a {2} threshold for {3}."""

    # For each tissue: its mask and the warnings to emit (in order) when
    # the mask is empty.
    empty_mask_warnings = (
        (mask_wm, (('FA higher', wm_fa_thr, 'lower FA', 'WM'),)),
        (mask_gm, (('FA lower', gm_fa_thr, 'higher FA', 'GM'),
                   ('MD lower', gm_md_thr, 'higher MD', 'GM'))),
        (mask_csf, (('FA lower', csf_fa_thr, 'higher FA', 'CSF'),
                    ('MD lower', csf_md_thr, 'higher MD', 'CSF'))),
    )
    for tissue_mask, specs in empty_mask_warnings:
        if np.sum(tissue_mask) != 0:
            continue
        for quantity, threshold, advice, tissue in specs:
            warnings.warn(msg.format(quantity, str(threshold), advice, tissue),
                          UserWarning)

    return mask_wm, mask_gm, mask_csf
Пример #8
0
    def __init__(self,
                 gtab,
                 response,
                 reg_sphere=default_sphere,
                 sh_order=8,
                 iso=2):
        r"""
        Multi-Shell Multi-Tissue Constrained Spherical Deconvolution
        (MSMT-CSD) [1]_. This method extends the CSD model proposed in [2]_ by
        the estimation of multiple response functions as a function of multiple
        b-values and multiple tissue types.

        Spherical deconvolution computes a fiber orientation distribution
        (FOD), also called fiber ODF (fODF) [2]_. The fODF is derived from
        different tissue types and thus overcomes the overestimation of WM in
        GM and CSF areas.

        The response function is based on the different tissue types
        and is provided as input to the MultiShellDeconvModel.
        It will be used as deconvolution kernel, as described in [2]_.

        Parameters
        ----------
        gtab : GradientTable
        response : ndarray or MultiShellResponse object
            Pre-computed multi-shell fiber response function in the form of a
            MultiShellResponse object, or simple response function as a ndarray.
            The latter must be of shape (3, len(bvals)-1, 4), where ``bvals``
            are the unique b-values returned by
            ``unique_bvals_tolerance(gtab.bvals, tol=20)``. The first axis
            holds the WM, GM and CSF responses, in that order; the second axis
            has one row per b-value shell except the lowest (b0); each row
            holds the three eigenvalues followed by the signal value without
            diffusion weighting (S0). Such an array is converted into a
            MultiShellResponse object via `multi_shell_fiber_response`. Note
            that in order to use more than three compartments, one must create
            a MultiShellResponse object on the side.
        reg_sphere : Sphere (optional)
            sphere used to build the regularization B matrix.
            Default: the module-level ``default_sphere``.
        sh_order : int (optional)
            maximal spherical harmonics order. Default: 8
        iso: int (optional)
            Number of tissue compartments for running the MSMT-CSD. Minimum
            number of compartments required is 2. An ndarray ``response``
            only supports ``iso=2``.
            Default: 2

        Raises
        ------
        ValueError
            If ``iso < 2``; if ``response`` is an ndarray and ``iso > 2``; or
            if an ndarray ``response`` does not have shape
            (3, len(bvals)-1, 4).

        References
        ----------
        .. [1] Jeurissen, B., et al. NeuroImage 2014. Multi-tissue constrained
               spherical deconvolution for improved analysis of multi-shell
               diffusion MRI data
        .. [2] Tournier, J.D., et al. NeuroImage 2007. Robust determination of
               the fibre orientation distribution in diffusion MRI:
               Non-negativity constrained super-resolved spherical
               deconvolution
        .. [3] Tournier, J.D, et al. Imaging Systems and Technology
               2012. MRtrix: Diffusion Tractography in Crossing Fiber Regions
        """
        if not iso >= 2:
            msg = ("Multi-tissue CSD requires at least 2 tissue compartments")
            raise ValueError(msg)

        super(MultiShellDeconvModel, self).__init__(gtab)

        if not isinstance(response, MultiShellResponse):
            # A raw ndarray can only describe the fixed WM/GM/CSF layout, so
            # reject extra compartments before clustering the b-values.
            if iso > 2:
                msg = """Too many compartments for this kind of response
                input. It must be two tissue compartments."""
                raise ValueError(msg)
            bvals = unique_bvals_tolerance(gtab.bvals, tol=20)
            if response.shape != (3, len(bvals) - 1, 4):
                msg = """Response must be of shape (3, len(bvals)-1, 4) or be a
                MultiShellResponse object."""
                raise ValueError(msg)
            response = multi_shell_fiber_response(sh_order,
                                                  bvals=bvals,
                                                  wm_rf=response[0],
                                                  gm_rf=response[1],
                                                  csf_rf=response[2])

        B, m, n = multi_tissue_basis(gtab, sh_order, iso)

        delta = _basic_delta(response.iso, response.m, response.n, 0., 0.)
        self.delta = delta
        multiplier_matrix = _inflate_response(response, gtab, n, delta)

        # Regularization matrix: identity on the first `iso` (isotropic)
        # coefficients, SH basis sampled on reg_sphere for the rest.
        r, theta, phi = geo.cart2sphere(*reg_sphere.vertices.T)
        odf_reg, _, _ = shm.real_sh_descoteaux(sh_order, theta, phi)
        reg = np.zeros([i + iso for i in odf_reg.shape])
        reg[:iso, :iso] = np.eye(iso)
        reg[iso:, iso:] = odf_reg

        X = B * multiplier_matrix

        self.fitter = QpFitter(X, reg)
        self.sh_order = sh_order
        self._X = X
        self.sphere = reg_sphere
        self.gtab = gtab
        self.B_dwi = B
        self.m = m
        self.n = n
        self.response = response
Пример #9
0
def mcsd_mod_est(gtab,
                 data,
                 B0_mask,
                 wm_in_dwi,
                 gm_in_dwi,
                 vent_csf_in_dwi,
                 sh_order=8,
                 roi_radii=10):
    """
    Estimate a Multi-Shell Multi-Tissue Constrained Spherical Deconvolution
    (MSMT-CSD) model from dwi data.

    Parameters
    ----------
    gtab : Obj
        DiPy object storing diffusion gradient information.
    data : array
        4D numpy array of diffusion image data.
    B0_mask : str
        File path to B0 brain mask.
    wm_in_dwi : str or image
        WM probability map in DWI space (any image accepted by nilearn's
        ``math_img``); voxels above 0.15 are treated as WM.
    gm_in_dwi : str or image
        GM probability map in DWI space; voxels above 0.10 are treated as GM.
    vent_csf_in_dwi : str or image
        Ventricular CSF probability map in DWI space; voxels above 0.50 are
        treated as CSF.
    sh_order : int
        The order of the SH model, used for both the fiber response and the
        deconvolution model. Default is 8.
    roi_radii : int or array-like
        Radii of the cuboid ROI passed to ``mask_for_response_msmt``.
        Default is 10.

    Returns
    -------
    mcsd_mod : ndarray (float32)
        SH coefficients of the MSMT-CSD reconstruction, with negative values
        clipped to 0.
    model : MultiShellDeconvModel
        The MSMT-CSD model used for the fit.

    References
    ----------
    .. [1] Tournier, J.D., et al. NeuroImage 2007. Robust determination of
      the fibre orientation distribution in diffusion MRI:
      Non-negativity constrained super-resolved spherical
      deconvolution
    .. [2] Descoteaux, M., et al. IEEE TMI 2009. Deterministic and
      Probabilistic Tractography Based on Complex Fibre Orientation
      Distributions
    .. [3] Côté, M-A., et al. Medical Image Analysis 2013. Tractometer:
      Towards validation of tractography pipelines
    .. [4] Tournier, J.D, et al. Imaging Systems and Technology
      2012. MRtrix: Diffusion Tractography in Crossing Fiber Regions

    """
    import dipy.reconst.dti as dti
    from nilearn.image import math_img
    from dipy.core.gradients import unique_bvals_tolerance
    from dipy.reconst.mcsd import (mask_for_response_msmt,
                                   response_from_mask_msmt,
                                   multi_shell_fiber_response,
                                   MultiShellDeconvModel)

    print("Reconstructing using MCSD...")

    B0_mask_data = np.nan_to_num(np.asarray(
        nib.load(B0_mask).dataobj)).astype("bool")

    # Load tissue maps and prepare tissue classifier
    gm_mask_img = math_img("img > 0.10", img=gm_in_dwi)
    gm_data = np.asarray(gm_mask_img.dataobj, dtype=np.float32)

    wm_mask_img = math_img("img > 0.15", img=wm_in_dwi)
    wm_data = np.asarray(wm_mask_img.dataobj, dtype=np.float32)

    vent_csf_in_dwi_img = math_img("img > 0.50", img=vent_csf_in_dwi)
    vent_csf_in_dwi_data = np.asarray(vent_csf_in_dwi_img.dataobj,
                                      dtype=np.float32)

    # Fit a simple DTI model
    tenfit = dti.TensorModel(gtab).fit(data)

    # Obtain the FA and MD metrics
    FA = tenfit.fa
    MD = tenfit.md

    indices_csf = np.where(((FA < 0.2) & (vent_csf_in_dwi_data > 0.50)))
    indices_gm = np.where(((FA < 0.2) & (gm_data > 0.10)))
    indices_wm = np.where(((FA >= 0.2) & (wm_data > 0.15)))

    selected_csf = np.zeros(FA.shape, dtype='bool')
    selected_gm = np.zeros(FA.shape, dtype='bool')
    selected_wm = np.zeros(FA.shape, dtype='bool')

    selected_csf[indices_csf] = True
    selected_gm[indices_gm] = True
    selected_wm[indices_wm] = True

    # Data-driven thresholds: the mean FA/MD of each tissue's selected voxels.
    mask_wm, mask_gm, mask_csf = mask_for_response_msmt(
        gtab,
        data,
        roi_radii=roi_radii,
        wm_fa_thr=np.nanmean(FA[selected_wm]),
        gm_fa_thr=np.nanmean(FA[selected_gm]),
        csf_fa_thr=np.nanmean(FA[selected_csf]),
        gm_md_thr=np.nanmean(MD[selected_gm]),
        csf_md_thr=np.nanmean(MD[selected_csf]))

    mask_wm *= wm_data.astype('int64')
    mask_gm *= gm_data.astype('int64')
    mask_csf *= vent_csf_in_dwi_data.astype('int64')

    response_wm, response_gm, response_csf = response_from_mask_msmt(
        gtab, data, mask_wm, mask_gm, mask_csf)

    # The response must be built with the same SH order as the model below;
    # a hard-coded 8 mismatched the model whenever sh_order != 8.
    response_mcsd = multi_shell_fiber_response(sh_order=sh_order,
                                               bvals=unique_bvals_tolerance(
                                                   gtab.bvals),
                                               wm_rf=response_wm,
                                               gm_rf=response_gm,
                                               csf_rf=response_csf)

    model = MultiShellDeconvModel(gtab, response_mcsd, sh_order=sh_order)
    mcsd_mod = model.fit(data, B0_mask_data).shm_coeff

    mcsd_mod = np.clip(mcsd_mod, 0, np.max(mcsd_mod, -1)[..., None])
    del response_mcsd, B0_mask_data
    return mcsd_mod.astype("float32"), model
Пример #10
0
def main():
    """Fit MSMT-CSD on a DWI volume from WM/GM/CSF FRF files and save outputs.

    Depending on the command-line arguments, saves WM/GM/CSF fODF SH
    coefficients, the volume fractions and an RGB volume-fraction map.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)

    if not args.not_all:
        args.wm_out_fODF = args.wm_out_fODF or 'wm_fodf.nii.gz'
        args.gm_out_fODF = args.gm_out_fODF or 'gm_fodf.nii.gz'
        args.csf_out_fODF = args.csf_out_fODF or 'csf_fodf.nii.gz'
        args.vf = args.vf or 'vf.nii.gz'
        args.vf_rgb = args.vf_rgb or 'vf_rgb.nii.gz'

    arglist = [args.wm_out_fODF, args.gm_out_fODF, args.csf_out_fODF,
               args.vf, args.vf_rgb]
    if args.not_all and not any(arglist):
        parser.error('When using --not_all, you need to specify at least ' +
                     'one file to output.')

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec,
                                 args.in_wm_frf, args.in_gm_frf,
                                 args.in_csf_frf])
    assert_outputs_exist(parser, args, arglist)

    # Loading data. ndmin=2 keeps a single-row FRF file (one shell) 2-D, so
    # the shape check below reports a clear error instead of IndexError.
    wm_frf = np.loadtxt(args.in_wm_frf, ndmin=2)
    gm_frf = np.loadtxt(args.in_gm_frf, ndmin=2)
    csf_frf = np.loadtxt(args.in_csf_frf, ndmin=2)
    vol = nib.load(args.in_dwi)
    data = vol.get_fdata(dtype=np.float32)
    bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec)

    # Checking mask
    if args.mask is None:
        mask = None
    else:
        mask = get_data_as_mask(nib.load(args.mask), dtype=bool)
        if mask.shape != data.shape[:-1]:
            raise ValueError("Mask is not the same shape as data.")

    sh_order = args.sh_order

    # Checking data and sh_order
    b0_thr = check_b0_threshold(
        args.force_b0_threshold, bvals.min(), bvals.min())
    if data.shape[-1] < (sh_order + 1) * (sh_order + 2) / 2:
        logging.warning(
            'We recommend having at least {} unique DWIs volumes, but you '
            'currently have {} volumes. Try lowering the parameter --sh_order '
            'in case of non convergence.'.format(
                (sh_order + 1) * (sh_order + 2) / 2, data.shape[-1]))

    # Checking bvals, bvecs values and loading gtab
    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)
    gtab = gradient_table(bvals, bvecs, b0_threshold=b0_thr)

    # Checking response functions and computing msmt response function
    for tissue, frf in (('WM', wm_frf), ('GM', gm_frf), ('CSF', csf_frf)):
        if frf.shape[1] != 4:
            raise ValueError('{} frf file did not contain 4 elements. '
                             'Invalid or deprecated FRF format'.format(tissue))
    ubvals = unique_bvals_tolerance(bvals, tol=20)
    msmt_response = multi_shell_fiber_response(sh_order, ubvals,
                                               wm_frf, gm_frf, csf_frf)

    # Loading spheres
    reg_sphere = get_sphere('symmetric362')

    # Computing msmt-CSD
    msmt_model = MultiShellDeconvModel(gtab, msmt_response,
                                       reg_sphere=reg_sphere,
                                       sh_order=sh_order)

    # Computing msmt-CSD fit
    msmt_fit = fit_from_model(msmt_model, data,
                              mask=mask, nbr_processes=args.nbr_processes)

    shm_coeff = msmt_fit.all_shm_coeff

    # Voxels the solver could not solve carry NaN in their first coefficient.
    nan_count = np.count_nonzero(np.isnan(shm_coeff[..., 0]))
    voxel_count = np.prod(shm_coeff.shape[:-1])

    if nan_count / voxel_count >= 0.05:
        msg = """There are {} voxels out of {} that could not be solved by
        the solver, reaching a critical amount of voxels. Make sure to tune the
        response functions properly, as the solving process is very sensitive
        to it. Proceeding to fill the problematic voxels by 0.
        """
        logging.warning(msg.format(nan_count, voxel_count))
    elif nan_count > 0:
        msg = """There are {} voxels out of {} that could not be solved by
        the solver. Make sure to tune the response functions properly, as the
        solving process is very sensitive to it. Proceeding to fill the
        problematic voxels by 0.
        """
        logging.warning(msg.format(nan_count, voxel_count))

    shm_coeff = np.where(np.isnan(shm_coeff), 0, shm_coeff)

    # Saving results. Coefficient layout: [..., 0] CSF, [..., 1] GM,
    # [..., 2:] WM SH coefficients.
    if args.wm_out_fODF:
        wm_coeff = shm_coeff[..., 2:]
        if args.sh_basis == 'tournier07':
            wm_coeff = convert_sh_basis(wm_coeff, reg_sphere, mask=mask,
                                        nbr_processes=args.nbr_processes)
        nib.save(nib.Nifti1Image(wm_coeff.astype(np.float32),
                                 vol.affine), args.wm_out_fODF)

    if args.gm_out_fODF:
        gm_coeff = shm_coeff[..., 1]
        if args.sh_basis == 'tournier07':
            gm_coeff = gm_coeff.reshape(gm_coeff.shape + (1,))
            gm_coeff = convert_sh_basis(gm_coeff, reg_sphere, mask=mask,
                                        nbr_processes=args.nbr_processes)
        nib.save(nib.Nifti1Image(gm_coeff.astype(np.float32),
                                 vol.affine), args.gm_out_fODF)

    if args.csf_out_fODF:
        csf_coeff = shm_coeff[..., 0]
        if args.sh_basis == 'tournier07':
            csf_coeff = csf_coeff.reshape(csf_coeff.shape + (1,))
            csf_coeff = convert_sh_basis(csf_coeff, reg_sphere, mask=mask,
                                         nbr_processes=args.nbr_processes)
        nib.save(nib.Nifti1Image(csf_coeff.astype(np.float32),
                                 vol.affine), args.csf_out_fODF)

    if args.vf:
        nib.save(nib.Nifti1Image(msmt_fit.volume_fractions.astype(np.float32),
                                 vol.affine), args.vf)

    if args.vf_rgb:
        vf = msmt_fit.volume_fractions
        vf_max = np.max(vf)
        # Guard against 0/0 (all-zero fractions), which would produce NaN
        # and an undefined uint8 cast.
        if vf_max > 0:
            vf_rgb = vf / vf_max * 255
        else:
            vf_rgb = np.zeros_like(vf)
        vf_rgb = np.clip(vf_rgb, 0, 255)
        nib.save(nib.Nifti1Image(vf_rgb.astype(np.uint8),
                                 vol.affine), args.vf_rgb)
Пример #11
0
# Show the mask-based CSF response next to the automatically estimated
# responses, one item per line.
print(response_csf,
      "Auto responses",
      auto_response_wm,
      auto_response_gm,
      auto_response_csf,
      sep="\n")
"""
At this point, there are two ways to use these response functions. We want to
create a MultiShellDeconvModel, which takes a response function as input. This
response function can either be passed directly in its current format, or it
can be converted to the MultiShellResponse format produced by the
``multi_shell_fiber_response`` function. That function assumes a
3-compartment model (WM, GM, CSF) and takes one response function per tissue
per b-value. Note that the b-values passed to it must be unique.
"""

# Cluster the acquisition b-values into unique shells first: the multi-shell
# response expects one response row per unique b-value.
ubvals = unique_bvals_tolerance(gtab.bvals)
response_mcsd = multi_shell_fiber_response(
    sh_order=8,
    bvals=ubvals,
    wm_rf=response_wm,
    gm_rf=response_gm,
    csf_rf=response_csf,
)
"""
As mentioned, we can also build the model directly, and it will call
``multi_shell_fiber_response`` internally. Important note: the function
``unique_bvals_tolerance`` is used to keep only the unique b-values from the
gtab given to the model, and these are passed to ``multi_shell_fiber_response``.
This may introduce differences between the responses calculated by the two
methods, depending on the b-values given to ``multi_shell_fiber_response``
externally.
"""

# Stack the per-tissue responses (WM, GM, CSF order) into one array, presumably
# of shape (3, n_shells, 4), which MultiShellDeconvModel accepts directly.
response = np.asarray([response_wm, response_gm, response_csf])
Пример #12
0
def _load_mask_matching(path, spatial_shape, error_msg):
    """Load an optional boolean mask and check it matches the DWI grid.

    Parameters
    ----------
    path : str or None
        Path of the mask image. A falsy value means no mask was given.
    spatial_shape : tuple of int
        Expected mask shape: the DWI shape without its last axis.
    error_msg : str
        Message of the ValueError raised on a shape mismatch.

    Returns
    -------
    ndarray of bool or None
        The mask, or None when `path` is falsy.

    Raises
    ------
    ValueError
        If the mask shape differs from `spatial_shape`.
    """
    if not path:
        return None
    mask = get_data_as_mask(nib.load(path), dtype=bool)
    if mask.shape != spatial_shape:
        raise ValueError(error_msg)
    return mask


def _build_frf_table(list_bvals, responses, tol):
    """Tabulate the first three response coefficients of each tissue per shell.

    Parameters
    ----------
    list_bvals : ndarray
        Unique b-values of the acquisition, as returned by
        ``unique_bvals_tolerance``.
    responses : sequence of 2d ndarray
        WM, GM and CSF responses, in that order.
    tol : float
        The first b-value is dropped (treated as b0) when it is below `tol`.

    Returns
    -------
    ndarray
        One row per kept b-value: the b-value, then 3 CSF, 3 GM and 3 WM
        coefficients.
    """
    bvals = list_bvals[1:] if list_bvals[0] < tol else list_bvals
    response_wm, response_gm, response_csf = (responses[0], responses[1],
                                              responses[2])
    coeffs = np.concatenate((response_csf[:, :3],
                             response_gm[:, :3],
                             response_wm[:, :3]), axis=1)
    return np.vstack((bvals, coeffs.T)).T


def main():
    """Compute multi-shell multi-tissue (WM/GM/CSF) fiber response functions.

    Reads the DWI, b-values and b-vectors given on the command line, runs
    ``compute_msmt_frf`` and saves one response file per tissue. Optionally
    saves the per-tissue masks returned by ``compute_msmt_frf`` and a table
    of per-shell response coefficients.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG if args.verbose
                        else logging.INFO)

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec])
    assert_outputs_exist(parser, args, [args.out_wm_frf, args.out_gm_frf,
                                        args.out_csf_frf])

    # A single value is used as-is; otherwise the list is passed through.
    # Presumably one radius per spatial axis, hence at most 3 values.
    n_radii = len(args.roi_radii)
    if n_radii == 1:
        roi_radii = args.roi_radii[0]
    elif n_radii == 2:
        parser.error('--roi_radii cannot be of size (2,).')
    elif n_radii > 3:
        parser.error('--roi_radii cannot have more than 3 values.')
    else:
        roi_radii = args.roi_radii
    roi_center = args.roi_center

    vol = nib.load(args.in_dwi)
    data = vol.get_fdata(dtype=np.float32)
    bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec)

    tol = args.tolerance
    dti_lim = args.dti_bval_limit

    # When some shells exceed the DTI b-value limit, give compute_msmt_frf a
    # separate copy of the data restricted to shells at or below that limit.
    list_bvals = unique_bvals_tolerance(bvals, tol=tol)
    if np.all(list_bvals <= dti_lim):
        data_dti = bvals_dti = bvecs_dti = None
    else:
        _, data_dti, bvals_dti, bvecs_dti = extract_dwi_shell(
            vol, bvals, bvecs, list_bvals[list_bvals <= dti_lim], tol=tol)
        bvals_dti = np.squeeze(bvals_dti)

    # Every mask must live on the same voxel grid as the DWI data.
    spatial_shape = data.shape[:-1]
    mask = _load_mask_matching(args.mask, spatial_shape,
                               "Mask is not the same shape as data.")
    mask_wm = _load_mask_matching(args.mask_wm, spatial_shape,
                                  "WM mask is not the same shape as data.")
    mask_gm = _load_mask_matching(args.mask_gm, spatial_shape,
                                  "GM mask is not the same shape as data.")
    mask_csf = _load_mask_matching(args.mask_csf, spatial_shape,
                                   "CSF mask is not the same shape as data.")

    force_b0_thr = args.force_b0_threshold
    responses, frf_masks = compute_msmt_frf(data, bvals, bvecs,
                                            data_dti=data_dti,
                                            bvals_dti=bvals_dti,
                                            bvecs_dti=bvecs_dti,
                                            mask=mask, mask_wm=mask_wm,
                                            mask_gm=mask_gm, mask_csf=mask_csf,
                                            fa_thr_wm=args.fa_thr_wm,
                                            fa_thr_gm=args.fa_thr_gm,
                                            fa_thr_csf=args.fa_thr_csf,
                                            md_thr_gm=args.md_thr_gm,
                                            md_thr_csf=args.md_thr_csf,
                                            min_nvox=args.min_nvox,
                                            roi_radii=roi_radii,
                                            roi_center=roi_center,
                                            tol=tol,
                                            force_b0_threshold=force_b0_thr)

    # Save the per-tissue masks returned by compute_msmt_frf, when requested.
    mask_files = [args.wm_frf_mask, args.gm_frf_mask, args.csf_frf_mask]
    for frf_mask, mask_file in zip(frf_masks, mask_files):
        if mask_file:
            nib.save(nib.Nifti1Image(frf_mask.astype(np.uint8), vol.affine),
                     mask_file)

    frf_out = [args.out_wm_frf, args.out_gm_frf, args.out_csf_frf]
    for frf_file, response in zip(frf_out, responses):
        np.savetxt(frf_file, response)

    if args.frf_table:
        np.savetxt(args.frf_table,
                   _build_frf_table(list_bvals, responses, tol))
Пример #13
0
def generate_kernel(gtab, sphere, wm_response, gm_response, csf_response,
                    *, tol=20):
    '''
    Generate deconvolution kernel

    Compute kernel mapping orientation densities of white matter fiber
    populations (along each vertex of the sphere) and isotropic volume
    fractions to a diffusion weighted signal.

    Parameters
    ----------
    gtab : GradientTable
    sphere : Sphere
        Sphere with which to sample discrete fiber orientations in order to
        construct kernel
    wm_response : 1d ndarray or 2d ndarray or AxSymShResponse
        Tensor eigenvalues as a (3,) ndarray, multishell eigenvalues as
        a (len(unique_bvals_tolerance(gtab.bvals, tol))-1, 3) ndarray in
        order of smallest to largest b-value, or an AxSymShResponse.
        Required; `None` is not accepted.
    gm_response : float or None
        Mean diffusivity for GM compartment. If `None`, then grey
        matter compartment set to all zeros.
    csf_response : float or None
        Mean diffusivity for CSF compartment. If `None`, then CSF
        compartment set to all zeros.
    tol : float, optional
        Tolerance used to group b-values into shells; passed to
        ``unique_bvals_tolerance`` and ``get_bval_indices``. The default
        (20) matches those functions' defaults; b-values expressed in other
        units (e.g. ms/um^2) need a correspondingly smaller tolerance.

    Returns
    -------
    kernel : 2d ndarray (N, M)
        Computed kernel; can be multiplied with a vector consisting of volume
        fractions for each of M-2 fiber populations as well as GM and CSF
        fractions to produce a diffusion weighted signal. N is the number of
        gradients in `gtab`, M is the number of sphere vertices plus 2.
    '''

    # Coordinates of sphere vertices: one WM fiber direction per vertex
    sticks = sphere.vertices

    n_grad = len(gtab.gradients)  # number of gradient directions
    n_wm_comp = sticks.shape[0]  # number of fiber populations
    n_comp = n_wm_comp + 2  # plus isotropic GM and CSF compartments

    kernel = np.zeros((n_grad, n_comp))

    # White matter compartments. The first unique b-value is treated as b0.
    list_bvals = unique_bvals_tolerance(gtab.bvals, tol=tol)
    n_bvals = len(list_bvals) - 1  # number of non-b0 unique b-values

    if isinstance(wm_response, AxSymShResponse):
        # Data-driven response
        where_dwi = lazy_index(~gtab.b0s_mask)
        gradients = gtab.gradients[where_dwi]
        gradients = gradients / np.linalg.norm(gradients, axis=1)[..., None]
        S0 = wm_response.S0
        for i in range(n_wm_comp):
            # Response oriented along [0, 0, 1], so must rotate sticks[i]
            rot_mat = vec2vec_rotmat(sticks[i], np.array([0, 0, 1]))
            rot_gradients = np.dot(rot_mat, gradients.T).T
            rot_sphere = Sphere(xyz=rot_gradients)
            # Project onto rotated sphere and scale
            rot_response = wm_response.on_sphere(rot_sphere) / S0
            kernel[where_dwi, i] = rot_response

        # Set b0 components
        kernel[gtab.b0s_mask, :] = 1

    elif wm_response.shape == (n_bvals, 3):
        # Multi-shell response: one set of tensor eigenvalues per shell
        bvals = gtab.bvals
        bvecs = gtab.bvecs
        for n, bval in enumerate(list_bvals[1:]):
            indices = get_bval_indices(bvals, bval, tol=tol)
            with warnings.catch_warnings():  # extract relevant b-value
                warnings.simplefilter("ignore")
                gtab_sub = gradient_table(bvals[indices], bvecs[indices])

            for i in range(n_wm_comp):
                # Signal generated by WM-fiber for each gradient direction
                S = single_tensor(gtab_sub,
                                  evals=wm_response[n],
                                  evecs=all_tensor_evecs(sticks[i]))
                kernel[indices, i] = S

        # Set b0 components
        b0_indices = get_bval_indices(bvals, list_bvals[0], tol=tol)
        kernel[b0_indices, :] = 1

    else:
        # Single tensor response applied to every gradient
        for i in range(n_wm_comp):
            # Signal generated by WM-fiber for each gradient direction
            S = single_tensor(gtab,
                              evals=wm_response,
                              evecs=all_tensor_evecs(sticks[i]))
            kernel[:, i] = S

        # Set b0 components
        kernel[gtab.b0s_mask, :] = 1

    def isotropic_signal(diffusivity):
        # Isotropic tensor (three equal eigenvalues), or zeros when the
        # compartment is disabled with None.
        if diffusivity is None:
            return np.zeros(n_grad)
        return single_tensor(gtab, evals=np.array([diffusivity] * 3))

    # Isotropic columns overwrite the b0 rows set above as well.
    kernel[:, n_comp - 2] = isotropic_signal(gm_response)
    kernel[:, n_comp - 1] = isotropic_signal(csf_response)

    return kernel