Example #1
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, args.in_bval)
    assert_outputs_exist(parser, args, args.out_bval)

    bvals, bvecs = read_bvals_bvecs(args.in_bval, None)
    # Find which shell each volume belongs to.
    tol = args.tolerance

    sorted_centroids, sorted_indices = identify_shells(bvals, tol, sort=True)

    bvals_to_extract = np.sort(args.bvals_to_extract)
    n_shells = np.shape(bvals_to_extract)[0]

    logging.info("number of shells: {}".format(n_shells))
    logging.info("bvals to extract: {}".format(bvals_to_extract))
    logging.info("estimated centroids: {}".format(sorted_centroids))
    logging.info("original bvals: {}".format(bvals))
    logging.info("selected indices: {}".format(sorted_indices))

    # Work on a copy so the original bvals array is left untouched.
    new_bvals = bvals.copy()
    for i in range(n_shells):
        if np.abs(sorted_centroids[i] - bvals_to_extract[i]) <= tol:
            new_bvals[np.where(sorted_indices == i)] = bvals_to_extract[i]
        else:
            parser.error("No bvals to resample: tolerance is too low.")

    logging.info("new bvals: {}".format(new_bvals))
    new_bvals.shape = (1, len(new_bvals))
    np.savetxt(args.out_bval, new_bvals, '%d')
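The resampling above relies on the shell/volume pairing that identify_shells returns. Below is a self-contained sketch of that pairing and the snapping step, using a simplified greedy stand-in for identify_shells (an assumption for illustration; scilpy's actual implementation differs):

import numpy as np

def identify_shells_toy(bvals, tol=20, sort=True):
    # Greedy grouping: each bval joins the first centroid within `tol`,
    # otherwise it starts a new shell. Toy stand-in, not scilpy's code.
    centroids, indices = [], np.zeros(len(bvals), dtype=int)
    for i, b in enumerate(bvals):
        for j, c in enumerate(centroids):
            if abs(b - c) <= tol:
                indices[i] = j
                break
        else:
            centroids.append(b)
            indices[i] = len(centroids) - 1
    centroids = np.asarray(centroids, dtype=float)
    if sort:
        order = np.argsort(centroids)
        centroids = centroids[order]
        indices = np.argsort(order)[indices]  # remap to sorted order
    return centroids, indices

bvals = np.array([0, 5, 995, 1005, 1990, 2010])
centroids, indices = identify_shells_toy(bvals)
# Snap every volume's bval to its target shell value, as main() does.
targets = np.array([0, 1000, 2000])
print(targets[indices])  # [   0    0 1000 1000 2000 2000]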
Example #2
def dwi_protocol(bvals, tol=20):
    """
    Return dwi protocol for each subject

    Parameters
    ----------
    bvals : list
        List of paths to bval files
    tol : int
        Tolerance used to decide whether a bval
        belongs to an already-identified shell

    Returns
    -------
    stats_per_subjects : dict
        One single-row DataFrame of shell statistics per subject
    stats : pd.DataFrame
        Shell and direction counts for every subject
    stats_across_subjects : pd.DataFrame
        Mean, std, min and max of those counts across subjects
    shells : dict
        Number of directions per shell, per subject
    """
    stats_per_subjects = {}
    values_stats = []
    column_names = ["Nbr shells", "Nbr directions"]
    shells = {}
    for i, filename in enumerate(bvals):
        values = []

        bval = np.loadtxt(bvals[i])

        centroids, shells_indices = identify_shells(bval, threshold=tol)
        s_centroids = sorted(centroids)
        values.append(', '.join(str(x) for x in s_centroids))
        values.append(len(shells_indices))
        columns = ["bvals"]
        columns.append("Nbr directions")
        for centroid in s_centroids:
            nearest_centroid = get_nearest_bval(list(shells.keys()), centroid)
            if int(nearest_centroid) not in shells:
                shells[int(nearest_centroid)] = {}
            nb_directions = len(shells_indices[shells_indices == np.where(
                centroids == centroid)[0]])
            print(centroid, nb_directions)
            if filename not in shells[int(nearest_centroid)]:
                shells[int(nearest_centroid)][filename] = 0
            shells[int(nearest_centroid)][filename] += nb_directions
            values.append(nb_directions)
            columns.append("Nbr bval {}".format(centroid))

        values_stats.append([len(centroids) - 1, len(shells_indices)])

        stats_per_subjects[filename] = pd.DataFrame([values],
                                                    index=[bvals[i]],
                                                    columns=columns)

    stats = pd.DataFrame(values_stats, index=[bvals], columns=column_names)

    stats_across_subjects = pd.DataFrame(
        [stats.mean(), stats.std(),
         stats.min(), stats.max()],
        index=['mean', 'std', 'min', 'max'],
        columns=column_names)

    return stats_per_subjects, stats, stats_across_subjects, shells
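get_nearest_bval is called above but not defined in this excerpt. A plausible, hypothetical reimplementation, assuming it returns the closest already-registered shell value (or the candidate itself when none exist yet):

import numpy as np

def get_nearest_bval(known_bvals, candidate):
    # Hypothetical helper: nearest previously-seen shell value.
    if not known_bvals:
        return candidate
    known = np.asarray(known_bvals, dtype=float)
    return known[np.argmin(np.abs(known - candidate))]

print(get_nearest_bval([], 1000))        # 1000
print(get_nearest_bval([0, 995], 1002))  # 995.0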
Example #3
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec],
                        args.in_mask)
    assert_output_dirs_exist_and_empty(parser, args,
                                       os.path.join(args.out_dir, 'NODDI'),
                                       optional=args.save_kernels)

    # COMMIT has some C-level stdout and non-logging prints that cannot
    # be easily silenced, so all printed output is redirected manually.
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    # Generate a scheme file from the bvals and bvecs files
    tmp_dir = tempfile.TemporaryDirectory()
    tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
    tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
    bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
    shells_centroids, indices_shells = identify_shells(bvals,
                                                       args.b_thr,
                                                       roundCentroids=True)
    np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
               newline=' ', fmt='%i')
    fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
    logging.debug('Computing NODDI with AMICO on {} shells found '
                  'at {}.'.format(len(shells_centroids), shells_centroids))

    with redirected_stdout:
        # Load the data
        amico.core.setup()
        ae = amico.Evaluation('.', '.')
        ae.load_data(args.in_dwi,
                     tmp_scheme_filename,
                     mask_filename=args.in_mask)
        # Compute the response functions
        ae.set_model("NODDI")

        intra_vol_frac = np.linspace(0.1, 0.99, 12)
        intra_orient_distr = np.hstack((np.array([0.03, 0.06]),
                                        np.linspace(0.09, 0.99, 10)))

        ae.model.set(args.para_diff, args.iso_diff,
                     intra_vol_frac, intra_orient_distr,
                     False)
        ae.set_solver(lambda1=args.lambda1, lambda2=args.lambda2)

        # The kernels are, by default, set to be in the current directory
        # Depending on the choice, manually change the saving location
        if args.save_kernels:
            kernels_dir = os.path.join(args.save_kernels)
            regenerate_kernels = True
        elif args.load_kernels:
            kernels_dir = os.path.join(args.load_kernels)
            regenerate_kernels = False
        else:
            kernels_dir = os.path.join(tmp_dir.name, 'kernels', ae.model.id)
            regenerate_kernels = True

        ae.set_config('ATOMS_path', kernels_dir)
        out_model_dir = os.path.join(args.out_dir, ae.model.id)
        ae.set_config('OUTPUT_path', out_model_dir)
        ae.generate_kernels(regenerate=regenerate_kernels)
        ae.load_kernels()

        # Set number of processes
        solver_params = ae.get_config('solver_params')
        solver_params['numThreads'] = args.nbr_processes
        ae.set_config('solver_params', solver_params)

        # Model fit
        ae.fit()
        # Save the results
        ae.save_results()

    tmp_dir.cleanup()
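The verbose/quiet redirection idiom above recurs in several of these scripts. A minimal sketch of the Python-level half (redirect_stdout_c is a separate scilpy helper for C-level stdout and is not covered here):

import io
import sys
from contextlib import redirect_stdout

verbose = False
if verbose:
    redirected_stdout = redirect_stdout(sys.stdout)  # pass-through
else:
    f = io.StringIO()
    redirected_stdout = redirect_stdout(f)           # swallow prints

with redirected_stdout:
    print('chatty library output')

if not verbose:
    captured = f.getvalue()  # still inspectable afterwards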
Example #4
def main():
    logger = logging.getLogger("Compute_DKI_Metrics")
    logger.setLevel(logging.INFO)

    parser = _build_args_parser()
    args = parser.parse_args()

    if not args.not_all:
        args.dki_fa = args.dki_fa or 'dki_fa.nii.gz'
        args.dki_md = args.dki_md or 'dki_md.nii.gz'
        args.dki_ad = args.dki_ad or 'dki_ad.nii.gz'
        args.dki_rd = args.dki_rd or 'dki_rd.nii.gz'
        args.mk = args.mk or 'mk.nii.gz'
        args.rk = args.rk or 'rk.nii.gz'
        args.ak = args.ak or 'ak.nii.gz'
        args.dki_residual = args.dki_residual or 'dki_residual.nii.gz'
        args.msk = args.msk or 'msk.nii.gz'
        args.msd = args.msd or 'msd.nii.gz'

    outputs = [args.dki_fa, args.dki_md, args.dki_ad, args.dki_rd,
               args.mk, args.rk, args.ak, args.dki_residual,
               args.msk, args.msd]

    if args.not_all and not any(outputs):
        parser.error('When using --not_all, you need to specify at least ' +
                     'one metric to output.')

    assert_inputs_exist(
        parser, [args.input, args.bvals, args.bvecs], args.mask)
    assert_outputs_exist(parser, args, outputs)

    img = nib.load(args.input)
    data = img.get_fdata()
    affine = img.affine
    if args.mask is None:
        mask = None
    else:
        mask = nib.load(args.mask).get_fdata().astype(bool)

    # Validate bvals and bvecs
    bvals, bvecs = read_bvals_bvecs(args.bvals, args.bvecs)
    if not is_normalized_bvecs(bvecs):
        logging.warning('Your b-vectors do not seem normalized...')
        bvecs = normalize_bvecs(bvecs)

    # Find the volume indices that correspond to the shells to extract.
    tol = args.tolerance
    shells, _ = identify_shells(bvals, tol)
    if len(shells) < 3:
        parser.error('Data is not multi-shell. You need at least 2 '
                     'non-zero b-values.')

    if (shells > 2500).any():
        logging.warning('You seem to be using b > 2500 s/mm2 DWI data. '
                        'In theory, this is beyond the optimal range '
                        'for DKI.')

    check_b0_threshold(args, bvals.min())
    gtab = gradient_table(bvals, bvecs, b0_threshold=bvals.min())

    fwhm = args.smooth
    if fwhm > 0:
        # converting fwhm to Gaussian std
        gauss_std = fwhm / np.sqrt(8 * np.log(2))
        data_smooth = np.zeros(data.shape)
        for v in range(data.shape[-1]):
            data_smooth[..., v] = gaussian_filter(data[..., v],
                                                  sigma=gauss_std)
        data = data_smooth

    # Compute DKI
    dkimodel = dki.DiffusionKurtosisModel(gtab)
    dkifit = dkimodel.fit(data, mask=mask)

    min_k = args.min_k
    max_k = args.max_k

    if args.dki_fa:
        FA = dkifit.fa
        FA[np.isnan(FA)] = 0
        FA = np.clip(FA, 0, 1)

        fa_img = nib.Nifti1Image(FA.astype(np.float32), affine)
        nib.save(fa_img, args.dki_fa)

    if args.dki_md:
        MD = dkifit.md
        md_img = nib.Nifti1Image(MD.astype(np.float32), affine)
        nib.save(md_img, args.dki_md)

    if args.dki_ad:
        AD = dkifit.ad
        ad_img = nib.Nifti1Image(AD.astype(np.float32), affine)
        nib.save(ad_img, args.dki_ad)

    if args.dki_rd:
        RD = dkifit.rd
        rd_img = nib.Nifti1Image(RD.astype(np.float32), affine)
        nib.save(rd_img, args.dki_rd)

    if args.mk:
        MK = dkifit.mk(min_k, max_k)
        mk_img = nib.Nifti1Image(MK.astype(np.float32), affine)
        nib.save(mk_img, args.mk)

    if args.ak:
        AK = dkifit.ak(min_k, max_k)
        ak_img = nib.Nifti1Image(AK.astype(np.float32), affine)
        nib.save(ak_img, args.ak)

    if args.rk:
        RK = dkifit.rk(min_k, max_k)
        rk_img = nib.Nifti1Image(RK.astype(np.float32), affine)
        nib.save(rk_img, args.rk)

    if args.msk or args.msd:
        # Compute MSDKI
        msdki_model = msdki.MeanDiffusionKurtosisModel(gtab)
        msdki_fit = msdki_model.fit(data, mask=mask)

        if args.msk:
            MSK = msdki_fit.msk
            MSK[np.isnan(MSK)] = 0
            MSK = np.clip(MSK, min_k, max_k)

            msk_img = nib.Nifti1Image(MSK.astype(np.float32), affine)
            nib.save(msk_img, args.msk)

        if args.msd:
            MSD = msdki_fit.msd
            msd_img = nib.Nifti1Image(MSD.astype(np.float32), affine)
            nib.save(msd_img, args.msd)

    if args.dki_residual:
        S0 = np.mean(data[..., gtab.b0s_mask], axis=-1)
        data_p = dkifit.predict(gtab, S0)
        R = np.mean(np.abs(data_p[..., ~gtab.b0s_mask] -
                           data[..., ~gtab.b0s_mask]), axis=-1)

        norm = np.linalg.norm(R)
        if norm != 0:
            R = R / norm

        if args.mask is not None:
            R *= mask

        R_img = nib.Nifti1Image(R.astype(np.float32), affine)
        nib.save(R_img, args.dki_residual)
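The fwhm / np.sqrt(8 * np.log(2)) conversion used above follows from the Gaussian identity FWHM = 2 * sqrt(2 * ln 2) * sigma ≈ 2.3548 * sigma. A standalone version of the per-volume smoothing on synthetic data:

import numpy as np
from scipy.ndimage import gaussian_filter

fwhm = 2.0
gauss_std = fwhm / np.sqrt(8 * np.log(2))  # ~0.8493 for fwhm=2

# Smooth each DWI volume independently, as in the script above.
data = np.random.default_rng(0).random((4, 4, 4, 3)).astype(np.float32)
smoothed = np.empty_like(data)
for v in range(data.shape[-1]):
    smoothed[..., v] = gaussian_filter(data[..., v], sigma=gauss_std)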
Example #5
def compute_sh_coefficients(dwi,
                            gradient_table,
                            sh_order=4,
                            basis_type='descoteaux07',
                            smooth=0.006,
                            use_attenuation=False,
                            force_b0_threshold=False,
                            mask=None,
                            sphere=None):
    """Fit a diffusion signal with spherical harmonics coefficients.

    Parameters
    ----------
    dwi : nib.Nifti1Image object
        Diffusion signal as weighted images (4D).
    gradient_table : GradientTable
        Dipy object that contains all bvals and bvecs.
    sh_order : int, optional
        SH order to fit, by default 4.
    basis_type : str, optional
        Either 'tournier07' or 'descoteaux07'.
    smooth : float, optional
        Lambda-regularization coefficient in the SH fit, by default 0.006.
    use_attenuation : bool, optional
        If true, we will use DWI attenuation. [False]
    force_b0_threshold : bool, optional
        If set, will continue even if the minimum bvalue is suspiciously high.
    mask: nib.Nifti1Image object, optional
        Binary mask. Only data inside the mask will be used for computations
        and reconstruction.
    sphere: Sphere
        Dipy object. If not provided, will use Sphere(xyz=bvecs).

    Returns
    -------
    sh_coeffs : np.ndarray with shape (X, Y, Z, #coeffs)
        Spherical harmonics coefficients at every voxel. The actual number
        of coefficients depends on `sh_order`.
    """

    # Extract info from the gradient table
    b0_mask = gradient_table.b0s_mask
    bvecs = gradient_table.bvecs
    bvals = gradient_table.bvals

    # Checks
    if not is_normalized_bvecs(bvecs):
        logging.warning("Your b-vectors do not seem normalized...")
        bvecs = normalize_bvecs(bvecs)
    check_b0_threshold(force_b0_threshold, bvals.min())

    # Ensure that this is on a single shell.
    shell_values, _ = identify_shells(bvals)
    shell_values.sort()
    if force_b0_threshold:
        b0_threshold = bvals.min()
    else:
        b0_threshold = DEFAULT_B0_THRESHOLD
    if shell_values.shape[0] != 2 or shell_values[0] > b0_threshold:
        raise ValueError("Can only work on single shell signals.")

    # Keep only the diffusion-weighted volumes (drop the b0s)
    bvecs = bvecs[np.logical_not(b0_mask)]
    weights = dwi[..., np.logical_not(b0_mask)]

    # Compute attenuation using the b0.
    if use_attenuation:
        b0 = dwi[..., b0_mask].mean(axis=3)
        weights = compute_dwi_attenuation(weights, b0)

    # Get cartesian coords from bvecs
    if sphere is None:
        sphere = Sphere(xyz=bvecs)

    # Fit SH
    sh = sf_to_sh(weights, sphere, sh_order, basis_type, smooth)

    # Apply mask
    if mask is not None:
        sh *= mask[..., None]

    return sh
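compute_dwi_attenuation is imported from scilpy and not shown; a hedged stand-in capturing the core idea, dividing every DWI volume by the mean b0 (the zero-guard and the clip to [0, 1] are assumptions about the real helper):

import numpy as np

def compute_dwi_attenuation_sketch(weights, b0):
    # weights: (X, Y, Z, N) DWI volumes; b0: (X, Y, Z) mean b0 image.
    b0 = b0[..., None]  # broadcast over the volume axis
    atten = np.divide(weights, b0, out=np.zeros_like(weights),
                      where=b0 > 0)
    return np.clip(atten, 0., 1.)

dwi = np.random.default_rng(1).random((2, 2, 2, 5)).astype(np.float32)
b0 = dwi.mean(axis=-1) + 1.
print(compute_dwi_attenuation_sketch(dwi, b0).shape)  # (2, 2, 2, 5)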
Example #6
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec],
                        args.mask)

    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       optional=args.save_kernels)

    # COMMIT has some C-level stdout and non-logging prints that cannot
    # be easily silenced, so all printed output is redirected manually.
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    # Generate a scheme file from the bvals and bvecs files
    tmp_dir = tempfile.TemporaryDirectory()
    tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
    tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
    bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
    shells_centroids, indices_shells = identify_shells(bvals,
                                                       args.b_thr,
                                                       roundCentroids=True)
    np.savetxt(tmp_bval_filename,
               shells_centroids[indices_shells],
               newline=' ',
               fmt='%i')
    fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
    logging.debug(
        'Computing FreeWater with AMICO on {} shells found at {}.'.format(
            len(shells_centroids), shells_centroids))

    with redirected_stdout:
        amico.core.setup()
        ae = amico.Evaluation('.', '.')
        # Load the data
        ae.load_data(args.in_dwi,
                     scheme_filename=tmp_scheme_filename,
                     mask_filename=args.mask)

        # Compute the response functions
        ae.set_model("FreeWater")
        model_type = 'Human'
        if args.mouse:
            model_type = 'Mouse'

        ae.model.set(args.para_diff,
                     np.linspace(args.perp_diff_min, args.perp_diff_max, 10),
                     [args.iso_diff], model_type)

        ae.set_solver(lambda1=args.lambda1, lambda2=args.lambda2)

        # The kernels are, by default, set to be in the current directory
        # Depending on the choice, manually change the saving location
        if args.save_kernels:
            kernels_dir = os.path.join(args.save_kernels)
            regenerate_kernels = True
        elif args.load_kernels:
            kernels_dir = os.path.join(args.load_kernels)
            regenerate_kernels = False
        else:
            kernels_dir = os.path.join(tmp_dir.name, 'kernels', ae.model.id)
            regenerate_kernels = True

        ae.set_config('ATOMS_path', kernels_dir)
        ae.set_config('OUTPUT_path', args.out_dir)
        ae.generate_kernels(regenerate=regenerate_kernels)
        if args.compute_only:
            return

        ae.load_kernels()

        # Set number of processes
        solver_params = ae.get_config('solver_params')
        solver_params['numThreads'] = args.nbr_processes
        ae.set_config('solver_params', solver_params)

        ae.set_config('doNormalizeSignal', True)
        ae.set_config('doKeepb0Intact', False)
        ae.set_config('doComputeNRMSE', True)
        ae.set_config('doSaveCorrectedDWI', True)

        # Model fit
        ae.fit()
        # Save the results
        ae.save_results()

    tmp_dir.cleanup()
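The np.savetxt(..., newline=' ', fmt='%i') call above writes the shell-snapped bvals as a single FSL-style row: indexing the centroid array by each volume's shell index expands one value per volume. A toy demonstration (the output filename is only an example):

import numpy as np

shells_centroids = np.array([0, 1000, 2000])
indices_shells = np.array([0, 1, 1, 2, 2, 0])
# newline=' ' joins all values on one line, as FSL bval files expect.
np.savetxt('bval_example.txt', shells_centroids[indices_shells],
           newline=' ', fmt='%i')
# File content: "0 1000 1000 2000 2000 0 "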
Example #7
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()
    assert_inputs_exist(parser, args.gradient_sampling_file)

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    if len(args.gradient_sampling_file) == 2:
        assert_gradients_filenames_valid(parser, args.gradient_sampling_file,
                                         'fsl')
    elif len(args.gradient_sampling_file) == 1:
        _, ext = os.path.splitext(args.gradient_sampling_file[0])
        if ext in ['.bvec', '.bvecs', '.bvals', '.bval']:
            parser.error('You should input two files for fsl format (.bvec '
                         'and .bval).')
        else:
            assert_gradients_filenames_valid(parser,
                                             args.gradient_sampling_file,
                                             'mrtrix')
    else:
        parser.error('Depending on the gradient format, you should provide '
                     'two files for FSL format or one file for MRtrix.')

    out_basename = None

    proj = args.enable_proj
    each = args.plot_shells

    if not (proj or each):
        parser.error('Select at least one type of rendering (proj or each).')

    if len(args.gradient_sampling_file) == 2:
        gradient_sampling_files = args.gradient_sampling_file
        gradient_sampling_files.sort()  # [bval, bvec]
        # bvecs/bvals (FSL) format, X Y Z AND b (or transpose)
        points = np.genfromtxt(gradient_sampling_files[1])
        if points.shape[0] == 3:
            points = points.T
        bvals = np.genfromtxt(gradient_sampling_files[0])
        centroids, shell_idx = identify_shells(bvals)
    else:
        # MRtrix format X, Y, Z, b
        gradient_sampling_file = args.gradient_sampling_file[0]
        tmp = np.genfromtxt(gradient_sampling_file, delimiter=' ')
        points = tmp[:, :3]
        bvals = tmp[:, 3]
        centroids, shell_idx = identify_shells(bvals)

    if args.out_basename:
        out_basename, ext = os.path.splitext(args.out_basename)
        possible_output_paths = [
            out_basename + '_shell_' + str(i) + '.png' for i in centroids
        ]
        possible_output_paths.append(out_basename + '.png')
        assert_outputs_exist(parser, args, possible_output_paths)

    for idx, b0 in enumerate(centroids[centroids < 40]):
        shell_idx[shell_idx == idx] = -1
        centroids = np.delete(centroids, np.where(centroids == b0))

    if len(shell_idx[shell_idx == -1]) > 0:
        shell_idx[shell_idx != -1] -= 1

    sym = args.enable_sym
    sph = args.enable_sph
    same = args.same_color

    ms = build_ms_from_shell_idx(points, shell_idx)
    if proj:
        plot_proj_shell(ms,
                        use_sym=sym,
                        use_sphere=sph,
                        same_color=same,
                        rad=0.025,
                        opacity=args.opacity,
                        ofile=out_basename,
                        ores=(args.res, args.res))
    if each:
        plot_each_shell(ms,
                        centroids,
                        plot_sym_vecs=sym,
                        use_sphere=sph,
                        same_color=same,
                        rad=0.025,
                        opacity=args.opacity,
                        ofile=out_basename,
                        ores=(args.res, args.res))
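The b0-removal loop above flags low-b volumes with -1 and then shifts the surviving indices down by one, which assumes a single shell falls under the threshold. A hedged, vectorized variant that stays consistent when several shells are removed:

import numpy as np

def drop_low_b_shells(centroids, shell_idx, b0_thr=40):
    keep = centroids >= b0_thr
    new_index = np.cumsum(keep) - 1        # old shell -> compacted index
    out_idx = np.where(keep[shell_idx], new_index[shell_idx], -1)
    return centroids[keep], out_idx

centroids = np.array([0, 5, 1000, 2000])
shell_idx = np.array([0, 1, 2, 2, 3, 3])
print(drop_low_b_shells(centroids, shell_idx))
# (array([1000, 2000]), array([-1, -1,  0,  0,  1,  1]))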
Example #8
def compute_snr(dwi,
                bval,
                bvec,
                b0_thr,
                mask,
                noise_mask=None,
                noise_map=None,
                split_shells=False,
                basename=None,
                verbose=False):
    """
    Compute snr

    Parameters
    ----------
    dwi: string
        Path to the dwi file
    bvec: string
        Path to the bvec file
    bval: string
        Path to the bval file
    b0_thr: int
        Threshold to define b0 minimum value
    mask: string
        Path to the mask
    noise_mask: string
        Path to the noise mask
    noise_map: string
        Path to the noise map
    basename: string
        Basename used for naming all output files

    verbose: boolean
        Set to use logging
    """
    if verbose:
        logging.basicConfig(level=logging.INFO)

    img = nib.load(dwi)
    data = img.get_fdata(dtype=np.float32)
    affine = img.affine
    mask = get_data_as_mask(nib.load(mask), dtype=bool)
    bvals, bvecs = read_bvals_bvecs(bval, bvec)

    if split_shells:
        centroids, shell_indices = identify_shells(bvals,
                                                   threshold=40.0,
                                                   roundCentroids=False,
                                                   sort=False)
        bvals = centroids[shell_indices]

    b0s_location = bvals <= b0_thr

    if not np.any(b0s_location):
        raise ValueError('You should adjust --b0_thr={} '
                         'since no b0s were found.'.format(b0_thr))

    if noise_mask is None and noise_map is None:
        b0_mask, noise_mask = median_otsu(data, vol_idx=b0s_location)

        # we inflate the mask, then invert it to recover only the noise
        noise_mask = binary_dilation(noise_mask, iterations=10).squeeze()

        # Add the upper half in order to delete the neck and shoulder
        # when inverting the mask
        noise_mask[..., :noise_mask.shape[-1] // 2] = 1

        # Reverse the mask to get only noise
        noise_mask = (~noise_mask).astype('float32')

        logging.info('Number of voxels found '
                     'in noise mask : {}'.format(np.count_nonzero(noise_mask)))
        logging.info('Total number of voxel '
                     'in volume : {}'.format(np.size(noise_mask)))

        nib.save(nib.Nifti1Image(noise_mask, affine),
                 basename + '_noise_mask.nii.gz')
    elif noise_mask:
        noise_mask = get_data_as_mask(nib.load(noise_mask),
                                      dtype=bool).squeeze()
    elif noise_map:
        img_noisemap = nib.load(noise_map)
        data_noisemap = img_noisemap.get_fdata(dtype=np.float32)

    # val: per-volume dict of bvec, bval, mean signal and noise std
    val = {0: {'bvec': [0, 0, 0], 'bval': 0, 'mean': 0, 'std': 0}}
    for idx in range(data.shape[-1]):
        val[idx] = {}
        val[idx]['bvec'] = bvecs[idx]
        val[idx]['bval'] = bvals[idx]
        val[idx]['mean'] = np.mean(data[..., idx:idx + 1][mask > 0])
        if noise_map:
            val[idx]['std'] = np.std(data_noisemap[mask > 0])
        else:
            val[idx]['std'] = np.std(data[..., idx:idx + 1][noise_mask > 0])
            if val[idx]['std'] == 0:
                raise ValueError('Your noise mask does not capture any data '
                                 '(std=0). Please check your noise mask.')

        val[idx]['snr'] = val[idx]['mean'] / val[idx]['std']

    return val
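The per-volume loop above can be expressed with boolean-mask indexing as well; a hedged vectorized sketch on synthetic data (it skips the per-volume bookkeeping dict, so it is not a drop-in replacement):

import numpy as np

rng = np.random.default_rng(2)
data = rng.normal(100., 10., size=(8, 8, 8, 4)).astype(np.float32)
mask = np.ones(data.shape[:3], dtype=bool)      # brain mask
noise_mask = rng.random(data.shape[:3]) > 0.5   # noise-only voxels

signal_mean = data[mask].mean(axis=0)     # one mean per volume
noise_std = data[noise_mask].std(axis=0)  # one noise std per volume
snr = signal_mean / noise_std
print(snr.shape)  # (4,)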
Example #9
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if len(args.bval) != len(args.bvec):
        parser.error("You must provide the same number of bval and bvec "
                     "files.")

    all_data = np.concatenate([args.bval, args.bvec])
    assert_inputs_exist(parser, all_data)
    assert_outputs_exist(parser, args, [args.output_report, "data", "libs"])

    if os.path.exists("data"):
        shutil.rmtree("data")
    os.makedirs("data")

    if os.path.exists("libs"):
        shutil.rmtree("libs")

    name = "DWI Protocol"
    summary, stats_for_graph, stats_all, shells = dwi_protocol(args.bval)
    warning_dict = {}
    warning_dict[name] = analyse_qa(stats_for_graph, stats_all,
                                    ["Nbr shells", "Nbr directions"])
    warning_images = list(warning_dict[name].values())
    warning_list = np.concatenate(warning_images)
    warning_dict[name]['nb_warnings'] = len(np.unique(warning_list))

    stats_html = dataframe_to_html(stats_all)
    summary_dict = {}
    summary_dict[name] = stats_html

    graphs = []
    graphs.append(
        graph_directions_per_shells("Nbr directions per shell", shells))
    graphs.append(graph_subjects_per_shells("Nbr subjects per shell", shells))
    for c in ["Nbr shells", "Nbr directions"]:
        graph = graph_dwi_protocol(c, c, stats_for_graph)
        graphs.append(graph)

    subjects_dict = {}
    for bval, bvec in zip(args.bval, args.bvec):
        filename = os.path.basename(bval)
        subjects_dict[bval] = {}
        points = np.genfromtxt(bvec)
        if points.shape[0] == 3:
            points = points.T
        bvals = np.genfromtxt(bval)
        centroids, shell_idx = identify_shells(bvals)
        ms = build_ms_from_shell_idx(points, shell_idx)
        plot_proj_shell(ms,
                        centroids,
                        use_sym=True,
                        use_sphere=True,
                        same_color=False,
                        rad=0.025,
                        opacity=0.2,
                        ofile=os.path.join("data", name + filename),
                        ores=(800, 800))
        subjects_dict[bval]['screenshot'] = os.path.join(
            "data", name + filename + '.png')
    metrics_dict = {}
    for subj in args.bval:
        summary_html = dataframe_to_html(summary[subj])
        subjects_dict[subj]['stats'] = summary_html
    metrics_dict[name] = subjects_dict

    nb_subjects = len(args.bval)
    report = Report(args.output_report)
    report.generate(title="Quality Assurance DWI protocol",
                    nb_subjects=nb_subjects,
                    metrics_dict=metrics_dict,
                    summary_dict=summary_dict,
                    graph_array=graphs,
                    warning_dict=warning_dict)
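Both this script and Example #7 accept bvec files stored either 3 x N or N x 3 and transpose when the first axis has length 3. A small self-contained loader sketch (the extra shape[1] != 3 guard for the ambiguous 3 x 3 case is an addition, not in the originals):

import io
import numpy as np

def load_bvecs_any_orientation(src):
    # `src` may be a path or a file-like object; genfromtxt takes both.
    points = np.genfromtxt(src)
    if points.ndim == 2 and points.shape[0] == 3 and points.shape[1] != 3:
        points = points.T  # 3 x N on disk -> N x 3 in memory
    return points

demo = io.StringIO('0 1 0 0\n0 0 1 0\n1 0 0 1\n')  # 3 x 4 layout
print(load_bvecs_any_orientation(demo).shape)       # (4, 3)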
Example #10
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(
        parser, [args.in_tractogram, args.in_dwi, args.in_bval, args.in_bvec],
        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       optional=args.save_kernels)

    if args.commit2:
        if os.path.splitext(args.in_tractogram)[1] != '.h5':
            parser.error('COMMIT2 requires .h5 file for connectomics.')
        args.ball_stick = True

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram, dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    # COMMIT has some C-level stdout and non-logging prints that cannot
    # be easily silenced, so all printed output is redirected manually.
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    hdf5_file = None
    offsets_list = None
    if ext == '.h5':
        logging.debug('Reconstructing {} into a tractogram for COMMIT.'.format(
            args.in_tractogram))

        hdf5_file = h5py.File(args.in_tractogram, 'r')
        if not (np.allclose(
                hdf5_file.attrs['affine'], dwi_img.affine, atol=1e-03)
                and np.array_equal(hdf5_file.attrs['dimensions'],
                                   dwi_img.shape[0:3])):
            parser.error('{} does not have a compatible header with {}'.format(
                args.in_tractogram, args.in_dwi))

        # Keep track of the order of connections/streamlines in relation to the
        # tractogram as well as the number of streamlines for each connection.
        bundle_groups_len = []
        hdf5_keys = list(hdf5_file.keys())
        streamlines = []
        for key in hdf5_keys:
            tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
            streamlines.extend(tmp_streamlines)
            bundle_groups_len.append(len(tmp_streamlines))

        offsets_list = np.cumsum([0] + bundle_groups_len)
        sft = StatefulTractogram(streamlines,
                                 args.in_dwi,
                                 Space.VOX,
                                 origin=Origin.TRACKVIS)
        tmp_tractogram_filename = os.path.join(tmp_dir.name, 'tractogram.trk')

        # Keeping the input variable, saving trk file for COMMIT internal use
        save_tractogram(sft, tmp_tractogram_filename)
        args.in_tractogram = tmp_tractogram_filename

    # Writing the scheme file with proper shells
    tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
    tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
    bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
    shells_centroids, indices_shells = identify_shells(bvals,
                                                       args.b_thr,
                                                       roundCentroids=True)
    np.savetxt(tmp_bval_filename,
               shells_centroids[indices_shells],
               newline=' ',
               fmt='%i')
    fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
    logging.debug('Launching COMMIT on {} shells found at {}.'.format(
        len(shells_centroids), shells_centroids))

    if len(shells_centroids) == 2 and not args.ball_stick:
        parser.error('The DWI data appears to be single-shell.\n'
                     'Use --ball_stick for single-shell.')

    with redirected_stdout:
        # Setting up the tractogram and nifti files
        trk2dictionary.run(filename_tractogram=args.in_tractogram,
                           filename_peaks=args.in_peaks,
                           peaks_use_affine=False,
                           filename_mask=args.in_tracking_mask,
                           ndirs=args.nbr_dir,
                           path_out=tmp_dir.name)

        # Preparation for fitting
        commit.core.setup(ndirs=args.nbr_dir)
        mit = commit.Evaluation('.', '.')

        # FIX for very small values during HCP processing
        # (based on order of magnitude of signal)
        img = nib.load(args.in_dwi)
        data = img.get_fdata(dtype=np.float32)
        data[data < (0.001 *
                     10**np.floor(np.log10(np.mean(data[data > 0]))))] = 0
        nib.save(nib.Nifti1Image(data, img.affine),
                 os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

        mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                      tmp_scheme_filename)
        mit.set_model('StickZeppelinBall')

        if args.ball_stick:
            logging.debug('Disabled zeppelin, using the Ball & Stick model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = []
            isotropic_diff = args.iso_diff or [2.0E-3]
            mit.model.set(para_diff, perp_diff, isotropic_diff)
        else:
            logging.debug('Using the Stick Zeppelin Ball model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
            isotropic_diff = args.iso_diff or [1.7E-3, 3.0E-3]
            mit.model.set(para_diff, perp_diff, isotropic_diff)

        # The kernels are, by default, set to be in the current directory
        # Depending on the choice, manually change the saving location
        if args.save_kernels:
            kernels_dir = os.path.join(args.save_kernels)
            regenerate_kernels = True
        elif args.load_kernels:
            kernels_dir = os.path.join(args.load_kernels)
            regenerate_kernels = False
        else:
            kernels_dir = os.path.join(tmp_dir.name, 'kernels', mit.model.id)
            regenerate_kernels = True
        mit.set_config('ATOMS_path', kernels_dir)

        mit.generate_kernels(ndirs=args.nbr_dir, regenerate=regenerate_kernels)
        if args.compute_only:
            return
        mit.load_kernels()
        use_mask = args.in_tracking_mask is not None
        mit.load_dictionary(tmp_dir.name, use_all_voxels_in_mask=use_mask)
        mit.set_threads(args.nbr_processes)

        mit.build_operator(build_dir=os.path.join(tmp_dir.name, 'build/'))
        tol_fun = 1e-2 if args.commit2 else 1e-3
        mit.fit(tol_fun=tol_fun, max_iter=args.nbr_iter, verbose=False)
        mit.save_results()
        _save_results_wrapper(args, tmp_dir, ext, hdf5_file, offsets_list,
                              'commit_1/', False)

        if args.commit2:
            tmp = np.insert(np.cumsum(bundle_groups_len), 0, 0)
            group_idx = np.array(
                [np.arange(tmp[i], tmp[i + 1]) for i in range(len(tmp) - 1)])
            group_w = np.empty_like(bundle_groups_len, dtype=np.float64)
            for k in range(len(bundle_groups_len)):
                group_w[k] = np.sqrt(bundle_groups_len[k]) / \
                    (np.linalg.norm(mit.x[group_idx[k]]) + 1e-12)
            prior_on_bundles = commit.solvers.init_regularisation(
                mit,
                structureIC=group_idx,
                weightsIC=group_w,
                regnorms=[
                    commit.solvers.group_sparsity, commit.solvers.non_negative,
                    commit.solvers.non_negative
                ],
                lambdas=[args.lambda_commit_2, 0.0, 0.0])
            mit.fit(tol_fun=1e-3,
                    max_iter=args.nbr_iter,
                    regularisation=prior_on_bundles,
                    verbose=False)
            mit.save_results()
            _save_results_wrapper(args, tmp_dir, ext, hdf5_file, offsets_list,
                                  'commit_2/', True)

    tmp_dir.cleanup()
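The COMMIT2 prior above weights each bundle by sqrt(group size) / ||x_group||, an adaptive group-sparsity reweighting in which 1e-12 guards against zero-norm groups. A standalone numeric sketch (using a plain list of index arrays, whereas the script wraps them in np.array):

import numpy as np

bundle_groups_len = [3, 2, 4]
x = np.abs(np.random.default_rng(3).normal(size=sum(bundle_groups_len)))

tmp = np.insert(np.cumsum(bundle_groups_len), 0, 0)
group_idx = [np.arange(tmp[i], tmp[i + 1]) for i in range(len(tmp) - 1)]
group_w = np.array([np.sqrt(len(g)) / (np.linalg.norm(x[g]) + 1e-12)
                    for g in group_idx])
print(group_w)  # larger weight for bundles with a small total contribution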
Example #11
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_dwi,
                                 args.in_bval, args.in_bvec],
                        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram,
                                                  dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    if args.threshold_weights in ('None', 'none'):
        args.threshold_weights = None
        if not args.keep_whole_tractogram and ext != '.h5':
            logging.warning('Without weight thresholding and without '
                            '--keep_whole_tractogram, no tractogram will '
                            'be saved.')
    else:
        args.threshold_weights = float(args.threshold_weights)

    # COMMIT has some C-level stdout and non-logging prints that cannot
    # be easily silenced, so all printed output is redirected manually.
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    if ext == '.h5':
        logging.debug('Reconstructing {} into a tractogram for COMMIT.'.format(
            args.in_tractogram))

        hdf5_file = h5py.File(args.in_tractogram, 'r')
        if not (np.allclose(hdf5_file.attrs['affine'], dwi_img.affine,
                            atol=1e-03)
                and np.array_equal(hdf5_file.attrs['dimensions'],
                                   dwi_img.shape[0:3])):
            parser.error('{} does not have a compatible header with {}'.format(
                args.in_tractogram, args.in_dwi))

        # Keep track of the order of connections/streamlines in relation to the
        # tractogram as well as the number of streamlines for each connection.
        hdf5_keys = list(hdf5_file.keys())
        streamlines = []
        offsets_list = [0]
        for key in hdf5_keys:
            tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                key)
            offsets_list.append(len(tmp_streamlines))
            streamlines.extend(tmp_streamlines)

        offsets_list = np.cumsum(offsets_list)

        sft = StatefulTractogram(streamlines, args.in_dwi,
                                 Space.VOX, origin=Origin.TRACKVIS)
        tmp_tractogram_filename = os.path.join(tmp_dir.name, 'tractogram.trk')

        # Keeping the input variable, saving trk file for COMMIT internal use
        save_tractogram(sft, tmp_tractogram_filename)
        args.in_tractogram = tmp_tractogram_filename

    # Writing the scheme file with proper shells
    tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
    tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
    bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
    shells_centroids, indices_shells = identify_shells(bvals, args.b_thr,
                                                       roundCentroids=True)
    np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
               newline=' ', fmt='%i')
    fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
    logging.debug('Launching COMMIT on {} shells found at {}.'.format(
        len(shells_centroids),
        shells_centroids))

    if len(shells_centroids) == 2 and not args.ball_stick:
        parser.error('The DWI data appears to be single-shell.\n'
                     'Use --ball_stick for single-shell.')

    with redirected_stdout:
        # Setting up the tractogram and nifti files
        trk2dictionary.run(filename_tractogram=args.in_tractogram,
                           filename_peaks=args.in_peaks,
                           peaks_use_affine=False,
                           filename_mask=args.in_tracking_mask,
                           ndirs=args.nbr_dir,
                           gen_trk=False,
                           path_out=tmp_dir.name)

        # Preparation for fitting
        commit.core.setup(ndirs=args.nbr_dir)
        mit = commit.Evaluation('.', '.')

        # FIX for very small values during HCP processing
        # (based on order of magnitude of signal)
        img = nib.load(args.in_dwi)
        data = img.get_fdata(dtype=np.float32)
        data[data < (0.001*10**np.floor(np.log10(np.mean(data[data > 0]))))] = 0
        nib.save(nib.Nifti1Image(data, img.affine),
                 os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

        mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                      tmp_scheme_filename)
        mit.set_model('StickZeppelinBall')

        if args.ball_stick:
            logging.debug('Disabled zeppelin, using the Ball & Stick model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = []
            isotropic_diff = args.iso_diff or [2.0E-3]
            mit.model.set(para_diff, perp_diff, isotropic_diff)
        else:
            logging.debug('Using the Stick Zeppelin Ball model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
            isotropic_diff = args.iso_diff or [1.7E-3, 3.0E-3]
            mit.model.set(para_diff, perp_diff, isotropic_diff)

        # The kernels are, by default, set to be in the current directory
        # Depending on the choice, manually change the saving location
        if args.save_kernels:
            kernels_dir = os.path.join(args.save_kernels)
            regenerate_kernels = True
        elif args.load_kernels:
            kernels_dir = os.path.join(args.load_kernels)
            regenerate_kernels = False
        else:
            kernels_dir = os.path.join(tmp_dir.name, 'kernels', mit.model.id)
            regenerate_kernels = True
        mit.set_config('ATOMS_path', kernels_dir)

        mit.generate_kernels(ndirs=500, regenerate=regenerate_kernels)
        if args.compute_only:
            return
        mit.load_kernels()
        mit.load_dictionary(tmp_dir.name,
                            use_mask=args.in_tracking_mask is not None)
        mit.set_threads(args.nbr_processes)

        mit.build_operator(build_dir=tmp_dir.name)
        mit.fit(tol_fun=1e-3, max_iter=args.nbr_iter, verbose=0)
        mit.save_results()

    # Simplifying output for streamlines and cleaning output directory
    commit_results_dir = os.path.join(tmp_dir.name,
                                      'Results_StickZeppelinBall')
    pk_path = os.path.join(commit_results_dir, 'results.pickle')
    with open(pk_path, 'rb') as pk_file:
        commit_output_dict = pickle.load(pk_file)
    nbr_streamlines = lazy_streamlines_count(args.in_tractogram)
    commit_weights = np.asarray(commit_output_dict[2][:nbr_streamlines])
    np.savetxt(os.path.join(commit_results_dir,
                            'commit_weights.txt'),
               commit_weights)

    if ext == '.h5':
        new_filename = os.path.join(commit_results_dir,
                                    'decompose_commit.h5')
        with h5py.File(new_filename, 'w') as new_hdf5_file:
            new_hdf5_file.attrs['affine'] = sft.affine
            new_hdf5_file.attrs['dimensions'] = sft.dimensions
            new_hdf5_file.attrs['voxel_sizes'] = sft.voxel_sizes
            new_hdf5_file.attrs['voxel_order'] = sft.voxel_order
            # Assign the weights into the hdf5, while respecting the ordering of
            # connections/streamlines
            logging.debug('Adding commit weights to {}.'.format(new_filename))
            for i, key in enumerate(hdf5_keys):
                new_group = new_hdf5_file.create_group(key)
                old_group = hdf5_file[key]
                tmp_commit_weights = commit_weights[offsets_list[i]:offsets_list[i+1]]
                if args.threshold_weights is not None:
                    essential_ind = np.where(
                        tmp_commit_weights > args.threshold_weights)[0]
                    tmp_streamlines = reconstruct_streamlines(old_group['data'],
                                                              old_group['offsets'],
                                                              old_group['lengths'],
                                                              indices=essential_ind)

                    # Replacing the data with the one above the threshold
                    # Safe since this hdf5 was a copy in the first place
                    new_group.create_dataset('data',
                                             data=tmp_streamlines.get_data(),
                                             dtype=np.float32)
                    new_group.create_dataset('offsets',
                                             data=tmp_streamlines._offsets,
                                             dtype=np.int64)
                    new_group.create_dataset('lengths',
                                             data=tmp_streamlines._lengths,
                                             dtype=np.int32)

                for dps_key in hdf5_file[key].keys():
                    if dps_key not in ['data', 'offsets', 'lengths']:
                        new_group.create_dataset(dps_key,
                                                 data=hdf5_file[key][dps_key])
                new_group.create_dataset('commit_weights',
                                         data=tmp_commit_weights)

    files = os.listdir(commit_results_dir)
    for f in files:
        shutil.move(os.path.join(commit_results_dir, f), args.out_dir)

    # Save split tractogram (essential/nonessential) and/or saving the
    # tractogram with data_per_streamline updated
    if args.keep_whole_tractogram or args.threshold_weights is not None:
        # Reload is needed because of COMMIT handling its file by itself
        tractogram_file = nib.streamlines.load(args.in_tractogram)
        tractogram = tractogram_file.tractogram
        tractogram.data_per_streamline['commit_weights'] = commit_weights

        if args.threshold_weights is not None:
            essential_ind = np.where(
                commit_weights > args.threshold_weights)[0]
            nonessential_ind = np.where(
                commit_weights <= args.threshold_weights)[0]
            logging.debug('{} essential streamlines were kept at '
                          'threshold {}'.format(len(essential_ind),
                                                args.threshold_weights))
            logging.debug('{} nonessential streamlines were kept at '
                          'threshold {}'.format(len(nonessential_ind),
                                                args.threshold_weights))

            # TODO PR when Dipy 1.2 is out with sft slicing
            essential_streamlines = tractogram.streamlines[essential_ind]
            essential_dps = tractogram.data_per_streamline[essential_ind]
            essential_dpp = tractogram.data_per_point[essential_ind]
            essential_tractogram = Tractogram(essential_streamlines,
                                              data_per_point=essential_dpp,
                                              data_per_streamline=essential_dps,
                                              affine_to_rasmm=np.eye(4))

            nonessential_streamlines = tractogram.streamlines[nonessential_ind]
            nonessential_dps = tractogram.data_per_streamline[nonessential_ind]
            nonessential_dpp = tractogram.data_per_point[nonessential_ind]
            nonessential_tractogram = Tractogram(nonessential_streamlines,
                                                 data_per_point=nonessential_dpp,
                                                 data_per_streamline=nonessential_dps,
                                                 affine_to_rasmm=np.eye(4))

            nib.streamlines.save(essential_tractogram,
                                 os.path.join(args.out_dir,
                                              'essential_tractogram.trk'),
                                 header=tractogram_file.header)
            nib.streamlines.save(nonessential_tractogram,
                                 os.path.join(args.out_dir,
                                              'nonessential_tractogram.trk'),
                                 header=tractogram_file.header,)
        if args.keep_whole_tractogram:
            output_filename = os.path.join(args.out_dir, 'tractogram.trk')
            logging.debug('Saving tractogram with weights as {}'.format(
                output_filename))
            nib.streamlines.save(tractogram_file, output_filename)

    tmp_dir.cleanup()
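The essential/nonessential split near the end reduces to partitioning streamline indices by their COMMIT weight; a minimal sketch:

import numpy as np

commit_weights = np.array([0., 0.8, 0.05, 1.2])
threshold = 0.1
essential_ind = np.where(commit_weights > threshold)[0]      # [1 3]
nonessential_ind = np.where(commit_weights <= threshold)[0]  # [0 2]
# Indexing the tractogram's streamlines/data_per_streamline with these
# arrays yields the two tractograms that are saved separately above.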