Exemplo n.º 1
0
def main():
    """Entry point: normalize a b-vectors file and write the result.

    Parses the command line, checks that the input b-vectors file exists
    and that the output path may be written, reads the b-vectors, warns
    when they already look unit-norm, and then hands them to
    ``normalize_bvecs`` together with the output filename.
    """
    arg_parser = _build_arg_parser()
    opts = arg_parser.parse_args()

    in_bvecs = opts.bvecs
    out_bvecs = opts.normalized_bvecs
    assert_inputs_exist(arg_parser, [in_bvecs])
    assert_outputs_exists(arg_parser, opts, [out_bvecs])

    # Only the b-vectors are needed here; no b-values file is read.
    _bvals, vectors = read_bvals_bvecs(None, in_bvecs)

    already_unit_norm = is_normalized_bvecs(vectors)
    if already_unit_norm:
        logging.warning('Your b-vectors are already normalized')

    # The output filename is passed through; presumably normalize_bvecs
    # writes the file itself (TODO confirm against its definition).
    normalize_bvecs(vectors, out_bvecs)
def main():
    """Compute DKI (and optionally MSDKI) metric maps from a DWI volume.

    Unless ``--not_all`` is given, every output whose filename was not
    supplied receives a default name. The DWI is optionally smoothed with an
    isotropic Gaussian (``--smooth`` is its FWHM), a DKI model is fitted,
    and every requested map is written as a float32 NIfTI with the input
    affine. MSK/MSD need a separate mean-signal DKI fit, which only runs
    when one of them is requested.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    if not args.not_all:
        args.dki_fa = args.dki_fa or 'dki_fa.nii.gz'
        args.dki_md = args.dki_md or 'dki_md.nii.gz'
        args.dki_ad = args.dki_ad or 'dki_ad.nii.gz'
        args.dki_rd = args.dki_rd or 'dki_rd.nii.gz'
        args.mk = args.mk or 'mk.nii.gz'
        args.rk = args.rk or 'rk.nii.gz'
        args.ak = args.ak or 'ak.nii.gz'
        args.dki_residual = args.dki_residual or 'dki_residual.nii.gz'
        args.msk = args.msk or 'msk.nii.gz'
        args.msd = args.msd or 'msd.nii.gz'

    outputs = [args.dki_fa, args.dki_md, args.dki_ad, args.dki_rd,
               args.mk, args.rk, args.ak, args.dki_residual,
               args.msk, args.msd]

    if args.not_all and not any(outputs):
        parser.error('When using --not_all, you need to specify at least ' +
                     'one metric to output.')

    assert_inputs_exist(
        parser, [args.input, args.bvals, args.bvecs], args.mask)
    assert_outputs_exist(parser, args, outputs)

    img = nib.load(args.input)
    data = img.get_fdata()
    affine = img.affine
    if args.mask is None:
        mask = None
    else:
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        mask = nib.load(args.mask).get_fdata().astype(bool)

    # Validate bvals and bvecs
    bvals, bvecs = read_bvals_bvecs(args.bvals, args.bvecs)
    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)

    # Find the distinct shells. At least 3 are required (per the error
    # message: the lowest shell plus at least 2 non-zero b-values).
    tol = args.tolerance
    shells, _ = identify_shells(bvals, tol)
    if not len(shells) >= 3:
        parser.error('Data is not multi-shell. You need at least 2 non-zero' +
                     ' b-values')

    if (shells > 2500).any():
        logging.warning('You seem to be using b > 2500 s/mm2 DWI data. ' +
                        'In theory, this is beyond the optimal range for DKI')

    check_b0_threshold(args, bvals.min())
    gtab = gradient_table(bvals, bvecs, b0_threshold=bvals.min())

    fwhm = args.smooth
    if fwhm > 0:
        # Gaussian FWHM = sqrt(8 ln 2) * sigma.
        gauss_std = fwhm / np.sqrt(8 * np.log(2))
        data_smooth = np.zeros(data.shape)
        # Smooth each volume independently (spatial smoothing only).
        for v in range(data.shape[-1]):
            data_smooth[..., v] = gaussian_filter(data[..., v],
                                                  sigma=gauss_std)
        data = data_smooth

    # Compute DKI
    dkimodel = dki.DiffusionKurtosisModel(gtab)
    dkifit = dkimodel.fit(data, mask=mask)

    # Kurtosis bounds passed to the MK/AK/RK getters and used to clip MSK.
    min_k = args.min_k
    max_k = args.max_k

    if args.dki_fa:
        FA = dkifit.fa
        FA[np.isnan(FA)] = 0
        FA = np.clip(FA, 0, 1)

        fa_img = nib.Nifti1Image(FA.astype(np.float32), affine)
        nib.save(fa_img, args.dki_fa)

    if args.dki_md:
        MD = dkifit.md
        md_img = nib.Nifti1Image(MD.astype(np.float32), affine)
        nib.save(md_img, args.dki_md)

    if args.dki_ad:
        AD = dkifit.ad
        ad_img = nib.Nifti1Image(AD.astype(np.float32), affine)
        nib.save(ad_img, args.dki_ad)

    if args.dki_rd:
        RD = dkifit.rd
        rd_img = nib.Nifti1Image(RD.astype(np.float32), affine)
        nib.save(rd_img, args.dki_rd)

    if args.mk:
        MK = dkifit.mk(min_k, max_k)
        mk_img = nib.Nifti1Image(MK.astype(np.float32), affine)
        nib.save(mk_img, args.mk)

    if args.ak:
        AK = dkifit.ak(min_k, max_k)
        ak_img = nib.Nifti1Image(AK.astype(np.float32), affine)
        nib.save(ak_img, args.ak)

    if args.rk:
        RK = dkifit.rk(min_k, max_k)
        rk_img = nib.Nifti1Image(RK.astype(np.float32), affine)
        nib.save(rk_img, args.rk)

    if args.msk or args.msd:
        # Compute MSDKI (separate model, fitted only when needed).
        msdki_model = msdki.MeanDiffusionKurtosisModel(gtab)
        msdki_fit = msdki_model.fit(data, mask=mask)

        if args.msk:
            MSK = msdki_fit.msk
            MSK[np.isnan(MSK)] = 0
            MSK = np.clip(MSK, min_k, max_k)

            msk_img = nib.Nifti1Image(MSK.astype(np.float32), affine)
            nib.save(msk_img, args.msk)

        if args.msd:
            MSD = msdki_fit.msd
            msd_img = nib.Nifti1Image(MSD.astype(np.float32), affine)
            nib.save(msd_img, args.msd)

    if args.dki_residual:
        # Per-voxel mean absolute difference between the model prediction
        # and the data over the non-b0 volumes, scaled by the L2 norm of
        # the whole residual map.
        S0 = np.mean(data[..., gtab.b0s_mask], axis=-1)
        data_p = dkifit.predict(gtab, S0)
        R = np.mean(np.abs(data_p[..., ~gtab.b0s_mask] -
                           data[..., ~gtab.b0s_mask]), axis=-1)

        norm = np.linalg.norm(R)
        if norm != 0:
            R = R / norm

        if args.mask is not None:
            R *= mask

        R_img = nib.Nifti1Image(R.astype(np.float32), affine)
        nib.save(R_img, args.dki_residual)
Exemplo n.º 3
0
def main():
    """Fit multi-shell multi-tissue CSD and save fODFs and volume fractions.

    Unless ``--not_all`` is given, every output whose filename was not
    supplied receives a default name. WM, GM and CSF response functions are
    read from text files (4 values per row). A ``MultiShellDeconvModel`` is
    fitted through ``fit_from_model``. Voxels whose coefficients are NaN
    (not solved) are zero-filled with a warning. The requested SH
    coefficients, volume fractions and an RGB volume-fraction map are then
    saved with the DWI affine.

    Raises
    ------
    ValueError
        If the mask shape does not match the DWI spatial shape, or if an
        FRF file is not a table with 4 columns.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)

    if not args.not_all:
        args.wm_out_fODF = args.wm_out_fODF or 'wm_fodf.nii.gz'
        args.gm_out_fODF = args.gm_out_fODF or 'gm_fodf.nii.gz'
        args.csf_out_fODF = args.csf_out_fODF or 'csf_fodf.nii.gz'
        args.vf = args.vf or 'vf.nii.gz'
        args.vf_rgb = args.vf_rgb or 'vf_rgb.nii.gz'

    arglist = [args.wm_out_fODF, args.gm_out_fODF, args.csf_out_fODF,
               args.vf, args.vf_rgb]
    if args.not_all and not any(arglist):
        parser.error('When using --not_all, you need to specify at least ' +
                     'one file to output.')

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec,
                                 args.in_wm_frf, args.in_gm_frf,
                                 args.in_csf_frf])
    assert_outputs_exist(parser, args, arglist)

    # Loading data
    wm_frf = np.loadtxt(args.in_wm_frf)
    gm_frf = np.loadtxt(args.in_gm_frf)
    csf_frf = np.loadtxt(args.in_csf_frf)
    vol = nib.load(args.in_dwi)
    data = vol.get_fdata(dtype=np.float32)
    bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec)

    # Checking mask
    if args.mask is None:
        mask = None
    else:
        mask = get_data_as_mask(nib.load(args.mask), dtype=bool)
        if mask.shape != data.shape[:-1]:
            raise ValueError("Mask is not the same shape as data.")

    sh_order = args.sh_order

    # Checking data and sh_order
    b0_thr = check_b0_threshold(
        args.force_b0_threshold, bvals.min(), bvals.min())
    # (sh_order + 1) * (sh_order + 2) is a product of consecutive integers,
    # hence even, so integer division is exact and prints as an int.
    n_recommended = (sh_order + 1) * (sh_order + 2) // 2
    if data.shape[-1] < n_recommended:
        logging.warning(
            'We recommend having at least {} unique DWIs volumes, but you '
            'currently have {} volumes. Try lowering the parameter --sh_order '
            'in case of non convergence.'.format(
                n_recommended, data.shape[-1]))

    # Checking bvals, bvecs values and loading gtab
    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)
    gtab = gradient_table(bvals, bvecs, b0_threshold=b0_thr)

    # Checking response functions. A single-row file loads as a 1D array;
    # the ndim test reports it with the intended ValueError instead of an
    # IndexError on ``shape[1]``.
    for tissue, frf in (('WM', wm_frf), ('GM', gm_frf), ('CSF', csf_frf)):
        if frf.ndim != 2 or frf.shape[1] != 4:
            raise ValueError('{} frf file did not contain 4 elements. '
                             'Invalid or deprecated FRF format'
                             .format(tissue))
    ubvals = unique_bvals_tolerance(bvals, tol=20)
    msmt_response = multi_shell_fiber_response(sh_order, ubvals,
                                               wm_frf, gm_frf, csf_frf)

    # Loading spheres
    reg_sphere = get_sphere('symmetric362')

    # Computing msmt-CSD
    msmt_model = MultiShellDeconvModel(gtab, msmt_response,
                                       reg_sphere=reg_sphere,
                                       sh_order=sh_order)

    # Computing msmt-CSD fit
    msmt_fit = fit_from_model(msmt_model, data,
                              mask=mask, nbr_processes=args.nbr_processes)

    shm_coeff = msmt_fit.all_shm_coeff

    # A NaN in the first coefficient marks a voxel the solver failed on.
    nan_count = np.count_nonzero(np.isnan(shm_coeff[..., 0]))
    voxel_count = np.prod(shm_coeff.shape[:-1])

    if nan_count / voxel_count >= 0.05:
        msg = """There are {} voxels out of {} that could not be solved by
        the solver, reaching a critical amount of voxels. Make sure to tune the
        response functions properly, as the solving process is very sensitive
        to it. Proceeding to fill the problematic voxels by 0.
        """
        logging.warning(msg.format(nan_count, voxel_count))
    elif nan_count > 0:
        msg = """There are {} voxels out of {} that could not be solved by
        the solver. Make sure to tune the response functions properly, as the
        solving process is very sensitive to it. Proceeding to fill the
        problematic voxels by 0.
        """
        logging.warning(msg.format(nan_count, voxel_count))

    shm_coeff = np.where(np.isnan(shm_coeff), 0, shm_coeff)

    # Saving results. Coefficient layout as used below: index 0 is CSF,
    # index 1 is GM, indices 2+ are the WM SH coefficients.
    if args.wm_out_fODF:
        wm_coeff = shm_coeff[..., 2:]
        if args.sh_basis == 'tournier07':
            wm_coeff = convert_sh_basis(wm_coeff, reg_sphere, mask=mask,
                                        nbr_processes=args.nbr_processes)
        nib.save(nib.Nifti1Image(wm_coeff.astype(np.float32),
                                 vol.affine), args.wm_out_fODF)

    if args.gm_out_fODF:
        gm_coeff = shm_coeff[..., 1]
        if args.sh_basis == 'tournier07':
            gm_coeff = gm_coeff.reshape(gm_coeff.shape + (1,))
            gm_coeff = convert_sh_basis(gm_coeff, reg_sphere, mask=mask,
                                        nbr_processes=args.nbr_processes)
        nib.save(nib.Nifti1Image(gm_coeff.astype(np.float32),
                                 vol.affine), args.gm_out_fODF)

    if args.csf_out_fODF:
        csf_coeff = shm_coeff[..., 0]
        if args.sh_basis == 'tournier07':
            csf_coeff = csf_coeff.reshape(csf_coeff.shape + (1,))
            csf_coeff = convert_sh_basis(csf_coeff, reg_sphere, mask=mask,
                                         nbr_processes=args.nbr_processes)
        nib.save(nib.Nifti1Image(csf_coeff.astype(np.float32),
                                 vol.affine), args.csf_out_fODF)

    if args.vf:
        nib.save(nib.Nifti1Image(msmt_fit.volume_fractions.astype(np.float32),
                                 vol.affine), args.vf)

    if args.vf_rgb:
        vf = msmt_fit.volume_fractions
        vf_max = np.nanmax(vf)
        if vf_max > 0:
            vf_rgb = vf / vf_max * 255
        else:
            # All-zero (or all-NaN) fractions: avoid 0/0, write an empty map.
            vf_rgb = np.zeros_like(vf)
        # NaN has no defined uint8 cast; map it to 0 before clipping.
        vf_rgb = np.clip(np.nan_to_num(vf_rgb), 0, 255)
        nib.save(nib.Nifti1Image(vf_rgb.astype(np.uint8),
                                 vol.affine), args.vf_rgb)
Exemplo n.º 4
0
def main():
    """Fit single-shell CSD and save fODF coefficients and peaks.

    Unless ``--not_all`` is given, every output whose filename was not
    supplied receives a default name. The fiber response is read from a
    4-value text file (3 response values followed by the mean b0 value). A
    ``ConstrainedSphericalDeconvModel`` is fitted through
    ``peaks_from_model``, and the SH coefficients, peak directions and peak
    indices are saved with the DWI affine.

    Raises
    ------
    ValueError
        If the FRF file does not contain exactly 4 values.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)

    if not args.not_all:
        args.fodf = args.fodf or 'fodf.nii.gz'
        args.peaks = args.peaks or 'peaks.nii.gz'
        args.peak_indices = args.peak_indices or 'peak_indices.nii.gz'

    arglist = [args.fodf, args.peaks, args.peak_indices]
    if args.not_all and not any(arglist):
        parser.error('When using --not_all, you need to specify at least '
                     'one file to output.')

    assert_inputs_exist(parser,
                        [args.input, args.bvals, args.bvecs, args.frf_file])
    assert_outputs_exist(parser, args, arglist)

    # Non-positive counts become None; exactly 1 disables parallelism.
    nbr_processes = args.nbr_processes
    parallel = True
    if nbr_processes is not None:
        if nbr_processes <= 0:
            nbr_processes = None
        elif nbr_processes == 1:
            parallel = False

    full_frf = np.loadtxt(args.frf_file)

    if not full_frf.shape[0] == 4:
        raise ValueError('FRF file did not contain 4 elements. '
                         'Invalid or deprecated FRF format')

    frf = full_frf[0:3]
    mean_b0_val = full_frf[3]

    vol = nib.load(args.input)
    # ``get_data()`` is removed in recent nibabel; ``asanyarray(dataobj)``
    # is its documented replacement (same dtype and scaling).
    data = np.asanyarray(vol.dataobj)

    bvals, bvecs = read_bvals_bvecs(args.bvals, args.bvecs)

    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)

    check_b0_threshold(args, bvals.min())
    gtab = gradient_table(bvals, bvecs, b0_threshold=bvals.min())

    if args.mask is None:
        mask = None
    else:
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        mask = nib.load(args.mask).get_fdata().astype(bool)

    # Raise warning for sh order if there is not enough DWIs. The product
    # of consecutive integers is even, so // is exact and prints as an int.
    n_recommended = (args.sh_order + 1) * (args.sh_order + 2) // 2
    if data.shape[-1] < n_recommended:
        logging.warning(
            'We recommend having at least {} unique DWIs volumes, but you '
            'currently have {} volumes. Try lowering the parameter --sh_order '
            'in case of non convergence.'.format(
                n_recommended, data.shape[-1]))

    reg_sphere = get_sphere('symmetric362')
    peaks_sphere = get_sphere('symmetric724')

    csd_model = ConstrainedSphericalDeconvModel(gtab, (frf, mean_b0_val),
                                                reg_sphere=reg_sphere,
                                                sh_order=args.sh_order)

    peaks_csd = peaks_from_model(model=csd_model,
                                 data=data,
                                 sphere=peaks_sphere,
                                 relative_peak_threshold=.5,
                                 min_separation_angle=25,
                                 mask=mask,
                                 return_sh=True,
                                 sh_basis_type=args.sh_basis,
                                 sh_order=args.sh_order,
                                 normalize_peaks=True,
                                 parallel=parallel,
                                 nbr_processes=nbr_processes)

    if args.fodf:
        nib.save(
            nib.Nifti1Image(peaks_csd.shm_coeff.astype(np.float32),
                            vol.affine), args.fodf)

    if args.peaks:
        nib.save(
            nib.Nifti1Image(reshape_peaks_for_visualization(peaks_csd),
                            vol.affine), args.peaks)

    if args.peak_indices:
        nib.save(nib.Nifti1Image(peaks_csd.peak_indices, vol.affine),
                 args.peak_indices)
Exemplo n.º 5
0
def main():
    """Compute DTI metric maps and quality-control outputs from a DWI volume.

    Unless ``--not_all`` is given, every output whose filename was not
    supplied receives a default name. A tensor model is fitted with
    ``args.method`` (the 'restore' method also receives a noise sigma from
    ``ne.estimate_sigma``). Each requested output is then written with the
    input affine:

    - scalar maps (FA, GA, MD, AD, RD, mode, norm) as float32;
    - an RGB FA map as uint8;
    - the reordered lower-triangular tensor;
    - eigenvectors and eigenvalues, plus one file per component;
    - a mask of physically implausible signals (a DWI brighter than the
      mean b0);
    - standard-deviation maps of the DWI and b0 volumes;
    - a residual map, with per-volume residual statistics saved as .npy
      files and a bar plot.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    if not args.not_all:
        args.fa = args.fa or 'fa.nii.gz'
        args.ga = args.ga or 'ga.nii.gz'
        args.rgb = args.rgb or 'rgb.nii.gz'
        args.md = args.md or 'md.nii.gz'
        args.ad = args.ad or 'ad.nii.gz'
        args.rd = args.rd or 'rd.nii.gz'
        args.mode = args.mode or 'mode.nii.gz'
        args.norm = args.norm or 'tensor_norm.nii.gz'
        args.tensor = args.tensor or 'tensor.nii.gz'
        args.evecs = args.evecs or 'tensor_evecs.nii.gz'
        args.evals = args.evals or 'tensor_evals.nii.gz'
        args.residual = args.residual or 'dti_residual.nii.gz'
        args.p_i_signal =\
            args.p_i_signal or 'physically_implausible_signals_mask.nii.gz'
        args.pulsation = args.pulsation or 'pulsation_and_misalignment.nii.gz'

    outputs = [args.fa, args.ga, args.rgb, args.md, args.ad, args.rd,
               args.mode, args.norm, args.tensor, args.evecs, args.evals,
               args.residual, args.p_i_signal, args.pulsation]
    if args.not_all and not any(outputs):
        parser.error('When using --not_all, you need to specify at least ' +
                     'one metric to output.')

    assert_inputs_exist(
        parser, [args.input, args.bvals, args.bvecs], [args.mask])
    assert_outputs_exists(parser, args, outputs)

    img = nib.load(args.input)
    # ``get_data()`` / ``get_affine()`` are removed in recent nibabel;
    # ``asanyarray(dataobj)`` and ``.affine`` are the documented replacements.
    data = np.asanyarray(img.dataobj)
    affine = img.affine
    if args.mask is None:
        mask = None
    else:
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        mask = nib.load(args.mask).get_fdata().astype(bool)

    # Validate bvals and bvecs
    logging.info('Tensor estimation with the %s method...', args.method)
    bvals, bvecs = read_bvals_bvecs(args.bvals, args.bvecs)

    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)

    check_b0_threshold(args, bvals.min())
    gtab = gradient_table(bvals, bvecs, b0_threshold=bvals.min())

    # Get tensors
    if args.method == 'restore':
        sigma = ne.estimate_sigma(data)
        tenmodel = TensorModel(gtab, fit_method=args.method, sigma=sigma,
                               min_signal=_get_min_nonzero_signal(data))
    else:
        tenmodel = TensorModel(gtab, fit_method=args.method,
                               min_signal=_get_min_nonzero_signal(data))

    tenfit = tenmodel.fit(data, mask)

    # FA is always computed because the RGB map also needs it.
    FA = fractional_anisotropy(tenfit.evals)
    FA[np.isnan(FA)] = 0
    FA = np.clip(FA, 0, 1)

    if args.tensor:
        # Get the Tensor values and format them for visualisation
        # in the Fibernavigator.
        tensor_vals = lower_triangular(tenfit.quadratic_form)
        correct_order = [0, 1, 3, 2, 4, 5]
        tensor_vals_reordered = tensor_vals[..., correct_order]
        fiber_tensors = nib.Nifti1Image(
            tensor_vals_reordered.astype(np.float32), affine)
        nib.save(fiber_tensors, args.tensor)

    if args.fa:
        fa_img = nib.Nifti1Image(FA.astype(np.float32), affine)
        nib.save(fa_img, args.fa)

    if args.ga:
        GA = geodesic_anisotropy(tenfit.evals)
        GA[np.isnan(GA)] = 0

        ga_img = nib.Nifti1Image(GA.astype(np.float32), affine)
        nib.save(ga_img, args.ga)

    if args.rgb:
        RGB = color_fa(FA, tenfit.evecs)
        rgb_img = nib.Nifti1Image(np.array(255 * RGB, 'uint8'), affine)
        nib.save(rgb_img, args.rgb)

    if args.md:
        MD = mean_diffusivity(tenfit.evals)
        md_img = nib.Nifti1Image(MD.astype(np.float32), affine)
        nib.save(md_img, args.md)

    if args.ad:
        AD = axial_diffusivity(tenfit.evals)
        ad_img = nib.Nifti1Image(AD.astype(np.float32), affine)
        nib.save(ad_img, args.ad)

    if args.rd:
        RD = radial_diffusivity(tenfit.evals)
        rd_img = nib.Nifti1Image(RD.astype(np.float32), affine)
        nib.save(rd_img, args.rd)

    if args.mode:
        # Compute tensor mode
        inter_mode = dipy_mode(tenfit.quadratic_form)

        # Since the mode computation can generate NANs when not masked,
        # we need to remove them.
        non_nan_indices = np.isfinite(inter_mode)
        mode = np.zeros(inter_mode.shape)
        mode[non_nan_indices] = inter_mode[non_nan_indices]

        mode_img = nib.Nifti1Image(mode.astype(np.float32), affine)
        nib.save(mode_img, args.mode)

    if args.norm:
        NORM = norm(tenfit.quadratic_form)
        norm_img = nib.Nifti1Image(NORM.astype(np.float32), affine)
        nib.save(norm_img, args.norm)

    if args.evecs:
        evecs = tenfit.evecs.astype(np.float32)
        evecs_img = nib.Nifti1Image(evecs, affine)
        nib.save(evecs_img, args.evecs)

        # save individual e-vectors also
        e1_img = nib.Nifti1Image(evecs[..., 0], affine)
        e2_img = nib.Nifti1Image(evecs[..., 1], affine)
        e3_img = nib.Nifti1Image(evecs[..., 2], affine)

        nib.save(e1_img, add_filename_suffix(args.evecs, '_v1'))
        nib.save(e2_img, add_filename_suffix(args.evecs, '_v2'))
        nib.save(e3_img, add_filename_suffix(args.evecs, '_v3'))

    if args.evals:
        evals = tenfit.evals.astype(np.float32)
        evals_img = nib.Nifti1Image(evals, affine)
        nib.save(evals_img, args.evals)

        # save individual e-values also
        e1_img = nib.Nifti1Image(evals[..., 0], affine)
        e2_img = nib.Nifti1Image(evals[..., 1], affine)
        e3_img = nib.Nifti1Image(evals[..., 2], affine)

        nib.save(e1_img, add_filename_suffix(args.evals, '_e1'))
        nib.save(e2_img, add_filename_suffix(args.evals, '_e2'))
        nib.save(e3_img, add_filename_suffix(args.evals, '_e3'))

    if args.p_i_signal:
        # A voxel is flagged when any DWI is brighter than the mean b0.
        S0 = np.mean(data[..., gtab.b0s_mask], axis=-1, keepdims=True)
        DWI = data[..., ~gtab.b0s_mask]
        pis_mask = np.max(S0 < DWI, axis=-1)

        if args.mask is not None:
            pis_mask *= mask

        pis_img = nib.Nifti1Image(pis_mask.astype(np.int16), affine)
        nib.save(pis_img, args.p_i_signal)

    if args.pulsation:
        STD = np.std(data[..., ~gtab.b0s_mask], axis=-1)

        if args.mask is not None:
            STD *= mask

        std_img = nib.Nifti1Image(STD.astype(np.float32), affine)
        nib.save(std_img, add_filename_suffix(args.pulsation, '_std_dwi'))

        # np.where on a 1D mask returns a 1-tuple, so the previous
        # ``len(np.where(...)) == 2`` test never fired; count b0s directly.
        nb_b0 = np.sum(gtab.b0s_mask)
        if nb_b0 <= 1:
            logging.info('Not enough b=0 images to output standard '
                         'deviation map')
        else:
            if nb_b0 == 2:
                logging.info('Only two b=0 images. Be careful with the '
                             'interpretation of this std map')

            STD = np.std(data[..., gtab.b0s_mask], axis=-1)

            if args.mask is not None:
                STD *= mask

            std_img = nib.Nifti1Image(STD.astype(np.float32), affine)
            nib.save(std_img, add_filename_suffix(args.pulsation, '_std_b0'))

    if args.residual:
        if args.mask is None:
            logging.info("Outlier detection will not be performed, since no "
                         "mask was provided.")
        S0 = np.mean(data[..., gtab.b0s_mask], axis=-1)
        data_p = tenfit.predict(gtab, S0)
        R = np.mean(np.abs(data_p[..., ~gtab.b0s_mask] -
                           data[..., ~gtab.b0s_mask]), axis=-1)

        if args.mask is not None:
            R *= mask

        R_img = nib.Nifti1Image(R.astype(np.float32), affine)
        nib.save(R_img, args.residual)

        R_k = np.zeros(data.shape[-1])  # mean residual per DWI
        std = np.zeros(data.shape[-1])  # std residual per DWI
        q1 = np.zeros(data.shape[-1])   # first quartile
        q3 = np.zeros(data.shape[-1])   # third quartile
        iqr = np.zeros(data.shape[-1])  # interquartile
        for i in range(data.shape[-1]):
            x = np.abs(data_p[..., i] - data[..., i])[mask]
            R_k[i] = np.mean(x)
            std[i] = np.std(x)
            q3[i], q1[i] = np.percentile(x, [75, 25])
            iqr[i] = q3[i] - q1[i]

            # Outliers are observations that fall below Q1 - 1.5(IQR) or
            # above Q3 + 1.5(IQR). We check if a volume is an outlier only if
            # we have a mask, else we are biased. The parentheses make the
            # mask condition guard BOTH fences (``and`` binds tighter than
            # ``or``).
            if args.mask is not None and (R_k[i] < q1[i] - 1.5 * iqr[i] or
                                          R_k[i] > q3[i] + 1.5 * iqr[i]):
                logging.warning('WARNING: Diffusion-Weighted Image i=%s is '
                                'an outlier', i)

        residual_basename, _ = split_name_with_nii(args.residual)
        res_stats_basename = residual_basename + ".npy"
        np.save(add_filename_suffix(
            res_stats_basename, "_mean_residuals"), R_k)
        np.save(add_filename_suffix(res_stats_basename, "_q1_residuals"), q1)
        np.save(add_filename_suffix(res_stats_basename, "_q3_residuals"), q3)
        np.save(add_filename_suffix(res_stats_basename, "_iqr_residuals"), iqr)
        np.save(add_filename_suffix(res_stats_basename, "_std_residuals"), std)

        # To do: I would like to have an error bar with q1 and q3.
        # Now, q1 acts as a std
        dwi = np.arange(R_k[~gtab.b0s_mask].shape[0])
        plt.bar(dwi, R_k[~gtab.b0s_mask], 0.75,
                color='y', yerr=q1[~gtab.b0s_mask])
        plt.xlabel('DW image')
        plt.ylabel('Mean residuals +- q1')
        plt.title('Residuals')
        plt.savefig(residual_basename + '_residuals_stats.png')
        plt.close()
Exemplo n.º 6
0
def compute_sh_coefficients(dwi,
                            gradient_table,
                            sh_order=4,
                            basis_type='descoteaux07',
                            smooth=0.006,
                            use_attenuation=False,
                            force_b0_threshold=False,
                            mask=None,
                            sphere=None):
    """Fit a single-shell diffusion signal with spherical harmonics.

    Parameters
    ----------
    dwi : np.ndarray
        Diffusion signal whose last axis indexes the volumes described by
        ``gradient_table`` (typically 4D, (X, Y, Z, N)). It is indexed
        directly, so pass the array data, not a ``nib.Nifti1Image``.
    gradient_table : GradientTable
        Dipy object that contains all bvals and bvecs.
    sh_order : int, optional
        SH order to fit, by default 4.
    basis_type : str, optional
        Either 'tournier07' or 'descoteaux07' (default).
    smooth : float, optional
        Lambda-regularization coefficient in the SH fit, by default 0.006.
    use_attenuation : bool, optional
        If true, fit the DWI attenuation (signal relative to the mean b0)
        instead of the raw signal. [False]
    force_b0_threshold : bool, optional
        If set, will continue even if the minimum bvalue is suspiciously
        high, and the minimum bvalue is then used as the b0 threshold.
    mask : np.ndarray, optional
        Array with the spatial shape of ``dwi`` (all axes but the last).
        The fit still runs on every voxel; the resulting coefficients are
        multiplied by the mask, so voxels where it is 0/False become 0.
    sphere : Sphere, optional
        Dipy object. If not provided, uses ``Sphere(xyz=bvecs)`` built from
        the (normalized) b-vectors of the non-b0 volumes.

    Returns
    -------
    sh_coeffs : np.ndarray with shape ``dwi.shape[:-1] + (#coeffs,)``
        Spherical harmonics coefficients at every voxel. The actual number
        of coefficients depends on `sh_order`.

    Raises
    ------
    ValueError
        If the b-values do not form exactly two groups (b0 + one shell)
        with the lowest group under the b0 threshold.
    """
    # Extracting infos
    b0_mask = gradient_table.b0s_mask
    dwi_mask = np.logical_not(b0_mask)
    bvecs = gradient_table.bvecs
    bvals = gradient_table.bvals

    # Checks
    if not is_normalized_bvecs(bvecs):
        logging.warning("Your b-vectors do not seem normalized...")
        bvecs = normalize_bvecs(bvecs)
    check_b0_threshold(force_b0_threshold, bvals.min())

    # Ensure that this is on a single shell: exactly two b-value groups,
    # the lowest of which must be under the b0 threshold.
    shell_values, _ = identify_shells(bvals)
    shell_values.sort()
    if force_b0_threshold:
        b0_threshold = bvals.min()
    else:
        b0_threshold = DEFAULT_B0_THRESHOLD
    if shell_values.shape[0] != 2 or shell_values[0] > b0_threshold:
        raise ValueError("Can only work on single shell signals.")

    # Keep only the diffusion-weighted volumes and their directions.
    bvecs = bvecs[dwi_mask]
    weights = dwi[..., dwi_mask]

    # Compute attenuation using the b0. Averaging over the last (volume)
    # axis works for any number of leading spatial axes.
    if use_attenuation:
        b0 = dwi[..., b0_mask].mean(axis=-1)
        weights = compute_dwi_attenuation(weights, b0)

    # Get cartesian coords from bvecs
    if sphere is None:
        sphere = Sphere(xyz=bvecs)

    # Fit SH
    sh = sf_to_sh(weights, sphere, sh_order, basis_type, smooth)

    # Apply mask (broadcast over the coefficient axis).
    if mask is not None:
        sh *= mask[..., None]

    return sh
Exemplo n.º 7
0
def main():
    """Plot how ROI-averaged DTI metrics vary with the b-value used.

    Volumes are grouped into shells: sorted b-values within 25 s/mm^2 of an
    existing shell join it, and the lowest shell (index 0) is treated as
    b0. For each higher shell, a WLS tensor is fitted on that shell plus
    shell 0. The ROI means of FA, MD, the largest and smallest eigenvalues,
    their difference, a contrast term and the raw signal are collected.
    Each quantity is plotted against the shell's maximum b-value, saved as
    ``./roi_plot_<NAME>.png``, and all figures are shown at the end.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    img = nib.load(args.input)
    data = img.get_fdata()
    n_vox = np.prod(data.shape[:3])

    print('\ndata shape ({}, {}, {}, {})'.format(
        data.shape[0], data.shape[1], data.shape[2], data.shape[3]))
    print('total voxels {}'.format(n_vox))

    # remove negatives (NOTE: counts negative entries over all volumes)
    n_neg = (data < 0).sum()
    print('\ncliping negative ({} voxels, {:.2f} % of total)'.format(
        n_neg, 100 * n_neg / float(n_vox)))
    data = np.clip(data, 0, np.inf)

    if args.mask is None:
        # An all-True mask keeps the boolean indexing below valid: indexing
        # with None would insert an axis instead of selecting voxels and
        # corrupt every ROI average.
        mask = np.ones(data.shape[:3], dtype=bool)
        masksum = n_vox
    else:
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        mask = nib.load(args.mask).get_fdata().astype(bool)
        masksum = mask.sum()

    print('\nMask has {} voxels, {:.2f} % of total'.format(
        masksum, 100 * masksum / float(n_vox)))

    # Validate bvals and bvecs
    bvals, bvecs = read_bvals_bvecs(args.bvals, args.bvecs)

    if not is_normalized_bvecs(bvecs):
        print('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)

    # Detect unique b-shells and assign a shell id to each volume. Sorting
    # first makes the shell ids increase with b-value.
    bvals_argsort = np.argsort(bvals)
    bvals_sorted = bvals[bvals_argsort]

    b_shell_threshold = 25.
    unique_bvalues = [bvals_sorted[0]]
    shell_idx = [0]
    for newb in bvals_sorted[1:]:
        for i, b in enumerate(unique_bvalues):
            if abs(newb - b) < b_shell_threshold:
                shell_idx.append(i)
                break
        else:
            # Not close to any known shell: open a new one.
            shell_idx.append(len(unique_bvalues))
            unique_bvalues.append(newb)

    unique_bvalues = np.array(unique_bvalues)
    # un-sort shells
    shells = np.zeros_like(bvals)
    shells[bvals_argsort] = shell_idx

    print('\nWe have {} shells'.format(len(unique_bvalues)))
    print('with b-values {}\n'.format(unique_bvalues))

    for i in range(len(unique_bvalues)):
        shell_b = bvals[shells == i]
        print('shell {}: n = {}, min/max {} {}'.format(
            i, len(shell_b), shell_b.min(), shell_b.max()))

    if len(unique_bvalues) < 2:
        # Nothing to fit: the per-shell averages below would be empty.
        print('\nNeed at least one shell above the lowest b-value; '
              'nothing to plot.')
        return

    # Get tensors
    method = 'WLS'
    min_signal = 1e-16
    print('\nUsing fitting method {}'.format(method))

    b0_thr = bvals.min() + 10
    print('\nassuming existence of b0 (thr = {})\n'.format(b0_thr))

    fas = []
    mds = []
    lams_max = []
    lams_min = []
    delta_S = []
    raw_signal = []
    bmaxs = []
    for i in range(1, len(unique_bvalues)):
        # Volumes of shell i plus the lowest shell (the b0 volumes).
        selection = np.logical_or(shells == i, shells == 0)
        bmax = bvals[shells == i].max()
        bmaxs.append(bmax)
        print('fitting using {} th shells (bmax = {})'.format(i + 1, bmax))

        gtab = gradient_table(bvals[selection], bvecs[selection],
                              b0_threshold=b0_thr)
        tenmodel = TensorModel(gtab, fit_method=method, min_signal=min_signal)

        sub_data = data[..., selection]
        tenfit = tenmodel.fit(sub_data, mask)
        raw_signal.append(sub_data[mask].mean(axis=1))

        evalmax = np.max(tenfit.evals, axis=-1)
        evalmin = np.min(tenfit.evals, axis=-1)

        # Zero out NaN/inf eigenvalues.
        evalmax[~np.isfinite(evalmax)] = 0
        evalmin[~np.isfinite(evalmin)] = 0

        # exp(-b * lambda_min) - exp(-b * lambda_max) at this shell's b.
        weird_contrast = (np.exp(-unique_bvalues[i] * evalmin) -
                          np.exp(-unique_bvalues[i] * evalmax))

        mds.append(tenfit.md[mask])
        fas.append(tenfit.fa[mask])
        lams_max.append(evalmax[mask])
        lams_min.append(evalmin[mask])
        delta_S.append(weird_contrast[mask])

    bmaxs = np.array(bmaxs)

    names = ['FA',
             'MD',
             'eval_max',
             'eval_min',
             'delta_S',
             'eval_max_minus_eval_min',
             'raw_signal']

    units = ['a.u.',
             'mm^2/s',
             'mm^2/s',
             'mm^2/s',
             'contrast (a.u.)',
             'mm^2/s',
             'raw signal (a.u.)']

    # ROI mean of each quantity, one value per fitted shell.
    datas = [np.array(fas).mean(axis=1),
             np.array(mds).mean(axis=1),
             np.array(lams_max).mean(axis=1),
             np.array(lams_min).mean(axis=1),
             np.array(delta_S).mean(axis=1),
             (np.array(lams_max) - np.array(lams_min)).mean(axis=1),
             np.array(raw_signal).mean(axis=1)]

    for name, unit, values in zip(names, units, datas):
        plt.figure()
        plt.plot(bmaxs, values)
        plt.title(name)
        plt.xlabel('bval (s/mm^2)')
        plt.ylabel(unit)

        plt.savefig('./roi_plot_' + name + '.png', dpi=150)

    plt.show()
Exemplo n.º 8
0
def main():
    """Compute DTI metric maps and optional quality-control outputs.

    Unless --not_all is given, every output that was not named explicitly is
    written under a default filename. With --not_all, only the explicitly
    named outputs are produced and at least one must be requested.

    Besides the scalar maps (FA, GA, RGB, MD, AD, RD, mode, norm) this can
    write the tensor coefficients (reordered for the Fibernavigator), the
    eigenvectors/eigenvalues (full file plus one file per component), a
    physically-implausible-signal mask, standard-deviation maps of the DWIs
    and b0s, and residual statistics (NIfTI map, per-volume .npy arrays and a
    summary figure).
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    if not args.not_all:
        args.fa = args.fa or 'fa.nii.gz'
        args.ga = args.ga or 'ga.nii.gz'
        args.rgb = args.rgb or 'rgb.nii.gz'
        args.md = args.md or 'md.nii.gz'
        args.ad = args.ad or 'ad.nii.gz'
        args.rd = args.rd or 'rd.nii.gz'
        args.mode = args.mode or 'mode.nii.gz'
        args.norm = args.norm or 'tensor_norm.nii.gz'
        args.tensor = args.tensor or 'tensor.nii.gz'
        args.evecs = args.evecs or 'tensor_evecs.nii.gz'
        args.evals = args.evals or 'tensor_evals.nii.gz'
        args.residual = args.residual or 'dti_residual.nii.gz'
        args.p_i_signal =\
            args.p_i_signal or 'physically_implausible_signals_mask.nii.gz'
        args.pulsation = args.pulsation or 'pulsation_and_misalignment.nii.gz'

    outputs = [args.fa, args.ga, args.rgb, args.md, args.ad, args.rd,
               args.mode, args.norm, args.tensor, args.evecs, args.evals,
               args.residual, args.p_i_signal, args.pulsation]
    if args.not_all and not any(outputs):
        parser.error('When using --not_all, you need to specify at least ' +
                     'one metric to output.')

    assert_inputs_exist(
        parser, [args.input, args.bvals, args.bvecs], args.mask)
    assert_outputs_exist(parser, args, outputs)

    img = nib.load(args.input)
    # img.get_data()/img.get_affine() were removed from nibabel;
    # np.asanyarray(img.dataobj) is the documented drop-in replacement.
    data = np.asanyarray(img.dataobj)
    affine = img.affine
    if args.mask is None:
        mask = None
    else:
        # np.bool was removed in NumPy 1.24; the builtin bool is equivalent.
        mask = np.asanyarray(nib.load(args.mask).dataobj).astype(bool)

    # Validate bvals and bvecs
    logging.info('Tensor estimation with the %s method...', args.method)
    bvals, bvecs = read_bvals_bvecs(args.bvals, args.bvecs)

    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)

    check_b0_threshold(args, bvals.min())
    gtab = gradient_table(bvals, bvecs, b0_threshold=bvals.min())

    # Get tensors
    min_signal = _get_min_nonzero_signal(data)
    if args.method == 'restore':
        sigma = ne.estimate_sigma(data)
        tenmodel = TensorModel(gtab, fit_method=args.method, sigma=sigma,
                               min_signal=min_signal)
    else:
        tenmodel = TensorModel(gtab, fit_method=args.method,
                               min_signal=min_signal)

    tenfit = tenmodel.fit(data, mask)

    FA = fractional_anisotropy(tenfit.evals)
    FA[np.isnan(FA)] = 0
    FA = np.clip(FA, 0, 1)

    if args.tensor:
        # Get the Tensor values and format them for visualisation
        # in the Fibernavigator.
        tensor_vals = lower_triangular(tenfit.quadratic_form)
        correct_order = [0, 1, 3, 2, 4, 5]
        tensor_vals_reordered = tensor_vals[..., correct_order]
        fiber_tensors = nib.Nifti1Image(
            tensor_vals_reordered.astype(np.float32), affine)
        nib.save(fiber_tensors, args.tensor)

    if args.fa:
        fa_img = nib.Nifti1Image(FA.astype(np.float32), affine)
        nib.save(fa_img, args.fa)

    if args.ga:
        GA = geodesic_anisotropy(tenfit.evals)
        GA[np.isnan(GA)] = 0

        ga_img = nib.Nifti1Image(GA.astype(np.float32), affine)
        nib.save(ga_img, args.ga)

    if args.rgb:
        RGB = color_fa(FA, tenfit.evecs)
        rgb_img = nib.Nifti1Image(np.array(255 * RGB, 'uint8'), affine)
        nib.save(rgb_img, args.rgb)

    if args.md:
        MD = mean_diffusivity(tenfit.evals)
        md_img = nib.Nifti1Image(MD.astype(np.float32), affine)
        nib.save(md_img, args.md)

    if args.ad:
        AD = axial_diffusivity(tenfit.evals)
        ad_img = nib.Nifti1Image(AD.astype(np.float32), affine)
        nib.save(ad_img, args.ad)

    if args.rd:
        RD = radial_diffusivity(tenfit.evals)
        rd_img = nib.Nifti1Image(RD.astype(np.float32), affine)
        nib.save(rd_img, args.rd)

    if args.mode:
        # Compute tensor mode
        inter_mode = dipy_mode(tenfit.quadratic_form)

        # Since the mode computation can generate NANs when not masked,
        # we need to remove them.
        non_nan_indices = np.isfinite(inter_mode)
        mode = np.zeros(inter_mode.shape)
        mode[non_nan_indices] = inter_mode[non_nan_indices]

        mode_img = nib.Nifti1Image(mode.astype(np.float32), affine)
        nib.save(mode_img, args.mode)

    if args.norm:
        NORM = norm(tenfit.quadratic_form)
        norm_img = nib.Nifti1Image(NORM.astype(np.float32), affine)
        nib.save(norm_img, args.norm)

    if args.evecs:
        evecs = tenfit.evecs.astype(np.float32)
        evecs_img = nib.Nifti1Image(evecs, affine)
        nib.save(evecs_img, args.evecs)

        # save individual e-vectors also
        e1_img = nib.Nifti1Image(evecs[..., 0], affine)
        e2_img = nib.Nifti1Image(evecs[..., 1], affine)
        e3_img = nib.Nifti1Image(evecs[..., 2], affine)

        nib.save(e1_img, add_filename_suffix(args.evecs, '_v1'))
        nib.save(e2_img, add_filename_suffix(args.evecs, '_v2'))
        nib.save(e3_img, add_filename_suffix(args.evecs, '_v3'))

    if args.evals:
        evals = tenfit.evals.astype(np.float32)
        evals_img = nib.Nifti1Image(evals, affine)
        nib.save(evals_img, args.evals)

        # save individual e-values also
        e1_img = nib.Nifti1Image(evals[..., 0], affine)
        e2_img = nib.Nifti1Image(evals[..., 1], affine)
        e3_img = nib.Nifti1Image(evals[..., 2], affine)

        nib.save(e1_img, add_filename_suffix(args.evals, '_e1'))
        nib.save(e2_img, add_filename_suffix(args.evals, '_e2'))
        nib.save(e3_img, add_filename_suffix(args.evals, '_e3'))

    if args.p_i_signal:
        # A voxel is flagged when any DWI value exceeds its mean b0 value.
        S0 = np.mean(data[..., gtab.b0s_mask], axis=-1, keepdims=True)
        DWI = data[..., ~gtab.b0s_mask]
        pis_mask = np.max(S0 < DWI, axis=-1)

        if args.mask is not None:
            pis_mask *= mask

        pis_img = nib.Nifti1Image(pis_mask.astype(np.int16), affine)
        nib.save(pis_img, args.p_i_signal)

    if args.pulsation:
        STD = np.std(data[..., ~gtab.b0s_mask], axis=-1)

        if args.mask is not None:
            STD *= mask

        std_img = nib.Nifti1Image(STD.astype(np.float32), affine)
        nib.save(std_img, add_filename_suffix(args.pulsation, '_std_dwi'))

        # Count the b0 volumes directly: len(np.where(mask_1d)) is always 1,
        # so the previous "exactly two b0s" check could never trigger.
        nb_b0s = np.count_nonzero(gtab.b0s_mask)
        if nb_b0s <= 1:
            logging.info('Not enough b=0 images to output standard '
                         'deviation map')
        else:
            if nb_b0s == 2:
                logging.info('Only two b=0 images. Be careful with the '
                             'interpretation of this std map')

            STD = np.std(data[..., gtab.b0s_mask], axis=-1)

            if args.mask is not None:
                STD *= mask

            std_img = nib.Nifti1Image(STD.astype(np.float32), affine)
            nib.save(std_img, add_filename_suffix(args.pulsation, '_std_b0'))

    if args.residual:
        # Mean residual image
        S0 = np.mean(data[..., gtab.b0s_mask], axis=-1)
        data_p = tenfit.predict(gtab, S0)
        R = np.mean(np.abs(data_p[..., ~gtab.b0s_mask] -
                           data[..., ~gtab.b0s_mask]), axis=-1)

        if args.mask is not None:
            R *= mask

        R_img = nib.Nifti1Image(R.astype(np.float32), affine)
        nib.save(R_img, args.residual)

        # Each volume's residual statistics
        if args.mask is None:
            logging.info("Outlier detection will not be performed, since no "
                         "mask was provided.")

        nb_volumes = data.shape[-1]
        # One dict per volume in the format expected by Axes.bxp. Stats are
        # computed manually here, but could also come from
        # cbook.boxplot_stats or pyplot.boxplot(x).
        stats = []
        R_k = np.zeros(nb_volumes)    # mean residual per DWI
        std = np.zeros(nb_volumes)    # std residual per DWI
        q1 = np.zeros(nb_volumes)     # first quartile per DWI
        q3 = np.zeros(nb_volumes)     # third quartile per DWI
        iqr = np.zeros(nb_volumes)    # interquartile per DWI
        percent_outliers = np.zeros(nb_volumes)
        nb_voxels = np.count_nonzero(mask)
        for k in range(nb_volumes):
            residual_k = np.abs(data_p[..., k] - data[..., k])
            x = residual_k.ravel() if mask is None else residual_k[mask]
            R_k[k] = np.mean(x)
            std[k] = np.std(x)
            # Use the true median (50th percentile), not the midpoint of
            # the quartiles.
            q3[k], median, q1[k] = np.percentile(x, [75, 50, 25])
            iqr[k] = q3[k] - q1[k]
            whislo = q1[k] - 1.5 * iqr[k]
            whishi = q3[k] + 1.5 * iqr[k]
            stats.append({'label': k,
                          'mean': R_k[k],
                          'med': median,
                          'q1': q1[k],
                          'q3': q3[k],
                          'iqr': iqr[k],
                          'whislo': whislo,
                          'whishi': whishi,
                          'fliers': []})

            # Outliers are observations that fall below Q1 - 1.5(IQR) or
            # above Q3 + 1.5(IQR). We check if a voxel is an outlier only if
            # we have a mask, else we are biased.
            if mask is not None:
                outliers = (x < whislo) | (x > whishi)
                percent_outliers[k] = np.sum(outliers) / nb_voxels * 100
                # NOTE: no threshold for "too many outliers" is defined yet
                # (e.g. mean(all_means) +- 3SD); users judge from the figure.

        # Saving all statistics as npy values
        residual_basename, _ = split_name_with_nii(args.residual)
        res_stats_basename = residual_basename + ".npy"
        np.save(add_filename_suffix(
            res_stats_basename, "_mean_residuals"), R_k)
        np.save(add_filename_suffix(res_stats_basename, "_q1_residuals"), q1)
        np.save(add_filename_suffix(res_stats_basename, "_q3_residuals"), q3)
        np.save(add_filename_suffix(res_stats_basename, "_iqr_residuals"), iqr)
        np.save(add_filename_suffix(res_stats_basename, "_std_residuals"), std)

        # Showing results in graph
        if args.mask is None:
            fig, axe = plt.subplots(nrows=1, ncols=1, squeeze=False)
        else:
            # Default is [6.4, 4.8]. Increasing width to see better.
            fig, axe = plt.subplots(nrows=1, ncols=2, squeeze=False,
                                    figsize=[10, 4.8])

        medianprops = dict(linestyle='-', linewidth=2.5, color='firebrick')
        meanprops = dict(linestyle='-', linewidth=2.5, color='green')
        axe[0, 0].bxp(stats, showmeans=True, meanline=True, showfliers=False,
                      medianprops=medianprops, meanprops=meanprops)
        axe[0, 0].set_xlabel('DW image')
        axe[0, 0].set_ylabel('Residuals per DWI volume. Red is median,\n'
                             'green is mean. Whiskers are 1.5*interquartile')
        axe[0, 0].set_title('Residuals')
        axe[0, 0].set_xticks(range(0, q1.shape[0], 5))
        axe[0, 0].set_xticklabels(range(0, q1.shape[0], 5))

        if args.mask is not None:
            axe[0, 1].plot(range(nb_volumes), percent_outliers)
            axe[0, 1].set_xticks(range(0, q1.shape[0], 5))
            axe[0, 1].set_xticklabels(range(0, q1.shape[0], 5))
            axe[0, 1].set_xlabel('DW image')
            axe[0, 1].set_ylabel('Percentage of outlier voxels')
            axe[0, 1].set_title('Outliers')
        plt.savefig(residual_basename + '_residuals_stats.png')
        plt.close(fig)
Exemplo n.º 9
0
def main():
    """Resample a DWI acquisition onto another gradient scheme using MAP-MRI.

    A MapmriModel is built on the source gradient table and fitted voxel by
    voxel inside the mask (every voxel when no mask is given). Each fit is
    then used to predict the signal on the target gradient table. Voxels
    outside the mask stay at 0 in the output image.

    Raises
    ------
    ValueError
        If the minimal source or target b-value is greater than 20.
    """
    parser = buildArgsParser()
    args = parser.parse_args()
    if args.isVerbose:
        logging.basicConfig(level=logging.DEBUG)

    if os.path.isfile(args.output):
        if args.overwrite:
            logging.debug('Overwriting "{0}".'.format(args.output))
        else:
            parser.error(
                '"{0}" already exists! Use -f to overwrite it.'.format(
                    args.output))

    bvals_source, bvecs_source = read_bvals_bvecs(args.bvals_source,
                                                  args.bvecs_source)
    if not is_normalized_bvecs(bvecs_source):
        logging.warning('Your source b-vectors do not seem normalized...')
        bvecs_source = normalize_bvecs(bvecs_source)
    if bvals_source.min() > 0:
        if bvals_source.min() > 20:
            raise ValueError('The minimal source b-value is greater than 20.' +
                             ' This is highly suspicious. Please check ' +
                             'your data to ensure everything is correct.\n' +
                             'Value found: {0}'.format(bvals_source.min()))
        else:
            logging.warning('Warning: no b=0 image. Setting b0_threshold to ' +
                            'bvals.min() = {0}'.format(bvals_source.min()))
    gtab_source = gradient_table(bvals_source,
                                 bvecs_source,
                                 b0_threshold=bvals_source.min())

    bvals_target, bvecs_target = read_bvals_bvecs(args.bvals_target,
                                                  args.bvecs_target)
    if not is_normalized_bvecs(bvecs_target):
        logging.warning('Your output b-vectors do not seem normalized...')
        bvecs_target = normalize_bvecs(bvecs_target)
    if bvals_target.min() != 0:
        if bvals_target.min() > 20:
            raise ValueError('The minimal target b-value is greater than 20.' +
                             ' This is highly suspicious. Please check ' +
                             'your data to ensure everything is correct.\n' +
                             'Value found: {0}'.format(bvals_target.min()))
        else:
            logging.warning('Warning: no b=0 image. Setting b0_threshold to ' +
                            'bvals.min() = {0}'.format(bvals_target.min()))
    gtab_target = gradient_table(bvals_target,
                                 bvecs_target,
                                 b0_threshold=bvals_target.min())

    dwi_img_source = nib.load(args.input)
    # get_data()/get_affine()/get_header() were removed from nibabel;
    # np.asanyarray(img.dataobj), .affine and .header replace them.
    data_source = np.asanyarray(dwi_img_source.dataobj)

    # Same spatial shape as the source, one volume per target b-value.
    data_target = np.zeros(list(data_source.shape)[:-1] + [len(bvals_target)])

    if args.mask is not None:
        mask = np.asanyarray(nib.load(args.mask).dataobj).astype(bool)
    else:
        # np.bool was removed in NumPy 1.24; the builtin bool is equivalent.
        mask = np.ones_like(data_source[..., 0], dtype=bool)

    mapmri = MapmriModel(gtab_source,
                         radial_order=args.radial_order,
                         lambd=args.lambd,
                         anisotropic_scaling=args.anisotropic_scaling,
                         eap_cons=args.eap_cons,
                         bmax_threshold=args.bmax_threshold)

    nbr_voxels_total = mask.sum()
    nbr_voxels_done = 0
    for idx in np.ndindex(mask.shape):
        if mask[idx] > 0:
            if nbr_voxels_done % 100 == 0:
                logging.warning("{}/{} voxels done".format(
                    nbr_voxels_done, nbr_voxels_total))

            # Single-voxel fit: the 3D mask has no meaning for a 1D signal,
            # so it is no longer passed here.
            fit = mapmri.fit(data_source[idx])
            data_target[idx] = fit.predict(gtab_target)
            nbr_voxels_done += 1

    # header information is updated accordingly by nibabel
    out_img = nib.Nifti1Image(data_target, dwi_img_source.affine,
                              dwi_img_source.header)
    out_img.to_filename(args.output)
Exemplo n.º 10
0
def compute_fodf(data,
                 bvals,
                 bvecs,
                 full_frf,
                 sh_order=8,
                 nbr_processes=None,
                 mask=None,
                 sh_basis='descoteaux07',
                 return_sh=True,
                 n_peaks=5,
                 force_b0_threshold=False):
    """
    Compute Constrained Spherical Deconvolution (CSD) fiber ODFs and peaks.

    See [Tournier et al. NeuroImage 2007] and [Cote et al Tractometer MedIA
    2013] for quantitative comparisons with Sharpening Deconvolution
    Transform (SDT).

    Parameters
    ----------
    data: ndarray
        4D Input diffusion volume with shape (X, Y, Z, N)
    bvals: ndarray
        1D bvals array with shape (N,)
    bvecs: ndarray
        2D bvecs array with shape (N, 3). They are re-normalized (with a
        warning) if they do not appear normalized.
    full_frf: ndarray
        Fiber response function, e.g. loaded from a frf file, with shape
        (4,): the three response values followed by the mean b0 value.
    sh_order: int, optional
        SH order used for the CSD. (Default: 8)
    nbr_processes: int, optional
        Number of sub processes to start. None (default) or 0 lets
        peaks_from_model pick its default (presumably the CPU count);
        1 disables parallel processing.
    mask: ndarray, optional
        3D mask with shape (X,Y,Z)
        Binary mask. Only the data inside the mask will be used for
        computations and reconstruction. Useful if no white matter mask is
        available.
    sh_basis: str, optional
        Spherical harmonics basis used for the SH coefficients. Must be either
        'descoteaux07' or 'tournier07' (default 'descoteaux07')
        - 'descoteaux07': SH basis from the Descoteaux et al. MRM 2007 paper
        - 'tournier07': SH basis from the Tournier et al. NeuroImage 2007 paper.
    return_sh: bool, optional
        If true, the SH coefficients are kept in the returned object.
    n_peaks: int, optional
        Nb of peaks for the fodf. Default: copied dipy's default, i.e. 5.
    force_b0_threshold: bool, optional
        If True, will continue even if the minimum bvalue is suspiciously high.

    Returns
    -------
    peaks_csd: PeaksAndMetrics
        An object with ``gfa``, ``peak_directions``, ``peak_values``,
        ``peak_indices``, ``odf``, ``shm_coeffs`` as attributes

    Raises
    ------
    ValueError
        If full_frf does not contain exactly 4 elements, or if nbr_processes
        is negative.
    """

    # Checking data and sh_order. Number of SH coefficients for a symmetric
    # basis of order sh_order; integer division keeps the message an integer.
    check_b0_threshold(force_b0_threshold, bvals.min())
    min_nb_volumes = (sh_order + 1) * (sh_order + 2) // 2
    if data.shape[-1] < min_nb_volumes:
        logging.warning(
            'We recommend having at least {} unique DWI volumes, but you '
            'currently have {} volumes. Try lowering the parameter sh_order '
            'in case of non convergence.'.format(
                min_nb_volumes, data.shape[-1]))

    # Checking bvals, bvecs values and loading gtab
    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)
    gtab = gradient_table(bvals, bvecs, b0_threshold=bvals.min())

    # Checking full_frf and separating it
    if not full_frf.shape[0] == 4:
        raise ValueError('FRF file did not contain 4 elements. '
                         'Invalid or deprecated FRF format')
    frf = full_frf[0:3]
    mean_b0_val = full_frf[3]

    # Checking if we will use parallel processing
    parallel = True
    if nbr_processes is not None:
        if nbr_processes == 0:  # Let peaks_from_model choose
            nbr_processes = None
        elif nbr_processes == 1:
            parallel = False
        elif nbr_processes < 0:
            raise ValueError('nbr_processes should be positive.')

    # Checking sh basis
    validate_sh_basis_choice(sh_basis)

    # Loading the spheres
    reg_sphere = get_sphere('symmetric362')
    peaks_sphere = get_sphere('symmetric724')

    # Computing CSD
    csd_model = ConstrainedSphericalDeconvModel(gtab, (frf, mean_b0_val),
                                                reg_sphere=reg_sphere,
                                                sh_order=sh_order)

    # Computing peaks, in parallel unless nbr_processes == 1
    peaks_csd = peaks_from_model(model=csd_model,
                                 data=data,
                                 sphere=peaks_sphere,
                                 relative_peak_threshold=.5,
                                 min_separation_angle=25,
                                 mask=mask,
                                 return_sh=return_sh,
                                 sh_basis_type=sh_basis,
                                 sh_order=sh_order,
                                 normalize_peaks=True,
                                 npeaks=n_peaks,
                                 parallel=parallel,
                                 nbr_processes=nbr_processes)

    return peaks_csd
Exemplo n.º 11
0
def main():
    """Fit a Q-ball or CSA ODF model and save the requested output maps.

    Outputs not named explicitly get default filenames unless --not_all is
    set, in which case at least one output must be named. Available outputs:
    GFA, peaks, peak indices, SH coefficients, number of fiber orientations
    (NuFO) and anisotropic power.

    Raises
    ------
    ValueError
        If the minimal b-value is greater than 20.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if not args.not_all:
        args.gfa = args.gfa or 'gfa.nii.gz'
        args.peaks = args.peaks or 'peaks.nii.gz'
        args.peak_indices = args.peak_indices or 'peaks_indices.nii.gz'
        args.sh = args.sh or 'sh.nii.gz'
        args.nufo = args.nufo or 'nufo.nii.gz'
        args.a_power = args.a_power or 'anisotropic_power.nii.gz'

    arglist = [
        args.gfa, args.peaks, args.peak_indices, args.sh, args.nufo,
        args.a_power
    ]
    if args.not_all and not any(arglist):
        parser.error('When using --not_all, you need to specify at least ' +
                     'one file to output.')

    assert_inputs_exist(parser, [args.input, args.bvals, args.bvecs])
    assert_outputs_exists(parser, args, arglist)

    # <= 0 lets peaks_from_model choose the number of processes; 1 runs
    # serially.
    nbr_processes = args.nbr_processes
    parallel = True
    if nbr_processes <= 0:
        nbr_processes = None
    elif nbr_processes == 1:
        parallel = False

    # Load data. get_data()/get_affine() were removed from nibabel;
    # np.asanyarray(img.dataobj) and img.affine are their replacements.
    img = nib.load(args.input)
    data = np.asanyarray(img.dataobj)
    affine = img.affine

    bvals, bvecs = read_bvals_bvecs(args.bvals, args.bvecs)

    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)

    if bvals.min() != 0:
        if bvals.min() > 20:
            raise ValueError(
                'The minimal bvalue is greater than 20. This is highly '
                'suspicious. Please check your data to ensure everything is '
                'correct.\nValue found: {0}'.format(bvals.min()))
        else:
            logging.warning(
                'Warning: no b=0 image. Setting b0_threshold to '
                'bvals.min() = %s', bvals.min())
            gtab = gradient_table(bvals, bvecs, b0_threshold=bvals.min())
    else:
        gtab = gradient_table(bvals, bvecs)

    sphere = get_sphere('symmetric724')

    if args.mask is None:
        mask = None
    else:
        # np.bool was removed in NumPy 1.24; the builtin bool is equivalent.
        mask = np.asanyarray(nib.load(args.mask).dataobj).astype(bool)

    sh_order = int(args.sh_order)
    if args.use_qball:
        model = QballModel(gtab, sh_order=sh_order, smooth=0.006)
    else:
        model = CsaOdfModel(gtab, sh_order=sh_order, smooth=0.006)

    odfpeaks = peaks_from_model(model=model,
                                data=data,
                                sphere=sphere,
                                relative_peak_threshold=.5,
                                min_separation_angle=25,
                                mask=mask,
                                return_odf=False,
                                normalize_peaks=True,
                                return_sh=True,
                                sh_order=sh_order,
                                sh_basis_type=args.basis,
                                npeaks=5,
                                parallel=parallel,
                                nbr_processes=nbr_processes)

    if args.gfa:
        nib.save(nib.Nifti1Image(odfpeaks.gfa.astype(np.float32), affine),
                 args.gfa)

    if args.peaks:
        nib.save(
            nib.Nifti1Image(reshape_peaks_for_visualization(odfpeaks), affine),
            args.peaks)

    if args.peak_indices:
        nib.save(nib.Nifti1Image(odfpeaks.peak_indices, affine),
                 args.peak_indices)

    if args.sh:
        nib.save(
            nib.Nifti1Image(odfpeaks.shm_coeff.astype(np.float32), affine),
            args.sh)

    if args.nufo:
        # Number of valid peaks per voxel (invalid peak indices are -1).
        peaks_count = (odfpeaks.peak_indices > -1).sum(3)
        nib.save(nib.Nifti1Image(peaks_count.astype(np.int32), affine),
                 args.nufo)

    if args.a_power:
        odf_a_power = anisotropic_power(odfpeaks.shm_coeff)
        nib.save(nib.Nifti1Image(odf_a_power.astype(np.float32), affine),
                 args.a_power)
Exemplo n.º 12
0
def main():
    """Fit a Q-ball (or CSA) ODF model and save the requested output maps.

    Each output that is not named explicitly falls back to a default
    filename unless --not_all is set. In that case at least one output must
    be requested. Outputs: GFA, peaks, peak indices, SH coefficients, NuFO
    (number of valid peaks per voxel) and anisotropic power.

    Raises
    ------
    ValueError
        If the provided mask does not match the spatial shape of the DWI.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # (attribute name, default filename), in output order.
    output_defaults = (('gfa', 'gfa.nii.gz'),
                       ('peaks', 'peaks.nii.gz'),
                       ('peak_indices', 'peaks_indices.nii.gz'),
                       ('sh', 'sh.nii.gz'),
                       ('nufo', 'nufo.nii.gz'),
                       ('a_power', 'anisotropic_power.nii.gz'))
    if not args.not_all:
        for attr_name, default_name in output_defaults:
            setattr(args, attr_name, getattr(args, attr_name) or default_name)

    arglist = [getattr(args, attr_name) for attr_name, _ in output_defaults]
    if args.not_all and not any(arglist):
        parser.error('When using --not_all, you need to specify at least '
                     'one file to output.')

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec])
    assert_outputs_exist(parser, args, arglist)
    validate_nbr_processes(parser, args)

    nbr_processes = args.nbr_processes
    parallel = nbr_processes > 1

    # Load data
    img = nib.load(args.in_dwi)
    data = img.get_fdata(dtype=np.float32)
    affine = img.affine

    bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec)
    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)

    smallest_bval = bvals.min()
    check_b0_threshold(args, smallest_bval)
    gtab = gradient_table(bvals, bvecs, b0_threshold=smallest_bval)

    sphere = get_sphere('symmetric724')

    mask = None
    if args.mask:
        mask = get_data_as_mask(nib.load(args.mask))
        # The mask must cover exactly the spatial grid of the DWI.
        if mask.shape != data.shape[:-1]:
            raise ValueError('Mask shape does not match data shape.')

    model_class = QballModel if args.use_qball else CsaOdfModel
    model = model_class(gtab, sh_order=args.sh_order, smooth=DEFAULT_SMOOTH)

    odfpeaks = peaks_from_model(model=model,
                                data=data,
                                sphere=sphere,
                                relative_peak_threshold=.5,
                                min_separation_angle=25,
                                mask=mask,
                                return_odf=False,
                                normalize_peaks=True,
                                return_sh=True,
                                sh_order=int(args.sh_order),
                                sh_basis_type=args.sh_basis,
                                npeaks=5,
                                parallel=parallel,
                                nbr_processes=nbr_processes)

    def _save_map(volume, filename):
        # Every output shares the DWI's affine.
        nib.save(nib.Nifti1Image(volume, affine), filename)

    if args.gfa:
        _save_map(odfpeaks.gfa.astype(np.float32), args.gfa)

    if args.peaks:
        _save_map(reshape_peaks_for_visualization(odfpeaks), args.peaks)

    if args.peak_indices:
        _save_map(odfpeaks.peak_indices, args.peak_indices)

    if args.sh:
        _save_map(odfpeaks.shm_coeff.astype(np.float32), args.sh)

    if args.nufo:
        # Peak indices of -1 mark missing peaks; count the others per voxel.
        nufo_map = (odfpeaks.peak_indices > -1).sum(3)
        _save_map(nufo_map.astype(np.int32), args.nufo)

    if args.a_power:
        a_power_map = anisotropic_power(odfpeaks.shm_coeff)
        _save_map(a_power_map.astype(np.float32), args.a_power)
Exemplo n.º 13
0
def main():
    """Estimate a single-fiber response function (FRF) and, optionally, CSD fODFs.

    The FRF is estimated with auto_response inside the white matter mask
    (whole volume if none). The FA threshold is lowered in steps of 0.05
    until at least 300 voxels are found, never going below 0.5. The FRF
    eigenvalues are saved next to the fODF file as '<fodf>_frf.txt'. Unless
    --frf_only is set, a CSD model is then fitted and the fODF SH
    coefficients, peaks and peak indices are saved.

    Raises
    ------
    ValueError
        If the minimal b-value is greater than 20, or if fewer than 300 voxels
        are found for the FRF estimation with an FA threshold >= 0.5.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)

    if not args.not_all:
        args.fodf = args.fodf or 'fodf.nii.gz'
        args.peaks = args.peaks or 'peaks.nii.gz'
        args.peak_indices = args.peak_indices or 'peak_indices.nii.gz'

    arglist = [args.fodf, args.peaks, args.peak_indices]
    if args.not_all and not any(arglist):
        parser.error('When using --not_all, you need to specify at least '
                     'one file to output.')

    assert_inputs_exist(parser, [args.input, args.bvals, args.bvecs])
    assert_outputs_exists(parser, args, arglist)

    # <= 0 lets peaks_from_model choose the number of processes; 1 runs
    # serially.
    nbr_processes = args.nbr_processes
    parallel = True
    if nbr_processes <= 0:
        nbr_processes = None
    elif nbr_processes == 1:
        parallel = False

    # Check for FRF filename.
    # NOTE(review): with --not_all and no --fodf, args.fodf is None here and
    # split_name_with_nii(None) will likely fail; confirm intended behavior.
    base_odf_name, _ = split_name_with_nii(args.fodf)
    frf_filename = base_odf_name + '_frf.txt'
    if os.path.isfile(frf_filename) and not args.overwrite:
        parser.error('Cannot save frf file, "{0}" already exists. '
                     'Use -f to overwrite.'.format(frf_filename))

    # get_data() was removed from nibabel; np.asanyarray(img.dataobj) is the
    # documented replacement.
    vol = nib.load(args.input)
    data = np.asanyarray(vol.dataobj)

    bvals, bvecs = read_bvals_bvecs(args.bvals, args.bvecs)

    if args.mask_wm is not None:
        wm_mask = np.asanyarray(nib.load(args.mask_wm).dataobj).astype(bool)
    else:
        # np.bool was removed in NumPy 1.24; the builtin bool is equivalent.
        wm_mask = np.ones_like(data[..., 0], dtype=bool)
        logging.info(
            'No white matter mask specified! mask_data will be used instead, '
            'if it has been supplied. \nBe *VERY* careful about the '
            'estimation of the fiber response function for the CSD.')

    data_in_wm = applymask(data, wm_mask)

    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)

    if bvals.min() != 0:
        if bvals.min() > 20:
            raise ValueError(
                'The minimal bvalue is greater than 20. This is highly '
                'suspicious. Please check your data to ensure everything is '
                'correct.\nValue found: {}'.format(bvals.min()))
        else:
            logging.warning(
                'Warning: no b=0 image. Setting b0_threshold to '
                'bvals.min() = %s', bvals.min())
            gtab = gradient_table(bvals, bvecs, b0_threshold=bvals.min())
    else:
        gtab = gradient_table(bvals, bvecs)

    if args.mask is None:
        mask = None
    else:
        mask = np.asanyarray(nib.load(args.mask).dataobj).astype(bool)

    # Raise warning for sh order if there is not enough DWIs. The previous
    # warnings.warn call passed the count as the warning *category*, which
    # raises TypeError, and the volume count was never formatted.
    min_nb_dwis = (args.sh_order + 1) * (args.sh_order + 2) // 2
    if data.shape[-1] < min_nb_dwis:
        logging.warning(
            'We recommend having at least %s unique DWIs volumes, but you '
            'currently have %s volumes. Try lowering the parameter --sh_order '
            'in case of non convergence.', min_nb_dwis, data.shape[-1])

    # If threshold is too high, try lower until enough indices are found.
    # Estimating a response function with fa < 0.5 does not make sense.
    min_nvox = 300
    min_fa_thresh = 0.5
    fa_thresh = args.fa_thresh
    nvox = 0
    # The small tolerance keeps 0.5 itself reachable despite the float drift
    # caused by repeatedly subtracting 0.05.
    while nvox < min_nvox and fa_thresh >= min_fa_thresh - 1e-5:
        response, ratio, nvox = auto_response(gtab,
                                              data_in_wm,
                                              roi_center=args.roi_center,
                                              roi_radius=args.roi_radius,
                                              fa_thr=fa_thresh,
                                              return_number_of_voxels=True)

        logging.info('Number of indices is %s with threshold of %s', nvox,
                     fa_thresh)
        fa_thresh -= 0.05

    # Also covers an initial fa_thresh below 0.5, where the loop never runs
    # and response/ratio would otherwise be undefined.
    if nvox < min_nvox:
        raise ValueError(
            'Could not find at least 300 voxels for estimating the frf!')

    logging.info('Found %s valid voxels for frf estimation.', nvox)

    response = list(response)
    logging.info('Response function is %s', response)

    if args.frf is not None:
        # Manually provided FRF values override the estimated eigenvalues.
        l01 = np.array(literal_eval(args.frf), dtype=np.float64)
        if not args.no_factor:
            l01 *= 10**-4

        response[0] = np.array([l01[0], l01[1], l01[1]])
        ratio = l01[1] / l01[0]

    logging.info("Eigenvalues for the frf of the input data are: %s",
                 response[0])
    logging.info("Ratio for smallest to largest eigen value is %s", ratio)
    np.savetxt(frf_filename, response[0])

    if not args.frf_only:
        reg_sphere = get_sphere('symmetric362')
        peaks_sphere = get_sphere('symmetric724')

        csd_model = ConstrainedSphericalDeconvModel(gtab,
                                                    response,
                                                    reg_sphere=reg_sphere,
                                                    sh_order=args.sh_order)

        peaks_csd = peaks_from_model(model=csd_model,
                                     data=data,
                                     sphere=peaks_sphere,
                                     relative_peak_threshold=.5,
                                     min_separation_angle=25,
                                     mask=mask,
                                     return_sh=True,
                                     sh_basis_type=args.basis,
                                     sh_order=args.sh_order,
                                     normalize_peaks=True,
                                     parallel=parallel,
                                     nbr_processes=nbr_processes)

        if args.fodf:
            nib.save(
                nib.Nifti1Image(peaks_csd.shm_coeff.astype(np.float32),
                                vol.affine), args.fodf)

        if args.peaks:
            nib.save(
                nib.Nifti1Image(reshape_peaks_for_visualization(peaks_csd),
                                vol.affine), args.peaks)

        if args.peak_indices:
            nib.save(nib.Nifti1Image(peaks_csd.peak_indices, vol.affine),
                     args.peak_indices)
Exemplo n.º 14
0
def main():
    """Fit a WLS diffusion tensor up to ``bmax`` and save derived maps.

    Negative intensities are clipped to zero, then only the volumes with
    ``bval < bmax + 22`` are used for the tensor fit. The MD, 1/MD, FA,
    1/(largest eigenvalue) and
    ``exp(-bmax * eval_min) - exp(-bmax * eval_max)`` maps are written in
    the current directory with names suffixed by ``_bmax_<bmax>``.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    img = nib.load(args.input)
    # get_data() is deprecated and was removed in nibabel 5.0.
    data = img.get_fdata()
    nb_voxels = np.prod(data.shape[:3])

    bmax = int(args.bmax)

    print('\ndata shape ({}, {}, {}, {})'.format(data.shape[0], data.shape[1],
                                                 data.shape[2], data.shape[3]))
    print('total voxels {}'.format(nb_voxels))

    # remove negatives
    nb_negative = (data < 0).sum()
    print('\ncliping negative ({} voxels, {:.2f} % of total)'.format(
        nb_negative, 100 * nb_negative / float(nb_voxels)))
    data = np.clip(data, 0, np.inf)

    if args.mask is None:
        mask = None
        masksum = nb_voxels
    else:
        # np.bool was removed in NumPy 1.24; the builtin bool is equivalent.
        mask = nib.load(args.mask).get_fdata().astype(bool)
        masksum = mask.sum()

    print('\nMask has {} voxels, {:.2f} % of total'.format(
        masksum, 100 * masksum / float(nb_voxels)))

    # Validate bvals and bvecs
    bvals, bvecs = read_bvals_bvecs(args.bvals, args.bvecs)
    if not is_normalized_bvecs(bvecs):
        print('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)

    # Get tensors
    method = 'WLS'
    min_signal = 1e-16
    print('\nUsing fitting method {}'.format(method))

    b0_thr = bvals.min() + 10
    print('\nassuming existence of b0 (thr = {})\n'.format(b0_thr))

    # Restrict the gradient table and the data to the volumes up to bmax
    # (with a tolerance of 22 on the b-value); computed once and reused.
    in_range = bvals < bmax + 22
    gtab = gradient_table(bvals[in_range], bvecs[in_range],
                          b0_threshold=b0_thr)

    tenmodel = TensorModel(gtab, fit_method=method, min_signal=min_signal)
    tenfit = tenmodel.fit(data[..., in_range], mask)

    MD = tenfit.md
    FA = tenfit.fa

    evalmax = np.max(tenfit.evals, axis=3)
    evalmin = np.min(tenfit.evals, axis=3)

    # Inverting zeros gives inf (and NaN inputs stay NaN); every non-finite
    # value is set to 0 just below, so the floating-point warnings are noise.
    with np.errstate(divide='ignore', invalid='ignore'):
        invevalmax = evalmax ** -1
        invMD = MD ** -1
    invevalmax[~np.isfinite(invevalmax)] = 0
    invMD[~np.isfinite(invMD)] = 0

    weird_contrast = np.exp(-bmax * evalmin) - np.exp(-bmax * evalmax)

    nib.nifti1.Nifti1Image(MD, img.affine).to_filename(
        './MD_bmax_{}'.format(bmax))
    nib.nifti1.Nifti1Image(invMD, img.affine).to_filename(
        './invMD_bmax_{}'.format(bmax))
    nib.nifti1.Nifti1Image(FA, img.affine).to_filename(
        './FA_bmax_{}'.format(bmax))
    nib.nifti1.Nifti1Image(invevalmax, img.affine).to_filename(
        './inv_e1_bmax_{}'.format(bmax))
    nib.nifti1.Nifti1Image(weird_contrast, img.affine).to_filename(
        './minmax_contrast_bmax_{}'.format(bmax))
Exemplo n.º 15
0
def compute_msmt_frf(data,
                     bvals,
                     bvecs,
                     data_dti=None,
                     bvals_dti=None,
                     bvecs_dti=None,
                     mask=None,
                     mask_wm=None,
                     mask_gm=None,
                     mask_csf=None,
                     fa_thr_wm=0.7,
                     fa_thr_gm=0.2,
                     fa_thr_csf=0.1,
                     md_thr_gm=0.0007,
                     md_thr_csf=0.003,
                     min_nvox=300,
                     roi_radii=10,
                     roi_center=None,
                     tol=20,
                     force_b0_threshold=False):
    """Compute multi-shell, multi-tissue (WM, GM, CSF) Fiber Response
    Functions from a DWI volume.

    The tissue voxels are selected by ``mask_for_response_msmt`` using FA and
    MD thresholds inside a cuboid ROI. This selection uses the ``*_dti``
    inputs when all three are given, and ``data`` otherwise. The responses
    are then estimated from ``data`` inside the selected voxels.

    Parameters
    ----------
    data : ndarray
        4D Input diffusion volume with shape (X, Y, Z, N)
    bvals : ndarray
        1D bvals array with shape (N,)
    bvecs : ndarray
        2D bvecs array with shape (N, 3)
    data_dti : ndarray, optional
        4D diffusion volume used only to select the tissue voxels. Must be
        given together with bvals_dti and bvecs_dti.
    bvals_dti : ndarray, optional
        1D bvals array matching data_dti.
    bvecs_dti : ndarray, optional
        2D bvecs array matching data_dti.
    mask : ndarray, optional
        3D mask with shape (X,Y,Z)
        Binary mask. Only the data inside the mask will be used for
        computations and reconstruction.
    mask_wm : ndarray, optional
        3D mask with shape (X,Y,Z)
        Binary white matter mask. Only the data inside this mask will be used
        to estimate the fiber response function of WM.
    mask_gm : ndarray, optional
        3D mask with shape (X,Y,Z)
        Binary grey matter mask. Only the data inside this mask will be used
        to estimate the fiber response function of GM.
    mask_csf : ndarray, optional
        3D mask with shape (X,Y,Z)
        Binary csf mask. Only the data inside this mask will be used to
        estimate the fiber response function of CSF.
    fa_thr_wm : float, optional
        Use this threshold to select single WM fiber voxels from the FA inside
        the WM mask defined by mask_wm. Each voxel above this threshold will be
        selected. Defaults to 0.7
    fa_thr_gm : float, optional
        Use this threshold to select GM voxels from the FA inside the GM mask
        defined by mask_gm. Each voxel below this threshold will be selected.
        Defaults to 0.2
    fa_thr_csf : float, optional
        Use this threshold to select CSF voxels from the FA inside the CSF mask
        defined by mask_csf. Each voxel below this threshold will be selected.
        Defaults to 0.1
    md_thr_gm : float, optional
        Use this threshold to select GM voxels from the MD inside the GM mask
        defined by mask_gm. Each voxel below this threshold will be selected.
        Defaults to 0.0007
    md_thr_csf : float, optional
        Use this threshold to select CSF voxels from the MD inside the CSF mask
        defined by mask_csf. Each voxel below this threshold will be selected.
        Defaults to 0.003
    min_nvox : int, optional
        Minimal number of voxels needing to be identified for each tissue
        mask. Defaults to 300.
    roi_radii : int or array-like (3,), optional
        Use those radii to select a cuboid roi to estimate the FRF. The roi
        will be a cuboid spanning from the middle of the volume in each
        direction with the different radii. Defaults to 10.
    roi_center : tuple(3), optional
        Use this center to span the roi of size roi_radius (center of the
        3D volume).
    tol : int, optional
        tolerance gap for b-values clustering. Defaults to 20
    force_b0_threshold : bool, optional
        If set, will continue even if the minimum bvalue is suspiciously high.

    Returns
    -------
    responses : list of ndarray
        Fiber Response Function of each (3) tissue type (WM, GM, CSF), as
        returned by ``response_from_mask_msmt``.
        NOTE(review): previously documented as shape (4, N); confirm against
        dipy (one row of 4 values per shell is expected).
    frf_masks : list of ndarray
        Mask where the frf was calculated, for each (3) tissue type, with
        shape (X, Y, Z).

    Raises
    ------
    ValueError
        If only some of data_dti, bvals_dti and bvecs_dti are given, or if
        less than `min_nvox` voxels were found for any tissue mask.
    """
    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)

    check_b0_threshold(force_b0_threshold, bvals.min())

    gtab = gradient_table(bvals, bvecs)

    # Choose which acquisition drives the tissue selection, then select once.
    dti_inputs = (data_dti, bvals_dti, bvecs_dti)
    if all(dti_input is None for dti_input in dti_inputs):
        logging.warning(
            "No data specific to DTI was given. If b-values go over 1200, "
            "this might produce wrong results.")
        gtab_masks, data_masks = gtab, data
    elif all(dti_input is not None for dti_input in dti_inputs):
        if not is_normalized_bvecs(bvecs_dti):
            logging.warning('Your b-vectors do not seem normalized...')
            bvecs_dti = normalize_bvecs(bvecs_dti)

        check_b0_threshold(force_b0_threshold, bvals_dti.min())
        gtab_masks = gradient_table(bvals_dti, bvecs_dti)
        data_masks = data_dti
    else:
        msg = """Input not valid. Either give no _dti input, or give all
        data_dti, bvals_dti and bvecs_dti."""
        raise ValueError(msg)

    wm_frf_mask, gm_frf_mask, csf_frf_mask \
        = mask_for_response_msmt(gtab_masks, data_masks,
                                 roi_center=roi_center,
                                 roi_radii=roi_radii,
                                 wm_fa_thr=fa_thr_wm,
                                 gm_fa_thr=fa_thr_gm,
                                 csf_fa_thr=fa_thr_csf,
                                 gm_md_thr=md_thr_gm,
                                 csf_md_thr=md_thr_csf)

    # Restrict the automatic selection with the optional external masks.
    if mask is not None:
        wm_frf_mask *= mask
        gm_frf_mask *= mask
        csf_frf_mask *= mask
    if mask_wm is not None:
        wm_frf_mask *= mask_wm
    if mask_gm is not None:
        gm_frf_mask *= mask_gm
    if mask_csf is not None:
        csf_frf_mask *= mask_csf

    msg = """Could not find at least {0} voxels for the {1} mask. Look at
    previous warnings or be sure that external tissue masks overlap with the
    cuboid ROI."""

    frf_masks = [wm_frf_mask, gm_frf_mask, csf_frf_mask]
    for frf_mask, tissue in zip(frf_masks, ("WM", "GM", "CSF")):
        if np.sum(frf_mask) < min_nvox:
            raise ValueError(msg.format(min_nvox, tissue))

    response_wm, response_gm, response_csf \
        = response_from_mask_msmt(gtab, data, wm_frf_mask, gm_frf_mask,
                                  csf_frf_mask, tol=tol)

    responses = [response_wm, response_gm, response_csf]

    return responses, frf_masks
Exemplo n.º 16
0
def compute_ssst_frf(data,
                     bvals,
                     bvecs,
                     mask=None,
                     mask_wm=None,
                     fa_thresh=0.7,
                     min_fa_thresh=0.5,
                     min_nvox=300,
                     roi_radii=10,
                     roi_center=None,
                     force_b0_threshold=False):
    """Compute a single-shell (under b=1500), single-tissue single Fiber
    Response Function from a DWI volume.

    Voxels containing a single fiber population are selected with
    ``mask_for_response_ssst`` using a threshold on the FA. The threshold
    starts at ``fa_thresh`` and is lowered by 0.05 until at least
    ``min_nvox`` voxels are found or it drops below ``min_fa_thresh``. The
    response is then estimated once, from the final selection.

    Parameters
    ----------
    data : ndarray
        4D Input diffusion volume with shape (X, Y, Z, N)
    bvals : ndarray
        1D bvals array with shape (N,)
    bvecs : ndarray
        2D bvecs array with shape (N, 3)
    mask : ndarray, optional
        3D mask with shape (X,Y,Z)
        Binary mask. Only the data inside the mask will be used for
        computations and reconstruction. Useful if no white matter mask is
        available.
    mask_wm : ndarray, optional
        3D mask with shape (X,Y,Z)
        Binary white matter mask. Only the data inside this mask and above the
        threshold defined by fa_thresh will be used to estimate the fiber
        response function.
    fa_thresh : float, optional
        Use this threshold as the initial threshold to select single fiber
        voxels. Defaults to 0.7
    min_fa_thresh : float, optional
        Minimal value that will be tried when looking for single fiber voxels.
        Defaults to 0.5
    min_nvox : int, optional
        Minimal number of voxels needing to be identified as single fiber
        voxels in the automatic estimation. Defaults to 300.
    roi_radii : int or array-like (3,), optional
        Use those radii to select a cuboid roi to estimate the FRF. The roi
        will be a cuboid spanning from the middle of the volume in each
        direction with the different radii. Defaults to 10.
    roi_center : tuple(3), optional
        Use this center to span the roi of size roi_radius (center of the
        3D volume).
    force_b0_threshold : bool, optional
        If set, will continue even if the minimum bvalue is suspiciously high.

    Returns
    -------
    full_response : ndarray
        Fiber Response Function, with shape (4,): the 3 FRF eigenvalues
        followed by the mean b=0 signal of the voxels used.

    Raises
    ------
    ValueError
        If less than `min_nvox` voxels were found with sufficient FA to
        estimate the FRF.
    """
    if min_fa_thresh < 0.4:
        logging.warning(
            "Minimal FA threshold ({:.2f}) seems really small. "
            "Make sure it makes sense for this dataset.".format(min_fa_thresh))

    if not is_normalized_bvecs(bvecs):
        logging.warning("Your b-vectors do not seem normalized...")
        bvecs = normalize_bvecs(bvecs)

    check_b0_threshold(force_b0_threshold, bvals.min())

    gtab = gradient_table(bvals, bvecs, b0_threshold=bvals.min())

    if mask is not None:
        data = applymask(data, mask)

    if mask_wm is not None:
        data = applymask(data, mask_wm)
    else:
        logging.warning(
            "No white matter mask specified! Only mask will be used "
            "(if it has been supplied). \nBe *VERY* careful about the "
            "estimation of the fiber response function to ensure no invalid "
            "voxel was used.")

    # Iteratively try to find at least min_nvox voxels, lowering the FA
    # threshold when it doesn't work. Fail if the fa threshold goes below
    # min_fa_thresh. We use an epsilon since the -= 0.05 might incur
    # numerical imprecision. The selection mask gets its own name so the
    # `mask` parameter is not overwritten.
    nvox = 0
    frf_mask = None
    while fa_thresh >= min_fa_thresh - 0.00001:
        frf_mask = mask_for_response_ssst(gtab,
                                          data,
                                          roi_center=roi_center,
                                          roi_radii=roi_radii,
                                          fa_thr=fa_thresh)
        # int() keeps the "{:d}" formats valid whatever the mask dtype.
        nvox = int(np.sum(frf_mask))

        logging.debug(
            "Number of indices is {:d} with threshold of {:.2f}".format(
                nvox, fa_thresh))
        fa_thresh -= 0.05
        if nvox >= min_nvox:
            break

    # frf_mask is None when no threshold could even be tried.
    if frf_mask is None or nvox < min_nvox:
        raise ValueError(
            "Could not find at least {:d} voxels with sufficient FA "
            "to estimate the FRF!".format(min_nvox))

    # Fit the response only once, on the selection that succeeded, instead
    # of on every (possibly empty) intermediate selection.
    response, ratio = response_from_mask_ssst(gtab, data, frf_mask)

    logging.debug("Found {:d} voxels with FA threshold {:.2f} for "
                  "FRF estimation".format(nvox, fa_thresh + 0.05))
    logging.debug("FRF eigenvalues: {}".format(str(response[0])))
    logging.debug("Ratio for smallest to largest eigen value "
                  "is {:.3f}".format(ratio))
    logging.debug("Mean of the b=0 signal for voxels used "
                  "for FRF: {}".format(response[1]))

    full_response = np.array(
        [response[0][0], response[0][1], response[0][2], response[1]])

    return full_response
Exemplo n.º 17
0
def main():
    """Estimate a single-shell, single-tissue FRF and save it as text.

    The FA threshold starts at ``args.fa_thresh`` and is lowered by 0.05
    until ``auto_response`` finds at least ``args.min_nvox`` voxels or the
    threshold drops below ``args.min_fa_thresh``. The saved file holds 4
    values: the 3 FRF eigenvalues followed by the mean b=0 signal of the
    voxels used.

    Raises
    ------
    ValueError
        If fewer than ``args.min_nvox`` voxels could be found.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, [args.input, args.bvals, args.bvecs])
    assert_outputs_exists(parser, args, [args.frf_file])

    vol = nib.load(args.input)
    # get_data() is deprecated and was removed in nibabel 5.0.
    data = vol.get_fdata()

    bvals, bvecs = read_bvals_bvecs(args.bvals, args.bvecs)

    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)

    check_b0_threshold(args, bvals.min())
    gtab = gradient_table(bvals, bvecs, b0_threshold=bvals.min())

    # logging.warn is a deprecated alias; use logging.warning.
    if args.min_fa_thresh < 0.4:
        logging.warning(
            'Minimal FA threshold ({}) seems really small. Make sure it '
            'makes sense for this dataset.'.format(args.min_fa_thresh))

    if args.mask:
        # np.bool was removed in NumPy 1.24; the builtin bool is equivalent.
        mask = nib.load(args.mask).get_fdata().astype(bool)
        data = applymask(data, mask)

    if args.mask_wm:
        wm_mask = nib.load(args.mask_wm).get_fdata().astype(bool)
    else:
        wm_mask = np.ones_like(data[..., 0], dtype=bool)
        logging.warning(
            'No white matter mask specified! mask_data will be used instead, '
            'if it has been supplied. \nBe *VERY* careful about the '
            'estimation of the fiber response function to ensure no invalid '
            'voxel was used.')

    data_in_wm = applymask(data, wm_mask)

    fa_thresh = args.fa_thresh
    # Iteratively try to fit at least args.min_nvox voxels, lowering the FA
    # threshold when it doesn't work. Fail if the fa threshold is smaller
    # than the min_threshold. We use an epsilon since the -= 0.05 might
    # incur numerical imprecision.
    nvox = 0
    while nvox < args.min_nvox and fa_thresh >= args.min_fa_thresh - 0.00001:
        response, ratio, nvox = auto_response(gtab,
                                              data_in_wm,
                                              roi_center=args.roi_center,
                                              roi_radius=args.roi_radius,
                                              fa_thr=fa_thresh,
                                              return_number_of_voxels=True)

        logging.debug('Number of indices is %s with threshold of %s', nvox,
                      fa_thresh)
        fa_thresh -= 0.05

    if nvox < args.min_nvox:
        raise ValueError(
            "Could not find at least {} voxels with sufficient FA "
            "to estimate the FRF!".format(args.min_nvox))

    logging.debug("Found %i voxels with FA threshold %f for FRF estimation",
                  nvox, fa_thresh + 0.05)
    logging.debug("FRF eigenvalues: %s", str(response[0]))
    logging.debug("Ratio for smallest to largest eigen value is %f", ratio)
    logging.debug("Mean of the b=0 signal for voxels used for FRF: %f",
                  response[1])

    full_response = np.array(
        [response[0][0], response[0][1], response[0][2], response[1]])

    np.savetxt(args.frf_file, full_response)
Exemplo n.º 18
0
def main():
    """Plot, for increasing bmax, the distribution of a tensor-based contrast.

    The b-values are clustered into shells: in increasing order, a b-value
    closer than 25 to an existing shell joins it, otherwise it starts a new
    shell. For each shell k >= 1, a WLS tensor is fitted on shells 0..k, and
    ``exp(-b_k * eval_min) - exp(-b_k * eval_max)`` is computed inside the
    mask, ``b_k`` being the clustered b-value of shell k.

    Outputs:

    - ``./dSEst_bmax_<b_k>.png``: a log-binned histogram of the contrast
      with its quartiles, saved for each k.
    - ``./bvalEst_Qs.png``: the quartiles as a function of bmax.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    img = nib.load(args.input)
    # get_data() is deprecated and was removed in nibabel 5.0.
    data = img.get_fdata()
    nb_voxels = np.prod(data.shape[:3])

    print('\ndata shape ({}, {}, {}, {})'.format(
        data.shape[0], data.shape[1], data.shape[2], data.shape[3]))
    print('total voxels {}'.format(nb_voxels))

    # remove negatives
    nb_negative = (data < 0).sum()
    print('\ncliping negative ({} voxels, {:.2f} % of total)'.format(
        nb_negative, 100 * nb_negative / float(nb_voxels)))
    data = np.clip(data, 0, np.inf)

    if args.mask is None:
        mask = None
        masksum = nb_voxels
    else:
        # np.bool was removed in NumPy 1.24; the builtin bool is equivalent.
        mask = nib.load(args.mask).get_fdata().astype(bool)
        masksum = mask.sum()

    print('\nMask has {} voxels, {:.2f} % of total'.format(
        masksum, 100 * masksum / float(nb_voxels)))

    # Validate bvals and bvecs
    bvals, bvecs = read_bvals_bvecs(args.bvals, args.bvecs)
    if not is_normalized_bvecs(bvecs):
        print('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)

    # Detect unique b-shells and assign a shell id to each volume, walking
    # the b-values in increasing order.
    bvals_argsort = np.argsort(bvals)
    bvals_sorted = bvals[bvals_argsort]

    b_shell_threshold = 25.
    unique_bvalues = [bvals_sorted[0]]
    shell_idx = [0]
    for newb in bvals_sorted[1:]:
        for i, b in enumerate(unique_bvalues):
            if abs(newb - b) < b_shell_threshold:
                # Assign each volume to exactly one shell.
                shell_idx.append(i)
                break
        else:
            shell_idx.append(len(unique_bvalues))
            unique_bvalues.append(newb)

    unique_bvalues = np.array(unique_bvalues)
    nb_shells = len(unique_bvalues)
    # Un-sort the shell ids back to the original volume order.
    shells = np.zeros_like(bvals)
    shells[bvals_argsort] = shell_idx

    print('\nWe have {} shells'.format(nb_shells))
    print('with b-values {}\n'.format(unique_bvalues))
    for i in range(nb_shells):
        shell_b = bvals[shells == i]
        print('shell {}: n = {}, min/max {} {}'.format(
            i, len(shell_b), shell_b.min(), shell_b.max()))

    # Get tensors
    method = 'WLS'
    min_signal = 1e-16
    print('\nUsing fitting method {}'.format(method))

    b0_thr = bvals.min() + 10
    print('\nassuming existence of b0 (thr = {})\n'.format(b0_thr))

    # Largest raw b-value of the highest shell used by each fit.
    bmaxs = np.array([bvals[shells == i + 1].max()
                      for i in range(nb_shells - 1)])

    mds = []
    for i in range(nb_shells - 1):
        print('fitting using first {} shells (bmax = {})'.format(
            i + 2, bmaxs[i]))

        # restricted gtab
        used = shells <= i + 1
        gtab = gradient_table(bvals[used], bvecs[used], b0_threshold=b0_thr)
        tenmodel = TensorModel(gtab, fit_method=method, min_signal=min_signal)
        tenfit = tenmodel.fit(data[..., used], mask)

        evalmax = np.max(tenfit.evals, axis=3)
        evalmin = np.min(tenfit.evals, axis=3)
        evalmax[~np.isfinite(evalmax)] = 0
        evalmin[~np.isfinite(evalmin)] = 0

        shell_b = unique_bvalues[i + 1]
        weird_contrast = (np.exp(-shell_b * evalmin)
                          - np.exp(-shell_b * evalmax))

        # Indexing with mask=None would add an axis instead of selecting
        # voxels, so flatten explicitly when there is no mask.
        if mask is None:
            mds.append(weird_contrast.ravel())
        else:
            mds.append(weird_contrast[mask])

    oneq = []
    twoq = []
    threeq = []
    th = 0.01
    # NOTE(review): no quantile truncation is applied below; only values
    # <= 0 are removed. This message is kept as-is for output compatibility.
    print('\nonly using values inside quantile [{}, {}] for plotting'.format(
        th, 1 - th))
    logbins = np.logspace(-2, np.log10(0.7), 100)
    minval = 0
    for i, contrast in enumerate(mds):
        plt.figure()
        tit = ('exp(-b diff_MIN) - exp(-b diff_MAX), '
               'first {} shells (bmax = {})'.format(i + 2, bmaxs[i]))
        print('\nbmax = {}'.format(bmaxs[i]))

        # Drop zeros (and negatives), which includes the voxels whose
        # eigenvalues were non-finite.
        values = contrast[contrast > minval]
        print('removed {} zeros'.format(contrast.shape[0] - values.shape[0]))

        plt.hist(values, bins=logbins, density=True, color='grey')
        plt.semilogx([], [])

        onequart = np.quantile(values, 0.25)
        twoquart = np.quantile(values, 0.5)
        threequart = np.quantile(values, 0.75)
        oneq.append(onequart)
        twoq.append(twoquart)
        threeq.append(threequart)
        plt.axvline(onequart, color='pink',
                    label='25% ({:.2f})'.format(onequart))
        plt.axvline(twoquart, color='yellow',
                    label='50% ({:.2f})'.format(twoquart))
        plt.axvline(threequart, color='green',
                    label='75% ({:.2f})'.format(threequart))
        plt.title(tit)
        plt.legend(loc=1)
        plt.xlim([0.01, 0.7])

        # Name the file after the shell this fit goes up to (i + 1), not
        # after the previous shell.
        plt.savefig('./dSEst_bmax_{:.0f}.png'.format(unique_bvalues[i + 1]))
        plt.close()

    plt.figure()
    plt.grid()
    plt.plot(bmaxs, oneq, '-x', label='fit')
    plt.plot(bmaxs, twoq, '-x', label='fit')
    plt.plot(bmaxs, threeq, '-x', label='fit')
    plt.xlabel('bmax')
    plt.ylabel('delta S')
    plt.title('Quartile')
    plt.savefig('./bvalEst_Qs.png')
    plt.close()
Exemplo n.º 19
0
def main():
    """Compute CSD fODF spherical-harmonic coefficients and save them.

    Loads the DWI, b-values/b-vectors, optional mask and a 4-value FRF file
    (3 FRF values followed by the mean b=0 value). It then fits a
    ``ConstrainedSphericalDeconvModel`` and saves the SH coefficients as
    float32 NIfTI, converted to the tournier07 basis when requested.

    Raises
    ------
    ValueError
        If the mask shape does not match the data, or if the FRF file does
        not contain exactly 4 values.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec,
                                 args.frf_file])
    assert_outputs_exist(parser, args, args.out_fODF)

    # Loading data
    full_frf = np.loadtxt(args.frf_file)
    vol = nib.load(args.in_dwi)
    data = vol.get_fdata(dtype=np.float32)
    bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec)

    # Checking mask
    if args.mask is None:
        mask = None
    else:
        mask = get_data_as_mask(nib.load(args.mask), dtype=bool)
        if mask.shape != data.shape[:-1]:
            raise ValueError("Mask is not the same shape as data.")

    sh_order = args.sh_order

    # Checking data and sh_order
    b0_thr = check_b0_threshold(
        args.force_b0_threshold, bvals.min(), bvals.min())
    # (sh_order + 1) * (sh_order + 2) is always even, so integer division is
    # exact; it keeps the count an int ("45", not "45.0") in the message.
    nb_recommended = (sh_order + 1) * (sh_order + 2) // 2
    if data.shape[-1] < nb_recommended:
        logging.warning(
            'We recommend having at least {} unique DWI volumes, but you '
            'currently have {} volumes. Try lowering the parameter sh_order '
            'in case of non convergence.'.format(
                nb_recommended, data.shape[-1]))

    # Checking bvals, bvecs values and loading gtab
    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)
    gtab = gradient_table(bvals, bvecs, b0_threshold=b0_thr)

    # Checking full_frf and separating it. A 0-d or multi-row array (which
    # could have 4 rows by chance) is not a valid single-shell FRF.
    if full_frf.ndim != 1 or full_frf.shape[0] != 4:
        raise ValueError('FRF file did not contain 4 elements. '
                         'Invalid or deprecated FRF format')
    frf = full_frf[0:3]
    mean_b0_val = full_frf[3]

    # Loading the sphere
    reg_sphere = get_sphere('symmetric362')

    # Computing CSD
    csd_model = ConstrainedSphericalDeconvModel(
        gtab, (frf, mean_b0_val),
        reg_sphere=reg_sphere,
        sh_order=sh_order)

    # Computing CSD fit
    csd_fit = fit_from_model(csd_model, data,
                             mask=mask, nbr_processes=args.nbr_processes)

    # Saving results
    shm_coeff = csd_fit.shm_coeff
    if args.sh_basis == 'tournier07':
        shm_coeff = convert_sh_basis(shm_coeff, reg_sphere, mask=mask,
                                     nbr_processes=args.nbr_processes)
    nib.save(nib.Nifti1Image(shm_coeff.astype(np.float32),
                             vol.affine), args.out_fODF)