Esempio n. 1
0
def test_mask_for_response_ssst_nvoxels():
    """Check the voxel count of ``mask_for_response_ssst`` and its warning.

    With ``fa_thr=0.7`` the test ROI must select exactly 5 voxels. With
    ``fa_thr=1`` no voxel qualifies: the mask must be empty and exactly one
    ``UserWarning`` mentioning the threshold must be emitted.
    """
    gtab, data, _, _, _ = get_test_data()

    mask = mask_for_response_ssst(gtab,
                                  data,
                                  roi_center=None,
                                  roi_radii=(1, 1, 0),
                                  fa_thr=0.7)

    nvoxels = np.sum(mask)
    npt.assert_equal(nvoxels, 5)

    with warnings.catch_warnings(record=True) as w:
        # Record UserWarnings regardless of the ambient filters (e.g. a
        # test run with ``-W ignore`` or ``-W error``), otherwise the
        # assertions below would depend on how the suite is invoked.
        warnings.simplefilter("always", UserWarning)
        mask = mask_for_response_ssst(gtab,
                                      data,
                                      roi_center=None,
                                      roi_radii=(1, 1, 0),
                                      fa_thr=1)

    npt.assert_equal(len(w), 1)
    npt.assert_(issubclass(w[0].category, UserWarning))
    npt.assert_(
        "No voxel with a FA higher than 1 were found" in str(w[0].message))

    nvoxels = np.sum(mask)
    npt.assert_equal(nvoxels, 0)
    def _init_odf(self):
        print("Initialising ODF")
        # fit DTI model to data
        if self.odf_mode == "DTI":
            print("DTI-based ODF computation")
            self.dti_model = dti.TensorModel(self.dataset.gtab,
                                             fit_method='LS')
            self.dti_fit = self.dti_model.fit(self.dataset.dwi,
                                              mask=self.dataset.binary_mask)
            # compute ODF
            odf = self.dti_fit.odf(self.sphere_odf)
        elif self.odf_mode == "CSD":
            print("CSD-based ODF computation")
            mask = mask_for_response_ssst(self.dataset.gtab,
                                          self.dataset.dwi,
                                          roi_radii=10,
                                          fa_thr=0.7)
            num_voxels = np.sum(mask)
            print(num_voxels)
            response, ratio = response_from_mask_ssst(self.dataset.gtab,
                                                      self.dataset.dwi, mask)
            print(response)
            self.dti_model = ConstrainedSphericalDeconvModel(
                self.dataset.gtab, response)
            self.dti_fit = self.dti_model.fit(self.dataset.dwi)
            odf = self.dti_fit.odf(self.sphere_odf)

        # -- set up interpolator for odf evaluation
        x_range = np.arange(odf.shape[0])
        y_range = np.arange(odf.shape[1])
        z_range = np.arange(odf.shape[2])

        self.odf_interpolator = RegularGridInterpolator(
            (x_range, y_range, z_range), odf)
Esempio n. 3
0
    def _init_odf(self):
        print("Initialising ODF")
        # fit DTI model to data
        if self.odf_mode == "DTI":
            print("DTI-based ODF computation")
            self.dti_model = dti.TensorModel(self.dataset.gtab,
                                             fit_method='LS')
            self.dti_fit = self.dti_model.fit(self.dataset.dwi,
                                              mask=self.dataset.binary_mask)
            # compute ODF
            odf = self.dti_fit.odf(self.sphere_odf)
        elif self.odf_mode == "CSD":
            print("CSD-based ODF computation")
            mask = mask_for_response_ssst(self.dataset.gtab,
                                          self.dataset.dwi,
                                          roi_radii=10,
                                          fa_thr=0.7)
            num_voxels = np.sum(mask)
            print(num_voxels)
            response, ratio = response_from_mask_ssst(self.dataset.gtab,
                                                      self.dataset.dwi, mask)
            print(response)
            self.dti_model = ConstrainedSphericalDeconvModel(
                self.dataset.gtab, response)
            self.dti_fit = self.dti_model.fit(self.dataset.dwi)
            odf = self.dti_fit.odf(self.sphere_odf)

        # -- set up interpolator for odf evaluation
        odf = torch.from_numpy(odf).to(device=self.device).float()
        self.odf_interpolator = TorchGridInterpolator(odf)
        print("..done!")
Esempio n. 4
0
    def _init_odf(self, odf_mode):
        print("Initialising ODF")
        # fit DTI model to data
        if odf_mode == "DTI":
            print("DTI-based ODF computation")
            dti_model = dti.TensorModel(self.dataset.gtab, fit_method='LS')
            dti_fit = dti_model.fit(self.dataset.dwi,
                                    mask=self.dataset.binary_mask)
            # compute ODF
            odf = dti_fit.odf(self.sphere)
        elif odf_mode == "CSD":
            print("CSD-based ODF computation")
            mask = mask_for_response_ssst(self.dataset.gtab,
                                          self.dataset.dwi,
                                          roi_radii=10,
                                          fa_thr=0.7)
            response, ratio = response_from_mask_ssst(self.dataset.gtab,
                                                      self.dataset.dwi, mask)
            dti_model = ConstrainedSphericalDeconvModel(
                self.dataset.gtab, response)
            dti_fit = dti_model.fit(self.dataset.dwi)
            odf = dti_fit.odf(self.sphere)
        else:
            raise NotImplementedError("ODF mode not found")
        # -- set up interpolator for odf evaluation
        odf = torch.from_numpy(odf).to(device=self.device).float()

        self.odf_interpolator = TorchGridInterpolator(odf)
Esempio n. 5
0
    def _init_odf(self):
        print("Initialising ODF")
        # fit DTI model to data
        if self.odf_mode == "DTI" or self.odf_mode == "CSD":
            if self.odf_mode == "DTI":
                print("DTI-based ODF computation")
                self.dti_model = dti.TensorModel(self.dataset.gtab,
                                                 fit_method='LS')
                self.dti_fit = self.dti_model.fit(
                    self.dataset.dwi, mask=self.dataset.binary_mask)
                # compute ODF
                odf = self.dti_fit.odf(self.sphere_odf)
            elif self.odf_mode == "CSD":
                print("CSD-based ODF computation")
                mask = mask_for_response_ssst(self.dataset.gtab,
                                              self.dataset.dwi,
                                              roi_radii=10,
                                              fa_thr=0.7)
                num_voxels = np.sum(mask)
                print(num_voxels)
                response, ratio = response_from_mask_ssst(
                    self.dataset.gtab, self.dataset.dwi, mask)
                print(response)
                self.dti_model = ConstrainedSphericalDeconvModel(
                    self.dataset.gtab, response)
                self.dti_fit = self.dti_model.fit(self.dataset.dwi)
                odf = self.dti_fit.odf(self.sphere_odf)

            # -- set up interpolator for odf evaluation
            x_range = np.arange(odf.shape[0])
            y_range = np.arange(odf.shape[1])
            z_range = np.arange(odf.shape[2])

            self.odf_interpolator = RegularGridInterpolator(
                (x_range, y_range, z_range), odf)

        elif self.odf_mode == "NN":
            print("Neural-Network-based ODF computation")
            print(
                "Warning: The currently used model was trained with 1x1x1 normalised and cropped DWI data from HCP."
            )
            print(
                "Only use this mode if you know that the model is compatible with the DWI data you are using"
            )
            script = torch.jit.load("1x1x1_model3.pt",
                                    map_location=self.device)

            def interpolate(coords_ijk):
                with torch.no_grad():
                    new_shape = (*coords_ijk.shape[:-1], -1)

                    dwi = torch.from_numpy(self.dataset.get_interpolated_dwi(self.dataset.to_ras(coords_ijk),
                                                                             postprocessing=Resample100()))\
                        .float().to(self.device)

                    odf_value = script(dwi).reshape(new_shape)
                    return odf_value.numpy()

            self.odf_interpolator = interpolate
Esempio n. 6
0
    def _setup_odf(self):
        """Estimate a single-shell response, fit CSD and store the ODF.

        The response function is estimated from voxels with FA > 0.7 inside
        a cuboid ROI of radius 10. A ``ConstrainedSphericalDeconvModel`` is
        then fitted to ``self.dataset.dwi``, and its ODF sampled on
        ``self.sphere`` is saved in ``self.odf``.
        """
        print("Setting up ODF")
        ds = self.dataset
        frf_mask = mask_for_response_ssst(ds.gtab, ds.dwi,
                                          roi_radii=10, fa_thr=0.7)

        print("Calculating response")
        response, _ = response_from_mask_ssst(ds.gtab, ds.dwi, frf_mask)
        csd_model = ConstrainedSphericalDeconvModel(ds.gtab, response)

        print("Fitting CSD model")
        csd_fit = csd_model.fit(ds.dwi)
        self.odf = csd_fit.odf(self.sphere)
Esempio n. 7
0
def test_mask_for_response_ssst():
    """The response mask must be non-empty and match the ground-truth mask."""
    gtab, data, mask_gt, _, _ = get_test_data()

    roi_kwargs = dict(roi_center=None, roi_radii=(1, 1, 0), fa_thr=0.7)
    mask = mask_for_response_ssst(gtab, data, **roi_kwargs)

    # The mask must select at least one voxel.
    n_selected = int(np.sum(mask))
    assert_equal(n_selected != 0, True)

    assert_array_almost_equal(mask_gt, mask)
Esempio n. 8
0
def test_auto_response_ssst():
    """``auto_response_ssst`` must match the mask + response two-step path.

    Both routes use identical ROI/FA parameters. They must yield the same
    response eigenvalues, the same mean S0 and the same eigenvalue ratio.
    """
    gtab, data, _, _, _ = get_test_data()
    params = {"roi_center": None, "roi_radii": (1, 1, 0), "fa_thr": 0.7}

    auto_response, auto_ratio = auto_response_ssst(gtab, data, **params)

    frf_mask = mask_for_response_ssst(gtab, data, **params)
    mask_response, mask_ratio = response_from_mask_ssst(gtab, data, frf_mask)

    # response[0]: eigenvalues, response[1]: mean S0
    assert_array_equal(auto_response[0], mask_response[0])
    assert_equal(auto_response[1], mask_response[1])
    assert_array_equal(auto_ratio, mask_ratio)
Esempio n. 9
0
from dipy.reconst.csdeconv import (auto_response_ssst, mask_for_response_ssst,
                                   response_from_mask_ssst)

# One-call estimate of the single-shell single-tissue response function,
# using a cuboid ROI of radius 10 and an FA threshold of 0.7.
response, ratio = auto_response_ssst(gtab, data, roi_radii=10, fa_thr=0.7)
"""
Note that the ``auto_response_ssst`` function calls two functions that can be
used separately. First, the function ``mask_for_response_ssst`` creates a mask
of voxels within the cuboid ROI that meet the FA threshold constraint. This
mask can be used to calculate the number of voxels that were kept, or to also
apply an external mask (a WM mask for example). Second, the function
``response_from_mask_ssst`` takes the mask and returns the response function
calculated within the mask. If no changes are made to the mask between the two
calls, the resulting responses should be identical.
"""

# Step 1: build the voxel mask with the same ROI/FA parameters and report
# how many voxels were selected.
mask = mask_for_response_ssst(gtab, data, roi_radii=10, fa_thr=0.7)
nvoxels = np.sum(mask)
print(nvoxels)

# Step 2: estimate the response from that mask; with an unchanged mask this
# rebinds ``response``/``ratio`` to the same values as the one-call version.
response, ratio = response_from_mask_ssst(gtab, data, mask)
"""
The ``response`` tuple contains two elements. The first is an array with
the eigenvalues of the response function and the second is the average S0 for
this response.

It is good practice to always validate the result of auto_response_ssst. For
this purpose we can print the elements of ``response`` and have a look at their
values.
"""

# Print the (eigenvalues, mean S0) tuple for visual validation.
print(response)
Esempio n. 10
0
def compute_ssst_frf(data,
                     bvals,
                     bvecs,
                     mask=None,
                     mask_wm=None,
                     fa_thresh=0.7,
                     min_fa_thresh=0.5,
                     min_nvox=300,
                     roi_radii=10,
                     roi_center=None,
                     force_b0_threshold=False):
    """Compute a single-shell (under b=1500), single-tissue single Fiber
    Response Function from a DWI volume.

    A DTI fit is made, and voxels containing a single fiber population are
    found using a threshold on the FA. The threshold starts at ``fa_thresh``
    and is lowered in steps of 0.05 until at least ``min_nvox`` voxels are
    selected or it drops below ``min_fa_thresh``.

    Parameters
    ----------
    data : ndarray
        4D Input diffusion volume with shape (X, Y, Z, N)
    bvals : ndarray
        1D bvals array with shape (N,)
    bvecs : ndarray
        2D bvecs array with shape (N, 3)
    mask : ndarray, optional
        3D mask with shape (X,Y,Z)
        Binary mask. Only the data inside the mask will be used for
        computations and reconstruction. Useful if no white matter mask is
        available.
    mask_wm : ndarray, optional
        3D mask with shape (X,Y,Z)
        Binary white matter mask. Only the data inside this mask and above the
        threshold defined by fa_thresh will be used to estimate the fiber
        response function.
    fa_thresh : float, optional
        Use this threshold as the initial threshold to select single fiber
        voxels. Defaults to 0.7
    min_fa_thresh : float, optional
        Minimal value that will be tried when looking for single fiber voxels.
        Defaults to 0.5
    min_nvox : int, optional
        Minimal number of voxels needing to be identified as single fiber
        voxels in the automatic estimation. Defaults to 300.
    roi_radii : int or array-like (3,), optional
        Use those radii to select a cuboid roi to estimate the FRF. The roi
        will be a cuboid spanning from the middle of the volume in each
        direction with the different radii. Defaults to 10.
    roi_center : tuple(3), optional
        Center of the cuboid roi spanned by ``roi_radii``. Defaults to None
        (the center of the 3D volume).
    force_b0_threshold : bool, optional
        If set, will continue even if the minimum bvalue is suspiciously high.

    Returns
    -------
    full_response : ndarray
        Fiber Response Function, with shape (4,): the three response
        eigenvalues followed by the mean b=0 signal of the voxels used.

    Raises
    ------
    ValueError
        If less than `min_nvox` voxels were found with sufficient FA to
        estimate the FRF.
    """
    if min_fa_thresh < 0.4:
        logging.warning(
            "Minimal FA threshold ({:.2f}) seems really small. "
            "Make sure it makes sense for this dataset.".format(min_fa_thresh))

    if not is_normalized_bvecs(bvecs):
        logging.warning("Your b-vectors do not seem normalized...")
        bvecs = normalize_bvecs(bvecs)

    check_b0_threshold(force_b0_threshold, bvals.min())

    gtab = gradient_table(bvals, bvecs, b0_threshold=bvals.min())

    if mask is not None:
        data = applymask(data, mask)

    if mask_wm is not None:
        data = applymask(data, mask_wm)
    else:
        logging.warning(
            "No white matter mask specified! Only mask will be used "
            "(if it has been supplied). \nBe *VERY* careful about the "
            "estimation of the fiber response function to ensure no invalid "
            "voxel was used.")

    # Iteratively try to select at least min_nvox voxels, lowering the FA
    # threshold when it doesn't work. Fail if the FA threshold becomes
    # smaller than min_fa_thresh.
    # We use an epsilon since the -= 0.05 might incur numerical imprecision.
    # A dedicated name avoids shadowing the ``mask`` parameter.
    frf_mask = None
    nvox = 0
    while nvox < min_nvox and fa_thresh >= min_fa_thresh - 0.00001:
        frf_mask = mask_for_response_ssst(gtab,
                                          data,
                                          roi_center=roi_center,
                                          roi_radii=roi_radii,
                                          fa_thr=fa_thresh)
        # Plain int so the ``{:d}`` formatting below works for any mask dtype.
        nvox = int(np.sum(frf_mask))

        logging.debug(
            "Number of indices is {:d} with threshold of {:.2f}".format(
                nvox, fa_thresh))
        fa_thresh -= 0.05

    if nvox < min_nvox:
        raise ValueError(
            "Could not find at least {:d} voxels with sufficient FA "
            "to estimate the FRF!".format(min_nvox))

    # Estimate the response only once, from the accepted mask. Doing it on
    # every iteration wasted work and could warn on empty/too-small masks.
    response, ratio = response_from_mask_ssst(gtab, data, frf_mask)

    # fa_thresh was decremented after the successful iteration; add it back
    # to report the threshold actually used.
    logging.debug("Found {:d} voxels with FA threshold {:.2f} for "
                  "FRF estimation".format(nvox, fa_thresh + 0.05))
    logging.debug("FRF eigenvalues: {}".format(str(response[0])))
    logging.debug("Ratio for smallest to largest eigen value "
                  "is {:.3f}".format(ratio))
    logging.debug("Mean of the b=0 signal for voxels used "
                  "for FRF: {}".format(response[1]))

    full_response = np.array(
        [response[0][0], response[0][1], response[0][2], response[1]])

    return full_response