Esempio n. 1
0
def _afd_rd_wrapper(args):
    """Compute the mean AFD along the streamlines of one HDF5 connection.

    Uses a single packed-argument signature so it can be mapped over a
    process pool.

    Parameters
    ----------
    args : sequence
        ``(in_hdf5_filename, key, fodf_img, sh_basis, length_weighting)``:
        - the HDF5 file to read;
        - the group key of the connection;
        - the fODF image, SH basis and length-weighting flag, which are
          forwarded unchanged to ``afd_map_along_streamlines``.

    Returns
    -------
    tuple
        ``(key, afd_mean)``. ``afd_mean`` is the average of the positive
        voxels of the AFD map. It is 0 when the connection has no
        streamlines or when the map has no positive voxel.
    """
    in_hdf5_filename = args[0]
    key = args[1]
    fodf_img = args[2]
    sh_basis = args[3]
    length_weighting = args[4]

    with h5py.File(in_hdf5_filename, 'r') as in_hdf5_file:
        affine = in_hdf5_file.attrs['affine']
        dimensions = in_hdf5_file.attrs['dimensions']
        voxel_sizes = in_hdf5_file.attrs['voxel_sizes']
        streamlines = reconstruct_streamlines_from_hdf5(in_hdf5_file, key)
        if len(streamlines) == 0:
            return key, 0

    header = create_nifti_header(affine, dimensions, voxel_sizes)
    # The streamlines are interpreted in voxel space with the TRACKVIS origin.
    sft = StatefulTractogram(streamlines,
                             header,
                             Space.VOX,
                             origin=Origin.TRACKVIS)
    # Only the AFD map is reported; the RD map is discarded.
    afd_mean_map, _ = afd_map_along_streamlines(
        sft, fodf_img, sh_basis, length_weighting)

    positive_afd = afd_mean_map[afd_mean_map > 0]
    if positive_afd.size == 0:
        # np.average of an empty selection warns and yields NaN; report 0,
        # as done for a connection without streamlines.
        return key, 0

    return key, np.average(positive_afd)
def _average_wrapper(args):
    """Average the density maps of one connection across several HDF5 files.

    Uses a single packed-argument signature so it can be mapped over a
    process pool.

    Parameters
    ----------
    args : sequence
        ``(hdf5_filenames, key, binary, out_dir)``:
        - the HDF5 files to combine; the first one defines the reference
          space;
        - the group key of the connection;
        - ``binary``: if True, each file contributes a binary mask;
          otherwise it contributes its density map scaled to a maximum
          of 1;
        - the directory where ``<key>.nii.gz`` is written.

    Raises
    ------
    IOError
        If a file's affine or dimensions differ from those of the first
        file.

    Notes
    -----
    Nothing is written when the averaged map is empty. Every HDF5 handle
    is closed, including when the header check fails.
    """
    hdf5_filenames = args[0]
    key = args[1]
    binary = args[2]
    out_dir = args[3]

    # The first file defines the reference space every other file must match.
    with h5py.File(hdf5_filenames[0], 'r') as hdf5_file_ref:
        affine = hdf5_file_ref.attrs['affine']
        dimensions = hdf5_file_ref.attrs['dimensions']

    density_data = np.zeros(dimensions, dtype=np.float32)
    for hdf5_filename in hdf5_filenames:
        with h5py.File(hdf5_filename, 'r') as hdf5_file:
            if not (np.allclose(hdf5_file.attrs['affine'], affine)
                    and np.allclose(hdf5_file.attrs['dimensions'],
                                    dimensions)):
                raise IOError('{} do not have a compatible header'.format(
                    hdf5_filename))
            # scil_decompose_connectivity.py saves the streamlines in VOX/CORNER
            streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
        density = compute_tract_counts_map(streamlines, dimensions)

        if binary:
            density_data[density > 0] += 1
        else:
            max_density = np.max(density)
            # An empty map contributes nothing (and must not be divided by 0).
            if max_density > 0:
                density_data += density / max_density

    if np.max(density_data) > 0:
        density_data /= len(hdf5_filenames)

        nib.save(nib.Nifti1Image(density_data, affine),
                 os.path.join(out_dir, '{}.nii.gz'.format(key)))
def main():
    """Export every connection of an HDF5 connectivity file to .trk.

    Each HDF5 group is saved as ``<out_dir>/<key>.trk``. With
    ``--include_dps``, every dataset of the group other than the streamline
    storage ('data', 'offsets', 'lengths') is attached as
    data_per_streamline.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_hdf5)
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       create_dir=True)

    # The context manager closes the file even if a save fails midway.
    with h5py.File(args.in_hdf5, 'r') as hdf5_file:
        # The spatial attributes are file-level, so one header serves every
        # connection.
        affine = hdf5_file.attrs['affine']
        dimensions = hdf5_file.attrs['dimensions']
        voxel_sizes = hdf5_file.attrs['voxel_sizes']
        header = create_nifti_header(affine, dimensions, voxel_sizes)

        for key in hdf5_file.keys():
            streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
            sft = StatefulTractogram(streamlines,
                                     header,
                                     Space.VOX,
                                     origin=Origin.TRACKVIS)
            if args.include_dps:
                for dps_key in hdf5_file[key].keys():
                    if dps_key not in ['data', 'offsets', 'lengths']:
                        sft.data_per_streamline[dps_key] = \
                            hdf5_file[key][dps_key]

            save_tractogram(sft,
                            '{}.trk'.format(os.path.join(args.out_dir, key)))
Esempio n. 4
0
def main():
    """Export selected connections of an HDF5 connectivity file to .trk.

    Connection selection:
    - ``--edge_keys``: keep the listed keys.
    - ``--node_keys``: keep keys starting with ``<node>_`` or ending with
      ``_<node>``.
    - Neither option: keep every key.

    With ``--save_empty``, the candidate keys are every label pair from
    ``--labels_list``. Pairs without streamlines, including pairs absent
    from the file, are still written as empty tractograms. Each selected
    connection is saved as ``<out_dir>/<key>.trk``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_hdf5)
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       create_dir=True)
    if args.save_empty and args.labels_list is None:
        parser.error("The option --save_empty requires --labels_list.")

    with h5py.File(args.in_hdf5, 'r') as hdf5_file:
        if args.save_empty:
            # atleast_1d: a single-label file loads as a 0-d array, which
            # itertools.combinations / zip cannot iterate.
            all_labels = np.atleast_1d(
                np.loadtxt(args.labels_list, dtype='str'))
            comb_list = list(itertools.combinations(all_labels, r=2))
            comb_list.extend(zip(all_labels, all_labels))
            keys = [i[0] + '_' + i[1] for i in comb_list]
        else:
            keys = list(hdf5_file.keys())

        if args.edge_keys is not None:
            selected_keys = [key for key in keys if key in args.edge_keys]
        elif args.node_keys is not None:
            selected_keys = []
            for node in args.node_keys:
                selected_keys.extend([
                    key for key in keys
                    if key.startswith(node + '_') or key.endswith('_' + node)
                ])
            # A key joining two requested nodes is matched once per node;
            # keep the first occurrence so each file is written only once.
            selected_keys = list(dict.fromkeys(selected_keys))
        else:
            selected_keys = keys

        affine = hdf5_file.attrs['affine']
        dimensions = hdf5_file.attrs['dimensions']
        voxel_sizes = hdf5_file.attrs['voxel_sizes']
        header = create_nifti_header(affine, dimensions, voxel_sizes)
        for key in selected_keys:
            # With --save_empty, keys come from the labels list and may be
            # absent from the file: treat those as empty connections.
            key_in_file = key in hdf5_file
            if key_in_file:
                streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                key)
            else:
                streamlines = []

            if len(streamlines) == 0 and not args.save_empty:
                continue

            sft = StatefulTractogram(streamlines,
                                     header,
                                     Space.VOX,
                                     origin=Origin.TRACKVIS)
            # hdf5_file[key] would raise KeyError for an absent connection.
            if args.include_dps and key_in_file:
                for dps_key in hdf5_file[key].keys():
                    if dps_key not in ['data', 'offsets', 'lengths']:
                        sft.data_per_streamline[dps_key] = hdf5_file[key][
                            dps_key]

            save_tractogram(sft,
                            '{}.trk'.format(os.path.join(args.out_dir, key)))
Esempio n. 5
0
def main():
    """Apply a transformation (and optional deformation) to an HDF5 file.

    Steps:
    1. Copy the input HDF5 file to the output path.
    2. Set the output's file-level spatial attributes to those of the
       target image.
    3. Warp every non-empty connection into the target space and rewrite
       its streamline storage ('data', 'offsets', 'lengths') in place, in
       voxel space with a corner origin.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_hdf5, args.in_target_file,
                                 args.in_transfo], args.in_deformation)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    # Load every input before copying, so that an unreadable input does not
    # leave a half-written output file behind.
    transfo = load_matrix_in_any_format(args.in_transfo)
    deformation_data = None
    if args.in_deformation is not None:
        deformation_data = np.squeeze(nib.load(
            args.in_deformation).get_fdata(dtype=np.float32))
    target_img = nib.load(args.in_target_file)
    (target_affine, target_dimensions,
     target_voxel_sizes, target_voxel_order) = get_reference_info(target_img)

    with h5py.File(args.in_hdf5, 'r') as in_hdf5_file:
        shutil.copy(args.in_hdf5, args.out_hdf5)

        # Source space: the attributes are file-level, shared by every key.
        affine = in_hdf5_file.attrs['affine']
        dimensions = in_hdf5_file.attrs['dimensions']
        voxel_sizes = in_hdf5_file.attrs['voxel_sizes']
        header = create_nifti_header(affine, dimensions, voxel_sizes)

        with h5py.File(args.out_hdf5, 'a') as out_hdf5_file:
            # Written unconditionally: previously these were only set while
            # processing a non-empty connection, so a file whose
            # connections were all empty kept the moving-space header.
            out_hdf5_file.attrs['affine'] = target_affine
            out_hdf5_file.attrs['dimensions'] = target_dimensions
            out_hdf5_file.attrs['voxel_sizes'] = target_voxel_sizes
            out_hdf5_file.attrs['voxel_order'] = target_voxel_order

            for key in in_hdf5_file.keys():
                streamlines = reconstruct_streamlines_from_hdf5(
                    in_hdf5_file, key)

                if len(streamlines) == 0:
                    continue

                moving_sft = StatefulTractogram(streamlines, header, Space.VOX,
                                                origin=Origin.TRACKVIS)

                new_sft = transform_warp_streamlines(
                    moving_sft, transfo, target_img,
                    inverse=args.inverse,
                    deformation_data=deformation_data,
                    remove_invalid=not args.cut_invalid,
                    cut_invalid=args.cut_invalid)
                new_sft.to_vox()
                new_sft.to_corner()

                # NOTE(review): other datasets in the group (e.g. per-streamline
                # data) are left as copied; if invalid streamlines were
                # removed they may no longer match the new count — confirm.
                group = out_hdf5_file[key]
                del group['data']
                group.create_dataset('data',
                                     data=new_sft.streamlines.get_data())
                del group['offsets']
                group.create_dataset('offsets',
                                     data=new_sft.streamlines._offsets)
                del group['lengths']
                group.create_dataset('lengths',
                                     data=new_sft.streamlines._lengths)
Esempio n. 6
0
def _processing_wrapper(args):
    """Compute the connectivity measures of one pair of labels.

    Uses a single packed-argument signature so it can be mapped over a
    process pool.

    Parameters
    ----------
    args : sequence
        ``(hdf5_filename, labels_img, (in_label, out_label),
        measures_to_compute, similarity, weighted, include_dps,
        min_lesion_vol)``.

        - ``similarity`` is None, or a sequence whose first item is the
          similarity directory.
        - ``measures_to_compute`` mixes:
          - measure names ('volume', 'streamline_count', 'length',
            'similarity');
          - directories of per-connection maps;
          - ``(filename, image)`` metric tuples;
          - ``((filename, lesion_labels), image)`` lesion tuples.

          It is shallow-copied, so the caller's list is never modified.
        - ``include_dps``, when truthy, is joined with each
          data_per_streamline key to form the output measure name.

    Returns
    -------
    dict or None
        ``{(in_label, out_label): {measure: value}}``, or None when the
        connection is absent from the HDF5 file or has no streamlines.

    Raises
    ------
    ValueError
        If the HDF5 header does not match ``labels_img``.
    IOError
        If a metric or lesion image does not match ``labels_img``.
    """
    hdf5_filename = args[0]
    labels_img = args[1]
    in_label, out_label = args[2]
    measures_to_compute = copy.copy(args[3])
    # Bound even without a similarity directory, so that requesting
    # 'similarity' without one skips it instead of raising NameError.
    similarity_directory = args[4][0] if args[4] is not None else None
    weighted = args[5]
    include_dps = args[6]
    min_lesion_vol = args[7]

    key = '{}_{}'.format(in_label, out_label)
    # Read everything needed from the HDF5 file up front so the handle is
    # always closed, including on the early returns.
    with h5py.File(hdf5_filename, 'r') as hdf5_file:
        if key not in hdf5_file:
            return None
        streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
        if len(streamlines) == 0:
            return None
        hdf5_affine = hdf5_file.attrs['affine']
        hdf5_dimensions = hdf5_file.attrs['dimensions']

        dps_data = {}
        if include_dps:
            for dps_key in hdf5_file[key].keys():
                if dps_key not in ['data', 'offsets', 'lengths']:
                    dps_data[dps_key] = hdf5_file[key][dps_key][()]

    affine, dimensions, voxel_sizes, _ = get_reference_info(labels_img)
    measures_to_return = {}

    if not (np.allclose(hdf5_affine, affine, atol=1e-03)
            and np.array_equal(hdf5_dimensions, dimensions)):
        raise ValueError('Provided hdf5 have incompatible headers.')

    # Precompute to save one transformation, insert later
    if 'length' in measures_to_compute:
        # scil_decompose_connectivity.py requires isotropic voxels
        mean_length = np.average(length(list(streamlines))) * voxel_sizes[0]

    # If density is not required, do not compute it
    # Only required for volume, similarity and any metrics
    if not ((len(measures_to_compute) == 1 and
             ('length' in measures_to_compute or
              'streamline_count' in measures_to_compute)) or
            (len(measures_to_compute) == 2 and
             ('length' in measures_to_compute and
              'streamline_count' in measures_to_compute))):

        density = compute_tract_counts_map(streamlines,
                                           dimensions)

    if 'volume' in measures_to_compute:
        measures_to_return['volume'] = np.count_nonzero(density) * \
            np.prod(voxel_sizes)
        measures_to_compute.remove('volume')
    if 'streamline_count' in measures_to_compute:
        measures_to_return['streamline_count'] = len(streamlines)
        measures_to_compute.remove('streamline_count')
    if 'length' in measures_to_compute:
        measures_to_return['length'] = mean_length
        measures_to_compute.remove('length')
    if 'similarity' in measures_to_compute and similarity_directory:
        density_sim = load_node_nifti(similarity_directory,
                                      in_label, out_label,
                                      labels_img)
        if density_sim is None:
            ba_vox = 0
        else:
            ba_vox = compute_bundle_adjacency_voxel(density, density_sim)

        measures_to_return['similarity'] = ba_vox
        measures_to_compute.remove('similarity')

    for measure in measures_to_compute:
        # Maps
        if isinstance(measure, str) and os.path.isdir(measure):
            map_dirname = measure
            map_data = load_node_nifti(map_dirname,
                                       in_label, out_label,
                                       labels_img)
            if map_data is None:
                # load_node_nifti found no map (presumably none exists for
                # this pair); report 0, as the similarity branch does.
                measures_to_return[map_dirname] = 0
            else:
                measures_to_return[map_dirname] = np.average(
                    map_data[map_data > 0])
        elif isinstance(measure, tuple):
            if not isinstance(measure[0], tuple) \
                    and os.path.isfile(measure[0]):
                metric_filename = measure[0]
                metric_img = measure[1]
                if not is_header_compatible(metric_img, labels_img):
                    logging.error('{} do not have a compatible header'.format(
                        metric_filename))
                    raise IOError

                metric_data = metric_img.get_fdata(dtype=np.float64)
                if weighted:
                    avg_value = np.average(metric_data, weights=density)
                else:
                    avg_value = np.average(metric_data[density > 0])
                measures_to_return[metric_filename] = avg_value
            # lesion
            else:
                lesion_filename = measure[0][0]
                computed_lesion_labels = measure[0][1]
                lesion_img = measure[1]
                if not is_header_compatible(lesion_img, labels_img):
                    logging.error('{} do not have a compatible header'.format(
                        lesion_filename))
                    raise IOError

                lesion_voxel_sizes = lesion_img.header.get_zooms()[0:3]
                lesion_img.set_filename('tmp.nii.gz')
                lesion_atlas = get_data_as_label(lesion_img)
                tmp_dict = compute_lesion_stats(
                    density.astype(bool), lesion_atlas,
                    voxel_sizes=lesion_voxel_sizes, single_label=True,
                    min_lesion_vol=min_lesion_vol,
                    precomputed_lesion_labels=computed_lesion_labels)

                tmp_ind = _streamlines_in_mask(list(streamlines),
                                               lesion_atlas.astype(np.uint8),
                                               np.eye(3), [0, 0, 0])
                # Number of streamlines flagged (value 1) as inside the mask.
                streamlines_count = int(np.count_nonzero(tmp_ind == 1))

                if tmp_dict:
                    measures_to_return[lesion_filename+'vol'] = \
                        tmp_dict['lesion_total_volume']
                    measures_to_return[lesion_filename+'count'] = \
                        tmp_dict['lesion_count']
                    measures_to_return[lesion_filename+'sc'] = \
                        streamlines_count
                else:
                    measures_to_return[lesion_filename+'vol'] = 0
                    measures_to_return[lesion_filename+'count'] = 0
                    measures_to_return[lesion_filename+'sc'] = 0

    for dps_key, dps_values in dps_data.items():
        out_file = os.path.join(include_dps, dps_key)
        # COMMIT weights are summed over the connection; other
        # per-streamline data are averaged.
        if 'commit' in dps_key:
            measures_to_return[out_file] = np.sum(dps_values)
        else:
            measures_to_return[out_file] = np.average(dps_values)

    return {(in_label, out_label): measures_to_return}
def _processing_wrapper(args):
    """Compute the connectivity measures of one pair of labels.

    Uses a single packed-argument signature so it can be mapped over a
    process pool.

    Parameters
    ----------
    args : sequence
        ``(hdf5_filename, labels_img, (in_label, out_label),
        measures_to_compute, similarity, weighted, include_dps)``.

        - ``similarity`` is None, or a sequence whose first item is the
          similarity directory.
        - ``measures_to_compute`` mixes:
          - measure names ('volume', 'streamline_count', 'length',
            'similarity');
          - directories of per-connection maps;
          - ``(filename, image)`` metric tuples.

          It is shallow-copied, so the caller's list is never modified.
        - ``include_dps``, when truthy, is joined with each
          data_per_streamline key to form the output measure name.

    Returns
    -------
    dict or None
        ``{(in_label, out_label): {measure: value}}``, or None when the
        connection is absent from the HDF5 file or has no streamlines.

    Raises
    ------
    ValueError
        If the HDF5 header does not match ``labels_img``.
    IOError
        If a metric image does not match ``labels_img``.
    """
    hdf5_filename = args[0]
    labels_img = args[1]
    in_label, out_label = args[2]
    measures_to_compute = copy.copy(args[3])
    # Bound even without a similarity directory, so that requesting
    # 'similarity' without one skips it instead of raising NameError.
    similarity_directory = args[4][0] if args[4] is not None else None
    weighted = args[5]
    include_dps = args[6]

    key = '{}_{}'.format(in_label, out_label)
    # Read everything needed from the HDF5 file up front so the handle is
    # always closed, including on the early returns.
    with h5py.File(hdf5_filename, 'r') as hdf5_file:
        if key not in hdf5_file:
            return None
        streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
        # An empty connection would give NaN lengths and a 0/0 density
        # normalisation; skip it like an absent key.
        if len(streamlines) == 0:
            return None
        hdf5_affine = hdf5_file.attrs['affine']
        hdf5_dimensions = hdf5_file.attrs['dimensions']

        dps_data = {}
        if include_dps:
            for dps_key in hdf5_file[key].keys():
                if dps_key not in ['data', 'offsets', 'lengths']:
                    dps_data[dps_key] = hdf5_file[key][dps_key][()]

    affine, dimensions, voxel_sizes, _ = get_reference_info(labels_img)
    measures_to_return = {}

    if not (np.allclose(hdf5_affine, affine, atol=1e-03)
            and np.array_equal(hdf5_dimensions, dimensions)):
        raise ValueError('Provided hdf5 have incompatible headers.')

    # Precompute to save one transformation, insert later
    if 'length' in measures_to_compute:
        # scil_decompose_connectivity.py requires isotropic voxels
        mean_length = np.average(length(list(streamlines))) * voxel_sizes[0]

    # If density is not required, do not compute it
    # Only required for volume, similarity and any metrics
    if not ((len(measures_to_compute) == 1 and
             ('length' in measures_to_compute
              or 'streamline_count' in measures_to_compute)) or
            (len(measures_to_compute) == 2 and
             ('length' in measures_to_compute
              and 'streamline_count' in measures_to_compute))):

        density = compute_tract_counts_map(streamlines, dimensions)

    if 'volume' in measures_to_compute:
        measures_to_return['volume'] = np.count_nonzero(density) * \
            np.prod(voxel_sizes)
        measures_to_compute.remove('volume')
    if 'streamline_count' in measures_to_compute:
        measures_to_return['streamline_count'] = len(streamlines)
        measures_to_compute.remove('streamline_count')
    if 'length' in measures_to_compute:
        measures_to_return['length'] = mean_length
        measures_to_compute.remove('length')
    if 'similarity' in measures_to_compute and similarity_directory:
        density_sim = load_node_nifti(similarity_directory, in_label,
                                      out_label, labels_img)
        if density_sim is None:
            ba_vox = 0
        else:
            ba_vox = compute_bundle_adjacency_voxel(density, density_sim)

        measures_to_return['similarity'] = ba_vox
        measures_to_compute.remove('similarity')

    for measure in measures_to_compute:
        if isinstance(measure, str) and os.path.isdir(measure):
            map_dirname = measure
            map_data = load_node_nifti(map_dirname, in_label, out_label,
                                       labels_img)
            if map_data is None:
                # load_node_nifti found no map (presumably none exists for
                # this pair); report 0, as the similarity branch does.
                measures_to_return[map_dirname] = 0
            else:
                measures_to_return[map_dirname] = np.average(
                    map_data[map_data > 0])
        elif isinstance(measure, tuple) and os.path.isfile(measure[0]):
            metric_filename = measure[0]
            metric_img = measure[1]
            if not is_header_compatible(metric_img, labels_img):
                logging.error('{} do not have a compatible header'.format(
                    metric_filename))
                raise IOError

            metric_data = metric_img.get_fdata(dtype=np.float64)
            if weighted:
                # True density-weighted mean, sum(m * d) / sum(d), so that a
                # constant metric averages to that constant.
                avg_value = np.average(metric_data, weights=density)
            else:
                avg_value = np.average(metric_data[density > 0])

            measures_to_return[metric_filename] = avg_value

    for dps_key, dps_values in dps_data.items():
        out_file = os.path.join(include_dps, dps_key)
        measures_to_return[out_file] = np.average(dps_values)

    return {(in_label, out_label): measures_to_return}
Esempio n. 8
0
def main():
    """Run COMMIT (and optionally COMMIT2) on a tractogram.

    Steps:
    1. Validate the options.
    2. Convert an .h5 connectivity file into a temporary .trk, keeping the
       per-connection streamline counts needed by COMMIT2.
    3. Write a scheme file from the b-values/b-vectors.
    4. Fit the selected model and save the results through
       ``_save_results_wrapper``.

    The temporary directory and the HDF5 handle are released on every exit
    path, including ``--compute_only`` and parser errors.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(
        parser, [args.in_tractogram, args.in_dwi, args.in_bval, args.in_bvec],
        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       optional=args.save_kernels)

    if args.commit2:
        if os.path.splitext(args.in_tractogram)[1] != '.h5':
            parser.error('COMMIT2 requires .h5 file for connectomics.')
        args.ball_stick = True

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram, dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    hdf5_file = None
    offsets_list = None
    try:
        if ext == '.h5':
            logging.debug(
                'Reconstructing {} into a tractogram for COMMIT.'.format(
                    args.in_tractogram))

            hdf5_file = h5py.File(args.in_tractogram, 'r')
            if not (np.allclose(
                    hdf5_file.attrs['affine'], dwi_img.affine, atol=1e-03)
                    and np.array_equal(hdf5_file.attrs['dimensions'],
                                       dwi_img.shape[0:3])):
                parser.error(
                    '{} does not have a compatible header with {}'.format(
                        args.in_tractogram, args.in_dwi))

            # Keep track of the order of connections/streamlines in relation
            # to the tractogram as well as the number of streamlines for
            # each connection.
            bundle_groups_len = []
            hdf5_keys = list(hdf5_file.keys())
            streamlines = []
            for key in hdf5_keys:
                tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                    key)
                streamlines.extend(tmp_streamlines)
                bundle_groups_len.append(len(tmp_streamlines))

            # offsets_list[k]:offsets_list[k + 1] spans connection k in the
            # concatenated tractogram.
            offsets_list = np.cumsum([0] + bundle_groups_len)
            sft = StatefulTractogram(streamlines,
                                     args.in_dwi,
                                     Space.VOX,
                                     origin=Origin.TRACKVIS)
            tmp_tractogram_filename = os.path.join(tmp_dir.name,
                                                   'tractogram.trk')

            # Keeping the input variable, saving trk file for COMMIT
            # internal use
            save_tractogram(sft, tmp_tractogram_filename)
            args.in_tractogram = tmp_tractogram_filename

        # Writing the scheme file with proper shells
        tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
        tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
        bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
        shells_centroids, indices_shells = identify_shells(
            bvals, args.b_thr, roundCentroids=True)
        np.savetxt(tmp_bval_filename,
                   shells_centroids[indices_shells],
                   newline=' ',
                   fmt='%i')
        fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
        logging.debug('Lauching COMMIT on {} shells at found at {}.'.format(
            len(shells_centroids), shells_centroids))

        if len(shells_centroids) == 2 and not args.ball_stick:
            parser.error('The DWI data appears to be single-shell.\n'
                         'Use --ball_stick for single-shell.')

        with redirected_stdout:
            # Setting up the tractogram and nifti files
            trk2dictionary.run(filename_tractogram=args.in_tractogram,
                               filename_peaks=args.in_peaks,
                               peaks_use_affine=False,
                               filename_mask=args.in_tracking_mask,
                               ndirs=args.nbr_dir,
                               path_out=tmp_dir.name)

            # Preparation for fitting
            commit.core.setup(ndirs=args.nbr_dir)
            mit = commit.Evaluation('.', '.')

            # FIX for very small values during HCP processing
            # (based on order of magnitude of signal)
            img = nib.load(args.in_dwi)
            data = img.get_fdata(dtype=np.float32)
            data[data < (0.001 *
                         10**np.floor(np.log10(np.mean(data[data > 0]))))] = 0
            nib.save(nib.Nifti1Image(data, img.affine),
                     os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

            mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                          tmp_scheme_filename)
            mit.set_model('StickZeppelinBall')

            if args.ball_stick:
                logging.debug(
                    'Disabled zeppelin, using the Ball & Stick model.')
                para_diff = args.para_diff or 1.7E-3
                perp_diff = []
                isotropc_diff = args.iso_diff or [2.0E-3]
                mit.model.set(para_diff, perp_diff, isotropc_diff)
            else:
                logging.debug('Using the Stick Zeppelin Ball model.')
                para_diff = args.para_diff or 1.7E-3
                perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
                isotropc_diff = args.iso_diff or [1.7E-3, 3.0E-3]
                mit.model.set(para_diff, perp_diff, isotropc_diff)

            # The kernels are, by default, set to be in the current directory
            # Depending on the choice, manually change the saving location
            if args.save_kernels:
                kernels_dir = os.path.join(args.save_kernels)
                regenerate_kernels = True
            elif args.load_kernels:
                kernels_dir = os.path.join(args.load_kernels)
                regenerate_kernels = False
            else:
                kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                           mit.model.id)
                regenerate_kernels = True
            mit.set_config('ATOMS_path', kernels_dir)

            mit.generate_kernels(ndirs=args.nbr_dir,
                                 regenerate=regenerate_kernels)
            if args.compute_only:
                return
            mit.load_kernels()
            use_mask = args.in_tracking_mask is not None
            mit.load_dictionary(tmp_dir.name, use_all_voxels_in_mask=use_mask)
            mit.set_threads(args.nbr_processes)

            mit.build_operator(build_dir=os.path.join(tmp_dir.name, 'build/'))
            tol_fun = 1e-2 if args.commit2 else 1e-3
            mit.fit(tol_fun=tol_fun, max_iter=args.nbr_iter, verbose=False)
            mit.save_results()
            _save_results_wrapper(args, tmp_dir, ext, hdf5_file, offsets_list,
                                  'commit_1/', False)

            if args.commit2:
                # One group per connection. An explicit 1-D object array of
                # index ranges is built because np.array() on a ragged list
                # raises ValueError on NumPy >= 1.24.
                n_groups = len(bundle_groups_len)
                group_idx = np.empty(n_groups, dtype=object)
                group_w = np.empty(n_groups, dtype=np.float64)
                for k in range(n_groups):
                    group_idx[k] = np.arange(offsets_list[k],
                                             offsets_list[k + 1])
                    group_w[k] = np.sqrt(bundle_groups_len[k]) / \
                        (np.linalg.norm(mit.x[group_idx[k]]) + 1e-12)
                prior_on_bundles = commit.solvers.init_regularisation(
                    mit,
                    structureIC=group_idx,
                    weightsIC=group_w,
                    regnorms=[
                        commit.solvers.group_sparsity,
                        commit.solvers.non_negative,
                        commit.solvers.non_negative
                    ],
                    lambdas=[args.lambda_commit_2, 0.0, 0.0])
                mit.fit(tol_fun=1e-3,
                        max_iter=args.nbr_iter,
                        regularisation=prior_on_bundles,
                        verbose=False)
                mit.save_results()
                _save_results_wrapper(args, tmp_dir, ext, hdf5_file,
                                      offsets_list, 'commit_2/', True)
    finally:
        # Runs on normal completion, on the --compute_only early return and
        # on parser.error() (SystemExit) raised after setup started.
        if hdf5_file is not None:
            hdf5_file.close()
        tmp_dir.cleanup()
Esempio n. 9
0
def main():
    """Run COMMIT on a tractogram and save the resulting streamline weights.

    Validates the command-line arguments and, for ``.h5`` input, rebuilds a
    temporary ``.trk`` tractogram (remembering the per-connection offsets).
    Writes a COMMIT gradient scheme from the bval/bvec files and fits either
    the Stick-Zeppelin-Ball or the Ball & Stick model. Finally it moves
    COMMIT's result files into ``args.out_dir``.

    Depending on the options, it also writes:
      * ``decompose_commit.h5`` (``.h5`` input only), with per-connection
        ``commit_weights`` and, when a threshold is set, only the streamlines
        above it;
      * ``essential_tractogram.trk`` / ``nonessential_tractogram.trk`` when a
        threshold is set;
      * ``tractogram.trk`` with a ``commit_weights`` data_per_streamline when
        ``--keep_whole_tractogram`` is given.

    The temporary directory and the input ``.h5`` handle are released on
    every exit path, including ``--compute_only`` and ``parser.error`` exits
    raised after the temporary directory was created.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_dwi,
                                 args.in_bval, args.in_bvec],
                        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram,
                                                  dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    if args.threshold_weights in ('None', 'none'):
        args.threshold_weights = None
        if not args.keep_whole_tractogram and ext != '.h5':
            logging.warning('Not thresholding weights with trk file without '
                            'the --keep_whole_tractogram will not save a '
                            'tractogram')
    else:
        args.threshold_weights = float(args.threshold_weights)

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        redirected_stdout = redirect_stdout(io.StringIO())
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    hdf5_file = None
    try:
        if ext == '.h5':
            logging.debug('Reconstructing {} into a tractogram for COMMIT.'
                          .format(args.in_tractogram))

            hdf5_file = h5py.File(args.in_tractogram, 'r')
            if not (np.allclose(hdf5_file.attrs['affine'], dwi_img.affine,
                                atol=1e-03)
                    and np.array_equal(hdf5_file.attrs['dimensions'],
                                       dwi_img.shape[0:3])):
                parser.error('{} does not have a compatible header with {}'
                             .format(args.in_tractogram, args.in_dwi))

            # Keep track of the order of connections/streamlines in relation
            # to the tractogram as well as the number of streamlines for each
            # connection.
            hdf5_keys = list(hdf5_file.keys())
            streamlines = []
            offsets_list = [0]
            for key in hdf5_keys:
                tmp_streamlines = reconstruct_streamlines_from_hdf5(
                    hdf5_file, key)
                offsets_list.append(len(tmp_streamlines))
                streamlines.extend(tmp_streamlines)

            # After cumsum, connection i spans
            # [offsets_list[i], offsets_list[i + 1]) in the tractogram.
            offsets_list = np.cumsum(offsets_list)

            sft = StatefulTractogram(streamlines, args.in_dwi,
                                     Space.VOX, origin=Origin.TRACKVIS)
            tmp_tractogram_filename = os.path.join(tmp_dir.name,
                                                   'tractogram.trk')

            # Keeping the input variable, saving trk file for COMMIT internal
            # use
            save_tractogram(sft, tmp_tractogram_filename)
            args.in_tractogram = tmp_tractogram_filename

        # Writing the scheme file with proper shells
        tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
        tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
        bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
        shells_centroids, indices_shells = identify_shells(bvals, args.b_thr,
                                                           roundCentroids=True)
        np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
                   newline=' ', fmt='%i')
        fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
        logging.debug('Launching COMMIT on {} shells found at {}.'.format(
            len(shells_centroids),
            shells_centroids))

        # Two centroids (presumably b=0 plus one shell) is treated as
        # single-shell data.
        if len(shells_centroids) == 2 and not args.ball_stick:
            parser.error('The DWI data appears to be single-shell.\n'
                         'Use --ball_stick for single-shell.')

        with redirected_stdout:
            # Setting up the tractogram and nifti files
            trk2dictionary.run(filename_tractogram=args.in_tractogram,
                               filename_peaks=args.in_peaks,
                               peaks_use_affine=False,
                               filename_mask=args.in_tracking_mask,
                               ndirs=args.nbr_dir,
                               gen_trk=False,
                               path_out=tmp_dir.name)

            # Preparation for fitting
            commit.core.setup(ndirs=args.nbr_dir)
            mit = commit.Evaluation('.', '.')

            # FIX for very small values during HCP processing
            # (based on order of magnitude of signal): values below 1/1000 of
            # the power of ten of the mean positive signal are set to zero.
            data = dwi_img.get_fdata(dtype=np.float32)
            magnitude = 10 ** np.floor(np.log10(np.mean(data[data > 0])))
            data[data < 0.001 * magnitude] = 0
            zero_fix_filename = os.path.join(tmp_dir.name,
                                             'dwi_zero_fix.nii.gz')
            nib.save(nib.Nifti1Image(data, dwi_img.affine), zero_fix_filename)

            mit.load_data(zero_fix_filename, tmp_scheme_filename)
            mit.set_model('StickZeppelinBall')

            para_diff = args.para_diff or 1.7E-3
            if args.ball_stick:
                logging.debug('Disabled zeppelin, using the Ball & Stick '
                              'model.')
                perp_diff = []
                isotropic_diff = args.iso_diff or [2.0E-3]
            else:
                logging.debug('Using the Stick Zeppelin Ball model.')
                perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
                isotropic_diff = args.iso_diff or [1.7E-3, 3.0E-3]
            mit.model.set(para_diff, perp_diff, isotropic_diff)

            # The kernels are, by default, set to be in the current directory
            # Depending on the choice, manually change the saving location
            if args.save_kernels:
                kernels_dir = args.save_kernels
                regenerate_kernels = True
            elif args.load_kernels:
                kernels_dir = args.load_kernels
                regenerate_kernels = False
            else:
                kernels_dir = os.path.join(tmp_dir.name, 'kernels',
                                           mit.model.id)
                regenerate_kernels = True
            mit.set_config('ATOMS_path', kernels_dir)

            # Keep ndirs consistent with commit.core.setup() and
            # trk2dictionary.run() above.
            mit.generate_kernels(ndirs=args.nbr_dir,
                                 regenerate=regenerate_kernels)
            if args.compute_only:
                # Kernels are saved; temporary files are removed in `finally`.
                return
            mit.load_kernels()
            mit.load_dictionary(tmp_dir.name,
                                use_mask=args.in_tracking_mask is not None)
            mit.set_threads(args.nbr_processes)

            mit.build_operator(build_dir=tmp_dir.name)
            mit.fit(tol_fun=1e-3, max_iter=args.nbr_iter, verbose=0)
            mit.save_results()

        # Simplifying output for streamlines and cleaning output directory
        commit_results_dir = os.path.join(tmp_dir.name,
                                          'Results_StickZeppelinBall')
        # results.pickle is produced by mit.save_results() above.
        with open(os.path.join(commit_results_dir, 'results.pickle'),
                  'rb') as pk_file:
            commit_output_dict = pickle.load(pk_file)
        nbr_streamlines = lazy_streamlines_count(args.in_tractogram)
        commit_weights = np.asarray(commit_output_dict[2][:nbr_streamlines])
        np.savetxt(os.path.join(commit_results_dir, 'commit_weights.txt'),
                   commit_weights)

        if ext == '.h5':
            new_filename = os.path.join(commit_results_dir,
                                        'decompose_commit.h5')
            with h5py.File(new_filename, 'w') as new_hdf5_file:
                new_hdf5_file.attrs['affine'] = sft.affine
                new_hdf5_file.attrs['dimensions'] = sft.dimensions
                new_hdf5_file.attrs['voxel_sizes'] = sft.voxel_sizes
                new_hdf5_file.attrs['voxel_order'] = sft.voxel_order
                # Assign the weights into the hdf5, while respecting the
                # ordering of connections/streamlines
                logging.debug('Adding commit weights to {}.'.format(
                    new_filename))
                for i, key in enumerate(hdf5_keys):
                    new_group = new_hdf5_file.create_group(key)
                    old_group = hdf5_file[key]
                    tmp_commit_weights = \
                        commit_weights[offsets_list[i]:offsets_list[i + 1]]
                    nbr_in_group = len(tmp_commit_weights)

                    if args.threshold_weights is not None:
                        essential_ind = np.where(
                            tmp_commit_weights > args.threshold_weights)[0]
                        tmp_streamlines = reconstruct_streamlines(
                            old_group['data'], old_group['offsets'],
                            old_group['lengths'], indices=essential_ind)

                        # Only the streamlines above the threshold are kept
                        new_group.create_dataset(
                            'data', data=tmp_streamlines.get_data(),
                            dtype=np.float32)
                        new_group.create_dataset(
                            'offsets', data=tmp_streamlines._offsets,
                            dtype=np.int64)
                        new_group.create_dataset(
                            'lengths', data=tmp_streamlines._lengths,
                            dtype=np.int32)
                        # Keep the weights aligned with the kept streamlines
                        tmp_commit_weights = tmp_commit_weights[essential_ind]
                    else:
                        essential_ind = None
                        # No thresholding: copy the streamlines unchanged
                        for name in ('data', 'offsets', 'lengths'):
                            new_group.create_dataset(
                                name, data=old_group[name][()])

                    for dps_key in old_group.keys():
                        if dps_key in ('data', 'offsets', 'lengths'):
                            continue
                        dps = old_group[dps_key][()]
                        # Arrays with one entry per streamline of this group
                        # are filtered the same way as the streamlines.
                        if essential_ind is not None and np.ndim(dps) > 0 \
                                and len(dps) == nbr_in_group:
                            dps = dps[essential_ind]
                        new_group.create_dataset(dps_key, data=dps)
                    new_group.create_dataset('commit_weights',
                                             data=tmp_commit_weights)

        for filename in os.listdir(commit_results_dir):
            shutil.move(os.path.join(commit_results_dir, filename),
                        args.out_dir)

        # Save split tractogram (essential/nonessential) and/or saving the
        # tractogram with data_per_streamline updated
        if args.keep_whole_tractogram or args.threshold_weights is not None:
            # Reload is needed because of COMMIT handling its file by itself
            tractogram_file = nib.streamlines.load(args.in_tractogram)
            tractogram = tractogram_file.tractogram
            tractogram.data_per_streamline['commit_weights'] = commit_weights

            if args.threshold_weights is not None:
                essential_ind = np.where(
                    commit_weights > args.threshold_weights)[0]
                nonessential_ind = np.where(
                    commit_weights <= args.threshold_weights)[0]
                logging.debug('{} essential streamlines were kept at '
                              'threshold {}'.format(len(essential_ind),
                                                    args.threshold_weights))
                logging.debug('{} nonessential streamlines were kept at '
                              'threshold {}'.format(len(nonessential_ind),
                                                    args.threshold_weights))

                # TODO PR when Dipy 1.2 is out with sft slicing
                essential_tractogram = Tractogram(
                    tractogram.streamlines[essential_ind],
                    data_per_point=tractogram.data_per_point[essential_ind],
                    data_per_streamline=tractogram.data_per_streamline[
                        essential_ind],
                    affine_to_rasmm=np.eye(4))

                nonessential_tractogram = Tractogram(
                    tractogram.streamlines[nonessential_ind],
                    data_per_point=tractogram.data_per_point[
                        nonessential_ind],
                    data_per_streamline=tractogram.data_per_streamline[
                        nonessential_ind],
                    affine_to_rasmm=np.eye(4))

                nib.streamlines.save(essential_tractogram,
                                     os.path.join(args.out_dir,
                                                  'essential_tractogram.trk'),
                                     header=tractogram_file.header)
                nib.streamlines.save(nonessential_tractogram,
                                     os.path.join(
                                         args.out_dir,
                                         'nonessential_tractogram.trk'),
                                     header=tractogram_file.header)
            if args.keep_whole_tractogram:
                output_filename = os.path.join(args.out_dir, 'tractogram.trk')
                logging.debug('Saving tractogram with weights as {}'.format(
                    output_filename))
                nib.streamlines.save(tractogram_file, output_filename)
    finally:
        if hdf5_file is not None:
            hdf5_file.close()
        tmp_dir.cleanup()