def test_multi_shell_fiber_response():
    """Check the response shape with and without a b0 shell.

    Without a b0 shell, ``multi_shell_fiber_response`` must emit the
    UserWarning "No b0 given. Proceeding either way." in addition to the
    legacy-SH PendingDeprecationWarnings.
    """
    sh_order = 8
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message=shm.descoteaux07_legacy_msg,
            category=PendingDeprecationWarning)
        response = multi_shell_fiber_response(sh_order,
                                              [0, 1000, 2000, 3500],
                                              wm_response,
                                              gm_response,
                                              csf_response)

    npt.assert_equal(response.response.shape, (4, 7))

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always", category=PendingDeprecationWarning)
        response = multi_shell_fiber_response(sh_order,
                                              [1000, 2000, 3500],
                                              wm_response,
                                              gm_response,
                                              csf_response)
        # Deprecation warnings from the legacy SH basis are raised as well
        # as the warning from multi_shell_fiber_response, so more than one
        # warning is expected.
        npt.assert_(len(w) > 1)
        # Find the b0 warning by category and message instead of assuming
        # it is the last one recorded: the test must not depend on the
        # order in which the other warnings are emitted.
        no_b0_warnings = [
            rec for rec in w
            if issubclass(rec.category, UserWarning)
            and "No b0 given. Proceeding either way." in str(rec.message)]
        npt.assert_equal(len(no_b0_warnings), 1)
        npt.assert_equal(response.response.shape, (3, 7))
def test_mcsd_model_delta():
    """Predicting from the WM delta must reproduce each shell's WM response.

    The isotropic entries of the model delta are zeroed and the remaining WM
    delta is predicted on every shell. The prediction is compared with the
    response projected onto the SH basis. Fitting the prediction on the
    3-shell table must give ~0 for every m != 0 coefficient.
    """
    sh_order = 8
    gtab = get_3shell_gtab()
    response = multi_shell_fiber_response(sh_order, [0, 1000, 2000, 3500],
                                          wm_response,
                                          gm_response,
                                          csf_response)
    model = MultiShellDeconvModel(gtab, response)
    iso = response.iso

    theta, phi = default_sphere.theta, default_sphere.phi
    # ``real_sph_harm`` is deprecated in DIPY; ``real_sh_descoteaux_from_index``
    # is its replacement and takes the same (m, n, theta, phi) arguments.
    B = shm.real_sh_descoteaux_from_index(
        response.m, response.n, theta[:, None], phi[:, None])

    wm_delta = model.delta.copy()
    # set isotropic components to zero
    wm_delta[:iso] = 0.
    wm_delta = _expand(model.m, iso, wm_delta)

    for i, s in enumerate([0, 1000, 2000, 3500]):
        g = GradientTable(default_sphere.vertices * s)
        signal = model.predict(wm_delta, g)
        expected = np.dot(response.response[i, iso:], B.T)
        npt.assert_array_almost_equal(signal, expected)

    signal = model.predict(wm_delta, gtab)
    fit = model.fit(signal)
    m = model.m
    npt.assert_array_almost_equal(fit.shm_coeff[m != 0], 0., 2)
def test_MSDeconvFit():
    """The fitted volume fractions must match a simulated 3-tissue mixture.

    CSF, GM and WM signals are mixed with fractions 0.325 / 0.2 / 0.475 and
    fitted with a MultiShellDeconvModel. The recovered fractions must agree
    to one decimal.
    """
    gtab = get_3shell_gtab()

    def mono_exp(rf):
        # S0 * exp(-b * D) using the first row of a tissue response array.
        return rf[0, 3] * np.exp(-gtab.bvals * rf[0, 0])

    first_row_evals = wm_response[0, :3]
    wm_signal = multi_tensor(gtab,
                             np.array([first_row_evals, first_row_evals]),
                             wm_response[0, 3],
                             angles=[(0, 0), (60, 0)],
                             fractions=[30., 70.], snr=None)[0]
    gm_signal = mono_exp(gm_response)
    csf_signal = mono_exp(csf_response)

    response = multi_shell_fiber_response(8, [0, 1000, 2000, 3500],
                                          wm_response,
                                          gm_response,
                                          csf_response)
    model = MultiShellDeconvModel(gtab, response)

    expected_vf = [0.325, 0.2, 0.475]
    tissue_signals = [csf_signal, gm_signal, wm_signal]
    mixture = sum(frac * sig for frac, sig in zip(expected_vf, tissue_signals))

    fit = model.fit(mixture)
    npt.assert_array_almost_equal(fit.volume_fractions, expected_vf, 1)
def test_MSDeconvFit():
    """The fitted volume fractions must match a simulated 3-tissue mixture.

    A two-tensor WM signal and mono-exponential GM and CSF signals are mixed
    with fractions [0.325, 0.2, 0.475] (in CSF, GM, WM order) and fitted.
    The fitted volume fractions must agree to one decimal.
    """
    gtab = get_3shell_gtab()
    bvals = gtab.bvals

    # Two tensors sharing the first WM row's eigenvalues and S0.
    wm_evals = np.array([wm_response[0, :3], wm_response[0, :3]])
    wm_signal, _ = multi_tensor(gtab, wm_evals, wm_response[0, 3],
                                angles=[(0, 0), (60, 0)],
                                fractions=[30., 70.], snr=None)
    gm_signal = gm_response[0, 3] * np.exp(-bvals * gm_response[0, 0])
    csf_signal = csf_response[0, 3] * np.exp(-bvals * csf_response[0, 0])

    with warnings.catch_warnings():
        # Only the legacy descoteaux07 SH-basis notice is silenced.
        warnings.filterwarnings(
            "ignore", message=shm.descoteaux07_legacy_msg,
            category=PendingDeprecationWarning)
        response = multi_shell_fiber_response(
            8, [0, 1000, 2000, 3500], wm_response, gm_response, csf_response)
        model = MultiShellDeconvModel(gtab, response)

    true_fractions = [0.325, 0.2, 0.475]
    mixed = 0
    for frac, tissue in zip(true_fractions,
                            (csf_signal, gm_signal, wm_signal)):
        mixed = mixed + frac * tissue

    fit = model.fit(mixed)
    npt.assert_array_almost_equal(fit.volume_fractions, true_fractions, 1)
def test_MultiShellDeconvModel_response():
    """A MultiShellResponse and the stacked raw responses give the same model.

    Also checks that two malformed response inputs are rejected with
    ValueError.
    """
    gtab = get_3shell_gtab()
    order = 8

    with warnings.catch_warnings():
        # Only the legacy descoteaux07 SH-basis notice is silenced.
        warnings.filterwarnings(
            "ignore", message=shm.descoteaux07_legacy_msg,
            category=PendingDeprecationWarning)
        msr = multi_shell_fiber_response(order, [0, 1000, 2000, 3500],
                                         wm_response,
                                         gm_response,
                                         csf_response)
        model_from_msr = MultiShellDeconvModel(gtab, msr, sh_order=order)
        raw_responses = np.array([wm_response, gm_response, csf_response])
        model_from_raw = MultiShellDeconvModel(gtab, raw_responses,
                                               sh_order=order)

    npt.assert_array_almost_equal(model_from_msr.response.response,
                                  model_from_raw.response.response, 0)

    with npt.assert_raises(ValueError):
        MultiShellDeconvModel(gtab, np.ones((4, 3, 4)))
    with npt.assert_raises(ValueError):
        MultiShellDeconvModel(gtab, np.ones((3, 3, 4)), iso=3)
def test_multi_shell_fiber_response():
    """Check the response shape with and without a b0 shell.

    When the b-values do not include 0, ``multi_shell_fiber_response`` must
    emit a UserWarning containing "No b0 given. Proceeding either way.".
    """
    sh_order = 8
    response = multi_shell_fiber_response(sh_order, [0, 1000, 2000, 3500],
                                          wm_response,
                                          gm_response,
                                          csf_response)

    npt.assert_equal(response.response.shape, (4, 7))

    with warnings.catch_warnings(record=True) as w:
        # Record every warning so the check does not depend on the
        # runner's warning filters.
        warnings.simplefilter("always")
        response = multi_shell_fiber_response(sh_order, [1000, 2000, 3500],
                                              wm_response,
                                              gm_response,
                                              csf_response)
        # Other warnings (e.g. SH-basis deprecations) may be recorded before
        # the b0 warning, so search for it instead of reading w[0].
        no_b0_warnings = [
            rec for rec in w
            if issubclass(rec.category, UserWarning)
            and "No b0 given. Proceeding either way." in str(rec.message)]
        npt.assert_equal(len(no_b0_warnings), 1)
        npt.assert_equal(response.response.shape, (3, 7))
def test_multi_shell_fiber_response():
    """Check the response shape with and without a b0 shell.

    Without a b0 shell, ``multi_shell_fiber_response`` must emit the
    UserWarning "No b0 given. Proceeding either way." in addition to the
    legacy-SH deprecation warnings.
    """
    sh_order = 8
    response = multi_shell_fiber_response(sh_order, [0, 1000, 2000, 3500],
                                          wm_response,
                                          gm_response,
                                          csf_response)

    npt.assert_equal(response.response.shape, (4, 7))

    with warnings.catch_warnings(record=True) as w:
        # PendingDeprecationWarning is ignored by Python's default filters,
        # so record it explicitly; otherwise the count below depends on how
        # the test runner configures warnings.
        warnings.simplefilter("always", category=PendingDeprecationWarning)
        response = multi_shell_fiber_response(sh_order, [1000, 2000, 3500],
                                              wm_response,
                                              gm_response,
                                              csf_response)
        # Test that the number of warnings raised is greater than 1, with
        # deprecation warnings being raised from using legacy SH bases as well
        # as a warning from multi_shell_fiber_response
        npt.assert_(len(w) > 1)
        # Find the b0 warning by category and message rather than by its
        # position in the list.
        no_b0_warnings = [
            rec for rec in w
            if issubclass(rec.category, UserWarning)
            and "No b0 given. Proceeding either way." in str(rec.message)]
        npt.assert_equal(len(no_b0_warnings), 1)
        npt.assert_equal(response.response.shape, (3, 7))
def test_mcsd_model_delta():
    """Predicting from the WM delta must reproduce each shell's WM response.

    The isotropic entries of the model delta are zeroed and the WM delta is
    predicted on every shell. The prediction is compared with the response
    projected onto the SH basis. Fitting the prediction on the 3-shell table
    must give ~0 for every m != 0 coefficient.
    """
    shells = [0, 1000, 2000, 3500]
    gtab = get_3shell_gtab()

    with warnings.catch_warnings():
        # Only the legacy descoteaux07 SH-basis notice is silenced.
        warnings.filterwarnings(
            "ignore", message=shm.descoteaux07_legacy_msg,
            category=PendingDeprecationWarning)

        response = multi_shell_fiber_response(8, shells, wm_response,
                                              gm_response, csf_response)
        model = MultiShellDeconvModel(gtab, response)
        n_iso = response.iso

        theta = default_sphere.theta
        phi = default_sphere.phi
        basis = shm.real_sh_descoteaux_from_index(
            response.m, response.n, theta[:, None], phi[:, None])

        delta_wm = model.delta.copy()
        delta_wm[:n_iso] = 0.  # drop the isotropic components
        delta_wm = _expand(model.m, n_iso, delta_wm)

        for shell_idx, bval in enumerate(shells):
            shell_gtab = GradientTable(default_sphere.vertices * bval)
            predicted = model.predict(delta_wm, shell_gtab)
            wanted = np.dot(response.response[shell_idx, n_iso:], basis.T)
            npt.assert_array_almost_equal(predicted, wanted)

        fit = model.fit(model.predict(delta_wm, gtab))

    npt.assert_array_almost_equal(fit.shm_coeff[model.m != 0], 0., 2)
def test_MultiShellDeconvModel_response():
    """Both ways of passing the responses must yield the same model response.

    One model is built from a MultiShellResponse and one from the stacked raw
    WM/GM/CSF response arrays. Two malformed response arrays must raise
    ValueError.
    """
    gtab = get_3shell_gtab()
    sh_order = 8

    rf_object = multi_shell_fiber_response(sh_order, [0, 1000, 2000, 3500],
                                           wm_response,
                                           gm_response,
                                           csf_response)
    first = MultiShellDeconvModel(gtab, rf_object, sh_order=sh_order)
    rf_arrays = np.array([wm_response, gm_response, csf_response])
    second = MultiShellDeconvModel(gtab, rf_arrays, sh_order=sh_order)

    npt.assert_array_almost_equal(first.response.response,
                                  second.response.response, 0)

    bad_inputs = ((np.ones((4, 3, 4)), {}),
                  (np.ones((3, 3, 4)), {'iso': 3}))
    for bad_response, extra_kwargs in bad_inputs:
        npt.assert_raises(ValueError, MultiShellDeconvModel, gtab,
                          bad_response, **extra_kwargs)
def create_mcsd_model(folder_name, data, gtab, labels, sh_order=8):
    """Fit a multi-shell multi-tissue CSD model from a tissue label map.

    Parameters
    ----------
    folder_name : str
        Directory in which ``coeff.npy`` (the SH coefficients with NaNs
        replaced by 0) is saved.
    data : ndarray
        Diffusion data used both to estimate the responses and to fit.
    gtab : GradientTable
        Gradient table matching ``data``.
    labels : ndarray
        Tissue label map. Voxels equal to 3 are used as WM, 2 as GM and 1 as
        CSF when estimating the tissue response functions.
    sh_order : int, optional
        SH order passed to ``multi_shell_fiber_response``. Default 8.

    Returns
    -------
    fit
        The fit returned by ``MultiShellDeconvModel.fit``. If any voxel's
        first coefficient was NaN, a new ``MSDeconvFit`` is returned instead,
        built from the coefficients with every NaN replaced by 0.
    """
    import os

    from dipy.core.gradients import unique_bvals_tolerance
    from dipy.reconst.mcsd import (MSDeconvFit, MultiShellDeconvModel,
                                   multi_shell_fiber_response,
                                   response_from_mask_msmt)

    mask_wm = (labels == 3).astype(float)
    mask_gm = (labels == 2).astype(float)
    mask_csf = (labels == 1).astype(float)

    response_wm, response_gm, response_csf = response_from_mask_msmt(
        gtab, data, mask_wm, mask_gm, mask_csf)
    ubvals = unique_bvals_tolerance(gtab.bvals)
    response_mcsd = multi_shell_fiber_response(sh_order, bvals=ubvals,
                                               wm_rf=response_wm,
                                               csf_rf=response_csf,
                                               gm_rf=response_gm)

    mcsd_model = MultiShellDeconvModel(gtab, response_mcsd)
    mcsd_fit = mcsd_model.fit(data)

    coeff = mcsd_fit.all_shm_coeff
    # Count failed voxels via the first coefficient. The voxel total is taken
    # over every axis but the last, so the fraction is right for any number
    # of spatial dimensions (the old version assumed exactly three).
    nan_count = np.count_nonzero(np.isnan(coeff[..., 0]))
    n_vox = np.prod(coeff.shape[:-1])
    if nan_count > 0:
        print(
            f'{nan_count / n_vox} of the voxels did not complete fodf calculation, NaN values replaced with 0'
        )
        coeff = np.where(np.isnan(coeff), 0, coeff)
        mcsd_fit = MSDeconvFit(mcsd_model, coeff, None)

    # os.path.join instead of a hard-coded '\' so the file lands inside
    # folder_name on every OS, not as a file literally named '...\coeff.npy'.
    np.save(os.path.join(folder_name, 'coeff.npy'), coeff)
    return mcsd_fit
def _msmt_ft(self):
    """Fit a multi-shell multi-tissue CSD model and store it on ``self``.

    Reads ``self.gtab``, ``self.data``, ``self.tissue_labels`` and
    ``self.parameters_dict['sh_order']``, and sets ``self.model_fit``. If
    any voxel's first SH coefficient is NaN (failed solve), this prints the
    failed fraction and stores a new ``MSDeconvFit`` whose NaNs are replaced
    by 0.
    """
    from dipy.core.gradients import unique_bvals_tolerance
    from dipy.reconst.mcsd import (MSDeconvFit, MultiShellDeconvModel,
                                   multi_shell_fiber_response,
                                   response_from_mask_msmt)

    # NOTE(review): this mapping (WM=2, GM=1, CSF=3) differs from the
    # WM=3 / GM=2 / CSF=1 mapping used by create_mcsd_model. Confirm it
    # matches the segmentation that produces self.tissue_labels.
    mask_wm = (self.tissue_labels == 2).astype(float)
    mask_gm = (self.tissue_labels == 1).astype(float)
    mask_csf = (self.tissue_labels == 3).astype(float)

    response_wm, response_gm, response_csf = response_from_mask_msmt(
        self.gtab, self.data, mask_wm, mask_gm, mask_csf)
    ubvals = unique_bvals_tolerance(self.gtab.bvals)
    response_mcsd = multi_shell_fiber_response(
        self.parameters_dict['sh_order'], bvals=ubvals, wm_rf=response_wm,
        csf_rf=response_csf, gm_rf=response_gm)

    mcsd_model = MultiShellDeconvModel(self.gtab, response_mcsd)
    mcsd_fit = mcsd_model.fit(self.data)

    coeff = mcsd_fit.all_shm_coeff
    # Voxel total over every axis but the last (coefficients), so the
    # reported fraction does not assume 3 spatial dimensions.
    nan_count = np.count_nonzero(np.isnan(coeff[..., 0]))
    n_vox = np.prod(coeff.shape[:-1])
    if nan_count > 0:
        print(
            f'{nan_count / n_vox} of the voxels did not complete fodf calculation, NaN values replaced with 0'
        )
        coeff = np.where(np.isnan(coeff), 0, coeff)
        mcsd_fit = MSDeconvFit(mcsd_model, coeff, None)
    self.model_fit = mcsd_fit
import os

# Tissue label volume: 3 = WM, 2 = GM, 1 = CSF.
t_mask_img = load_nifti(tissue_mask)[0]
wm = t_mask_img == 3
gm = t_mask_img == 2
csf = t_mask_img == 1
mask_wm = wm.astype(float)
mask_gm = gm.astype(float)
mask_csf = csf.astype(float)

response_wm, response_gm, response_csf = response_from_mask_msmt(
    gtab, data, mask_wm, mask_gm, mask_csf)
ubvals = unique_bvals_tolerance(bvals)
response_mcsd = multi_shell_fiber_response(sh_order=8, bvals=ubvals,
                                           wm_rf=response_wm,
                                           csf_rf=response_csf,
                                           gm_rf=response_gm)
mcsd_model = MultiShellDeconvModel(gtab, response_mcsd)
# NOTE(review): the responses above are estimated from `data` but the model
# is fitted to `denoised_arr`. Confirm this mix is intentional.
mcsd_fit = mcsd_model.fit(denoised_arr)

sh_coeff = mcsd_fit.all_shm_coeff
nan_count = np.count_nonzero(np.isnan(sh_coeff[..., 0]))
coeff = mcsd_fit.all_shm_coeff
n_vox = np.prod(coeff.shape[:-1])
# Only report when something actually failed; previously this printed a
# "0.0 of the voxels ..." message on every run.
if nan_count > 0:
    print(
        f'{nan_count/n_vox} of the voxels did not complete fodf calculation, NaN values replaced with 0'
    )
coeff = np.where(np.isnan(coeff), 0, coeff)
mcsd_fit = MSDeconvFit(mcsd_model, coeff, None)
# Portable path: the old `folder_name + r'\coeff.npy'` only worked on Windows.
np.save(os.path.join(folder_name, 'coeff.npy'), coeff)
def mcsd_mod_est(gtab, data, B0_mask, wm_in_dwi, gm_in_dwi, vent_csf_in_dwi,
                 sh_order=8, roi_radii=10):
    """
    Estimate a Multi-Shell Multi-Tissue Constrained Spherical Deconvolution
    (MSMT-CSD) model from dwi data.

    Parameters
    ----------
    gtab : Obj
        DiPy object storing diffusion gradient information.
    data : array
        4D numpy array of diffusion image data.
    B0_mask : str
        File path to B0 brain mask.
    wm_in_dwi : image or str
        White-matter map in DWI space (anything accepted by
        ``nilearn.image.math_img``). Voxels > 0.15 are treated as WM.
    gm_in_dwi : image or str
        Grey-matter map in DWI space. Voxels > 0.10 are treated as GM.
    vent_csf_in_dwi : image or str
        Ventricular CSF map in DWI space. Voxels > 0.50 are treated as CSF.
    sh_order : int
        The order of the SH model, used for both the WM response and the
        deconvolution model. Default is 8.
    roi_radii : int
        Radius passed to ``mask_for_response_msmt``. Default is 10.

    Returns
    -------
    mcsd_mod : ndarray
        float32 SH coefficients of the MSMT-CSD fit within the B0 mask,
        clipped per voxel to [0, that voxel's maximum coefficient].
    model : obj
        The MultiShellDeconvModel used for the fit.

    References
    ----------
    .. [1] Tournier, J.D., et al. NeuroImage 2007. Robust determination of
      the fibre orientation distribution in diffusion MRI: Non-negativity
      constrained super-resolved spherical deconvolution
    .. [2] Descoteaux, M., et al. IEEE TMI 2009. Deterministic and
      Probabilistic Tractography Based on Complex Fibre Orientation
      Distributions
    .. [3] Côté, M-A., et al. Medical Image Analysis 2013. Tractometer:
      Towards validation of tractography pipelines
    .. [4] Tournier, J.D, et al. Imaging Systems and Technology 2012.
      MRtrix: Diffusion Tractography in Crossing Fiber Regions
    """
    import dipy.reconst.dti as dti
    from nilearn.image import math_img
    from dipy.core.gradients import unique_bvals_tolerance
    from dipy.reconst.mcsd import (mask_for_response_msmt,
                                   response_from_mask_msmt,
                                   multi_shell_fiber_response,
                                   MultiShellDeconvModel)

    print("Reconstructing using MCSD...")

    B0_mask_data = np.nan_to_num(np.asarray(
        nib.load(B0_mask).dataobj)).astype("bool")

    # Load tissue maps and binarize them at their tissue thresholds.
    gm_mask_img = math_img("img > 0.10", img=gm_in_dwi)
    gm_data = np.asarray(gm_mask_img.dataobj, dtype=np.float32)
    wm_mask_img = math_img("img > 0.15", img=wm_in_dwi)
    wm_data = np.asarray(wm_mask_img.dataobj, dtype=np.float32)
    vent_csf_in_dwi_img = math_img("img > 0.50", img=vent_csf_in_dwi)
    vent_csf_in_dwi_data = np.asarray(vent_csf_in_dwi_img.dataobj,
                                      dtype=np.float32)

    # Fit a simple DTI model to get FA and MD.
    tenfit = dti.TensorModel(gtab).fit(data)
    FA = tenfit.fa
    MD = tenfit.md

    # Boolean selections combining DTI metrics with the tissue maps. They
    # are used only to derive the response-selection thresholds below.
    selected_csf = (FA < 0.2) & (vent_csf_in_dwi_data > 0.50)
    selected_gm = (FA < 0.2) & (gm_data > 0.10)
    selected_wm = (FA >= 0.2) & (wm_data > 0.15)

    mask_wm, mask_gm, mask_csf = mask_for_response_msmt(
        gtab, data, roi_radii=roi_radii,
        wm_fa_thr=np.nanmean(FA[selected_wm]),
        gm_fa_thr=np.nanmean(FA[selected_gm]),
        csf_fa_thr=np.nanmean(FA[selected_csf]),
        gm_md_thr=np.nanmean(MD[selected_gm]),
        csf_md_thr=np.nanmean(MD[selected_csf]))

    # Keep only response voxels that also lie inside the tissue maps.
    mask_wm *= wm_data.astype('int64')
    mask_gm *= gm_data.astype('int64')
    mask_csf *= vent_csf_in_dwi_data.astype('int64')

    response_wm, response_gm, response_csf = response_from_mask_msmt(
        gtab, data, mask_wm, mask_gm, mask_csf)

    # Build the response at the same SH order as the model below. This was
    # hard-coded to 8 and ignored the ``sh_order`` argument.
    response_mcsd = multi_shell_fiber_response(
        sh_order=sh_order,
        bvals=unique_bvals_tolerance(gtab.bvals),
        wm_rf=response_wm,
        gm_rf=response_gm,
        csf_rf=response_csf)

    model = MultiShellDeconvModel(gtab, response_mcsd, sh_order=sh_order)
    mcsd_mod = model.fit(data, B0_mask_data).shm_coeff
    # NOTE(review): this clips every SH coefficient to [0, per-voxel max],
    # which also zeroes negative coefficients. Confirm this is intended.
    mcsd_mod = np.clip(mcsd_mod, 0, np.max(mcsd_mod, -1)[..., None])
    del response_mcsd, B0_mask_data
    return mcsd_mod.astype("float32"), model
def main():
    """Fit multi-shell multi-tissue CSD and save the requested outputs.

    Loads the DWI, b-values/b-vectors, an optional mask and the WM/GM/CSF
    fiber response function (FRF) text files. It then fits a
    ``MultiShellDeconvModel`` and saves:

    - the WM/GM/CSF fODF SH coefficients, converted with
      ``convert_sh_basis`` when ``--sh_basis tournier07``;
    - the volume fractions;
    - a uint8 rescaling of the volume fractions.

    Unless ``--not_all`` is given, every output without a name gets a default
    file name. Voxels the solver could not solve (NaN) are reported and
    filled with 0 in all saved outputs.

    Raises
    ------
    ValueError
        If the mask shape does not match the data, or an FRF file does not
        have 4 columns.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)

    if not args.not_all:
        args.wm_out_fODF = args.wm_out_fODF or 'wm_fodf.nii.gz'
        args.gm_out_fODF = args.gm_out_fODF or 'gm_fodf.nii.gz'
        args.csf_out_fODF = args.csf_out_fODF or 'csf_fodf.nii.gz'
        args.vf = args.vf or 'vf.nii.gz'
        args.vf_rgb = args.vf_rgb or 'vf_rgb.nii.gz'

    arglist = [args.wm_out_fODF, args.gm_out_fODF, args.csf_out_fODF,
               args.vf, args.vf_rgb]
    if args.not_all and not any(arglist):
        parser.error('When using --not_all, you need to specify at least ' +
                     'one file to output.')

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec,
                                 args.in_wm_frf, args.in_gm_frf,
                                 args.in_csf_frf])
    assert_outputs_exist(parser, args, arglist)

    # Loading data. ndmin=2 makes a one-line FRF file parse as a (1, 4)
    # table like multi-line files. Without it np.loadtxt returns a 1-D
    # array and the column check below crashed with IndexError.
    wm_frf = np.loadtxt(args.in_wm_frf, ndmin=2)
    gm_frf = np.loadtxt(args.in_gm_frf, ndmin=2)
    csf_frf = np.loadtxt(args.in_csf_frf, ndmin=2)
    vol = nib.load(args.in_dwi)
    data = vol.get_fdata(dtype=np.float32)
    bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec)

    # Checking mask
    if args.mask is None:
        mask = None
    else:
        mask = get_data_as_mask(nib.load(args.mask), dtype=bool)
        if mask.shape != data.shape[:-1]:
            raise ValueError("Mask is not the same shape as data.")

    sh_order = args.sh_order

    # Checking data and sh_order
    b0_thr = check_b0_threshold(
        args.force_b0_threshold, bvals.min(), bvals.min())
    # (sh_order + 1) * (sh_order + 2) is always even, so integer division is
    # exact; it also makes the warning print "45" rather than "45.0".
    n_recommended = (sh_order + 1) * (sh_order + 2) // 2
    if data.shape[-1] < n_recommended:
        logging.warning(
            'We recommend having at least {} unique DWIs volumes, but you '
            'currently have {} volumes. Try lowering the parameter --sh_order '
            'in case of non convergence.'.format(
                n_recommended, data.shape[-1]))

    # Checking bvals, bvecs values and loading gtab
    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)
    gtab = gradient_table(bvals, bvecs, b0_threshold=b0_thr)

    # Checking response functions and computing msmt response function
    for tissue, frf in (('WM', wm_frf), ('GM', gm_frf), ('CSF', csf_frf)):
        if frf.shape[1] != 4:
            raise ValueError('{} frf file did not contain 4 elements. '
                             'Invalid or deprecated FRF format'.format(tissue))

    ubvals = unique_bvals_tolerance(bvals, tol=20)
    msmt_response = multi_shell_fiber_response(sh_order, ubvals,
                                               wm_frf, gm_frf, csf_frf)

    # Loading spheres
    reg_sphere = get_sphere('symmetric362')

    # Computing msmt-CSD
    msmt_model = MultiShellDeconvModel(gtab, msmt_response,
                                       reg_sphere=reg_sphere,
                                       sh_order=sh_order)

    # Computing msmt-CSD fit
    msmt_fit = fit_from_model(msmt_model, data,
                              mask=mask, nbr_processes=args.nbr_processes)

    shm_coeff = msmt_fit.all_shm_coeff

    # Voxels the solver could not solve come back as NaN.
    nan_count = np.count_nonzero(np.isnan(shm_coeff[..., 0]))
    voxel_count = np.prod(shm_coeff.shape[:-1])

    if nan_count / voxel_count >= 0.05:
        msg = ("There are {} voxels out of {} that could not be solved by "
               "the solver, reaching a critical amount of voxels. Make sure "
               "to tune the response functions properly, as the solving "
               "process is very sensitive to it. Proceeding to fill the "
               "problematic voxels by 0. ")
        logging.warning(msg.format(nan_count, voxel_count))
    elif nan_count > 0:
        msg = ("There are {} voxels out of {} that could not be solved by "
               "the solver. Make sure to tune the response functions "
               "properly, as the solving process is very sensitive to it. "
               "Proceeding to fill the problematic voxels by 0. ")
        logging.warning(msg.format(nan_count, voxel_count))

    shm_coeff = np.where(np.isnan(shm_coeff), 0, shm_coeff)

    # Saving results
    if args.wm_out_fODF:
        wm_coeff = shm_coeff[..., 2:]
        if args.sh_basis == 'tournier07':
            wm_coeff = convert_sh_basis(wm_coeff, reg_sphere, mask=mask,
                                        nbr_processes=args.nbr_processes)
        nib.save(nib.Nifti1Image(wm_coeff.astype(np.float32),
                                 vol.affine), args.wm_out_fODF)

    if args.gm_out_fODF:
        gm_coeff = shm_coeff[..., 1]
        if args.sh_basis == 'tournier07':
            gm_coeff = gm_coeff.reshape(gm_coeff.shape + (1,))
            gm_coeff = convert_sh_basis(gm_coeff, reg_sphere, mask=mask,
                                        nbr_processes=args.nbr_processes)
        nib.save(nib.Nifti1Image(gm_coeff.astype(np.float32),
                                 vol.affine), args.gm_out_fODF)

    if args.csf_out_fODF:
        csf_coeff = shm_coeff[..., 0]
        if args.sh_basis == 'tournier07':
            csf_coeff = csf_coeff.reshape(csf_coeff.shape + (1,))
            csf_coeff = convert_sh_basis(csf_coeff, reg_sphere, mask=mask,
                                         nbr_processes=args.nbr_processes)
        nib.save(nib.Nifti1Image(csf_coeff.astype(np.float32),
                                 vol.affine), args.csf_out_fODF)

    if args.vf or args.vf_rgb:
        # The volume fractions come from the un-filled fit, so unsolved
        # voxels are NaN there. Fill them by 0 as the warning above
        # announces. Otherwise np.max(vf) is NaN and the whole RGB map
        # becomes garbage.
        vf = msmt_fit.volume_fractions
        vf = np.where(np.isnan(vf), 0, vf)

        if args.vf:
            nib.save(nib.Nifti1Image(vf.astype(np.float32),
                                     vol.affine), args.vf)

        if args.vf_rgb:
            vf_max = np.max(vf)
            # Guard against an all-zero map (division by zero).
            if vf_max > 0:
                vf_rgb = vf / vf_max * 255
            else:
                vf_rgb = np.zeros_like(vf)
            vf_rgb = np.clip(vf_rgb, 0, 255)
            nib.save(nib.Nifti1Image(vf_rgb.astype(np.uint8),
                                     vol.affine), args.vf_rgb)