def _init_odf(self):
        print("Initialising ODF")
        # fit DTI model to data
        if self.odf_mode == "DTI":
            print("DTI-based ODF computation")
            self.dti_model = dti.TensorModel(self.dataset.gtab,
                                             fit_method='LS')
            self.dti_fit = self.dti_model.fit(self.dataset.dwi,
                                              mask=self.dataset.binary_mask)
            # compute ODF
            odf = self.dti_fit.odf(self.sphere_odf)
        elif self.odf_mode == "CSD":
            print("CSD-based ODF computation")
            mask = mask_for_response_ssst(self.dataset.gtab,
                                          self.dataset.dwi,
                                          roi_radii=10,
                                          fa_thr=0.7)
            num_voxels = np.sum(mask)
            print(num_voxels)
            response, ratio = response_from_mask_ssst(self.dataset.gtab,
                                                      self.dataset.dwi, mask)
            print(response)
            self.dti_model = ConstrainedSphericalDeconvModel(
                self.dataset.gtab, response)
            self.dti_fit = self.dti_model.fit(self.dataset.dwi)
            odf = self.dti_fit.odf(self.sphere_odf)

        # -- set up interpolator for odf evaluation
        x_range = np.arange(odf.shape[0])
        y_range = np.arange(odf.shape[1])
        z_range = np.arange(odf.shape[2])

        self.odf_interpolator = RegularGridInterpolator(
            (x_range, y_range, z_range), odf)
Ejemplo n.º 2
0
    def _init_odf(self, odf_mode):
        print("Initialising ODF")
        # fit DTI model to data
        if odf_mode == "DTI":
            print("DTI-based ODF computation")
            dti_model = dti.TensorModel(self.dataset.gtab, fit_method='LS')
            dti_fit = dti_model.fit(self.dataset.dwi,
                                    mask=self.dataset.binary_mask)
            # compute ODF
            odf = dti_fit.odf(self.sphere)
        elif odf_mode == "CSD":
            print("CSD-based ODF computation")
            mask = mask_for_response_ssst(self.dataset.gtab,
                                          self.dataset.dwi,
                                          roi_radii=10,
                                          fa_thr=0.7)
            response, ratio = response_from_mask_ssst(self.dataset.gtab,
                                                      self.dataset.dwi, mask)
            dti_model = ConstrainedSphericalDeconvModel(
                self.dataset.gtab, response)
            dti_fit = dti_model.fit(self.dataset.dwi)
            odf = dti_fit.odf(self.sphere)
        else:
            raise NotImplementedError("ODF mode not found")
        # -- set up interpolator for odf evaluation
        odf = torch.from_numpy(odf).to(device=self.device).float()

        self.odf_interpolator = TorchGridInterpolator(odf)
Ejemplo n.º 3
0
    def _init_odf(self):
        print("Initialising ODF")
        # fit DTI model to data
        if self.odf_mode == "DTI":
            print("DTI-based ODF computation")
            self.dti_model = dti.TensorModel(self.dataset.gtab,
                                             fit_method='LS')
            self.dti_fit = self.dti_model.fit(self.dataset.dwi,
                                              mask=self.dataset.binary_mask)
            # compute ODF
            odf = self.dti_fit.odf(self.sphere_odf)
        elif self.odf_mode == "CSD":
            print("CSD-based ODF computation")
            mask = mask_for_response_ssst(self.dataset.gtab,
                                          self.dataset.dwi,
                                          roi_radii=10,
                                          fa_thr=0.7)
            num_voxels = np.sum(mask)
            print(num_voxels)
            response, ratio = response_from_mask_ssst(self.dataset.gtab,
                                                      self.dataset.dwi, mask)
            print(response)
            self.dti_model = ConstrainedSphericalDeconvModel(
                self.dataset.gtab, response)
            self.dti_fit = self.dti_model.fit(self.dataset.dwi)
            odf = self.dti_fit.odf(self.sphere_odf)

        # -- set up interpolator for odf evaluation
        odf = torch.from_numpy(odf).to(device=self.device).float()
        self.odf_interpolator = TorchGridInterpolator(odf)
        print("..done!")
Ejemplo n.º 4
0
def test_response_from_mask_ssst():
    """The response estimated from the ground-truth mask matches the reference.

    The eigenvalues (first element) are compared approximately, the S0 value
    (second element) exactly.
    """
    gtab, data, mask_gt, response_gt, _ = get_test_data()

    estimated, _unused_ratio = response_from_mask_ssst(gtab, data, mask_gt)
    expected_evals, expected_s0 = response_gt[0], response_gt[1]

    assert_array_almost_equal(estimated[0], expected_evals)
    assert_equal(estimated[1], expected_s0)
Ejemplo n.º 5
0
    def _init_odf(self):
        print("Initialising ODF")
        # fit DTI model to data
        if self.odf_mode == "DTI" or self.odf_mode == "CSD":
            if self.odf_mode == "DTI":
                print("DTI-based ODF computation")
                self.dti_model = dti.TensorModel(self.dataset.gtab,
                                                 fit_method='LS')
                self.dti_fit = self.dti_model.fit(
                    self.dataset.dwi, mask=self.dataset.binary_mask)
                # compute ODF
                odf = self.dti_fit.odf(self.sphere_odf)
            elif self.odf_mode == "CSD":
                print("CSD-based ODF computation")
                mask = mask_for_response_ssst(self.dataset.gtab,
                                              self.dataset.dwi,
                                              roi_radii=10,
                                              fa_thr=0.7)
                num_voxels = np.sum(mask)
                print(num_voxels)
                response, ratio = response_from_mask_ssst(
                    self.dataset.gtab, self.dataset.dwi, mask)
                print(response)
                self.dti_model = ConstrainedSphericalDeconvModel(
                    self.dataset.gtab, response)
                self.dti_fit = self.dti_model.fit(self.dataset.dwi)
                odf = self.dti_fit.odf(self.sphere_odf)

            # -- set up interpolator for odf evaluation
            x_range = np.arange(odf.shape[0])
            y_range = np.arange(odf.shape[1])
            z_range = np.arange(odf.shape[2])

            self.odf_interpolator = RegularGridInterpolator(
                (x_range, y_range, z_range), odf)

        elif self.odf_mode == "NN":
            print("Neural-Network-based ODF computation")
            print(
                "Warning: The currently used model was trained with 1x1x1 normalised and cropped DWI data from HCP."
            )
            print(
                "Only use this mode if you know that the model is compatible with the DWI data you are using"
            )
            script = torch.jit.load("1x1x1_model3.pt",
                                    map_location=self.device)

            def interpolate(coords_ijk):
                with torch.no_grad():
                    new_shape = (*coords_ijk.shape[:-1], -1)

                    dwi = torch.from_numpy(self.dataset.get_interpolated_dwi(self.dataset.to_ras(coords_ijk),
                                                                             postprocessing=Resample100()))\
                        .float().to(self.device)

                    odf_value = script(dwi).reshape(new_shape)
                    return odf_value.numpy()

            self.odf_interpolator = interpolate
Ejemplo n.º 6
0
    def _setup_odf(self):
        """Fit a CSD model to the dataset and store its ODF in ``self.odf``.

        The single-fibre response is estimated from the voxels selected by
        ``mask_for_response_ssst`` (ROI radius 10, FA threshold 0.7); the ODF
        is evaluated on ``self.sphere``.
        """
        print("Setting up ODF")
        ds = self.dataset
        response_mask = mask_for_response_ssst(ds.gtab, ds.dwi,
                                               roi_radii=10, fa_thr=0.7)

        print("Calculating response")
        response, _ratio = response_from_mask_ssst(ds.gtab, ds.dwi,
                                                   response_mask)
        csd_model = ConstrainedSphericalDeconvModel(ds.gtab, response)

        print("Fitting CSD model")
        csd_fit = csd_model.fit(ds.dwi)
        self.odf = csd_fit.odf(self.sphere)
Ejemplo n.º 7
0
def test_auto_response_ssst():
    """``auto_response_ssst`` agrees with the explicit two-step path.

    The two-step path is ``mask_for_response_ssst`` followed by
    ``response_from_mask_ssst``, called with the same ROI and FA settings.
    """
    gtab, data, _, _, _ = get_test_data()
    roi_kwargs = dict(roi_center=None, roi_radii=(1, 1, 0), fa_thr=0.7)

    auto_resp, auto_ratio = auto_response_ssst(gtab, data, **roi_kwargs)

    voxel_mask = mask_for_response_ssst(gtab, data, **roi_kwargs)
    mask_resp, mask_ratio = response_from_mask_ssst(gtab, data, voxel_mask)

    assert_array_equal(auto_resp[0], mask_resp[0])
    assert_equal(auto_resp[1], mask_resp[1])
    assert_array_equal(auto_ratio, mask_ratio)
Ejemplo n.º 8
0
"""
Note that the ``auto_response_ssst`` function calls two functions that can be
used separately. First, the function ``mask_for_response_ssst`` creates a mask
of voxels within the cuboid ROI that meet the FA threshold constraint. This
mask can be used to calculate the number of voxels that were kept, or to also
apply an external mask (a WM mask for example). Second, the function
``response_from_mask_ssst`` takes the mask and returns the response function
calculated within the mask. If no changes are made to the mask between the two
calls, the resulting responses should be identical.
"""

# Select the voxels inside the cuboid ROI (radius 10) that meet the FA
# threshold of 0.7.
mask = mask_for_response_ssst(gtab, data, roi_radii=10, fa_thr=0.7)
# Number of voxels kept by the mask -- a quick sanity check of the ROI and
# threshold choice.
nvoxels = np.sum(mask)
print(nvoxels)

# Estimate the single-shell single-tissue response inside the mask.
response, ratio = response_from_mask_ssst(gtab, data, mask)
"""
The ``response`` tuple contains two elements. The first is an array with
the eigenvalues of the response function and the second is the average S0 for
this response.

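It is good practice to always validate the estimated response function (here
the output of ``response_from_mask_ssst``, which is identical to what
``auto_response_ssst`` returns for the same mask). For this purpose we can
print the elements of ``response`` and have a look at their values.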
"""

# Inspect the response tuple: (eigenvalues array, average S0).
print(response)
"""
(array([ 0.0014,  0.00029,  0.00029]), 416.206)

The tensor generated from the response must be prolate (two smaller eigenvalues
Ejemplo n.º 9
0
def response_from_mask_msmt(gtab, data, mask_wm, mask_gm, mask_csf, tol=20):
    """ Computation of multi-shell multi-tissue (msmt) response
        functions from given tissues masks.

    Parameters
    ----------
    gtab : GradientTable
    data : ndarray
        diffusion data (diffusion-weighted volumes along the last axis)
    mask_wm : ndarray
        mask from where to compute the WM response function.
    mask_gm : ndarray
        mask from where to compute the GM response function.
    mask_csf : ndarray
        mask from where to compute the CSF response function.
    tol : int
        tolerance gap for b-values clustering. (Default = 20)

    Returns
    -------
    response_wm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for WM for each unique bvalues (except b0).
    response_gm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for GM for each unique bvalues (except b0).
    response_csf : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for CSF for each unique bvalues (except b0).

    Notes
    -----
    In msmt-CSD there is an important pre-processing step: the estimation of
    every tissue's response function. In order to do this, we look for voxels
    corresponding to WM, GM and CSF. This information can be obtained by using
    mcsd.mask_for_response_msmt() through masks of selected voxels. The present
    function uses such masks to compute the msmt response functions.

    For the responses, we base our approach on the function
    csdeconv.response_from_mask_ssst(), with the added layers of multishell and
    multi-tissue (see the ssst function for more information about the
    computation of the ssst response function). For each non-b0 shell
    (b-values clustered using the tolerance gap) a single-shell gradient table
    and data volume (mean b0 + that shell) are built once, and the ssst
    response is computed inside each tissue mask. Each row of a returned array
    holds the response for one shell.
    """

    bvals = gtab.bvals
    bvecs = gtab.bvecs
    btens = gtab.btens

    list_bvals = unique_bvals_tolerance(bvals, tol)

    b0_indices = get_bval_indices(bvals, list_bvals[0], tol)
    # Mean b0 image with a trailing singleton axis so it can be prepended to
    # every shell's volumes.
    b0_map = np.mean(data[..., b0_indices], axis=-1)[..., np.newaxis]

    # The first b0 volume's direction / b-tensor is prepended to every shell
    # sub-table; it does not depend on the shell, so fetch it once.
    b0_bvec = bvecs[b0_indices[0]][np.newaxis, :]
    if btens is not None:
        btens_b0 = btens[b0_indices[0]].reshape((1, 3, 3))
    else:
        btens_b0 = None

    masks = [mask_wm, mask_gm, mask_csf]
    tissue_responses = [[] for _ in masks]
    for bval in list_bvals[1:]:
        indices = get_bval_indices(bvals, bval, tol)

        bvecs_sub = np.concatenate([b0_bvec, bvecs[indices]])
        bvals_sub = np.concatenate([[0], bvals[indices]])
        if btens_b0 is not None:
            btens_sub = np.concatenate([btens_b0, btens[indices]])
        else:
            btens_sub = None

        # The b0 + shell data and gradient table are identical for all three
        # tissues, so build them once per shell rather than once per tissue.
        data_conc = np.concatenate([b0_map, data[..., indices]], axis=-1)
        gtab_sub = gradient_table(bvals_sub, bvecs_sub, btens=btens_sub)

        for responses, mask in zip(tissue_responses, masks):
            response, _ = response_from_mask_ssst(gtab_sub, data_conc, mask)
            # Flatten (evals, S0) into a single row: [ev1, ev2, ev3, S0].
            responses.append(list(np.concatenate([response[0],
                                                  [response[1]]])))

    wm_response = np.asarray(tissue_responses[0])
    gm_response = np.asarray(tissue_responses[1])
    csf_response = np.asarray(tissue_responses[2])
    return wm_response, gm_response, csf_response
Ejemplo n.º 10
0
def compute_ssst_frf(data,
                     bvals,
                     bvecs,
                     mask=None,
                     mask_wm=None,
                     fa_thresh=0.7,
                     min_fa_thresh=0.5,
                     min_nvox=300,
                     roi_radii=10,
                     roi_center=None,
                     force_b0_threshold=False):
    """Compute a single-shell (under b=1500), single-tissue single Fiber
    Response Function from a DWI volume.
    A DTI fit is made, and voxels containing a single fiber population are
    found using a threshold on the FA.

    Parameters
    ----------
    data : ndarray
        4D Input diffusion volume with shape (X, Y, Z, N)
    bvals : ndarray
        1D bvals array with shape (N,)
    bvecs : ndarray
        2D bvecs array with shape (N, 3)
    mask : ndarray, optional
        3D mask with shape (X,Y,Z)
        Binary mask. Only the data inside the mask will be used for
        computations and reconstruction. Useful if no white matter mask is
        available.
    mask_wm : ndarray, optional
        3D mask with shape (X,Y,Z)
        Binary white matter mask. Only the data inside this mask and above the
        threshold defined by fa_thresh will be used to estimate the fiber
        response function.
    fa_thresh : float, optional
        Use this threshold as the initial threshold to select single fiber
        voxels. Defaults to 0.7
    min_fa_thresh : float, optional
        Minimal value that will be tried when looking for single fiber voxels.
        Defaults to 0.5
    min_nvox : int, optional
        Minimal number of voxels needing to be identified as single fiber
        voxels in the automatic estimation. Defaults to 300.
    roi_radii : int or array-like (3,), optional
        Use those radii to select a cuboid roi to estimate the FRF. The roi
        will be a cuboid spanning from the middle of the volume in each
        direction with the different radii. Defaults to 10.
    roi_center : tuple(3), optional
        Use this center to span the roi of size roi_radius (center of the
        3D volume).
    force_b0_threshold : bool, optional
        If set, will continue even if the minimum bvalue is suspiciously high.

    Returns
    -------
    full_response : ndarray
        Fiber Response Function, with shape (4,): the three eigenvalues
        followed by the mean b=0 signal of the selected voxels.

    Raises
    ------
    ValueError
        If less than `min_nvox` voxels were found with sufficient FA to
        estimate the FRF, or if the initial `fa_thresh` is already below
        `min_fa_thresh` so no threshold could be tried.
    """
    if min_fa_thresh < 0.4:
        logging.warning(
            "Minimal FA threshold ({:.2f}) seems really small. "
            "Make sure it makes sense for this dataset.".format(min_fa_thresh))

    if not is_normalized_bvecs(bvecs):
        logging.warning("Your b-vectors do not seem normalized...")
        bvecs = normalize_bvecs(bvecs)

    check_b0_threshold(force_b0_threshold, bvals.min())

    gtab = gradient_table(bvals, bvecs, b0_threshold=bvals.min())

    if mask is not None:
        data = applymask(data, mask)

    if mask_wm is not None:
        data = applymask(data, mask_wm)
    else:
        logging.warning(
            "No white matter mask specified! Only mask will be used "
            "(if it has been supplied). \nBe *VERY* careful about the "
            "estimation of the fiber response function to ensure no invalid "
            "voxel was used.")

    # Iteratively try to find at least min_nvox single-fiber voxels, lowering
    # the FA threshold by 0.05 each time it fails. Stop once the threshold
    # drops below min_fa_thresh; the epsilon absorbs floating-point drift from
    # the repeated -= 0.05. The `frf_mask is None` term guarantees at least one
    # attempt (when the initial threshold is valid), so a mask exists even if
    # min_nvox <= 0 -- previously that case left `response` unbound.
    eps = 1e-5
    nvox = 0
    frf_mask = None
    used_fa_thresh = fa_thresh
    while ((frf_mask is None or nvox < min_nvox)
           and fa_thresh >= min_fa_thresh - eps):
        frf_mask = mask_for_response_ssst(gtab,
                                          data,
                                          roi_center=roi_center,
                                          roi_radii=roi_radii,
                                          fa_thr=fa_thresh)
        nvox = int(np.sum(frf_mask))
        used_fa_thresh = fa_thresh

        logging.debug(
            "Number of indices is {:d} with threshold of {:.2f}".format(
                nvox, fa_thresh))
        fa_thresh -= 0.05

    if frf_mask is None or nvox < min_nvox:
        raise ValueError(
            "Could not find at least {:d} voxels with sufficient FA "
            "to estimate the FRF!".format(min_nvox))

    # Only the accepted mask's response is needed; computing it once here
    # avoids fitting on rejected (possibly empty) masks inside the loop.
    response, ratio = response_from_mask_ssst(gtab, data, frf_mask)

    logging.debug("Found {:d} voxels with FA threshold {:.2f} for "
                  "FRF estimation".format(nvox, used_fa_thresh))
    logging.debug("FRF eigenvalues: {}".format(str(response[0])))
    logging.debug("Ratio for smallest to largest eigen value "
                  "is {:.3f}".format(ratio))
    logging.debug("Mean of the b=0 signal for voxels used "
                  "for FRF: {}".format(response[1]))

    full_response = np.array(
        [response[0][0], response[0][1], response[0][2], response[1]])

    return full_response