Пример #1
0
def test_mask_for_response_msmt_nvoxels():
    """Check voxel counts and warnings of ``mask_for_response_msmt``.

    With reasonable thresholds the WM/GM/CSF masks must hold 5/2/2 voxels
    and only the high b-value warning is emitted. With thresholds that no
    voxel can satisfy, all three masks must be empty and the b-value warning
    must be followed by one warning per empty FA/MD selection.
    """
    gtab, data, _, _ = get_test_data()

    with warnings.catch_warnings(record=True) as w:
        # Record every warning regardless of the ambient filters (e.g. a
        # runner configured with "ignore" or "error"), so the counts
        # asserted below do not depend on the test environment.
        warnings.simplefilter("always")
        wm_mask, gm_mask, csf_mask = mask_for_response_msmt(gtab, data,
                                                            roi_center=None,
                                                            roi_radii=(1, 1, 0),
                                                            wm_fa_thr=0.7,
                                                            gm_fa_thr=0.3,
                                                            csf_fa_thr=0.15,
                                                            gm_md_thr=0.001,
                                                            csf_md_thr=0.0032)

    npt.assert_equal(len(w), 1)
    npt.assert_(issubclass(w[0].category, UserWarning))
    npt.assert_("""Some b-values are higher than 1200.""" in
                str(w[0].message))

    npt.assert_equal(np.sum(wm_mask), 5)
    npt.assert_equal(np.sum(gm_mask), 2)
    npt.assert_equal(np.sum(csf_mask), 2)

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        wm_mask, gm_mask, csf_mask = mask_for_response_msmt(gtab, data,
                                                            roi_center=None,
                                                            roi_radii=(1, 1, 0),
                                                            wm_fa_thr=1,
                                                            gm_fa_thr=0,
                                                            csf_fa_thr=0,
                                                            gm_md_thr=0,
                                                            csf_md_thr=0)

    # Expected warnings, in emission order: the b-value warning first, then
    # one warning for each FA/MD selection that found no voxel.
    expected_messages = ["Some b-values are higher than 1200.",
                         "No voxel with a FA higher than 1 were found",
                         "No voxel with a FA lower than 0 were found",
                         "No voxel with a MD lower than 0 were found",
                         "No voxel with a FA lower than 0 were found",
                         "No voxel with a MD lower than 0 were found"]
    npt.assert_equal(len(w), len(expected_messages))
    npt.assert_(issubclass(w[0].category, UserWarning))
    for record, expected in zip(w, expected_messages):
        npt.assert_(expected in str(record.message))

    npt.assert_equal(np.sum(wm_mask), 0)
    npt.assert_equal(np.sum(gm_mask), 0)
    npt.assert_equal(np.sum(csf_mask), 0)
Пример #2
0
def _model(gtab, data, response=None, sh_order=None, msmt=False):
    """
    Helper function that builds (but does not fit) a CSD model.

    Parameters
    ----------
    gtab : GradientTable
        Acquisition scheme; ``gtab.b0s_mask`` flags the b0 volumes.
    data : ndarray
        Diffusion data, only used to estimate ``response`` when it is None.
    response : optional
        Response function. Estimated from ``data`` when None.
    sh_order : int, optional
        Spherical harmonics order. When None, it is derived from the number
        of non-b0 volumes, rounded down to an even value and capped at 8.
    msmt : bool, optional
        If True, build ``mcsd.MultiShellDeconvModel``; otherwise build
        ``csd.ConstrainedSphericalDeconvModel``.

    Returns
    -------
    The model instance, constructed with ``sh_order=sh_order``.
    """
    if sh_order is None:
        n_dwi = np.sum(~gtab.b0s_mask)
        # Solve (L + 1)(L + 2) / 2 = n_dwi for L and truncate
        # (see dipy.reconst.shm.calculate_max_order).
        max_order = int((np.sqrt(1 + 8 * n_dwi) - 3) / 2.0)
        # Only even orders are used: step down an odd value, then cap at 8.
        if max_order % 2:
            max_order -= 1
        sh_order = min(max_order, 8)

    if not msmt:
        model_cls = csd.ConstrainedSphericalDeconvModel
        if response is None:
            response, _ = csd.auto_response_ssst(gtab, data,
                                                 roi_radii=10, fa_thr=0.7)
        return model_cls(gtab, response, sh_order=sh_order)

    model_cls = mcsd.MultiShellDeconvModel
    if response is None:
        wm_mask, gm_mask, csf_mask = \
            mcsd.mask_for_response_msmt(gtab, data)
        wm_rf, gm_rf, csf_rf = \
            mcsd.response_from_mask_msmt(gtab, data,
                                         wm_mask, gm_mask, csf_mask)
        response = np.array([wm_rf, gm_rf, csf_rf])
    return model_cls(gtab, response, sh_order=sh_order)
Пример #3
0
def test_auto_response_msmt():
    """``auto_response_msmt`` must warn about b-values above 1200 and give
    the same responses as ``mask_for_response_msmt`` followed by
    ``response_from_mask_msmt`` with identical parameters."""
    gtab, data, _, _ = get_test_data()

    with warnings.catch_warnings(record=True) as w:
        # Record every warning regardless of the ambient filters. The
        # reference computation stays inside this block so its warnings are
        # captured as well (and cannot be escalated to errors by the runner).
        warnings.simplefilter("always")
        response_auto_wm, response_auto_gm, response_auto_csf = \
            auto_response_msmt(gtab, data, tol=20,
                               roi_center=None, roi_radii=(1, 1, 0),
                               wm_fa_thr=0.7, gm_fa_thr=0.3, csf_fa_thr=0.15,
                               gm_md_thr=0.001, csf_md_thr=0.0032)

        # Fail with a clear message instead of an IndexError on w[0].
        npt.assert_(len(w) > 0, "auto_response_msmt emitted no warning")
        npt.assert_(issubclass(w[0].category, UserWarning))
        npt.assert_("""Some b-values are higher than 1200.
        The DTI fit might be affected. It is advised to use
        mask_for_response_msmt with bvalues lower than 1200, followed by
        response_from_mask_msmt with all bvalues to overcome this."""
                    in str(w[0].message))

        mask_wm, mask_gm, mask_csf = mask_for_response_msmt(gtab, data,
                                                            roi_center=None,
                                                            roi_radii=(1, 1, 0),
                                                            wm_fa_thr=0.7,
                                                            gm_fa_thr=0.3,
                                                            csf_fa_thr=0.15,
                                                            gm_md_thr=0.001,
                                                            csf_md_thr=0.0032)

        response_from_mask_wm, response_from_mask_gm, response_from_mask_csf = \
            response_from_mask_msmt(gtab, data,
                                    mask_wm, mask_gm, mask_csf,
                                    tol=20)

        npt.assert_array_equal(response_auto_wm, response_from_mask_wm)
        npt.assert_array_equal(response_auto_gm, response_from_mask_gm)
        npt.assert_array_equal(response_auto_csf, response_from_mask_csf)
Пример #4
0
def test_mask_for_response_msmt():
    """Masks from ``mask_for_response_msmt`` must be non-empty, match the
    ground-truth masks, and come with exactly one high b-value warning."""
    gtab, data, masks_gt, _ = get_test_data()

    with warnings.catch_warnings(record=True) as w:
        # Record every warning regardless of the ambient filters (e.g. a
        # runner configured with "ignore" or "error"), so len(w) is
        # deterministic.
        warnings.simplefilter("always")
        wm_mask, gm_mask, csf_mask = mask_for_response_msmt(gtab, data,
                                                            roi_center=None,
                                                            roi_radii=(1, 1, 0),
                                                            wm_fa_thr=0.7,
                                                            gm_fa_thr=0.3,
                                                            csf_fa_thr=0.15,
                                                            gm_md_thr=0.001,
                                                            csf_md_thr=0.0032)

    npt.assert_equal(len(w), 1)
    npt.assert_(issubclass(w[0].category, UserWarning))
    npt.assert_("""Some b-values are higher than 1200.""" in
                str(w[0].message))

    # Verifies that masks are not empty, so the comparisons below cannot
    # pass trivially on all-zero masks.
    masks_sum = int(np.sum(wm_mask) + np.sum(gm_mask) + np.sum(csf_mask))
    npt.assert_(masks_sum != 0, "all tissue masks are empty")

    npt.assert_array_almost_equal(masks_gt[0], wm_mask)
    npt.assert_array_almost_equal(masks_gt[1], gm_mask)
    npt.assert_array_almost_equal(masks_gt[2], csf_mask)
Пример #5
0
def test_mask_for_response_msmt():
    """``mask_for_response_msmt`` must produce non-empty WM/GM/CSF masks that
    match the ground-truth masks returned by ``get_test_data``."""
    gtab, data, masks_gt, _ = get_test_data()

    thresholds = dict(wm_fa_thr=0.7,
                      gm_fa_thr=0.3,
                      csf_fa_thr=0.15,
                      gm_md_thr=0.001,
                      csf_md_thr=0.0032)
    wm_mask, gm_mask, csf_mask = mask_for_response_msmt(gtab, data,
                                                        roi_center=None,
                                                        roi_radii=(1, 1, 0),
                                                        **thresholds)
    computed = (wm_mask, gm_mask, csf_mask)

    # An all-empty result would make the comparisons below meaningless.
    total_voxels = int(sum(np.sum(tissue_mask) for tissue_mask in computed))
    npt.assert_equal(total_voxels != 0, True)

    for idx, tissue_mask in enumerate(computed):
        npt.assert_array_almost_equal(masks_gt[idx], tissue_mask)
Пример #6
0
def compute_msmt_frf(data,
                     bvals,
                     bvecs,
                     data_dti=None,
                     bvals_dti=None,
                     bvecs_dti=None,
                     mask=None,
                     mask_wm=None,
                     mask_gm=None,
                     mask_csf=None,
                     fa_thr_wm=0.7,
                     fa_thr_gm=0.2,
                     fa_thr_csf=0.1,
                     md_thr_gm=0.0007,
                     md_thr_csf=0.003,
                     min_nvox=300,
                     roi_radii=10,
                     roi_center=None,
                     tol=20,
                     force_b0_threshold=False):
    """Compute multi-shell, multi-tissue (WM, GM, CSF) Fiber Response
    Functions from a DWI volume.

    Voxels of each tissue are selected by ``mask_for_response_msmt`` inside a
    cuboid ROI, using FA/MD thresholds on a DTI fit. The DTI fit uses
    ``data_dti``/``bvals_dti``/``bvecs_dti`` when all three are given, and
    ``data``/``bvals``/``bvecs`` otherwise. The selections can be further
    restricted with ``mask`` and the per-tissue masks. The responses are then
    estimated on the full ``data`` with ``response_from_mask_msmt``.

    Parameters
    ----------
    data : ndarray
        4D Input diffusion volume with shape (X, Y, Z, N)
    bvals : ndarray
        1D bvals array with shape (N,)
    bvecs : ndarray
        2D bvecs array with shape (N, 3)
    data_dti : ndarray, optional
        4D diffusion volume used only for the DTI fit that selects the tissue
        voxels (e.g. only the shells with b-values up to 1200). Must be given
        together with ``bvals_dti`` and ``bvecs_dti``.
    bvals_dti : ndarray, optional
        1D bvals array matching ``data_dti``.
    bvecs_dti : ndarray, optional
        2D bvecs array matching ``data_dti``.
    mask : ndarray, optional
        3D mask with shape (X,Y,Z)
        Binary mask. Only the data inside the mask will be used for
        computations and reconstruction.
    mask_wm : ndarray, optional
        3D mask with shape (X,Y,Z)
        Binary white matter mask. Only the data inside this mask will be used
        to estimate the fiber response function of WM.
    mask_gm : ndarray, optional
        3D mask with shape (X,Y,Z)
        Binary grey matter mask. Only the data inside this mask will be used
        to estimate the fiber response function of GM.
    mask_csf : ndarray, optional
        3D mask with shape (X,Y,Z)
        Binary csf mask. Only the data inside this mask will be used to
        estimate the fiber response function of CSF.
    fa_thr_wm : float, optional
        Use this threshold to select single WM fiber voxels from the FA inside
        the WM mask defined by mask_wm. Each voxel above this threshold will be
        selected. Defaults to 0.7
    fa_thr_gm : float, optional
        Use this threshold to select GM voxels from the FA inside the GM mask
        defined by mask_gm. Each voxel below this threshold will be selected.
        Defaults to 0.2
    fa_thr_csf : float, optional
        Use this threshold to select CSF voxels from the FA inside the CSF mask
        defined by mask_csf. Each voxel below this threshold will be selected.
        Defaults to 0.1
    md_thr_gm : float, optional
        Use this threshold to select GM voxels from the MD inside the GM mask
        defined by mask_gm. Each voxel below this threshold will be selected.
        Defaults to 0.0007
    md_thr_csf : float, optional
        Use this threshold to select CSF voxels from the MD inside the CSF mask
        defined by mask_csf. Each voxel below this threshold will be selected.
        Defaults to 0.003
    min_nvox : int, optional
        Minimal number of voxels needing to be identified for each tissue in
        the automatic estimation. Defaults to 300.
    roi_radii : int or array-like (3,), optional
        Use those radii to select a cuboid roi to estimate the FRF. The roi
        will be a cuboid spanning from the middle of the volume in each
        direction with the different radii. Defaults to 10.
    roi_center : tuple(3), optional
        Use this center to span the roi of size roi_radii (defaults to the
        center of the 3D volume).
    tol : int
        tolerance gap for b-values clustering. Defaults to 20
    force_b0_threshold : bool, optional
        If set, will continue even if the minimum bvalue is suspiciously high.

    Returns
    -------
    responses : list of ndarray
        Fiber Response Function of each (3) tissue type (WM, GM, CSF), as
        returned by ``response_from_mask_msmt``.
    frf_masks : list of ndarray
        Mask where the frf was calculated, for each (3) tissue type, with
        shape (X, Y, Z).

    Raises
    ------
    ValueError
        If only some of ``data_dti``, ``bvals_dti`` and ``bvecs_dti`` are
        given, or if fewer than `min_nvox` voxels were selected for any of
        the three tissues.
    """
    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)

    check_b0_threshold(force_b0_threshold, bvals.min())

    gtab = gradient_table(bvals, bvecs)

    # Pick the acquisition used for the DTI-based voxel selection: either the
    # full data, or the dedicated *_dti inputs (all-or-nothing).
    dti_inputs = (data_dti, bvals_dti, bvecs_dti)
    if all(item is None for item in dti_inputs):
        logging.warning(
            "No data specific to DTI was given. If b-values go over 1200, "
            "this might produce wrong results.")
        gtab_for_dti, data_for_dti = gtab, data
    elif all(item is not None for item in dti_inputs):
        if not is_normalized_bvecs(bvecs_dti):
            logging.warning('Your b-vectors do not seem normalized...')
            bvecs_dti = normalize_bvecs(bvecs_dti)

        check_b0_threshold(force_b0_threshold, bvals_dti.min())
        gtab_for_dti, data_for_dti = \
            gradient_table(bvals_dti, bvecs_dti), data_dti
    else:
        msg = """Input not valid. Either give no _dti input, or give all
        data_dti, bvals_dti and bvecs_dti."""
        raise ValueError(msg)

    wm_frf_mask, gm_frf_mask, csf_frf_mask \
        = mask_for_response_msmt(gtab_for_dti, data_for_dti,
                                 roi_center=roi_center,
                                 roi_radii=roi_radii,
                                 wm_fa_thr=fa_thr_wm,
                                 gm_fa_thr=fa_thr_gm,
                                 csf_fa_thr=fa_thr_csf,
                                 gm_md_thr=md_thr_gm,
                                 csf_md_thr=md_thr_csf)

    # Restrict the threshold-based selections with the optional user masks
    # (in-place, keeping the dtype returned by mask_for_response_msmt).
    if mask is not None:
        wm_frf_mask *= mask
        gm_frf_mask *= mask
        csf_frf_mask *= mask
    if mask_wm is not None:
        wm_frf_mask *= mask_wm
    if mask_gm is not None:
        gm_frf_mask *= mask_gm
    if mask_csf is not None:
        csf_frf_mask *= mask_csf

    frf_masks = [wm_frf_mask, gm_frf_mask, csf_frf_mask]

    msg = """Could not find at least {0} voxels for the {1} mask. Look at
    previous warnings or be sure that external tissue masks overlap with the
    cuboid ROI."""

    for tissue, frf_mask in zip(("WM", "GM", "CSF"), frf_masks):
        if np.sum(frf_mask) < min_nvox:
            raise ValueError(msg.format(min_nvox, tissue))

    # Responses are always estimated on the full (multi-shell) data.
    response_wm, response_gm, response_csf \
        = response_from_mask_msmt(gtab, data, wm_frf_mask, gm_frf_mask,
                                  csf_frf_mask, tol=tol)

    responses = [response_wm, response_gm, response_csf]

    return responses, frf_masks
Пример #7
0
def mcsd_mod_est(gtab,
                 data,
                 B0_mask,
                 wm_in_dwi,
                 gm_in_dwi,
                 vent_csf_in_dwi,
                 sh_order=8,
                 roi_radii=10):
    """
    Estimate a Multi-Shell Multi-Tissue Constrained Spherical Deconvolution
    (MSMT-CSD) model from dwi data.

    Parameters
    ----------
    gtab : Obj
        DiPy object storing diffusion gradient information.
    data : array
        4D numpy array of diffusion image data.
    B0_mask : str
        File path to B0 brain mask.
    wm_in_dwi : Niimg-like
        White-matter tissue map in DWI space (path or image accepted by
        ``nilearn.image.math_img``); binarized with ``img > 0.15``.
    gm_in_dwi : Niimg-like
        Grey-matter tissue map in DWI space; binarized with ``img > 0.10``.
    vent_csf_in_dwi : Niimg-like
        Ventricular CSF tissue map in DWI space; binarized with
        ``img > 0.50``.
    sh_order : int
        The order of the SH model, used both for the multi-shell response
        and for the deconvolution model. Default is 8.
    roi_radii : int
        Radii of the cuboid ROI passed to ``mask_for_response_msmt``.
        Default is 10.

    Returns
    -------
    mcsd_mod : ndarray
        Coefficients of the mcsd reconstruction (float32), with negative
        values clipped to 0.
    model : obj
        The MultiShellDeconvModel used for the reconstruction.

    References
    ----------
    .. [1] Tournier, J.D., et al. NeuroImage 2007. Robust determination of
      the fibre orientation distribution in diffusion MRI:
      Non-negativity constrained super-resolved spherical
      deconvolution
    .. [2] Descoteaux, M., et al. IEEE TMI 2009. Deterministic and
      Probabilistic Tractography Based on Complex Fibre Orientation
      Distributions
    .. [3] Côté, M-A., et al. Medical Image Analysis 2013. Tractometer:
      Towards validation of tractography pipelines
    .. [4] Tournier, J.D, et al. Imaging Systems and Technology
      2012. MRtrix: Diffusion Tractography in Crossing Fiber Regions

    """
    import dipy.reconst.dti as dti
    from nilearn.image import math_img
    from dipy.core.gradients import unique_bvals_tolerance
    from dipy.reconst.mcsd import (mask_for_response_msmt,
                                   response_from_mask_msmt,
                                   multi_shell_fiber_response,
                                   MultiShellDeconvModel)

    print("Reconstructing using MCSD...")

    B0_mask_data = np.nan_to_num(np.asarray(
        nib.load(B0_mask).dataobj)).astype("bool")

    # Load tissue maps and prepare tissue classifier
    gm_mask_img = math_img("img > 0.10", img=gm_in_dwi)
    gm_data = np.asarray(gm_mask_img.dataobj, dtype=np.float32)

    wm_mask_img = math_img("img > 0.15", img=wm_in_dwi)
    wm_data = np.asarray(wm_mask_img.dataobj, dtype=np.float32)

    vent_csf_in_dwi_img = math_img("img > 0.50", img=vent_csf_in_dwi)
    vent_csf_in_dwi_data = np.asarray(vent_csf_in_dwi_img.dataobj,
                                      dtype=np.float32)

    # Fit a simple DTI model
    tenfit = dti.TensorModel(gtab).fit(data)

    # Obtain the FA and MD metrics
    FA = tenfit.fa
    MD = tenfit.md

    # Pre-select voxels per tissue (low FA for GM/CSF, FA >= 0.2 for WM,
    # inside the binarized tissue maps); their mean FA/MD are used below as
    # data-driven thresholds for mask_for_response_msmt.
    indices_csf = np.where(((FA < 0.2) & (vent_csf_in_dwi_data > 0.50)))
    indices_gm = np.where(((FA < 0.2) & (gm_data > 0.10)))
    indices_wm = np.where(((FA >= 0.2) & (wm_data > 0.15)))

    selected_csf = np.zeros(FA.shape, dtype='bool')
    selected_gm = np.zeros(FA.shape, dtype='bool')
    selected_wm = np.zeros(FA.shape, dtype='bool')

    selected_csf[indices_csf] = True
    selected_gm[indices_gm] = True
    selected_wm[indices_wm] = True

    mask_wm, mask_gm, mask_csf = mask_for_response_msmt(
        gtab,
        data,
        roi_radii=roi_radii,
        wm_fa_thr=np.nanmean(FA[selected_wm]),
        gm_fa_thr=np.nanmean(FA[selected_gm]),
        csf_fa_thr=np.nanmean(FA[selected_csf]),
        gm_md_thr=np.nanmean(MD[selected_gm]),
        csf_md_thr=np.nanmean(MD[selected_csf]))

    # Keep only threshold-selected voxels that are also in the tissue maps.
    mask_wm *= wm_data.astype('int64')
    mask_gm *= gm_data.astype('int64')
    mask_csf *= vent_csf_in_dwi_data.astype('int64')

    response_wm, response_gm, response_csf = response_from_mask_msmt(
        gtab, data, mask_wm, mask_gm, mask_csf)

    # Build the response with the same SH order as the model below; it was
    # previously hard-coded to 8 regardless of ``sh_order``.
    response_mcsd = multi_shell_fiber_response(sh_order=sh_order,
                                               bvals=unique_bvals_tolerance(
                                                   gtab.bvals),
                                               wm_rf=response_wm,
                                               gm_rf=response_gm,
                                               csf_rf=response_csf)

    model = MultiShellDeconvModel(gtab, response_mcsd, sh_order=sh_order)
    mcsd_mod = model.fit(data, B0_mask_data).shm_coeff

    # Clip negative coefficients to 0 (upper bound is each voxel's own max).
    mcsd_mod = np.clip(mcsd_mod, 0, np.max(mcsd_mod, -1)[..., None])
    del response_mcsd, B0_mask_data
    return mcsd_mod.astype("float32"), model
Пример #8
0
``mask_for_response_msmt`` and ``response_from_mask_msmt`` is needed.

The ``mask_for_response_msmt`` function returns, for each tissue, a mask of the
voxels that lie within a cuboid ROI and meet some threshold constraints.
More precisely, voxels in the WM mask must have an FA value above a given
threshold, while voxels in the GM and CSF masks must have an FA below given
thresholds and an MD below other thresholds.

Note that for ``mask_for_response_msmt``, the gtab and data should only contain
b-values under 1200, for an optimal tensor fit.
"""

# Threshold-based tissue masks inside a cuboid ROI (roi_radii=100): WM voxels
# need an FA above wm_fa_thr; GM and CSF voxels need an FA below their FA
# threshold and an MD below their MD threshold.
mask_wm, mask_gm, mask_csf = mask_for_response_msmt(gtab,
                                                    data,
                                                    roi_radii=100,
                                                    wm_fa_thr=0.7,
                                                    gm_fa_thr=0.3,
                                                    csf_fa_thr=0.15,
                                                    gm_md_thr=0.001,
                                                    csf_md_thr=0.0032)
"""
If one wants to use the previously computed tissue segmentation in addition to
the threshold method, this can be done by simply multiplying both masks together.
"""

# Keep only voxels that are non-zero in both the threshold mask and the
# previously computed segmentation (wm, gm and csf are defined earlier in the
# tutorial; presumably per-tissue maps on the same voxel grid — confirm).
mask_wm *= wm
mask_gm *= gm
mask_csf *= csf
"""
The masks can also be used to calculate the number of voxels for each tissue.
"""