コード例 #1
0
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(
        parser, [args.in_tractogram, args.in_dwi, args.in_bval, args.in_bvec],
        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       optional=args.save_kernels)

    if args.commit2:
        if os.path.splitext(args.in_tractogram)[1] != '.h5':
            parser.error('COMMIT2 requires .h5 file for connectomics.')
        args.ball_stick = True

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with ' 'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram, dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    hdf5_file = None
    offsets_list = None
    if ext == '.h5':
        logging.debug('Reconstructing {} into a tractogram for COMMIT.'.format(
            args.in_tractogram))

        hdf5_file = h5py.File(args.in_tractogram, 'r')
        if not (np.allclose(
                hdf5_file.attrs['affine'], dwi_img.affine, atol=1e-03)
                and np.array_equal(hdf5_file.attrs['dimensions'],
                                   dwi_img.shape[0:3])):
            parser.error('{} does not have a compatible header with {}'.format(
                args.in_tractogram, args.in_dwi))

        # Keep track of the order of connections/streamlines in relation to the
        # tractogram as well as the number of streamlines for each connection.
        bundle_groups_len = []
        hdf5_keys = list(hdf5_file.keys())
        streamlines = []
        for key in hdf5_keys:
            tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
            streamlines.extend(tmp_streamlines)
            bundle_groups_len.append(len(tmp_streamlines))

        offsets_list = np.cumsum([0] + bundle_groups_len)
        sft = StatefulTractogram(streamlines,
                                 args.in_dwi,
                                 Space.VOX,
                                 origin=Origin.TRACKVIS)
        tmp_tractogram_filename = os.path.join(tmp_dir.name, 'tractogram.trk')

        # Keeping the input variable, saving trk file for COMMIT internal use
        save_tractogram(sft, tmp_tractogram_filename)
        args.in_tractogram = tmp_tractogram_filename

    # Writing the scheme file with proper shells
    tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
    tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
    bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
    shells_centroids, indices_shells = identify_shells(bvals,
                                                       args.b_thr,
                                                       roundCentroids=True)
    np.savetxt(tmp_bval_filename,
               shells_centroids[indices_shells],
               newline=' ',
               fmt='%i')
    fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
    logging.debug('Lauching COMMIT on {} shells at found at {}.'.format(
        len(shells_centroids), shells_centroids))

    if len(shells_centroids) == 2 and not args.ball_stick:
        parser.error('The DWI data appears to be single-shell.\n'
                     'Use --ball_stick for single-shell.')

    with redirected_stdout:
        # Setting up the tractogram and nifti files
        trk2dictionary.run(filename_tractogram=args.in_tractogram,
                           filename_peaks=args.in_peaks,
                           peaks_use_affine=False,
                           filename_mask=args.in_tracking_mask,
                           ndirs=args.nbr_dir,
                           path_out=tmp_dir.name)

        # Preparation for fitting
        commit.core.setup(ndirs=args.nbr_dir)
        mit = commit.Evaluation('.', '.')

        # FIX for very small values during HCP processing
        # (based on order of magnitude of signal)
        img = nib.load(args.in_dwi)
        data = img.get_fdata(dtype=np.float32)
        data[data < (0.001 *
                     10**np.floor(np.log10(np.mean(data[data > 0]))))] = 0
        nib.save(nib.Nifti1Image(data, img.affine),
                 os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

        mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                      tmp_scheme_filename)
        mit.set_model('StickZeppelinBall')

        if args.ball_stick:
            logging.debug('Disabled zeppelin, using the Ball & Stick model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = []
            isotropc_diff = args.iso_diff or [2.0E-3]
            mit.model.set(para_diff, perp_diff, isotropc_diff)
        else:
            logging.debug('Using the Stick Zeppelin Ball model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
            isotropc_diff = args.iso_diff or [1.7E-3, 3.0E-3]
            mit.model.set(para_diff, perp_diff, isotropc_diff)

        # The kernels are, by default, set to be in the current directory
        # Depending on the choice, manually change the saving location
        if args.save_kernels:
            kernels_dir = os.path.join(args.save_kernels)
            regenerate_kernels = True
        elif args.load_kernels:
            kernels_dir = os.path.join(args.load_kernels)
            regenerate_kernels = False
        else:
            kernels_dir = os.path.join(tmp_dir.name, 'kernels', mit.model.id)
            regenerate_kernels = True
        mit.set_config('ATOMS_path', kernels_dir)

        mit.generate_kernels(ndirs=args.nbr_dir, regenerate=regenerate_kernels)
        if args.compute_only:
            return
        mit.load_kernels()
        use_mask = args.in_tracking_mask is not None
        mit.load_dictionary(tmp_dir.name, use_all_voxels_in_mask=use_mask)
        mit.set_threads(args.nbr_processes)

        mit.build_operator(build_dir=os.path.join(tmp_dir.name, 'build/'))
        tol_fun = 1e-2 if args.commit2 else 1e-3
        mit.fit(tol_fun=tol_fun, max_iter=args.nbr_iter, verbose=False)
        mit.save_results()
        _save_results_wrapper(args, tmp_dir, ext, hdf5_file, offsets_list,
                              'commit_1/', False)

        if args.commit2:
            tmp = np.insert(np.cumsum(bundle_groups_len), 0, 0)
            group_idx = np.array(
                [np.arange(tmp[i], tmp[i + 1]) for i in range(len(tmp) - 1)])
            group_w = np.empty_like(bundle_groups_len, dtype=np.float64)
            for k in range(len(bundle_groups_len)):
                group_w[k] = np.sqrt(bundle_groups_len[k]) / \
                    (np.linalg.norm(mit.x[group_idx[k]]) + 1e-12)
            prior_on_bundles = commit.solvers.init_regularisation(
                mit,
                structureIC=group_idx,
                weightsIC=group_w,
                regnorms=[
                    commit.solvers.group_sparsity, commit.solvers.non_negative,
                    commit.solvers.non_negative
                ],
                lambdas=[args.lambda_commit_2, 0.0, 0.0])
            mit.fit(tol_fun=1e-3,
                    max_iter=args.nbr_iter,
                    regularisation=prior_on_bundles,
                    verbose=False)
            mit.save_results()
            _save_results_wrapper(args, tmp_dir, ext, hdf5_file, offsets_list,
                                  'commit_2/', True)

    tmp_dir.cleanup()
コード例 #2
0
def main():

    parser = build_argparser()
    args = parser.parse_args()

    # check for errors in input
    try:
        dwi = nib.load(args.dwi)
    except:
        parser.error("Expecting DWI as first image")
    # and so on

    # Essential part of COMMIT filtering

    # Create a dictionary for a tractogram
    from commit import trk2dictionary

    trk2dictionary.run(
        filename_trk=args.tracks,  # 'LausanneTwoShell/fibers.trk',
        path_out=args.output,  #'LausanneTwoShell/CommitOutput',
        filename_peaks=args.peaks,  #'LausanneTwoShell/peaks.nii.gz',
        filename_mask=args.wm,  #'LausanneTwoShell/WM.nii.gz',
        fiber_shift=0.5,
        peaks_use_affine=True)

    # Precompute the rotation matrices used internally by COMMIT to create the lookup-tables for the response functions
    import commit
    commit.core.setup()

    # Load the data and fit selected model
    if (args.model_type == 'StickZeppelinBall_Model'):

        mit = commit.Evaluation('.', args.output)

        mit.CONFIG['doNormalizeSignal'] = False
        mit.CONFIG['doDemean'] = False

        mit.load_data(args.dwi, args.dwi_scheme)

        mit.set_model('StickZeppelinBall')
        mit.model.set(1.7E-3, [0.7], [1.7E-3, 3.0E-3])
        mit.generate_kernels(regenerate=True)
        mit.load_kernels()

    if (args.model_type == 'LiFE_Model'):

        mit = commit.Evaluation('.', args.output)

        mit.CONFIG['doNormalizeSignal'] = False
        mit.CONFIG['doDemean'] = True

        mit.load_data(args.dwi, args.dwi_scheme)

        mit.set_model('StickZeppelinBall')
        mit.model.set(1.7E-3, [], [])
        mit.generate_kernels(regenerate=True)
        mit.load_kernels()

    # Load in memory the sparse data-structure previously created with trk2dicitonary.run():
    mit.load_dictionary('.')

    # Now it's time to build the linear operator A to compute the matrix-vector multiplications for solving the linear system. This operator uses information from the segments loaded in the previous step and the lookup-tables for the response functions; it also needs to know the workload to be assigned to each thread durint the multiplications. To this aim, run the following commands:

    mit.set_threads(8)
    mit.build_operator()

    # fit the model (Stick-Zeppelin-Ball in this case) to the data
    mit.fit(tol_fun=1e-3, max_iter=500)

    # Save results
    #mit.save_results()
    if (args.model_type == 'StickZeppelinBall_Model'):
        mit.save_results(
            save_coeff=True)  #make sure it's saved in proper location
    if (args.model_type == 'LiFE_Model'):
        mit.save_results(save_coeff=True)
コード例 #3
0
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_dwi,
                                 args.in_bval, args.in_bvec],
                        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram,
                                                  dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    if args.threshold_weights == 'None' or args.threshold_weights == 'none':
        args.threshold_weights = None
        if not args.keep_whole_tractogram and ext != '.h5':
            logging.warning('Not thresholding weigth with trk file without '
                            'the --keep_whole_tractogram will not save a '
                            'tractogram')
    else:
        args.threshold_weights = float(args.threshold_weights)

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    if ext == '.h5':
        logging.debug('Reconstructing {} into a tractogram for COMMIT.'.format(
            args.in_tractogram))

        hdf5_file = h5py.File(args.in_tractogram, 'r')
        if not (np.allclose(hdf5_file.attrs['affine'], dwi_img.affine,
                            atol=1e-03)
                and np.array_equal(hdf5_file.attrs['dimensions'],
                                   dwi_img.shape[0:3])):
            parser.error('{} does not have a compatible header with {}'.format(
                args.in_tractogram, args.in_dwi))

        # Keep track of the order of connections/streamlines in relation to the
        # tractogram as well as the number of streamlines for each connection.
        hdf5_keys = list(hdf5_file.keys())
        streamlines = []
        offsets_list = [0]
        for key in hdf5_keys:
            tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                key)
            offsets_list.append(len(tmp_streamlines))
            streamlines.extend(tmp_streamlines)

        offsets_list = np.cumsum(offsets_list)

        sft = StatefulTractogram(streamlines, args.in_dwi,
                                 Space.VOX, origin=Origin.TRACKVIS)
        tmp_tractogram_filename = os.path.join(tmp_dir.name, 'tractogram.trk')

        # Keeping the input variable, saving trk file for COMMIT internal use
        save_tractogram(sft, tmp_tractogram_filename)
        args.in_tractogram = tmp_tractogram_filename

    # Writing the scheme file with proper shells
    tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
    tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
    bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
    shells_centroids, indices_shells = identify_shells(bvals, args.b_thr,
                                                       roundCentroids=True)
    np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
               newline=' ', fmt='%i')
    fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
    logging.debug('Lauching COMMIT on {} shells at found at {}.'.format(
        len(shells_centroids),
        shells_centroids))

    if len(shells_centroids) == 2 and not args.ball_stick:
        parser.error('The DWI data appears to be single-shell.\n'
                     'Use --ball_stick for single-shell.')

    with redirected_stdout:
        # Setting up the tractogram and nifti files
        trk2dictionary.run(filename_tractogram=args.in_tractogram,
                           filename_peaks=args.in_peaks,
                           peaks_use_affine=False,
                           filename_mask=args.in_tracking_mask,
                           ndirs=args.nbr_dir,
                           gen_trk=False,
                           path_out=tmp_dir.name)

        # Preparation for fitting
        commit.core.setup(ndirs=args.nbr_dir)
        mit = commit.Evaluation('.', '.')

        # FIX for very small values during HCP processing
        # (based on order of magnitude of signal)
        img = nib.load(args.in_dwi)
        data = img.get_fdata(dtype=np.float32)
        data[data < (0.001*10**np.floor(np.log10(np.mean(data[data > 0]))))] = 0
        nib.save(nib.Nifti1Image(data, img.affine),
                 os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

        mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                      tmp_scheme_filename)
        mit.set_model('StickZeppelinBall')

        if args.ball_stick:
            logging.debug('Disabled zeppelin, using the Ball & Stick model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = []
            isotropc_diff = args.iso_diff or [2.0E-3]
            mit.model.set(para_diff, perp_diff, isotropc_diff)
        else:
            logging.debug('Using the Stick Zeppelin Ball model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
            isotropc_diff = args.iso_diff or [1.7E-3, 3.0E-3]
            mit.model.set(para_diff, perp_diff, isotropc_diff)

        # The kernels are, by default, set to be in the current directory
        # Depending on the choice, manually change the saving location
        if args.save_kernels:
            kernels_dir = os.path.join(args.save_kernels)
            regenerate_kernels = True
        elif args.load_kernels:
            kernels_dir = os.path.join(args.load_kernels)
            regenerate_kernels = False
        else:
            kernels_dir = os.path.join(tmp_dir.name, 'kernels', mit.model.id)
            regenerate_kernels = True
        mit.set_config('ATOMS_path', kernels_dir)

        mit.generate_kernels(ndirs=500, regenerate=regenerate_kernels)
        if args.compute_only:
            return
        mit.load_kernels()
        mit.load_dictionary(tmp_dir.name,
                            use_mask=args.in_tracking_mask is not None)
        mit.set_threads(args.nbr_processes)

        mit.build_operator(build_dir=tmp_dir.name)
        mit.fit(tol_fun=1e-3, max_iter=args.nbr_iter, verbose=0)
        mit.save_results()

    # Simplifying output for streamlines and cleaning output directory
    commit_results_dir = os.path.join(tmp_dir.name,
                                      'Results_StickZeppelinBall')
    pk_file = open(os.path.join(commit_results_dir, 'results.pickle'), 'rb')
    commit_output_dict = pickle.load(pk_file)
    nbr_streamlines = lazy_streamlines_count(args.in_tractogram)
    commit_weights = np.asarray(commit_output_dict[2][:nbr_streamlines])
    np.savetxt(os.path.join(commit_results_dir,
                            'commit_weights.txt'),
               commit_weights)

    if ext == '.h5':
        new_filename = os.path.join(commit_results_dir,
                                    'decompose_commit.h5')
        with h5py.File(new_filename, 'w') as new_hdf5_file:
            new_hdf5_file.attrs['affine'] = sft.affine
            new_hdf5_file.attrs['dimensions'] = sft.dimensions
            new_hdf5_file.attrs['voxel_sizes'] = sft.voxel_sizes
            new_hdf5_file.attrs['voxel_order'] = sft.voxel_order
            # Assign the weights into the hdf5, while respecting the ordering of
            # connections/streamlines
            logging.debug('Adding commit weights to {}.'.format(new_filename))
            for i, key in enumerate(hdf5_keys):
                new_group = new_hdf5_file.create_group(key)
                old_group = hdf5_file[key]
                tmp_commit_weights = commit_weights[offsets_list[i]:offsets_list[i+1]]
                if args.threshold_weights is not None:
                    essential_ind = np.where(
                        tmp_commit_weights > args.threshold_weights)[0]
                    tmp_streamlines = reconstruct_streamlines(old_group['data'],
                                                              old_group['offsets'],
                                                              old_group['lengths'],
                                                              indices=essential_ind)

                    # Replacing the data with the one above the threshold
                    # Safe since this hdf5 was a copy in the first place
                    new_group.create_dataset('data',
                                             data=tmp_streamlines.get_data(),
                                             dtype=np.float32)
                    new_group.create_dataset('offsets',
                                             data=tmp_streamlines._offsets,
                                             dtype=np.int64)
                    new_group.create_dataset('lengths',
                                             data=tmp_streamlines._lengths,
                                             dtype=np.int32)

                for dps_key in hdf5_file[key].keys():
                    if dps_key not in ['data', 'offsets', 'lengths']:
                        new_group.create_dataset(key,
                                                 data=hdf5_file[key][dps_key])
                new_group.create_dataset('commit_weights',
                                         data=tmp_commit_weights)

    files = os.listdir(commit_results_dir)
    for f in files:
        shutil.move(os.path.join(commit_results_dir, f), args.out_dir)

    # Save split tractogram (essential/nonessential) and/or saving the
    # tractogram with data_per_streamline updated
    if args.keep_whole_tractogram or args.threshold_weights is not None:
        # Reload is needed because of COMMIT handling its file by itself
        tractogram_file = nib.streamlines.load(args.in_tractogram)
        tractogram = tractogram_file.tractogram
        tractogram.data_per_streamline['commit_weights'] = commit_weights

        if args.threshold_weights is not None:
            essential_ind = np.where(
                commit_weights > args.threshold_weights)[0]
            nonessential_ind = np.where(
                commit_weights <= args.threshold_weights)[0]
            logging.debug('{} essential streamlines were kept at '
                          'threshold {}'.format(len(essential_ind),
                                                args.threshold_weights))
            logging.debug('{} nonessential streamlines were kept at '
                          'threshold {}'.format(len(nonessential_ind),
                                                args.threshold_weights))

            # TODO PR when Dipy 1.2 is out with sft slicing
            essential_streamlines = tractogram.streamlines[essential_ind]
            essential_dps = tractogram.data_per_streamline[essential_ind]
            essential_dpp = tractogram.data_per_point[essential_ind]
            essential_tractogram = Tractogram(essential_streamlines,
                                              data_per_point=essential_dpp,
                                              data_per_streamline=essential_dps,
                                              affine_to_rasmm=np.eye(4))

            nonessential_streamlines = tractogram.streamlines[nonessential_ind]
            nonessential_dps = tractogram.data_per_streamline[nonessential_ind]
            nonessential_dpp = tractogram.data_per_point[nonessential_ind]
            nonessential_tractogram = Tractogram(nonessential_streamlines,
                                                 data_per_point=nonessential_dpp,
                                                 data_per_streamline=nonessential_dps,
                                                 affine_to_rasmm=np.eye(4))

            nib.streamlines.save(essential_tractogram,
                                 os.path.join(args.out_dir,
                                              'essential_tractogram.trk'),
                                 header=tractogram_file.header)
            nib.streamlines.save(nonessential_tractogram,
                                 os.path.join(args.out_dir,
                                              'nonessential_tractogram.trk'),
                                 header=tractogram_file.header,)
        if args.keep_whole_tractogram:
            output_filename = os.path.join(args.out_dir, 'tractogram.trk')
            logging.debug('Saving tractogram with weights as {}'.format(
                output_filename))
            nib.streamlines.save(tractogram_file, output_filename)

    tmp_dir.cleanup()