예제 #1
0
def test_multi_shell_fiber_response():
    """Check the shapes and warnings produced by ``multi_shell_fiber_response``.

    With a b0 shell among the b-values the response has 4 rows (one per
    b-value given). Without it, more than one warning must be recorded, the
    last being a ``UserWarning`` about the missing b0, and the response has
    3 rows.
    """
    order = 8
    tissue_rfs = (wm_response, gm_response, csf_response)

    # First call: silence the legacy descoteaux07 basis notice.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message=shm.descoteaux07_legacy_msg,
            category=PendingDeprecationWarning)
        with_b0 = multi_shell_fiber_response(order, [0, 1000, 2000, 3500],
                                             *tissue_rfs)

    npt.assert_equal(with_b0.response.shape, (4, 7))

    with warnings.catch_warnings(record=True) as caught:
        # Record every pending deprecation so the count below is reliable.
        warnings.simplefilter("always", category=PendingDeprecationWarning)
        without_b0 = multi_shell_fiber_response(order, [1000, 2000, 3500],
                                                *tissue_rfs)
        # Legacy-SH-basis deprecations plus the missing-b0 warning.
        npt.assert_(len(caught) > 1)
        final = caught[-1]
        # The warning issued by multi_shell_fiber_response comes last.
        npt.assert_(issubclass(final.category, UserWarning))
        npt.assert_("""No b0 given. Proceeding either way.""" in
                    str(final.message))
        npt.assert_equal(without_b0.response.shape, (3, 7))
예제 #2
0
파일: test_mcsd.py 프로젝트: mvgolub/dipy
def test_mcsd_model_delta():
    """Predicting from a WM-only delta must reproduce the WM response.

    The isotropic entries of the model's delta are zeroed, and the delta is
    expanded. On every shell the prediction must then equal the WM response
    row projected onto the SH basis. Fitting the full-gtab prediction must
    give (near) zero coefficients for every ``m != 0`` term.
    """
    bvalues = [0, 1000, 2000, 3500]
    sh_order = 8
    gtab = get_3shell_gtab()
    response = multi_shell_fiber_response(sh_order, bvalues,
                                          wm_response, gm_response,
                                          csf_response)
    model = MultiShellDeconvModel(gtab, response)
    n_iso = response.iso

    basis = shm.real_sph_harm(response.m, response.n,
                              default_sphere.theta[:, None],
                              default_sphere.phi[:, None])

    # Zero the isotropic part, keeping only the WM (anisotropic) delta.
    delta = model.delta.copy()
    delta[:n_iso] = 0.
    delta = _expand(model.m, n_iso, delta)

    for row, bval in enumerate(bvalues):
        shell_gtab = GradientTable(default_sphere.vertices * bval)
        predicted = model.predict(delta, shell_gtab)
        npt.assert_array_almost_equal(
            predicted, np.dot(response.response[row, n_iso:], basis.T))

    fit = model.fit(model.predict(delta, gtab))
    npt.assert_array_almost_equal(fit.shm_coeff[model.m != 0], 0., 2)
예제 #3
0
파일: test_mcsd.py 프로젝트: mvgolub/dipy
def test_MSDeconvFit():
    """Recover known tissue volume fractions from a synthetic mixture.

    A two-fibre WM signal and two mono-exponential isotropic signals (GM and
    CSF) are mixed with known weights. The fitted ``volume_fractions``
    (reported in CSF, GM, WM order) must match the weights to 1 decimal.
    """
    gtab = get_3shell_gtab()

    wm_evals = wm_response[0, :3]
    S_wm, _ = multi_tensor(gtab,
                           np.array([wm_evals, wm_evals]),
                           wm_response[0, 3],
                           angles=[(0, 0), (60, 0)],
                           fractions=[30., 70.],
                           snr=None)
    # Isotropic compartments: rf[0, 3] * exp(-b * rf[0, 0]).
    S_gm = gm_response[0, 3] * np.exp(-gtab.bvals * gm_response[0, 0])
    S_csf = csf_response[0, 3] * np.exp(-gtab.bvals * csf_response[0, 0])

    response = multi_shell_fiber_response(8, [0, 1000, 2000, 3500],
                                          wm_response, gm_response,
                                          csf_response)
    model = MultiShellDeconvModel(gtab, response)

    vf = [0.325, 0.2, 0.475]
    csf_w, gm_w, wm_w = vf
    signal = csf_w * S_csf + gm_w * S_gm + wm_w * S_wm

    # Testing volume fractions
    npt.assert_array_almost_equal(model.fit(signal).volume_fractions, vf, 1)
예제 #4
0
def test_MSDeconvFit():
    """Recover known tissue volume fractions from a synthetic mixture.

    Same scenario as the unfiltered variant, but the response and model are
    built with the legacy descoteaux07 SH-basis notice silenced.
    """
    gtab = get_3shell_gtab()

    wm_evals = wm_response[0, :3]
    S_wm, _ = multi_tensor(gtab, np.array([wm_evals, wm_evals]),
                           wm_response[0, 3], angles=[(0, 0), (60, 0)],
                           fractions=[30., 70.], snr=None)
    S_gm = gm_response[0, 3] * np.exp(-gtab.bvals * gm_response[0, 0])
    S_csf = csf_response[0, 3] * np.exp(-gtab.bvals * csf_response[0, 0])

    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message=shm.descoteaux07_legacy_msg,
            category=PendingDeprecationWarning)
        msr = multi_shell_fiber_response(8, [0, 1000, 2000, 3500],
                                         wm_response,
                                         gm_response,
                                         csf_response)
        model = MultiShellDeconvModel(gtab, msr)

    expected_vf = [0.325, 0.2, 0.475]
    compartments = [S_csf, S_gm, S_wm]
    mixture = sum(frac * sig for frac, sig in zip(expected_vf, compartments))

    # Testing volume fractions
    npt.assert_array_almost_equal(model.fit(mixture).volume_fractions,
                                  expected_vf, 1)
예제 #5
0
def test_MultiShellDeconvModel_response():
    """Models built from a response object or from raw arrays must agree.

    Also checks that ``MultiShellDeconvModel`` rejects two malformed response
    arrays with ``ValueError`` (presumably because the tissue count does not
    match ``iso`` + 1).
    """
    gtab = get_3shell_gtab()
    sh_order = 8
    raw_responses = np.array([wm_response, gm_response, csf_response])

    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message=shm.descoteaux07_legacy_msg,
            category=PendingDeprecationWarning)
        msr = multi_shell_fiber_response(sh_order, [0, 1000, 2000, 3500],
                                         wm_response,
                                         gm_response,
                                         csf_response)
        from_object = MultiShellDeconvModel(gtab, msr, sh_order=sh_order)
        from_arrays = MultiShellDeconvModel(gtab, raw_responses,
                                            sh_order=sh_order)

    npt.assert_array_almost_equal(from_object.response.response,
                                  from_arrays.response.response, 0)

    npt.assert_raises(ValueError, MultiShellDeconvModel,
                      gtab, np.ones((4, 3, 4)))
    npt.assert_raises(ValueError, MultiShellDeconvModel,
                      gtab, np.ones((3, 3, 4)), iso=3)
예제 #6
0
def test_multi_shell_fiber_response():
    """Check shapes and the missing-b0 warning of ``multi_shell_fiber_response``.

    With a b0 shell the response has 4 rows. Without one, a ``UserWarning``
    mentioning the missing b0 must be raised and the response has 3 rows.
    """
    sh_order = 8
    response = multi_shell_fiber_response(sh_order, [0, 1000, 2000, 3500],
                                          wm_response, gm_response,
                                          csf_response)

    npt.assert_equal(response.response.shape, (4, 7))

    with warnings.catch_warnings(record=True) as w:
        # Record every warning, including ones that default filters would
        # suppress or that were already emitted earlier in the session.
        warnings.simplefilter("always")
        response = multi_shell_fiber_response(sh_order, [1000, 2000, 3500],
                                              wm_response, gm_response,
                                              csf_response)

    # Other warnings (e.g. legacy SH basis deprecations) may be recorded
    # first, so search for the missing-b0 warning instead of assuming w[0].
    no_b0_warnings = [
        rec for rec in w
        if issubclass(rec.category, UserWarning)
        and """No b0 given. Proceeding either way.""" in str(rec.message)]
    npt.assert_(len(no_b0_warnings) > 0)
    npt.assert_equal(response.response.shape, (3, 7))
예제 #7
0
파일: test_mcsd.py 프로젝트: mvgolub/dipy
def test_multi_shell_fiber_response():
    """Check shapes and warnings of ``multi_shell_fiber_response``.

    With a b0 shell the response has 4 rows. Without one, more than one
    warning must be recorded. The last one is a ``UserWarning`` about the
    missing b0, and the response has 3 rows.
    """
    sh_order = 8
    response = multi_shell_fiber_response(sh_order, [0, 1000, 2000, 3500],
                                          wm_response, gm_response,
                                          csf_response)

    npt.assert_equal(response.response.shape, (4, 7))

    with warnings.catch_warnings(record=True) as w:
        # PendingDeprecationWarning is ignored by Python's default filters;
        # record it explicitly so len(w) > 1 does not depend on how the
        # test runner configures warnings.
        warnings.simplefilter("always", category=PendingDeprecationWarning)
        response = multi_shell_fiber_response(sh_order, [1000, 2000, 3500],
                                              wm_response, gm_response,
                                              csf_response)
        # Test that the number of warnings raised is greater than 1, with
        # deprecation warnings being raised from using legacy SH bases as well
        # as a warning from multi_shell_fiber_response
        npt.assert_(len(w) > 1)
        # The last warning in list is the one from multi_shell_fiber_response
        npt.assert_(issubclass(w[-1].category, UserWarning))
        npt.assert_(
            """No b0 given. Proceeding either way.""" in str(w[-1].message))
        npt.assert_equal(response.response.shape, (3, 7))
예제 #8
0
def test_mcsd_model_delta():
    """Predicting from a WM-only delta must reproduce the WM response.

    Calls that go through the legacy descoteaux07 SH basis are wrapped so its
    PendingDeprecationWarning is silenced. The delta manipulation and the
    final ``model.fit`` run without that filter.
    """
    shells = [0, 1000, 2000, 3500]
    sh_order = 8
    gtab = get_3shell_gtab()

    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message=shm.descoteaux07_legacy_msg,
            category=PendingDeprecationWarning)
        response = multi_shell_fiber_response(sh_order, shells,
                                              wm_response,
                                              gm_response,
                                              csf_response)
        model = MultiShellDeconvModel(gtab, response)
        n_iso = response.iso
        basis = shm.real_sh_descoteaux_from_index(
            response.m, response.n,
            default_sphere.theta[:, None], default_sphere.phi[:, None])

    # Zero the isotropic entries so only the WM part of the delta remains.
    delta = model.delta.copy()
    delta[:n_iso] = 0.
    delta = _expand(model.m, n_iso, delta)

    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message=shm.descoteaux07_legacy_msg,
            category=PendingDeprecationWarning)
        for row, bval in enumerate(shells):
            shell_gtab = GradientTable(default_sphere.vertices * bval)
            predicted = model.predict(delta, shell_gtab)
            npt.assert_array_almost_equal(
                predicted, np.dot(response.response[row, n_iso:], basis.T))
        full_signal = model.predict(delta, gtab)

    fit = model.fit(full_signal)
    npt.assert_array_almost_equal(fit.shm_coeff[model.m != 0], 0., 2)
예제 #9
0
파일: test_mcsd.py 프로젝트: kerkelae/dipy
def test_MultiShellDeconvModel_response():
    """Models built from a response object or from raw arrays must agree.

    Also checks that two malformed response arrays are rejected with
    ``ValueError``.
    """
    gtab = get_3shell_gtab()
    sh_order = 8

    prebuilt = multi_shell_fiber_response(sh_order, [0, 1000, 2000, 3500],
                                          wm_response,
                                          gm_response,
                                          csf_response)
    stacked = np.array([wm_response, gm_response, csf_response])

    model_obj, model_arr = (
        MultiShellDeconvModel(gtab, rf, sh_order=sh_order)
        for rf in (prebuilt, stacked))
    npt.assert_array_almost_equal(model_obj.response.response,
                                  model_arr.response.response, 0)

    too_many_tissues = np.ones((4, 3, 4))
    npt.assert_raises(ValueError, MultiShellDeconvModel,
                      gtab, too_many_tissues)
    npt.assert_raises(ValueError, MultiShellDeconvModel,
                      gtab, np.ones((3, 3, 4)), iso=3)
예제 #10
0
def create_mcsd_model(folder_name, data, gtab, labels, sh_order=8):
    """Fit a multi-shell multi-tissue CSD model and save its coefficients.

    Parameters
    ----------
    folder_name : str
        Directory in which ``coeff.npy`` is written.
    data : ndarray
        Diffusion data passed to the response estimation and the fit; the
        last axis presumably indexes volumes (TODO confirm).
    gtab : GradientTable
        Gradient table matching ``data``.
    labels : ndarray
        Tissue label map: 3 = WM, 2 = GM, 1 = CSF.
    sh_order : int
        SH order used for both the multi-shell response and the model.

    Returns
    -------
    MSDeconvFit
        Fit rebuilt from the model coefficients with NaNs replaced by 0.
    """
    import os

    from dipy.reconst.mcsd import response_from_mask_msmt
    from dipy.reconst.mcsd import MultiShellDeconvModel, multi_shell_fiber_response, MSDeconvFit
    from dipy.core.gradients import unique_bvals_tolerance

    # Float (0./1.) masks for each tissue class of the label map.
    mask_wm = (labels == 3).astype(float)
    mask_gm = (labels == 2).astype(float)
    mask_csf = (labels == 1).astype(float)

    response_wm, response_gm, response_csf = response_from_mask_msmt(
        gtab, data, mask_wm, mask_gm, mask_csf)

    response_mcsd = multi_shell_fiber_response(
        sh_order,
        bvals=unique_bvals_tolerance(gtab.bvals),
        wm_rf=response_wm,
        csf_rf=response_csf,
        gm_rf=response_gm)
    # Give the model the same SH order as the response instead of relying
    # on the model's own default.
    mcsd_model = MultiShellDeconvModel(gtab, response_mcsd, sh_order=sh_order)

    coeff = mcsd_model.fit(data).all_shm_coeff
    # Voxels the solver could not handle come back as NaN.
    nan_count = np.count_nonzero(np.isnan(coeff[..., 0]))
    if nan_count > 0:
        # All axes but the last (coefficients) are spatial.
        n_vox = int(np.prod(coeff.shape[:-1]))
        print(
            f'{nan_count / n_vox} of the voxels did not complete fodf calculation, NaN values replaced with 0'
        )
    coeff = np.where(np.isnan(coeff), 0, coeff)
    mcsd_fit = MSDeconvFit(mcsd_model, coeff, None)
    # os.path.join keeps the path portable (a literal backslash separator
    # only works on Windows).
    np.save(os.path.join(folder_name, 'coeff.npy'), coeff)

    return mcsd_fit
예제 #11
0
    def _msmt_ft(self):
        """Fit a multi-shell multi-tissue CSD model and store it on the instance.

        Reads ``self.gtab``, ``self.data``, ``self.tissue_labels`` and
        ``self.parameters_dict['sh_order']``. Sets ``self.model_fit`` to an
        ``MSDeconvFit`` built from the fitted coefficients with NaNs
        replaced by 0.
        """
        from dipy.reconst.mcsd import response_from_mask_msmt
        from dipy.reconst.mcsd import MultiShellDeconvModel, multi_shell_fiber_response, MSDeconvFit
        from dipy.core.gradients import unique_bvals_tolerance

        sh_order = self.parameters_dict['sh_order']

        # Label convention used here: 2 = WM, 1 = GM, 3 = CSF.
        # NOTE(review): confirm this matches the segmentation that produced
        # self.tissue_labels.
        mask_wm = (self.tissue_labels == 2).astype(float)
        mask_gm = (self.tissue_labels == 1).astype(float)
        mask_csf = (self.tissue_labels == 3).astype(float)

        response_wm, response_gm, response_csf = response_from_mask_msmt(
            self.gtab, self.data, mask_wm, mask_gm, mask_csf)

        response_mcsd = multi_shell_fiber_response(
            sh_order,
            bvals=unique_bvals_tolerance(self.gtab.bvals),
            wm_rf=response_wm,
            csf_rf=response_csf,
            gm_rf=response_gm)
        # Use the same SH order for the model as for the response rather
        # than relying on the model's default.
        mcsd_model = MultiShellDeconvModel(self.gtab, response_mcsd,
                                           sh_order=sh_order)

        coeff = mcsd_model.fit(self.data).all_shm_coeff
        # Voxels the solver could not handle come back as NaN.
        nan_count = np.count_nonzero(np.isnan(coeff[..., 0]))
        if nan_count > 0:
            # All axes but the last (coefficients) are spatial.
            n_vox = int(np.prod(coeff.shape[:-1]))
            print(
                f'{nan_count / n_vox} of the voxels did not complete fodf calculation, NaN values replaced with 0'
            )
        coeff = np.where(np.isnan(coeff), 0, coeff)
        self.model_fit = MSDeconvFit(mcsd_model, coeff, None)
예제 #12
0
파일: msmt_example.py 프로젝트: HilaGast/FT
import os

# Tissue label convention of the mask: 3 = WM, 2 = GM, 1 = CSF.
t_mask_img = load_nifti(tissue_mask)[0]
wm = t_mask_img == 3
gm = t_mask_img == 2
csf = t_mask_img == 1

mask_wm = wm.astype(float)
mask_gm = gm.astype(float)
mask_csf = csf.astype(float)

response_wm, response_gm, response_csf = response_from_mask_msmt(
    gtab, data, mask_wm, mask_gm, mask_csf)

ubvals = unique_bvals_tolerance(bvals)
response_mcsd = multi_shell_fiber_response(sh_order=8,
                                           bvals=ubvals,
                                           wm_rf=response_wm,
                                           csf_rf=response_csf,
                                           gm_rf=response_gm)

# NOTE(review): responses are estimated from `data` but the fit below uses
# `denoised_arr`; confirm this mix is intentional.
mcsd_model = MultiShellDeconvModel(gtab, response_mcsd)
mcsd_fit = mcsd_model.fit(denoised_arr)
sh_coeff = mcsd_fit.all_shm_coeff
nan_count = len(np.argwhere(np.isnan(sh_coeff[..., 0])))
coeff = mcsd_fit.all_shm_coeff
n_vox = coeff.shape[0] * coeff.shape[1] * coeff.shape[2]
# Only report when some voxels actually failed to fit.
if nan_count > 0:
    print(
        f'{nan_count/n_vox} of the voxels did not complete fodf calculation, NaN values replaced with 0'
    )
coeff = np.where(np.isnan(coeff), 0, coeff)
mcsd_fit = MSDeconvFit(mcsd_model, coeff, None)
# os.path.join keeps the output path portable; a literal backslash
# separator only works on Windows.
np.save(os.path.join(folder_name, 'coeff.npy'), coeff)
예제 #13
0
파일: estimation.py 프로젝트: dPys/PyNets
def mcsd_mod_est(gtab,
                 data,
                 B0_mask,
                 wm_in_dwi,
                 gm_in_dwi,
                 vent_csf_in_dwi,
                 sh_order=8,
                 roi_radii=10):
    """
    Estimate a multi-shell multi-tissue Constrained Spherical Deconvolution
    (MSMT-CSD) model from dwi data.

    Parameters
    ----------
    gtab : Obj
        DiPy object storing diffusion gradient information.
    data : array
        4D numpy array of diffusion image data.
    B0_mask : str
        File path to B0 brain mask.
    wm_in_dwi : image or str
        White-matter map in DWI space, accepted by
        ``nilearn.image.math_img``; binarized at > 0.15.
    gm_in_dwi : image or str
        Grey-matter map in DWI space; binarized at > 0.10.
    vent_csf_in_dwi : image or str
        Ventricular CSF map in DWI space; binarized at > 0.50.
    sh_order : int
        The order of the SH model, used for both the multi-shell response
        function and the deconvolution model. Default is 8.
    roi_radii : int
        ROI radii forwarded to ``mask_for_response_msmt``. Default is 10.

    Returns
    -------
    mcsd_mod : ndarray
        float32 ``shm_coeff`` of the fit, passed through
        ``np.clip(x, 0, per-voxel max)`` (negative values become 0).
    model : obj
        The ``MultiShellDeconvModel`` used for the fit.

    References
    ----------
    .. [1] Tournier, J.D., et al. NeuroImage 2007. Robust determination of
      the fibre orientation distribution in diffusion MRI:
      Non-negativity constrained super-resolved spherical
      deconvolution
    .. [2] Descoteaux, M., et al. IEEE TMI 2009. Deterministic and
      Probabilistic Tractography Based on Complex Fibre Orientation
      Distributions
    .. [3] Côté, M-A., et al. Medical Image Analysis 2013. Tractometer:
      Towards validation of tractography pipelines
    .. [4] Tournier, J.D, et al. Imaging Systems and Technology
      2012. MRtrix: Diffusion Tractography in Crossing Fiber Regions

    """
    import dipy.reconst.dti as dti
    from nilearn.image import math_img
    from dipy.core.gradients import unique_bvals_tolerance
    from dipy.reconst.mcsd import (mask_for_response_msmt,
                                   response_from_mask_msmt,
                                   multi_shell_fiber_response,
                                   MultiShellDeconvModel)

    print("Reconstructing using MCSD...")

    B0_mask_data = np.nan_to_num(np.asarray(
        nib.load(B0_mask).dataobj)).astype("bool")

    # Binarize the tissue maps
    gm_mask_img = math_img("img > 0.10", img=gm_in_dwi)
    gm_data = np.asarray(gm_mask_img.dataobj, dtype=np.float32)

    wm_mask_img = math_img("img > 0.15", img=wm_in_dwi)
    wm_data = np.asarray(wm_mask_img.dataobj, dtype=np.float32)

    vent_csf_in_dwi_img = math_img("img > 0.50", img=vent_csf_in_dwi)
    vent_csf_in_dwi_data = np.asarray(vent_csf_in_dwi_img.dataobj,
                                      dtype=np.float32)

    # Fit a simple DTI model and obtain the FA and MD metrics
    tenfit = dti.TensorModel(gtab).fit(data)
    FA = tenfit.fa
    MD = tenfit.md

    # Candidate voxels per tissue: low-FA voxels inside the CSF / GM maps,
    # high-FA voxels inside the WM map.
    selected_csf = (FA < 0.2) & (vent_csf_in_dwi_data > 0.50)
    selected_gm = (FA < 0.2) & (gm_data > 0.10)
    selected_wm = (FA >= 0.2) & (wm_data > 0.15)

    # Mean FA/MD of the candidates serve as the response-mask thresholds.
    mask_wm, mask_gm, mask_csf = mask_for_response_msmt(
        gtab,
        data,
        roi_radii=roi_radii,
        wm_fa_thr=np.nanmean(FA[selected_wm]),
        gm_fa_thr=np.nanmean(FA[selected_gm]),
        csf_fa_thr=np.nanmean(FA[selected_csf]),
        gm_md_thr=np.nanmean(MD[selected_gm]),
        csf_md_thr=np.nanmean(MD[selected_csf]))

    # Restrict the response masks to the binarized tissue maps
    mask_wm *= wm_data.astype('int64')
    mask_gm *= gm_data.astype('int64')
    mask_csf *= vent_csf_in_dwi_data.astype('int64')

    response_wm, response_gm, response_csf = response_from_mask_msmt(
        gtab, data, mask_wm, mask_gm, mask_csf)

    # Build the response with the caller's sh_order so it matches the model;
    # a hard-coded order breaks the pairing whenever sh_order != 8.
    response_mcsd = multi_shell_fiber_response(sh_order=sh_order,
                                               bvals=unique_bvals_tolerance(
                                                   gtab.bvals),
                                               wm_rf=response_wm,
                                               gm_rf=response_gm,
                                               csf_rf=response_csf)

    model = MultiShellDeconvModel(gtab, response_mcsd, sh_order=sh_order)
    mcsd_mod = model.fit(data, B0_mask_data).shm_coeff

    mcsd_mod = np.clip(mcsd_mod, 0, np.max(mcsd_mod, -1)[..., None])
    del response_mcsd, B0_mask_data
    return mcsd_mod.astype("float32"), model
예제 #14
0
def main():
    """Fit a multi-shell multi-tissue CSD model and save the requested maps.

    Inputs (DWI, bvals/bvecs, WM/GM/CSF response functions, optional mask)
    and outputs (WM/GM/CSF fODF coefficients, volume fractions and their RGB
    rendering) come from the command-line parser. Voxels the solver could not
    fit (NaN) are written as 0 in every output.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)

    # Without --not_all, every output gets a default file name.
    if not args.not_all:
        args.wm_out_fODF = args.wm_out_fODF or 'wm_fodf.nii.gz'
        args.gm_out_fODF = args.gm_out_fODF or 'gm_fodf.nii.gz'
        args.csf_out_fODF = args.csf_out_fODF or 'csf_fodf.nii.gz'
        args.vf = args.vf or 'vf.nii.gz'
        args.vf_rgb = args.vf_rgb or 'vf_rgb.nii.gz'

    arglist = [args.wm_out_fODF, args.gm_out_fODF, args.csf_out_fODF,
               args.vf, args.vf_rgb]
    if args.not_all and not any(arglist):
        parser.error('When using --not_all, you need to specify at least ' +
                     'one file to output.')

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec,
                                 args.in_wm_frf, args.in_gm_frf,
                                 args.in_csf_frf])
    assert_outputs_exist(parser, args, arglist)

    # Loading data. ndmin=2 keeps a single-row FRF file 2D, so the shape
    # checks below report a bad file with ValueError instead of IndexError.
    wm_frf = np.loadtxt(args.in_wm_frf, ndmin=2)
    gm_frf = np.loadtxt(args.in_gm_frf, ndmin=2)
    csf_frf = np.loadtxt(args.in_csf_frf, ndmin=2)
    vol = nib.load(args.in_dwi)
    data = vol.get_fdata(dtype=np.float32)
    bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec)

    # Checking mask
    if args.mask is None:
        mask = None
    else:
        mask = get_data_as_mask(nib.load(args.mask), dtype=bool)
        if mask.shape != data.shape[:-1]:
            raise ValueError("Mask is not the same shape as data.")

    sh_order = args.sh_order

    # Checking data and sh_order. The minimum b-value is passed both as the
    # smallest b-value and as the candidate b0 threshold.
    b0_thr = check_b0_threshold(
        args.force_b0_threshold, bvals.min(), bvals.min())
    # (sh_order + 1)(sh_order + 2) / 2 is always an integer; use // so the
    # message prints e.g. "45" rather than "45.0".
    n_recommended = (sh_order + 1) * (sh_order + 2) // 2
    if data.shape[-1] < n_recommended:
        logging.warning(
            'We recommend having at least {} unique DWIs volumes, but you '
            'currently have {} volumes. Try lowering the parameter --sh_order '
            'in case of non convergence.'.format(
                n_recommended, data.shape[-1]))

    # Checking bvals, bvecs values and loading gtab
    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)
    gtab = gradient_table(bvals, bvecs, b0_threshold=b0_thr)

    # Checking response functions and computing msmt response function
    for tissue, frf in (('WM', wm_frf), ('GM', gm_frf), ('CSF', csf_frf)):
        if frf.shape[1] != 4:
            raise ValueError('{} frf file did not contain 4 elements. '
                             'Invalid or deprecated FRF format'.format(tissue))
    ubvals = unique_bvals_tolerance(bvals, tol=20)
    msmt_response = multi_shell_fiber_response(sh_order, ubvals,
                                               wm_frf, gm_frf, csf_frf)

    # Loading spheres
    reg_sphere = get_sphere('symmetric362')

    # Computing msmt-CSD
    msmt_model = MultiShellDeconvModel(gtab, msmt_response,
                                       reg_sphere=reg_sphere,
                                       sh_order=sh_order)

    # Computing msmt-CSD fit
    msmt_fit = fit_from_model(msmt_model, data,
                              mask=mask, nbr_processes=args.nbr_processes)

    shm_coeff = msmt_fit.all_shm_coeff

    nan_count = len(np.argwhere(np.isnan(shm_coeff[..., 0])))
    voxel_count = np.prod(shm_coeff.shape[:-1])

    if nan_count / voxel_count >= 0.05:
        msg = """There are {} voxels out of {} that could not be solved by
        the solver, reaching a critical amount of voxels. Make sure to tune the
        response functions properly, as the solving process is very sensitive
        to it. Proceeding to fill the problematic voxels by 0.
        """
        logging.warning(msg.format(nan_count, voxel_count))
    elif nan_count > 0:
        msg = """There are {} voxels out of {} that could not be solved by
        the solver. Make sure to tune the response functions properly, as the
        solving process is very sensitive to it. Proceeding to fill the
        problematic voxels by 0.
        """
        logging.warning(msg.format(nan_count, voxel_count))

    shm_coeff = np.where(np.isnan(shm_coeff), 0, shm_coeff)

    # Saving results
    if args.wm_out_fODF:
        wm_coeff = shm_coeff[..., 2:]
        if args.sh_basis == 'tournier07':
            wm_coeff = convert_sh_basis(wm_coeff, reg_sphere, mask=mask,
                                        nbr_processes=args.nbr_processes)
        nib.save(nib.Nifti1Image(wm_coeff.astype(np.float32),
                                 vol.affine), args.wm_out_fODF)

    if args.gm_out_fODF:
        gm_coeff = shm_coeff[..., 1]
        if args.sh_basis == 'tournier07':
            gm_coeff = gm_coeff.reshape(gm_coeff.shape + (1,))
            gm_coeff = convert_sh_basis(gm_coeff, reg_sphere, mask=mask,
                                        nbr_processes=args.nbr_processes)
        nib.save(nib.Nifti1Image(gm_coeff.astype(np.float32),
                                 vol.affine), args.gm_out_fODF)

    if args.csf_out_fODF:
        csf_coeff = shm_coeff[..., 0]
        if args.sh_basis == 'tournier07':
            csf_coeff = csf_coeff.reshape(csf_coeff.shape + (1,))
            csf_coeff = convert_sh_basis(csf_coeff, reg_sphere, mask=mask,
                                         nbr_processes=args.nbr_processes)
        nib.save(nib.Nifti1Image(csf_coeff.astype(np.float32),
                                 vol.affine), args.csf_out_fODF)

    if args.vf or args.vf_rgb:
        # Volume fractions are read from the fit object, not from the
        # NaN-filled shm_coeff above; presumably they carry the same NaNs,
        # so fill them with 0 as well. Otherwise np.max(vf) would be NaN
        # and poison the whole RGB map.
        vf = msmt_fit.volume_fractions
        vf = np.where(np.isnan(vf), 0, vf)

        if args.vf:
            nib.save(nib.Nifti1Image(vf.astype(np.float32),
                                     vol.affine), args.vf)

        if args.vf_rgb:
            vf_max = np.max(vf)
            # Guard against an all-zero map (0 / 0 would yield NaN).
            if vf_max > 0:
                vf_rgb = vf / vf_max * 255
            else:
                vf_rgb = np.zeros_like(vf)
            vf_rgb = np.clip(vf_rgb, 0, 255)
            nib.save(nib.Nifti1Image(vf_rgb.astype(np.uint8),
                                     vol.affine), args.vf_rgb)