示例#1
0
def bench_load_trk():
    """Benchmark TRK loading through ``tv.read`` vs ``nib.streamlines.load``.

    Two scenarios are timed with synthetic random data written to a
    temporary directory: streamlines only, then streamlines carrying ten
    scalars per point. For each scenario the timing of both readers and the
    resulting speedup are printed, and the data returned by both readers is
    checked for equality.
    """
    rng = np.random.RandomState(42)
    dtype = 'float32'
    NB_STREAMLINES = 5000
    NB_POINTS = 1000
    points = [rng.rand(NB_POINTS, 3).astype(dtype)
              for _ in range(NB_STREAMLINES)]
    scalars = [rng.rand(NB_POINTS, 10).astype(dtype)
               for _ in range(NB_STREAMLINES)]

    # Forwarded as the second argument of `measure`.
    repeat = 10

    # Points only
    with InTemporaryDirectory():
        trk_file = "tmp.trk"
        tractogram = Tractogram(points, affine_to_rasmm=np.eye(4))
        TrkFile(tractogram).save(trk_file)

        # The points returned by the old reader are shifted by 0.5 before
        # being compared with the new reader's output.
        # NOTE(review): presumably a half-voxel offset convention - confirm.
        streamlines_old = [d[0] - 0.5
                           for d in tv.read(trk_file, points_space="rasmm")[0]]
        mtime_old = measure('tv.read(trk_file, points_space="rasmm")', repeat)
        print("Old: Loaded {:,} streamlines in {:6.2f}".format(NB_STREAMLINES,
                                                               mtime_old))

        trk = nib.streamlines.load(trk_file, lazy_load=False)
        streamlines_new = trk.streamlines
        mtime_new = measure('nib.streamlines.load(trk_file, lazy_load=False)',
                            repeat)
        # Fixed-point presentation ('f') so that this line matches the
        # "Old" one instead of falling back to general float formatting.
        print("\nNew: Loaded {:,} streamlines in {:6.2f}".format(
            NB_STREAMLINES, mtime_new))
        print("Speedup of {:.2f}".format(mtime_old / mtime_new))
        for s1, s2 in zip(streamlines_new, streamlines_old):
            assert_array_equal(s1, s2)

    # Points and scalars
    with InTemporaryDirectory():

        trk_file = "tmp.trk"
        tractogram = Tractogram(points,
                                data_per_point={'scalars': scalars},
                                affine_to_rasmm=np.eye(4))
        TrkFile(tractogram).save(trk_file)

        # Only the scalars (second item of every record) are compared in
        # this scenario, so the file is read once for the reference data.
        scalars_old = [d[1]
                       for d in tv.read(trk_file, points_space="rasmm")[0]]
        mtime_old = measure('tv.read(trk_file, points_space="rasmm")', repeat)
        msg = "Old: Loaded {:,} streamlines with scalars in {:6.2f}"
        print(msg.format(NB_STREAMLINES, mtime_old))

        trk = nib.streamlines.load(trk_file, lazy_load=False)
        scalars_new = trk.tractogram.data_per_point['scalars']
        mtime_new = measure('nib.streamlines.load(trk_file, lazy_load=False)',
                            repeat)
        msg = "New: Loaded {:,} streamlines with scalars in {:6.2f}"
        print(msg.format(NB_STREAMLINES, mtime_new))
        # '{:.2f}' (two decimals), not '{:2f}' (minimum width of two).
        print("Speedup of {:.2f}".format(mtime_old / mtime_new))
        for s1, s2 in zip(scalars_new, scalars_old):
            assert_array_equal(s1, s2)
示例#2
0
    def _core_run(self, stopping_path, stopping_thr, seeding_path,
                  seed_density, step_size, direction_getter, out_tract,
                  save_seeds):
        """Run local tracking and write the resulting tractogram.

        The stopping volume at ``stopping_path`` is thresholded with
        ``stopping_thr`` to build the classifier, seeds are generated from
        the mask at ``seeding_path`` using the same ``seed_density`` along
        each of the three axes, and ``LocalTracking`` is driven by
        ``direction_getter`` with the given ``step_size``. When
        ``save_seeds`` is truthy, every tracking result is unpacked as a
        (streamline, seed) pair and the seeds are stored under the
        ``'seeds'`` key of ``data_per_streamline``. The tractogram is saved
        to ``out_tract``.
        """
        stop_volume, affine = load_nifti(stopping_path)
        classifier = ThresholdTissueClassifier(stop_volume, stopping_thr)
        logging.info('classifier done')

        seed_mask, _ = load_nifti(seeding_path)
        seeds = utils.seeds_from_mask(seed_mask,
                                      density=[seed_density] * 3,
                                      affine=affine)
        logging.info('seeds done')

        tracker = LocalTracking(direction_getter, classifier, seeds, affine,
                                step_size=step_size, save_seeds=save_seeds)
        logging.info('LocalTracking initiated')

        if not save_seeds:
            tractogram = Tractogram(tracker, affine_to_rasmm=np.eye(4))
        else:
            # Split the (streamline, seed) pairs into two parallel tuples.
            streamlines, streamline_seeds = zip(*tracker)
            tractogram = Tractogram(streamlines, affine_to_rasmm=np.eye(4))
            tractogram.data_per_streamline['seeds'] = streamline_seeds

        save(tractogram, out_tract)
        logging.info('Saved {0}'.format(out_tract))
示例#3
0
    def harvest(
        self,
        states: np.ndarray,
        compress=False,
    ) -> Tuple[Tractogram, np.ndarray, np.ndarray]:
        """Internally keep only the streamlines and corresponding env. states
        that haven't stopped yet, and return the streamlines that triggered a
        stopping flag.

        Parameters
        ----------
        states: np.ndarray
            Environment states to be "pruned"
        compress: bool, optional
            If truthy, the harvested streamlines are passed through
            `compress_streamlines` (with 0.1 as second argument) before
            being stored in the tractogram.

        Returns
        -------
        tractogram : nib.streamlines.Tractogram
            Tractogram containing the streamlines that stopped tracking,
            along with the stopping_flags information and seeds in
            `tractogram.data_per_streamline`
        states: np.ndarray
            The input states after filtering by `self._keep` with
            `self.continue_idx`, i.e. only those of continuing streamlines.
        continue_idx: np.ndarray
            `self.continue_idx`, returned in case an RL algorithm would need
            it. NOTE(review): presumably an index array - confirm.
        """

        # Harvest stopped streamlines and associated data
        stopped_seeds = self.starting_points[self.stopping_idx]
        stopped_streamlines = self.streamlines[self.stopping_idx, :self.length]

        # Drop last point if it triggered a flag we don't want
        flags = is_flag_set(self.stopping_flags,
                            StoppingFlags.STOPPING_CURVATURE)
        streamlines = [
            s[:-1] if f else s for f, s in zip(flags, stopped_streamlines)
        ]

        if compress:
            streamlines = compress_streamlines(streamlines, 0.1)

        # Harvested tractogram
        tractogram = Tractogram(streamlines=streamlines,
                                data_per_streamline={
                                    "stopping_flags": self.stopping_flags,
                                    "seeds": stopped_seeds
                                },
                                affine_to_rasmm=self.affine_vox2rasmm)

        # Keep only streamlines that should continue
        states = self._keep(self.continue_idx, states)

        return tractogram, states, self.continue_idx
示例#4
0
def main():
    """Compute a per-streamline loss for one subject and save it.

    The experiment folder is resolved from ``args.name`` (either a
    directory itself or a name under ``./experiments``), its
    hyperparameters are loaded, the dataset and the model are built, and
    the losses produced by the loss view are attached to the dataset's
    streamlines under the ``"loss"`` key before writing ``args.out``.
    """
    parser = build_parser()
    args = parser.parse_args()
    print(args)

    # Resolve the experiment folder: a directory, or an experiment name.
    experiment_path = args.name
    if not os.path.isdir(experiment_path):
        experiment_path = pjoin(".", "experiments", args.name)
        if not os.path.isdir(experiment_path):
            parser.error("Cannot find experiment: {0}!".format(args.name))

    # Hyperparameters live in the experiment folder or in its parent.
    hyperparams_file = pjoin(experiment_path, "hyperparams.json")
    try:
        hyperparams = smartutils.load_dict_from_json_file(hyperparams_file)
    except FileNotFoundError:
        parent_file = pjoin(experiment_path, "..", "hyperparams.json")
        hyperparams = smartutils.load_dict_from_json_file(parent_file)

    with Timer("Loading dataset", newline=True):
        volume_manager = VolumeManager()
        dataset = datasets.load_tractography_dataset(
            [args.subject],
            volume_manager,
            name="dataset",
            use_sh_coeffs=hyperparams["use_sh_coeffs"],
        )
        print("Dataset size:", len(dataset))

    with Timer("Loading model"):
        if hyperparams["model"] != "gru_regression":
            raise NameError("Unknown model: {}".format(hyperparams["model"]))

        from learn2track.models import GRU_Regression

        model = GRU_Regression.create(experiment_path,
                                      volume_manager=volume_manager)

    with Timer("Building evaluation function"):
        loss = loss_factory(hyperparams, model, dataset)
        scheduler_options = dict(
            batch_size=1000,
            noisy_streamlines_sigma=None,
            use_data_augment=False,  # Otherwise it doubles the number of losses
            seed=1234,
            shuffle_streamlines=False,
            normalize_target=hyperparams["normalize"],
        )
        batch_scheduler = batch_schedulers.TractographyBatchScheduler(
            dataset, **scheduler_options)

        loss_view = views.LossView(loss=loss, batch_scheduler=batch_scheduler)
        losses = loss_view.losses.view()

    with Timer("Saving streamlines"):
        subject_affine = dataset.subjects[0].signal.affine
        tractogram = Tractogram(dataset.streamlines,
                                affine_to_rasmm=subject_affine)
        tractogram.data_per_streamline["loss"] = losses
        nib.streamlines.save(tractogram, args.out)
示例#5
0
    def _core_run(self, stopping_path, stopping_thr, seeding_path,
                  seed_density, use_sh, pam, out_tract):
        """Run local tracking from a peaks object and save the tractogram.

        The classifier is built by thresholding the volume at
        ``stopping_path`` with ``stopping_thr``; seeds come from the mask at
        ``seeding_path`` with ``seed_density`` applied on all three axes.
        ``pam`` itself is used as direction getter unless ``use_sh`` is
        truthy, in which case a ``DeterministicMaximumDirectionGetter`` is
        built from ``pam.shm_coeff`` and ``pam.sphere`` with a maximum angle
        of 30. Tracking uses a fixed step size of 0.5 and the result is
        saved to ``out_tract``.
        """
        stop_volume, affine = load_nifti(stopping_path)
        classifier = ThresholdTissueClassifier(stop_volume, stopping_thr)
        logging.info('classifier done')

        seed_mask, _ = load_nifti(seeding_path)
        seeds = utils.seeds_from_mask(seed_mask,
                                      density=[seed_density] * 3,
                                      affine=affine)
        logging.info('seeds done')

        if use_sh:
            direction_getter = DeterministicMaximumDirectionGetter.from_shcoeff(
                pam.shm_coeff, max_angle=30., sphere=pam.sphere)
        else:
            direction_getter = pam

        tracker = LocalTracking(direction_getter, classifier, seeds, affine,
                                step_size=.5)
        logging.info('LocalTracking initiated')

        save(Tractogram(tracker, affine_to_rasmm=np.eye(4)), out_tract)

        logging.info('Saved {0}'.format(out_tract))
示例#6
0
    def _core_run(self, stopping_path, stopping_thr, seeding_path,
                  seed_density, step_size, direction_getter, out_tract):
        """Run local tracking and save the resulting tractogram.

        A threshold classifier is built from the volume at
        ``stopping_path`` and ``stopping_thr``, seeds are derived from the
        mask at ``seeding_path`` with ``seed_density`` used on all three
        axes, and ``LocalTracking`` is run with ``direction_getter`` and
        ``step_size``. The output is written to ``out_tract`` with an
        identity ``affine_to_rasmm``.
        """
        stop_volume, affine = load_nifti(stopping_path)
        classifier = ThresholdTissueClassifier(stop_volume, stopping_thr)
        logging.info('classifier done')

        seed_mask, _ = load_nifti(seeding_path)
        seeds = utils.seeds_from_mask(seed_mask,
                                      density=[seed_density] * 3,
                                      affine=affine)
        logging.info('seeds done')

        tracker = LocalTracking(direction_getter, classifier, seeds, affine,
                                step_size=step_size)
        logging.info('LocalTracking initiated')

        save(Tractogram(tracker, affine_to_rasmm=np.eye(4)), out_tract)
        logging.info('Saved {0}'.format(out_tract))
示例#7
0
def get_seeds_from_wm(wm_path, threshold=0):
    """Write one single-point streamline per supra-threshold voxel as TRK.

    Parameters
    ----------
    wm_path : str
        Path of the image loaded with ``nib.load``. The output file
        ``seeds_from_wm.trk`` is written in the same directory.
    threshold : number, optional
        Voxels whose value is strictly greater than this are used as seeds.

    Returns
    -------
    None
        The result is only written to disk; the output path is printed.
    """
    wm_file = nib.load(wm_path)
    wm_img = wm_file.get_fdata()

    # Voxel indices above threshold, extended to homogeneous coordinates.
    seeds = np.argwhere(wm_img > threshold)
    seeds = np.hstack([seeds, np.ones([len(seeds), 1])])

    # Apply the image affine, drop the homogeneous component and give each
    # seed its own leading axis so it forms a one-point streamline.
    seeds = (wm_file.affine.dot(seeds.T).T)[:, :3].reshape(-1, 1, 3)

    # TRK header describing the grid of the source image.
    header = TrkFile.create_empty_header()
    header["voxel_to_rasmm"] = wm_file.affine
    header["dimensions"] = wm_file.header["dim"][1:4]
    header["voxel_sizes"] = wm_file.header["pixdim"][1:4]
    header["voxel_order"] = get_reference_info(wm_file)[3]

    tractogram = Tractogram(streamlines=ArraySequence(seeds),
                            affine_to_rasmm=np.eye(4))

    save_path = os.path.join(os.path.dirname(wm_path), "seeds_from_wm.trk")

    print("Saving {}".format(save_path))
    TrkFile(tractogram, header).save(save_path)
示例#8
0
def main():
    """Save a subset of the input tractogram's streamlines.

    The streamlines and their per-point / per-streamline data are handed to
    ``get_subset_streamlines`` together with ``args.max_num_streamlines``
    and ``args.seed``; the selection is saved to ``args.out_tractogram``
    reusing the header of the input file.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram])
    assert_outputs_exists(parser, args, args.out_tractogram)

    tractogram_file = load(args.in_tractogram)
    streamlines = list(tractogram_file.streamlines)
    source = tractogram_file.tractogram

    subset = get_subset_streamlines(streamlines,
                                    source.data_per_point,
                                    source.data_per_streamline,
                                    args.max_num_streamlines,
                                    args.seed)
    new_streamlines, new_per_point, new_per_streamline = subset

    new_tractogram = Tractogram(
        new_streamlines,
        data_per_point=new_per_point,
        data_per_streamline=new_per_streamline,
        affine_to_rasmm=np.eye(4),
    )

    save(new_tractogram, args.out_tractogram, header=tractogram_file.header)
示例#9
0
    def harvest(self):
        """Return the finished sprouts as a tractogram and drop them.

        ``self.is_stopping`` splits the sprouts into undone and done ones
        and provides their stopping flags. The done sprouts, without their
        last point, become the streamlines of the returned tractogram
        (compressed first when ``self.compress_streamlines`` is truthy);
        only the undone sprouts are kept afterwards.
        """
        undone, done, stopping_flags = self.is_stopping(self.sprouts,
                                                        self.sprouts_stop)

        # The last point is discarded since it almost surely raised the
        # stopping flag.
        finished = list(self.sprouts[done, :-1])
        if self.compress_streamlines:
            finished = compress_streamlines(finished)

        harvested = Tractogram(
            streamlines=finished,
            data_per_streamline={"stopping_flags": stopping_flags},
        )

        # Only undone sprouts remain tracked.
        self._keep(undone)

        return harvested
示例#10
0
def save_from_voxel_space(streamlines, anat, ref_tracts, out_name):
    """Scale streamlines by the anatomy's voxel sizes and save them.

    ``ref_tracts`` and ``anat`` may each be given as a path (loaded here)
    or as an already loaded object. The header of the reference tractogram
    supplies both the ``affine_to_rasmm`` and the header of the output
    file; the streamline coordinates are multiplied in place by
    ``anat.header['pixdim'][1:4]`` before writing to ``out_name``.
    """
    if isinstance(anat, six.string_types):
        anat_img = nib.load(anat)
    else:
        anat_img = anat

    reference = (nib.streamlines.load(ref_tracts, lazy_load=True)
                 if isinstance(ref_tracts, six.string_types)
                 else ref_tracts)

    tractogram = Tractogram(
        streamlines=streamlines,
        affine_to_rasmm=get_affine_trackvis_to_rasmm(reference.header),
    )

    # Scale every coordinate directly in the underlying data buffer.
    voxel_sizes = anat_img.header['pixdim'][1:4]
    tractogram.streamlines._data *= voxel_sizes

    nib.streamlines.save(tractogram, out_name, header=reference.header)
示例#11
0
def main():
    """Resample the streamlines of a tractogram and save the result.

    The input streamlines are passed to ``resample_streamlines`` with
    ``args.nb_pts_per_streamline`` and ``args.arclength``. The input's
    per-streamline data and header are carried over to the output file.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram])
    assert_outputs_exist(parser, args, args.out_tractogram)

    tractogram_file = load(args.in_tractogram)

    resampled = resample_streamlines(list(tractogram_file.streamlines),
                                     args.nb_pts_per_streamline,
                                     args.arclength)

    per_streamline = tractogram_file.tractogram.data_per_streamline
    new_tractogram = Tractogram(resampled,
                                data_per_streamline=per_streamline,
                                affine_to_rasmm=np.eye(4))

    save(new_tractogram, args.out_tractogram, header=tractogram_file.header)
示例#12
0
def main():
    """Filter the streamlines of a tractogram by length and save them.

    Streamlines and their associated data are passed to
    ``filter_streamlines_by_length`` with ``args.minL`` and ``args.maxL``;
    what it returns is saved to ``args.out_tractogram`` with the header of
    the input file.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    tractogram_file = load(args.in_tractogram)
    streamlines = list(tractogram_file.streamlines)
    source = tractogram_file.tractogram

    filtered = filter_streamlines_by_length(streamlines,
                                            source.data_per_point,
                                            source.data_per_streamline,
                                            args.minL,
                                            args.maxL)
    new_streamlines, new_per_point, new_per_streamline = filtered

    new_tractogram = Tractogram(
        new_streamlines,
        data_per_streamline=new_per_streamline,
        data_per_point=new_per_point,
        affine_to_rasmm=np.eye(4),
    )

    save(new_tractogram, args.out_tractogram, header=tractogram_file.header)
示例#13
0
    def render(
        self,
        tractogram: Tractogram = None,
        filename: str = None
    ):
        """ Render the streamlines, either directly or through a file
        Might render from "outside" the environment, like for comet

        Parameters:
        -----------
        tractogram: Tractogram, optional
            Object containing the streamlines and seeds. When omitted, one
            is built from `self.streamlines[:, :self.length]` and
            `self.starting_points`.
        filename: str, optional
            If set, save the image under that name instead of displaying
            directly
        """
        from fury import window, actor
        # Might be rendering from outside the environment
        if tractogram is None:
            tractogram = Tractogram(
                streamlines=self.streamlines[:, :self.length],
                data_per_streamline={
                    'seeds': self.starting_points
                })

        # Reshape peaks for displaying: the last axis of size M is split
        # into (5, M // 5)
        X, Y, Z, M = self.peaks.data.shape
        peaks = np.reshape(self.peaks.data, (X, Y, Z, 5, M//5))

        # Setup scene and actors
        scene = window.Scene()

        stream_actor = actor.streamtube(tractogram.streamlines)
        peak_actor = actor.peak_slicer(peaks,
                                       np.ones((X, Y, Z, M)),
                                       colors=(0.2, 0.2, 1.),
                                       opacity=0.5)
        dot_actor = actor.dots(tractogram.data_per_streamline['seeds'],
                               color=(1, 1, 1),
                               opacity=1,
                               dot_size=2.5)
        scene.add(stream_actor)
        scene.add(peak_actor)
        scene.add(dot_actor)
        scene.reset_camera_tight(0.95)

        # Save or display scene
        if filename is not None:
            # NOTE(review): dirname() of "<experiment_path>/render" is the
            # experiment path itself, so snapshots land next to the
            # experiment rather than in a "render" sub-folder. Kept as is
            # because existing callers may rely on that location - confirm.
            directory = os.path.dirname(pjoin(self.experiment_path, 'render'))
            # exist_ok avoids the check-then-create race of testing for the
            # directory before creating it.
            os.makedirs(directory, exist_ok=True)
            dest = pjoin(directory, filename)
            window.snapshot(
                scene,
                fname=dest,
                offscreen=True,
                size=(800, 800))
        else:
            showm = window.ShowManager(scene, reset_camera=True)
            showm.initialize()
            showm.start()
示例#14
0
    def run(self,
            pam_files,
            wm_files,
            gm_files,
            csf_files,
            seeding_files,
            step_size=0.2,
            seed_density=1,
            pmf_threshold=0.1,
            max_angle=20.,
            pft_back=2,
            pft_front=1,
            pft_count=15,
            out_dir='',
            out_tractogram='tractogram.trk',
            save_seeds=False):
        """Workflow for Particle Filtering Tracking.

        This workflow uses a saved peaks and metrics (PAM) file as input.

        Parameters
        ----------
        pam_files : string
           Path to the peaks and metrics files. This path may contain
            wildcards to use multiple masks at once.
        wm_files : string
            Path to white matter partial volume estimate for tracking (CMC).
        gm_files : string
            Path to grey matter partial volume estimate for tracking (CMC).
        csf_files : string
            Path to cerebrospinal fluid partial volume estimate for tracking
            (CMC).
        seeding_files : string
            A binary image showing where we need to seed for tracking.
        step_size : float, optional
            Step size used for tracking (default 0.2mm).
        seed_density : int, optional
            Number of seeds per dimension inside voxel (default 1).
             For example, seed_density of 2 means 8 regularly distributed
             points in the voxel. And seed density of 1 means 1 point at the
             center of the voxel.
        pmf_threshold : float, optional
            Threshold for ODF functions (default 0.1).
        max_angle : float, optional
            Maximum angle between streamline segments (range [0, 90],
            default 20).
        pft_back : float, optional
            Distance in mm to back track before starting the particle filtering
            tractography (default 2mm). The total particle filtering
            tractography distance is equal to back_tracking_dist +
            front_tracking_dist.
        pft_front : float, optional
            Distance in mm to run the particle filtering tractography after
            the back track distance (default 1mm). The total particle filtering
            tractography distance is equal to back_tracking_dist +
            front_tracking_dist.
        pft_count : int, optional
            Number of particles to use in the particle filter (default 15).
        out_dir : string, optional
           Output directory (default input file directory)
        out_tractogram : string, optional
           Name of the tractogram file to be saved (default 'tractogram.trk')
        save_seeds : bool, optional
            If true, save the seeds associated to their streamline
            in the 'data_per_streamline' Tractogram dictionary using
            'seeds' as the key

        References
        ----------
        Girard, G., Whittingstall, K., Deriche, R., & Descoteaux, M. Towards
        quantitative connectivity analysis: reducing tractography biases.
        NeuroImage, 98, 266-278, 2014.

        """
        # NOTE(review): the path arguments are not read directly below; the
        # iterator presumably derives the per-run paths from them - confirm.
        io_it = self.get_io_iterator()

        # Each iteration yields five input paths plus one output path.
        for pams_path, wm_path, gm_path, csf_path, seeding_path, out_tract \
                in io_it:

            logging.info(
                'Particle Filtering tracking on {0}'.format(pams_path))

            pam = load_peaks(pams_path, verbose=False)

            # The affine and voxel size come from the white matter map only;
            # for the other maps just the data array is kept.
            wm, affine, voxel_size = load_nifti(wm_path, return_voxsize=True)
            gm, _ = load_nifti(gm_path)
            csf, _ = load_nifti(csf_path)
            avs = sum(voxel_size) / len(voxel_size)  # average_voxel_size
            classifier = CmcTissueClassifier.from_pve(wm,
                                                      gm,
                                                      csf,
                                                      step_size=step_size,
                                                      average_voxel_size=avs)
            logging.info('classifier done')
            seed_mask, _ = load_nifti(seeding_path)
            # The same density is requested along each of the three axes.
            seeds = utils.seeds_from_mask(
                seed_mask,
                density=[seed_density, seed_density, seed_density],
                affine=affine)
            logging.info('seeds done')
            dg = ProbabilisticDirectionGetter

            direction_getter = dg.from_shcoeff(pam.shm_coeff,
                                               max_angle=max_angle,
                                               sphere=pam.sphere,
                                               pmf_threshold=pmf_threshold)

            # pft_max_trial is fixed to 20; the other particle filtering
            # settings are forwarded from the arguments of this method.
            tracking_result = ParticleFilteringTracking(
                direction_getter,
                classifier,
                seeds,
                affine,
                step_size=step_size,
                pft_back_tracking_dist=pft_back,
                pft_front_tracking_dist=pft_front,
                pft_max_trial=20,
                particle_count=pft_count,
                save_seeds=save_seeds)

            logging.info('ParticleFilteringTracking initiated')

            if save_seeds:
                # Each result is unpacked as a (streamline, seed) pair.
                # `seeds` is rebound here to the per-streamline seeds.
                streamlines, seeds = zip(*tracking_result)
                tractogram = Tractogram(streamlines, affine_to_rasmm=np.eye(4))
                tractogram.data_per_streamline['seeds'] = seeds
            else:
                tractogram = Tractogram(tracking_result,
                                        affine_to_rasmm=np.eye(4))

            save(tractogram, out_tract)

            logging.info('Saved {0}'.format(out_tract))
from nibabel.orientations import aff2axcodes

import numpy as np

# Convert streamlines stored in a .npy file into a TRK file whose header is
# derived from a reference image.
streamlines = np.load(
    '/media/localadmin/HagmannHDD/Seb/testPFT/diffusion_preproc_resampled_streamlines.npy'
)

imref = nb.load('/media/localadmin/HagmannHDD/Seb/testPFT/shore_gfa.nii.gz')

affine = imref.affine.copy()

print(imref.affine.copy())
print(affine)

# TRK header fields taken from the reference image.
header = {}
header[Field.ORIGIN] = affine[:3, 3]
header[Field.VOXEL_TO_RASMM] = affine
header[Field.VOXEL_SIZES] = imref.header.get_zooms()[:3]
header[Field.DIMENSIONS] = imref.shape[:3]
header[Field.VOXEL_ORDER] = "".join(aff2axcodes(affine))

# Subtract the translation part of the affine from every point. The offset
# is taken once and each streamline is updated with a single vectorized
# assignment written back in place.
# Assumes each streamline is an ndarray of shape (n_points, 3) - TODO confirm.
translation = affine[:3, 3]
for streamline in streamlines:
    streamline[...] = streamline - translation

print(header[Field.VOXEL_ORDER])
tractogram = Tractogram(streamlines=streamlines, affine_to_rasmm=affine)
out_fname = '/media/localadmin/HagmannHDD/Seb/testPFT/track_nib1.trk'
nb.streamlines.save(tractogram, out_fname, header=header)
示例#16
0
def main():
    """Apply a set operation to several tractograms and save the result.

    Every file in ``args.inputs`` is loaded, ``args.operation`` is applied
    to their streamlines, and the selected streamlines are written to
    ``args.output`` using the header of the first input. Unless
    ``args.no_data`` is set, the per-streamline and per-point data of the
    selected streamlines are carried over. Optionally the selected indices
    are also dumped, per input file, to a JSON file.
    """
    parser = build_args_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    if os.path.isfile(args.output):
        if not args.force:
            parser.error('{0} already exist! Use -f to overwrite it.'.format(
                args.output))
        else:
            logging.info('Overwriting {0}.'.format(args.output))

    # Load all input streamlines.
    loaded = [load_data(path) for path in args.inputs]
    streamlines, data_per_streamline, data_per_point = zip(*loaded)
    nb_streamlines = [len(s) for s in streamlines]

    # Apply the requested operation to each input file.
    logging.info('Performing operation \'{}\'.'.format(args.operation))
    operation = OPERATIONS[args.operation]
    new_streamlines, indices = perform_streamlines_operation(
        operation, streamlines, args.precision)

    # Get the meta data of the streamlines.
    new_data_per_streamline = {}
    new_data_per_point = {}
    if not args.no_data:

        # The keys of the first input decide which entries are carried over.
        for key in data_per_streamline[0].keys():
            stacked = np.vstack([dps[key] for dps in data_per_streamline])
            new_data_per_streamline[key] = stacked[indices, :]

        # Add the indices to the metadata if requested.
        if args.save_meta_indices:
            new_data_per_streamline['ids'] = indices

        for key in data_per_point[0].keys():
            flattened = list(
                chain.from_iterable(dpp[key] for dpp in data_per_point))
            new_data_per_point[key] = [flattened[i] for i in indices]

    # Save the indices to a file if requested.
    if args.save_indices is not None:
        indices_dict = {'filenames': args.inputs}
        start = 0
        for name, count in zip(args.inputs, nb_streamlines):
            end = start + count
            # Indices are stored relative to the start of their own file.
            indices_dict[name] = [i - start for i in indices
                                  if start <= i < end]
            start = end
        with open(args.save_indices, 'wt') as f:
            json.dump(indices_dict, f)

    # Save the new streamlines.
    logging.info('Saving streamlines to {0}.'.format(args.output))
    reference_file = load(args.inputs[0], True)
    new_tractogram = Tractogram(new_streamlines,
                                data_per_streamline=new_data_per_streamline,
                                data_per_point=new_data_per_point)

    # If the reference is a .tck, the affine will be None.
    reference_affine = reference_file.tractogram.affine_to_rasmm
    new_tractogram.affine_to_rasmm = (
        np.eye(4) if reference_affine is None else reference_affine)

    new_header = reference_file.header.copy()
    new_header['nb_streamlines'] = len(new_streamlines)
    save(new_tractogram, args.output, header=new_header)
示例#17
0
def main():
    """Run COMMIT on a tractogram and save the resulting streamline weights.

    The function validates the command-line arguments, converts an ``.h5``
    input into a temporary ``.trk`` file when needed, writes a scheme file
    from the bval/bvec inputs, fits the COMMIT model and moves the results
    into ``args.out_dir``. Depending on the options it also saves the
    essential/nonessential split and/or the whole tractogram with the
    ``commit_weights`` attached as data_per_streamline.

    Invalid argument combinations and incompatible headers are reported
    through ``parser.error``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_dwi,
                                 args.in_bval, args.in_bvec],
                        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram,
                                                  dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    # The threshold arrives as a string so that 'None'/'none' can disable it.
    if args.threshold_weights == 'None' or args.threshold_weights == 'none':
        args.threshold_weights = None
        if not args.keep_whole_tractogram and ext != '.h5':
            logging.warning('Not thresholding weigth with trk file without '
                            'the --keep_whole_tractogram will not save a '
                            'tractogram')
    else:
        args.threshold_weights = float(args.threshold_weights)

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    if ext == '.h5':
        logging.debug('Reconstructing {} into a tractogram for COMMIT.'.format(
            args.in_tractogram))

        hdf5_file = h5py.File(args.in_tractogram, 'r')
        if not (np.allclose(hdf5_file.attrs['affine'], dwi_img.affine,
                            atol=1e-03)
                and np.array_equal(hdf5_file.attrs['dimensions'],
                                   dwi_img.shape[0:3])):
            parser.error('{} does not have a compatible header with {}'.format(
                args.in_tractogram, args.in_dwi))

        # Keep track of the order of connections/streamlines in relation to the
        # tractogram as well as the number of streamlines for each connection.
        hdf5_keys = list(hdf5_file.keys())
        streamlines = []
        offsets_list = [0]
        for key in hdf5_keys:
            tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                key)
            offsets_list.append(len(tmp_streamlines))
            streamlines.extend(tmp_streamlines)

        # After the cumulative sum, offsets_list[i]:offsets_list[i + 1] is
        # the slice of streamlines belonging to hdf5_keys[i].
        offsets_list = np.cumsum(offsets_list)

        sft = StatefulTractogram(streamlines, args.in_dwi,
                                 Space.VOX, origin=Origin.TRACKVIS)
        tmp_tractogram_filename = os.path.join(tmp_dir.name, 'tractogram.trk')

        # Keeping the input variable, saving trk file for COMMIT internal use
        save_tractogram(sft, tmp_tractogram_filename)
        args.in_tractogram = tmp_tractogram_filename

    # Writing the scheme file with proper shells
    tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
    tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
    bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
    shells_centroids, indices_shells = identify_shells(bvals, args.b_thr,
                                                       roundCentroids=True)
    np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
               newline=' ', fmt='%i')
    fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
    logging.debug('Lauching COMMIT on {} shells at found at {}.'.format(
        len(shells_centroids),
        shells_centroids))

    if len(shells_centroids) == 2 and not args.ball_stick:
        parser.error('The DWI data appears to be single-shell.\n'
                     'Use --ball_stick for single-shell.')

    with redirected_stdout:
        # Setting up the tractogram and nifti files
        trk2dictionary.run(filename_tractogram=args.in_tractogram,
                           filename_peaks=args.in_peaks,
                           peaks_use_affine=False,
                           filename_mask=args.in_tracking_mask,
                           ndirs=args.nbr_dir,
                           gen_trk=False,
                           path_out=tmp_dir.name)

        # Preparation for fitting
        commit.core.setup(ndirs=args.nbr_dir)
        mit = commit.Evaluation('.', '.')

        # FIX for very small values during HCP processing
        # (based on order of magnitude of signal)
        img = nib.load(args.in_dwi)
        data = img.get_fdata(dtype=np.float32)
        data[data < (0.001*10**np.floor(np.log10(np.mean(data[data > 0]))))] = 0
        nib.save(nib.Nifti1Image(data, img.affine),
                 os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

        mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                      tmp_scheme_filename)
        mit.set_model('StickZeppelinBall')

        if args.ball_stick:
            logging.debug('Disabled zeppelin, using the Ball & Stick model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = []
            isotropc_diff = args.iso_diff or [2.0E-3]
            mit.model.set(para_diff, perp_diff, isotropc_diff)
        else:
            logging.debug('Using the Stick Zeppelin Ball model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
            isotropc_diff = args.iso_diff or [1.7E-3, 3.0E-3]
            mit.model.set(para_diff, perp_diff, isotropc_diff)

        # The kernels are, by default, set to be in the current directory
        # Depending on the choice, manually change the saving location
        if args.save_kernels:
            kernels_dir = os.path.join(args.save_kernels)
            regenerate_kernels = True
        elif args.load_kernels:
            kernels_dir = os.path.join(args.load_kernels)
            regenerate_kernels = False
        else:
            kernels_dir = os.path.join(tmp_dir.name, 'kernels', mit.model.id)
            regenerate_kernels = True
        mit.set_config('ATOMS_path', kernels_dir)

        mit.generate_kernels(ndirs=500, regenerate=regenerate_kernels)
        if args.compute_only:
            # The final cleanup at the end of the function is not reached on
            # this path, so release the temporary directory here.
            tmp_dir.cleanup()
            return
        mit.load_kernels()
        mit.load_dictionary(tmp_dir.name,
                            use_mask=args.in_tracking_mask is not None)
        mit.set_threads(args.nbr_processes)

        mit.build_operator(build_dir=tmp_dir.name)
        mit.fit(tol_fun=1e-3, max_iter=args.nbr_iter, verbose=0)
        mit.save_results()

    # Simplifying output for streamlines and cleaning output directory
    commit_results_dir = os.path.join(tmp_dir.name,
                                      'Results_StickZeppelinBall')
    # The pickle was written by COMMIT into our own temporary directory; the
    # context manager makes sure the handle is closed after loading.
    with open(os.path.join(commit_results_dir, 'results.pickle'),
              'rb') as pk_file:
        commit_output_dict = pickle.load(pk_file)
    nbr_streamlines = lazy_streamlines_count(args.in_tractogram)
    commit_weights = np.asarray(commit_output_dict[2][:nbr_streamlines])
    np.savetxt(os.path.join(commit_results_dir,
                            'commit_weights.txt'),
               commit_weights)

    if ext == '.h5':
        new_filename = os.path.join(commit_results_dir,
                                    'decompose_commit.h5')
        with h5py.File(new_filename, 'w') as new_hdf5_file:
            new_hdf5_file.attrs['affine'] = sft.affine
            new_hdf5_file.attrs['dimensions'] = sft.dimensions
            new_hdf5_file.attrs['voxel_sizes'] = sft.voxel_sizes
            new_hdf5_file.attrs['voxel_order'] = sft.voxel_order
            # Assign the weights into the hdf5, while respecting the ordering of
            # connections/streamlines
            logging.debug('Adding commit weights to {}.'.format(new_filename))
            for i, key in enumerate(hdf5_keys):
                new_group = new_hdf5_file.create_group(key)
                old_group = hdf5_file[key]
                tmp_commit_weights = \
                    commit_weights[offsets_list[i]:offsets_list[i+1]]
                # NOTE(review): 'data', 'offsets' and 'lengths' are only
                # written when a threshold is set -- confirm this is intended.
                if args.threshold_weights is not None:
                    essential_ind = np.where(
                        tmp_commit_weights > args.threshold_weights)[0]
                    tmp_streamlines = reconstruct_streamlines(
                        old_group['data'],
                        old_group['offsets'],
                        old_group['lengths'],
                        indices=essential_ind)

                    # Only the streamlines above the threshold are written
                    # to the new file.
                    new_group.create_dataset('data',
                                             data=tmp_streamlines.get_data(),
                                             dtype=np.float32)
                    new_group.create_dataset('offsets',
                                             data=tmp_streamlines._offsets,
                                             dtype=np.int64)
                    new_group.create_dataset('lengths',
                                             data=tmp_streamlines._lengths,
                                             dtype=np.int32)

                for dps_key in hdf5_file[key].keys():
                    if dps_key not in ['data', 'offsets', 'lengths']:
                        # Each extra dataset keeps its own name; naming it
                        # after the group key would reuse one name for all.
                        new_group.create_dataset(dps_key,
                                                 data=hdf5_file[key][dps_key])
                new_group.create_dataset('commit_weights',
                                         data=tmp_commit_weights)

        # The input file is not read past this point.
        hdf5_file.close()

    files = os.listdir(commit_results_dir)
    for f in files:
        shutil.move(os.path.join(commit_results_dir, f), args.out_dir)

    # Save split tractogram (essential/nonessential) and/or saving the
    # tractogram with data_per_streamline updated
    if args.keep_whole_tractogram or args.threshold_weights is not None:
        # Reload is needed because of COMMIT handling its file by itself
        tractogram_file = nib.streamlines.load(args.in_tractogram)
        tractogram = tractogram_file.tractogram
        tractogram.data_per_streamline['commit_weights'] = commit_weights

        if args.threshold_weights is not None:
            essential_ind = np.where(
                commit_weights > args.threshold_weights)[0]
            nonessential_ind = np.where(
                commit_weights <= args.threshold_weights)[0]
            logging.debug('{} essential streamlines were kept at '
                          'threshold {}'.format(len(essential_ind),
                                                args.threshold_weights))
            logging.debug('{} nonessential streamlines were kept at '
                          'threshold {}'.format(len(nonessential_ind),
                                                args.threshold_weights))

            # TODO PR when Dipy 1.2 is out with sft slicing
            essential_streamlines = tractogram.streamlines[essential_ind]
            essential_dps = tractogram.data_per_streamline[essential_ind]
            essential_dpp = tractogram.data_per_point[essential_ind]
            essential_tractogram = Tractogram(
                essential_streamlines,
                data_per_point=essential_dpp,
                data_per_streamline=essential_dps,
                affine_to_rasmm=np.eye(4))

            nonessential_streamlines = tractogram.streamlines[nonessential_ind]
            nonessential_dps = tractogram.data_per_streamline[nonessential_ind]
            nonessential_dpp = tractogram.data_per_point[nonessential_ind]
            nonessential_tractogram = Tractogram(
                nonessential_streamlines,
                data_per_point=nonessential_dpp,
                data_per_streamline=nonessential_dps,
                affine_to_rasmm=np.eye(4))

            nib.streamlines.save(essential_tractogram,
                                 os.path.join(args.out_dir,
                                              'essential_tractogram.trk'),
                                 header=tractogram_file.header)
            nib.streamlines.save(nonessential_tractogram,
                                 os.path.join(args.out_dir,
                                              'nonessential_tractogram.trk'),
                                 header=tractogram_file.header)
        if args.keep_whole_tractogram:
            output_filename = os.path.join(args.out_dir, 'tractogram.trk')
            logging.debug('Saving tractogram with weights as {}'.format(
                output_filename))
            nib.streamlines.save(tractogram_file, output_filename)

    tmp_dir.cleanup()
示例#18
0
def auto_extract_VCs(streamlines, ref_bundles):
    """Assign streamlines to reference bundles and score each bundle.

    The streamlines are processed in chunks of 5000. Each chunk is resampled,
    clustered with QuickBundles and passed to ``auto_extract`` once per
    reference bundle. A streamline claimed by one bundle is not given to a
    later bundle of the same chunk.

    Parameters
    ----------
    streamlines : sequence
        All the streamlines to score; must support ``len()``, slicing and
        integer indexing.
    ref_bundles : list of dict
        One entry per reference bundle, read through the keys ``'name'``,
        ``'cluster_map'``, ``'threshold'`` and ``'mask'``.

    Returns
    -------
    VC_idx : set of int
        Indices (into ``streamlines``) of all the selected streamlines.
    found_vbs_info : dict
        Per bundle name: ``'nb_streamlines'``, ``'streamlines_indices'``,
        ``'overlap'``, ``'overreach'``, ``'overreach_norm'`` and
        ``'f1_score'``.
    bundles_found : list of bool
        One flag per reference bundle, True when at least one streamline
        was assigned to it.
    """
    chunk_size = 5000

    total_vc = 0
    all_vc_indices = set()
    bundle_infos = {
        bundle['name']: {'nb_streamlines': 0, 'streamlines_indices': set()}
        for bundle in ref_bundles
    }
    found_flags = [False] * len(ref_bundles)

    logging.debug("Starting scoring VCs")

    clusterer = QuickBundles(threshold=20,
                             metric=AveragePointwiseEuclideanMetric())

    # Bookkeeping is needed because big datasets are handled chunk by chunk.
    nb_processed = 0
    chunk_no = 0
    while nb_processed < len(streamlines):
        logging.debug("Starting chunk: {0}".format(chunk_no))

        first = chunk_no * chunk_size
        chunk = streamlines[first:first + chunk_size]
        nb_processed += len(chunk)

        # Resample and cluster the chunk once, so that it is not redone for
        # every reference bundle. The cast is there because qb.cluster had
        # problems with f8.
        resampled = [s.astype('f4')
                     for s in set_number_of_points(chunk, NB_POINTS_RESAMPLE)]
        chunk_clusters = clusterer.cluster(resampled)
        chunk_clusters.refdata = chunk

        logging.debug("Starting VC identification through auto_extract")

        claimed = set()
        for bundle_idx, ref_bundle in enumerate(ref_bundles):
            # Indices returned here are relative to the chunk.
            extracted = auto_extract(ref_bundle['cluster_map'],
                                     chunk_clusters,
                                     clean_thr=ref_bundle['threshold'])

            # Drop the streamlines already assigned to a previous bundle.
            hits = set(extracted) - claimed
            claimed |= hits

            if not hits:
                continue

            found_flags[bundle_idx] = True
            total_vc += len(hits)

            # Shift the indices so that they refer to the whole dataset.
            shifted = {v + first for v in hits}

            info = bundle_infos[ref_bundle['name']]
            info['nb_streamlines'] += len(hits)
            info['streamlines_indices'] |= shifted

            all_vc_indices |= shifted

        chunk_no += 1

    # Compute the coverage scores of each bundle and store them in its info.
    for ref_bundle in ref_bundles:
        info = bundle_infos[ref_bundle["name"]]
        mask_img = ref_bundle["mask"]

        # Streamlines are in voxel space since that's how they were
        # loaded in the scoring function.
        tractogram = Tractogram(
            streamlines=(streamlines[i]
                         for i in info['streamlines_indices']),
            affine_to_rasmm=mask_img.affine)

        if len(tractogram) > 0:
            scores = compute_bundle_coverage_scores(tractogram, mask_img)
        else:
            scores = {}

        info.update(overlap=scores.get("OL", 0),
                    overreach=scores.get("OR", 0),
                    overreach_norm=scores.get("ORn", 0),
                    f1_score=scores.get("F1", 0))

    return all_vc_indices, bundle_infos, found_flags
示例#19
0
def dwi_dipy_run(dwi_dir,
                 node_size,
                 dir_path,
                 conn_model,
                 parc,
                 atlas_select,
                 network,
                 wm_mask=None):
    """Run tractography with dipy and return a connectivity matrix.

    A tensor model is always fitted first. Streamlines are then generated
    either with EuDX (``conn_model == 'tensor'``) or with local tracking on
    a CSD model (``conn_model == 'csd'``), saved to ``prob_streamlines.trk``
    and turned into a connectivity matrix over the atlas labels.

    Parameters
    ----------
    dwi_dir : str
        Directory expected to hold ``dwi.nii.gz``, ``nodif_brain_mask.nii.gz``,
        ``bval`` and ``bvec``. NOTE: currently overwritten by a hard-coded
        path at the top of the function.
    node_size
        Replaced by ``'parc'`` when ``parc`` is True; not used otherwise.
    dir_path, atlas_select, network, wm_mask
        Not used by the current implementation.
    conn_model : str
        Either ``'csd'`` or ``'tensor'``.
    parc : bool
        See ``node_size``.

    Returns
    -------
    conn_matrix
        Connectivity matrix returned by ``utils.connectivity_matrix`` with
        its first three rows and columns set to zero.

    Raises
    ------
    RuntimeError
        If ``conn_model`` is neither ``'csd'`` nor ``'tensor'``.
    ValueError
        If the tracking method is not one of the supported options.
    """
    from dipy.reconst.dti import TensorModel, quantize_evecs
    from dipy.reconst.csdeconv import ConstrainedSphericalDeconvModel, recursive_response
    from dipy.tracking.local import LocalTracking, ActTissueClassifier
    from dipy.tracking import utils
    from dipy.direction import peaks_from_model
    from dipy.tracking.eudx import EuDX
    from dipy.data import get_sphere, default_sphere
    from dipy.core.gradients import gradient_table
    from dipy.io import read_bvals_bvecs
    from dipy.tracking.streamline import Streamlines
    from dipy.direction import ProbabilisticDirectionGetter, ClosestPeakDirectionGetter, BootDirectionGetter
    from nibabel.streamlines import save as save_trk
    from nibabel.streamlines import Tractogram

    ##
    # NOTE(review): hard-coded, machine-specific debug values. They override
    # the ``dwi_dir`` argument and are the only source of the tissue maps and
    # atlas used below -- replace with real inputs before relying on this.
    dwi_dir = '/Users/PSYC-dap3463/Downloads/bedpostx_s002'
    img_pve_csf = nib.load(
        '/Users/PSYC-dap3463/Downloads/002_all/tmp/reg_a/t1w_vent_csf_diff_dwi.nii.gz'
    )
    img_pve_wm = nib.load(
        '/Users/PSYC-dap3463/Downloads/002_all/tmp/reg_a/t1w_wm_in_dwi_bin.nii.gz'
    )
    img_pve_gm = nib.load(
        '/Users/PSYC-dap3463/Downloads/002_all/tmp/reg_a/t1w_gm_mask_dwi.nii.gz'
    )
    labels_img = nib.load(
        '/Users/PSYC-dap3463/Downloads/002_all/tmp/reg_a/dwi_aligned_atlas.nii.gz'
    )
    num_total_samples = 10000
    tracking_method = 'boot'  # Options are 'boot', 'prob', 'peaks', 'closest'
    procmem = [2, 4]
    ##

    if parc is True:
        node_size = 'parc'

    dwi_img = "%s%s" % (dwi_dir, '/dwi.nii.gz')
    nodif_brain_mask_path = "%s%s" % (dwi_dir, '/nodif_brain_mask.nii.gz')
    bvals = "%s%s" % (dwi_dir, '/bval')
    bvecs = "%s%s" % (dwi_dir, '/bvec')

    dwi_img = nib.load(dwi_img)
    data = dwi_img.get_data()
    [bvals, bvecs] = read_bvals_bvecs(bvals, bvecs)
    gtab = gradient_table(bvals, bvecs)
    gtab.b0_threshold = min(bvals)
    sphere = get_sphere('symmetric724')

    # Loads mask and ensures it's a true binary mask
    mask_img = nib.load(nodif_brain_mask_path)
    mask = mask_img.get_data()
    mask = mask > 0

    # Fit a basic tensor model first
    model = TensorModel(gtab)
    ten = model.fit(data, mask)
    fa = ten.fa

    # Tractography
    if conn_model == 'csd':
        print('Tracking with csd model...')
    elif conn_model == 'tensor':
        print('Tracking with tensor model...')
    else:
        raise RuntimeError("%s%s" % (conn_model, ' is not a valid model.'))

    # Combine seed counts from voxel with seed counts total
    wm_mask_data = img_pve_wm.get_data()
    wm_mask_data[0, :, :] = False
    wm_mask_data[:, 0, :] = False
    wm_mask_data[:, :, 0] = False
    seeds = utils.seeds_from_mask(wm_mask_data,
                                  density=1,
                                  affine=dwi_img.affine)
    # NOTE(review): no affine is passed here, unlike seeds_from_mask above --
    # confirm that both seed sets are in the same coordinate space.
    seeds_rnd = utils.random_seeds_from_mask(ten.fa > 0.02,
                                             seeds_count=num_total_samples,
                                             seed_count_per_voxel=True)
    seeds_all = np.vstack([seeds, seeds_rnd])

    # Load tissue maps and prepare tissue classifier (Anatomically-Constrained Tractography (ACT))
    background = np.ones(img_pve_gm.shape)
    background[(img_pve_gm.get_data() + img_pve_wm.get_data() +
                img_pve_csf.get_data()) > 0] = 0
    include_map = img_pve_gm.get_data()
    include_map[background > 0] = 1
    exclude_map = img_pve_csf.get_data()
    act_classifier = ActTissueClassifier(include_map, exclude_map)

    if conn_model == 'tensor':
        ind = quantize_evecs(ten.evecs, sphere.vertices)
        streamline_generator = EuDX(a=fa,
                                    ind=ind,
                                    seeds=seeds_all,
                                    odf_vertices=sphere.vertices,
                                    a_low=0.05,
                                    step_sz=.5)
    elif conn_model == 'csd':
        print('Tracking with CSD model...')
        wm_bool_mask = img_pve_wm.get_data().astype('bool')
        response = recursive_response(
            gtab,
            data,
            mask=wm_bool_mask,
            sh_order=8,
            peak_thr=0.01,
            init_fa=0.05,
            init_trace=0.0021,
            iter=8,
            convergence=0.001,
            parallel=True)
        csd_model = ConstrainedSphericalDeconvModel(gtab, response)
        # Only some tracking methods fit the model; binding the name up front
        # lets the cleanup below stay unconditional.
        csd_fit = None
        if tracking_method == 'boot':
            dg = BootDirectionGetter.from_data(data,
                                               csd_model,
                                               max_angle=30.,
                                               sphere=default_sphere)
        elif tracking_method == 'prob':
            try:
                print(
                    'First attempting to build the direction getter directly from the spherical harmonic representation of the FOD...'
                )
                csd_fit = csd_model.fit(data, mask=wm_bool_mask)
                dg = ProbabilisticDirectionGetter.from_shcoeff(
                    csd_fit.shm_coeff, max_angle=30., sphere=default_sphere)
            except Exception:
                # Deliberate fallback, but without trapping
                # KeyboardInterrupt/SystemExit as a bare except would.
                print(
                    'Sphereical harmonic not available for this model. Using peaks_from_model to represent the ODF of the model on a spherical harmonic basis instead...'
                )
                peaks = peaks_from_model(
                    csd_model,
                    data,
                    default_sphere,
                    .5,
                    25,
                    mask=wm_bool_mask,
                    return_sh=True,
                    parallel=True,
                    nbr_processes=procmem[0])
                dg = ProbabilisticDirectionGetter.from_shcoeff(
                    peaks.shm_coeff, max_angle=30., sphere=default_sphere)
        elif tracking_method == 'peaks':
            dg = peaks_from_model(model=csd_model,
                                  data=data,
                                  sphere=default_sphere,
                                  relative_peak_threshold=.5,
                                  min_separation_angle=25,
                                  mask=wm_bool_mask,
                                  parallel=True,
                                  nbr_processes=procmem[0])
        elif tracking_method == 'closest':
            csd_fit = csd_model.fit(data, mask=wm_bool_mask)
            pmf = csd_fit.odf(default_sphere).clip(min=0)
            dg = ClosestPeakDirectionGetter.from_pmf(pmf,
                                                     max_angle=30.,
                                                     sphere=default_sphere)
        else:
            raise ValueError(
                "%s%s" % (tracking_method, ' is not a valid tracking method.'))
        streamline_generator = LocalTracking(dg,
                                             act_classifier,
                                             seeds_all,
                                             affine=dwi_img.affine,
                                             step_size=0.5)
        # Drop the local references to the large intermediate objects.
        del dg, csd_fit, response, csd_model

    # Materialize the streamlines for both models; doing this only in the
    # csd branch left ``streamlines`` undefined for the tensor model.
    streamlines = Streamlines(streamline_generator, buffer_size=512)

    save_trk(Tractogram(streamlines, affine_to_rasmm=dwi_img.affine),
             'prob_streamlines.trk')
    # Streamlines made of a single point are discarded.
    tracks = [sl for sl in streamlines if len(sl) > 1]
    labels_data = labels_img.get_data().astype('int')
    labels_affine = labels_img.affine
    conn_matrix, grouping = utils.connectivity_matrix(
        tracks,
        labels_data,
        affine=labels_affine,
        return_mapping=True,
        mapping_as_streamlines=True,
        symmetric=True)
    conn_matrix[:3, :] = 0
    conn_matrix[:, :3] = 0

    return conn_matrix
示例#20
0
def get_ismrm_seeds(data_dir, source, keep, weighted, threshold, voxel):
    """Build a seed file for the ISMRM dataset and save it as a .trk file.

    The anatomical mask is first resized with ``mrresize``. Seeds are then
    either the endpoints of the bundles found in ``<data_dir>/bundles``
    (``source == "trk"``) or the world coordinates of the voxels of the
    resized mask above ``threshold`` (``source`` in ``"wm"``/``"brain"``).
    The result is written to ``<data_dir>/seeds`` and the intermediate files
    (resized mask, converted .trk bundles) are removed.

    Parameters
    ----------
    data_dir : str
        Root directory holding the ``bundles`` and ``masks`` sub-directories.
    source : str
        One of ``"wm"``, ``"trk"`` or ``"brain"``.
    keep : float
        Fraction of the seeds to keep; values below 1 trigger a random
        subsampling with a fixed random seed of 42.
    weighted : bool
        With ``source == "trk"`` and ``keep < 1``, give every bundle the same
        total sampling probability. Forced to False for the other sources.
    threshold : float
        Mask values strictly above this threshold produce a seed.
    voxel : int
        Voxel size in hundredths of the unit passed to ``mrresize`` (the
        value is divided by 100); also used in the output file name.

    Raises
    ------
    ValueError
        If ``source`` is not one of the supported values.
    RuntimeError
        If ``source == "trk"`` and no .trk bundle is found after conversion.
    """
    # Fail early: an unknown source used to surface later as an unbound
    # local variable, after part of the work had already been attempted.
    if source not in ("wm", "trk", "brain"):
        raise ValueError(
            "Unknown source '{}': expected 'wm', 'trk' or 'brain'.".format(
                source))

    trk_dir = os.path.join(data_dir, "bundles")

    if source in ["wm", "trk"]:
        anat_path = os.path.join(data_dir, "masks", "wm.nii.gz")
        resized_path = os.path.join(data_dir, "masks",
                                    "wm_{}.nii.gz".format(voxel))
    else:
        # NOTE(review): these paths are relative to the working directory
        # and the "125" suffix ignores ``voxel`` -- confirm this is intended.
        anat_path = os.path.join("subjects", "ismrm_gt",
                                 "dwi_brain_mask.nii.gz")
        resized_path = os.path.join("subjects", "ismrm_gt",
                                    "dwi_brain_mask_125.nii.gz")

    sp.call([
        "mrresize", "-voxel", "{:1.2f}".format(voxel / 100), anat_path,
        resized_path
    ])

    if source == "trk":

        print("Running Tractconverter...")
        sp.call([
            "python", "tractconverter/scripts/WalkingTractConverter.py", "-i",
            trk_dir, "-a", resized_path, "-vtk2trk"
        ])

        print("Loading seed bundles...")
        seed_bundles = []
        header = None
        for trk_path in glob.glob(os.path.join(trk_dir, "*.trk")):
            trk_file = nib.streamlines.load(trk_path)
            endpoints = []
            for fiber in trk_file.tractogram.streamlines:
                endpoints.append(fiber[0])
                endpoints.append(fiber[-1])
            seed_bundles.append(endpoints)
            # The header of the first bundle is reused for the output file.
            if header is None:
                header = trk_file.header

        if header is None:
            raise RuntimeError(
                "No .trk bundle found in {}.".format(trk_dir))

        n_seeds = sum(len(b) for b in seed_bundles)
        n_bundles = len(seed_bundles)

        print("Loaded {} seeds from {} bundles.".format(n_seeds, n_bundles))

        # Every seed is stored as a streamline made of a single point.
        seeds = np.array([[seed] for bundle in seed_bundles
                          for seed in bundle])

        if keep < 1:
            if weighted:
                # Spread 1 / n_bundles uniformly over the seeds of each
                # bundle, so that small bundles are not under-sampled.
                p = np.zeros(n_seeds)
                offset = 0
                for bundle in seed_bundles:
                    bundle_size = len(bundle)
                    p[offset:offset + bundle_size] = \
                        1 / (bundle_size * n_bundles)
                    offset += bundle_size
            else:
                p = np.ones(n_seeds) / n_seeds

    else:

        weighted = False

        wm_file = nib.load(resized_path)
        wm_img = wm_file.get_fdata()

        # Voxel indices above the threshold, in homogeneous coordinates.
        seeds = np.argwhere(wm_img > threshold)
        seeds = np.hstack([seeds, np.ones([len(seeds), 1])])

        # Apply the affine, then store each seed as a one-point streamline.
        seeds = (wm_file.affine.dot(seeds.T).T)[:, :3].reshape(-1, 1, 3)

        n_seeds = len(seeds)

        if keep < 1:
            p = np.ones(n_seeds) / n_seeds

        header = TrkFile.create_empty_header()

        header["voxel_to_rasmm"] = wm_file.affine
        header["dimensions"] = wm_file.header["dim"][1:4]
        header["voxel_sizes"] = wm_file.header["pixdim"][1:4]
        header["voxel_order"] = get_reference_info(wm_file)[3]

    if keep < 1:
        keep_n = int(keep * n_seeds)
        print("Subsampling from {} seeds to {} seeds".format(n_seeds, keep_n))
        np.random.seed(42)
        keep_idx = np.random.choice(len(seeds),
                                    size=keep_n,
                                    replace=False,
                                    p=p)
        seeds = seeds[keep_idx]

    tractogram = Tractogram(streamlines=ArraySequence(seeds),
                            affine_to_rasmm=np.eye(4))

    save_dir = os.path.join(data_dir, "seeds")
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "seeds_from_{}_{}_vox{:03d}.trk")
    save_path = save_path.format(
        source, "W" + str(int(100 * keep)) if weighted else "all", voxel)

    print("Saving {}".format(save_path))
    TrkFile(tractogram, header).save(save_path)

    # Remove the intermediate files created above.
    os.remove(resized_path)
    for trk_path in glob.glob(os.path.join(trk_dir, "*.trk")):
        os.remove(trk_path)
示例#21
0
def main():
    """Track streamlines with a trained model and save them to disk.

    Steps, in order:

    1. Parse the command line and locate the experiment folder and its
       ``hyperparams.json`` (also looked up one level above the folder).
    2. Load the DWI together with its ``.bvals/.bvecs`` (or ``.bval/.bvec``)
       files and turn it into the model input volume.
    3. Restore the model selected by ``hyperparams["model"]``.
    4. Optionally load (and dilate) the tracking mask.
    5. Build the seeds: streamline extremities for ``.trk``/``.tck`` inputs,
       random points inside every non-zero voxel otherwise.
    6. Track in DWI voxel space, then bring the result back to RAS+mm.
    7. Filter the tractogram (empty, shorter than ``min_length``, optionally
       stopped by curvature and/or with a loss above ``filter_threshold``).
    8. Save the kept streamlines and, with ``save_rejected``, the rejected
       ones, inside the experiment folder.

    Raises
    ------
    FileNotFoundError
        If no gradient files are found next to the DWI.
    ValueError
        If ``hyperparams["model"]`` names an unknown model.
    """
    parser = build_argparser()
    args = parser.parse_args()

    # Get experiment folder
    experiment_path = args.name
    if not os.path.isdir(experiment_path):
        # If not a directory, it must be the name of the experiment.
        experiment_path = pjoin(".", "experiments", args.name)

    if not os.path.isdir(experiment_path):
        parser.error('Cannot find experiment: {0}!'.format(args.name))

    # Load experiments hyperparameters
    try:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "hyperparams.json"))
    except FileNotFoundError:
        hyperparams = smartutils.load_dict_from_json_file(
            pjoin(experiment_path, "..", "hyperparams.json"))

    with Timer("Loading DWIs"):
        # Load gradients table
        dwi_name = args.dwi
        if dwi_name.endswith(".gz"):
            dwi_name = dwi_name[:-3]
        if dwi_name.endswith(".nii"):
            dwi_name = dwi_name[:-4]

        try:
            bvals_filename = dwi_name + ".bvals"
            bvecs_filename = dwi_name + ".bvecs"
            bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(
                bvals_filename, bvecs_filename)
        except FileNotFoundError:
            try:
                bvals_filename = dwi_name + ".bval"
                bvecs_filename = dwi_name + ".bvec"
                bvals, bvecs = dipy.io.gradients.read_bvals_bvecs(
                    bvals_filename, bvecs_filename)
            except FileNotFoundError:
                print("Could not find .bvals/.bvecs or .bval/.bvec files...")
                # Bare raise keeps the original traceback intact.
                raise

        dwi = nib.load(args.dwi)
        if hyperparams["use_sh_coeffs"]:
            # Use 45 spherical harmonic coefficients to represent the diffusion signal.
            weights = neurotools.get_spherical_harmonics_coefficients(
                dwi, bvals, bvecs).astype(np.float32)
        else:
            # Resample the diffusion signal to have 100 directions.
            weights = neurotools.resample_dwi(dwi, bvals,
                                              bvecs).astype(np.float32)

        affine_rasmm2dwivox = np.linalg.inv(dwi.affine)

    with Timer("Loading model"):
        if hyperparams["model"] == "gru_regression":
            from learn2track.models import GRU_Regression
            model_class = GRU_Regression
        elif hyperparams['model'] == 'gru_gaussian':
            from learn2track.models import GRU_Gaussian
            model_class = GRU_Gaussian
        elif hyperparams['model'] == 'gru_mixture':
            from learn2track.models import GRU_Mixture
            model_class = GRU_Mixture
        elif hyperparams['model'] == 'gru_multistep':
            from learn2track.models import GRU_Multistep_Gaussian
            model_class = GRU_Multistep_Gaussian
        elif hyperparams['model'] == 'ffnn_regression':
            from learn2track.models import FFNN_Regression
            model_class = FFNN_Regression
        else:
            raise ValueError("Unknown model!")

        kwargs = {}
        volume_manager = neurotools.VolumeManager()
        volume_manager.register(weights)
        kwargs['volume_manager'] = volume_manager

        # Load the actual model.
        model = model_class.create(
            pjoin(experiment_path),
            **kwargs)  # Create new instance and restore model.
        model.drop_prob = 0.
        print(str(model))

    mask = None
    # Bound up front so building the stopping criteria below never trips on
    # an unbound local when no mask is given.
    # NOTE(review): confirm make_is_outside_mask accepts mask=None.
    affine_maskvox2dwivox = None
    if args.mask is not None:
        with Timer("Loading mask"):
            mask_nii = nib.load(args.mask)
            mask = mask_nii.get_data()
            # Compute the affine allowing to evaluate the mask at some coordinates correctly.

            # affine_maskvox2dwivox = mask_vox => rasmm space => dwi_vox
            affine_maskvox2dwivox = np.dot(affine_rasmm2dwivox,
                                           mask_nii.affine)
            if args.dilate_mask:
                # Import the submodule explicitly: a bare ``import scipy``
                # does not guarantee ``scipy.ndimage`` is loaded.
                from scipy import ndimage
                mask = ndimage.binary_dilation(mask).astype(mask.dtype)

    with Timer("Generating seeds"):
        seeds = []

        for filename in args.seeds:
            if filename.endswith('.trk') or filename.endswith('.tck'):
                tfile = nib.streamlines.load(filename)
                # Send the streamlines to voxel since that's where we'll track.
                tfile.tractogram.apply_affine(affine_rasmm2dwivox)

                # Use extremities of the streamlines as seeding points.
                seeds += [s[0] for s in tfile.streamlines]
                seeds += [s[-1] for s in tfile.streamlines]

            else:
                # Assume it is a binary mask.
                rng = np.random.RandomState(args.seeding_rng_seed)
                nii_seeds = nib.load(filename)

                # affine_seedsvox2dwivox = mask_vox => rasmm space => dwi_vox
                affine_seedsvox2dwivox = np.dot(affine_rasmm2dwivox,
                                                nii_seeds.affine)

                nii_seeds_data = nii_seeds.get_data()

                if args.dilate_seeding_mask:
                    from scipy import ndimage
                    nii_seeds_data = ndimage.binary_dilation(
                        nii_seeds_data).astype(nii_seeds_data.dtype)

                indices = np.array(np.where(nii_seeds_data)).T
                for idx in indices:
                    # Jitter uniformly inside the voxel, then map the points
                    # from seeding-mask voxel space to DWI voxel space.
                    seeds_in_voxel = idx + rng.uniform(
                        -0.5, 0.5, size=(args.nb_seeds_per_voxel, 3))
                    seeds_in_voxel = nib.affines.apply_affine(
                        affine_seedsvox2dwivox, seeds_in_voxel)
                    seeds.extend(seeds_in_voxel)

        seeds = np.array(seeds, dtype=theano.config.floatX)

    with Timer("Tracking in the diffusion voxel space"):
        voxel_sizes = np.asarray(dwi.header.get_zooms()[:3])
        if not np.all(voxel_sizes == dwi.header.get_zooms()[0]):
            print("* Careful voxel are anisotropic {}!".format(
                tuple(voxel_sizes)))
        # Since we are tracking in diffusion voxel space, convert step_size (in mm) to voxel.

        if args.step_size is not None:
            step_size = np.float32(args.step_size / voxel_sizes.max())
            # Also convert max length (in mm) to voxel.
            max_nb_points = int(np.ceil(args.max_length / args.step_size))
        else:
            step_size = None
            max_nb_points = args.max_length

        if args.theta is not None:
            theta = np.deg2rad(args.theta)
        elif args.curvature is not None and args.curvature > 0:
            theta = get_max_angle_from_curvature(args.curvature, step_size)
        else:
            theta = np.deg2rad(45)

        print("Angle: {}".format(np.rad2deg(theta)))
        print("Step size (vox): {}".format(step_size))
        print("Max nb. points: {}".format(max_nb_points))

        is_outside_mask = make_is_outside_mask(mask,
                                               affine_maskvox2dwivox,
                                               threshold=args.mask_threshold)
        is_too_long = make_is_too_long(max_nb_points)
        is_too_curvy = make_is_too_curvy(np.rad2deg(theta))
        is_unlikely = make_is_unlikely(0.5)
        is_stopping = make_is_stopping({
            STOPPING_MASK: is_outside_mask,
            STOPPING_LENGTH: is_too_long,
            STOPPING_CURVATURE: is_too_curvy,
            STOPPING_LIKELIHOOD: is_unlikely
        })

        is_stopping.max_nb_points = max_nb_points  # Small hack

        tractogram = batch_track(model,
                                 weights,
                                 seeds,
                                 step_size=step_size,
                                 is_stopping=is_stopping,
                                 batch_size=args.batch_size,
                                 args=args)

        # Streamlines have been generated in voxel space.
        # Transform them back to RAS+mm space using the dwi's affine.
        tractogram.affine_to_rasmm = dwi.affine
        tractogram.to_world()  # Performed in-place.

    nb_streamlines = len(tractogram)

    if args.save_rejected:
        rejected_tractogram = Tractogram()
        rejected_tractogram.affine_to_rasmm = tractogram.affine_to_rasmm

    print("Generated {:,} (compressed) streamlines".format(nb_streamlines))
    with Timer("Cleaning streamlines", newline=True):
        # Flush streamlines that have no points.
        nb_points = np.array(list(map(len, tractogram)))
        if args.save_rejected:
            rejected_tractogram += tractogram[nb_points <= 0]

        tractogram = tractogram[nb_points > 0]
        print("Removed {:,} empty streamlines".format(nb_streamlines -
                                                      len(tractogram)))

        # Remove small streamlines
        nb_streamlines = len(tractogram)
        lengths = dipy.tracking.streamline.length(tractogram.streamlines)

        if args.save_rejected:
            rejected_tractogram += tractogram[lengths < args.min_length]

        tractogram = tractogram[lengths >= args.min_length]
        lengths = lengths[lengths >= args.min_length]
        if len(lengths) > 0:
            print("Average length: {:.2f} mm.".format(lengths.mean()))
            print("Minimum length: {:.2f} mm. Maximum length: {:.2f}".format(
                lengths.min(), lengths.max()))
        print("Removed {:,} streamlines smaller than {:.2f} mm".format(
            nb_streamlines - len(tractogram), args.min_length))
        if args.discard_stopped_by_curvature:
            nb_streamlines = len(tractogram)
            stopping_curvature_flag_is_set = is_flag_set(
                tractogram.data_per_streamline['stopping_flags'][:, 0],
                STOPPING_CURVATURE)

            if args.save_rejected:
                rejected_tractogram += tractogram[
                    stopping_curvature_flag_is_set]

            tractogram = tractogram[np.logical_not(
                stopping_curvature_flag_is_set)]
            print(
                "Removed {:,} streamlines stopped for having a curvature higher than {:.2f} degree"
                .format(nb_streamlines - len(tractogram), np.rad2deg(theta)))

        if args.filter_threshold is not None:
            # Remove streamlines that produces a reconstruction error higher than a certain threshold.
            nb_streamlines = len(tractogram)
            losses = compute_loss_errors(tractogram.streamlines, model,
                                         hyperparams)
            # Reported spread is the standard error of the mean.
            print("Mean loss: {:.4f} ± {:.4f}".format(
                np.mean(losses),
                np.std(losses, ddof=1) / np.sqrt(len(losses))))

            if args.save_rejected:
                rejected_tractogram += tractogram[
                    losses > args.filter_threshold]

            tractogram = tractogram[losses <= args.filter_threshold]
            # Streamlines above the threshold are the ones discarded.
            print(
                "Removed {:,} streamlines producing a loss higher than {:.2f} mm"
                .format(nb_streamlines - len(tractogram),
                        args.filter_threshold))

    with Timer("Saving {:,} (compressed) streamlines".format(len(tractogram))):
        filename = args.out
        if filename is None:
            prefix = args.prefix
            if prefix is None:
                dwi_name = os.path.basename(args.dwi)
                if dwi_name.endswith(".nii.gz"):
                    dwi_name = dwi_name[:-7]
                else:  # .nii
                    dwi_name = dwi_name[:-4]

                prefix = os.path.basename(os.path.dirname(args.dwi)) + dwi_name
                prefix = prefix.replace(".", "_")

            # seed_mask_type / mask_type are only consumed by the filename
            # items that are currently disabled below.
            seed_mask_type = args.seeds[0].replace(".", "_").replace(
                "_", "").replace("/", "-")
            if "int" in args.seeds[0]:
                seed_mask_type = "int"
            elif "wm" in args.seeds[0]:
                seed_mask_type = "wm"
            elif "rois" in args.seeds[0]:
                seed_mask_type = "rois"
            elif "bundles" in args.seeds[0]:
                seed_mask_type = "bundles"

            mask_name = args.mask if args.mask is not None else ""
            mask_type = ""
            if "fa" in mask_name:
                mask_type = "fa"
            elif "wm" in mask_name:
                mask_type = "wm"

            if args.dilate_seeding_mask:
                seed_mask_type += "D"

            if args.dilate_mask:
                mask_type += "D"

            # The step is pre-formatted because ``args.step_size`` may be
            # None, which a ``{:.2f}`` field cannot render.
            if args.step_size is not None:
                step_label = "{:.2f}mm".format(args.step_size)
            else:
                step_label = "None"

            filename_items = [
                "{}",
                "useMaxComponent-{}",
                # "seed-{}",
                # "mask-{}",
                "step-{}",
                "nbSeeds-{}",
                "maxAngleDeg-{:.1f}"
                # "keepCurv-{}",
                # "filtered-{}",
                # "minLen-{}",
                # "pftRetry-{}",
                # "pftHist-{}",
                # "trackLikePeter-{}",
            ]
            filename_values = (
                prefix,
                args.use_max_component,
                # seed_mask_type,
                # mask_type,
                step_label,
                args.nb_seeds_per_voxel,
                np.rad2deg(theta)
                # not args.discard_stopped_by_curvature,
                # args.filter_threshold,
                # args.min_length,
                # args.pft_nb_retry,
                # args.pft_nb_backtrack_steps,
                # args.track_like_peter
            )
            filename = ('_'.join(filename_items) + ".tck").format(
                *filename_values)

            # Same pattern with a literal "rejected" right after the prefix.
            rejected_filename_items = filename_items.copy()
            rejected_filename_items.insert(1, "rejected")
            rejected_filename = (
                '_'.join(rejected_filename_items) + ".tck"
            ).format(*filename_values)
        else:
            # Explicit output name: derive the rejected name from it so that
            # ``save_rejected`` also works together with ``out``.
            out_stem, out_ext = os.path.splitext(filename)
            rejected_filename = out_stem + "_rejected" + out_ext

        save_path = pjoin(experiment_path, filename)
        save_dir = os.path.dirname(save_path)
        if save_dir:  # Create dirs, if needed.
            os.makedirs(save_dir, exist_ok=True)

        print("Saving to {}".format(save_path))
        nib.streamlines.save(tractogram, save_path)

    if args.save_rejected:
        with Timer("Saving {:,} (compressed) rejected streamlines".format(
                len(rejected_tractogram))):
            rejected_save_path = pjoin(experiment_path, rejected_filename)
            rejected_dir = os.path.dirname(rejected_save_path)
            if rejected_dir:  # Create dirs, if needed.
                os.makedirs(rejected_dir, exist_ok=True)

            print("Saving rejected streamlines to {}".format(
                rejected_save_path))
            nib.streamlines.save(rejected_tractogram, rejected_save_path)
示例#22
0
    def run(self,
            pam_files,
            wm_files,
            gm_files,
            csf_files,
            seeding_files,
            step_size=0.2,
            back_tracking_dist=2,
            front_tracking_dist=1,
            max_trial=20,
            particle_count=15,
            seed_density=1,
            pmf_threshold=0.1,
            max_angle=30.,
            out_dir='',
            out_tractogram='tractogram.trk'):
        """Workflow for Particle Filtering Tracking.

        This workflow uses a saved peaks and metrics (PAM) file as input.

        Parameters
        ----------
        pam_files : string
           Path to the peaks and metrics files. This path may contain
            wildcards to use multiple files at once.
        wm_files : string
            Path of the white matter map used by the stopping criterion.
        gm_files : string
            Path of the grey matter map used by the stopping criterion.
        csf_files : string
            Path of the cerebrospinal fluid map used by the stopping
            criterion.
        seeding_files : string
            A binary image showing where we need to seed for tracking.
        step_size : float, optional
            Step size used for tracking.
        back_tracking_dist : float, optional
            Distance in mm to back track before starting the particle filtering
            tractography. The total particle filtering tractography distance is
            equal to back_tracking_dist + front_tracking_dist.
            By default this is set to 2 mm.
        front_tracking_dist : float, optional
            Distance in mm to run the particle filtering tractography after
            the back track distance. The total particle filtering tractography
            distance is equal to back_tracking_dist + front_tracking_dist. By
            default this is set to 1 mm.
        max_trial : int, optional
            Maximum number of trials for the particle filtering tractography
            (Prevents infinite loops, default=20).
        particle_count : int, optional
            Number of particles to use in the particle filter. (default 15)
        seed_density : int, optional
            Number of seeds per dimension inside voxel (default 1).
             For example, seed_density of 2 means 8 regularly distributed
             points in the voxel. And seed density of 1 means 1 point at the
             center of the voxel.
        pmf_threshold : float, optional
            Threshold for ODF functions. (default 0.1)
        max_angle : float, optional
            Maximum angle between tract segments. This angle can be more
            generous (larger) than values typically used with probabilistic
            direction getters. The angle range is (0, 90)
        out_dir : string, optional
           Output directory (default input file directory)
        out_tractogram : string, optional
           Name of the tractogram file to be saved (default 'tractogram.trk')

        References
        ----------
        Girard, G., Whittingstall, K., Deriche, R., & Descoteaux, M.
               Towards quantitative connectivity analysis: reducing
               tractography biases. NeuroImage, 98, 266-278, 2014..

        """
        # Must stay the first statement: it is issued before any local is
        # bound. NOTE(review): presumably the iterator is built from this
        # call's arguments -- confirm before moving it.
        io_it = self.get_io_iterator()

        for io_paths in io_it:
            (pams_path, wm_path, gm_path, csf_path, seeding_path,
             out_tract) = io_paths

            logging.info(
                'Particle Filtering tracking on {0}'.format(pams_path))

            peaks = load_peaks(pams_path, verbose=False)

            # The white matter image supplies the affine and voxel size used
            # for everything that follows.
            wm_map, affine, voxel_size = load_nifti(wm_path,
                                                    return_voxsize=True)
            gm_map, _ = load_nifti(gm_path)
            csf_map, _ = load_nifti(csf_path)
            mean_voxel_size = sum(voxel_size) / len(voxel_size)
            tissue_classifier = CmcTissueClassifier.from_pve(
                wm_map,
                gm_map,
                csf_map,
                step_size=step_size,
                average_voxel_size=mean_voxel_size)
            logging.info('classifier done')

            seeding_mask, _ = load_nifti(seeding_path)
            seeds = utils.seeds_from_mask(seeding_mask,
                                          density=[seed_density] * 3,
                                          affine=affine)
            logging.info('seeds done')

            direction_getter = ProbabilisticDirectionGetter.from_shcoeff(
                peaks.shm_coeff,
                max_angle=max_angle,
                sphere=peaks.sphere,
                pmf_threshold=pmf_threshold)

            pft_streamlines = ParticleFilteringTracking(
                direction_getter,
                tissue_classifier,
                seeds,
                affine,
                step_size=step_size,
                pft_back_tracking_dist=back_tracking_dist,
                pft_front_tracking_dist=front_tracking_dist,
                pft_max_trial=max_trial,
                particle_count=particle_count)

            logging.info('ParticleFilteringTracking initiated')

            # Identity affine_to_rasmm: the streamlines are stored as-is.
            tractogram = Tractogram(pft_streamlines,
                                    affine_to_rasmm=np.eye(4))
            save(tractogram, out_tract)

            logging.info('Saved {0}'.format(out_tract))
示例#23
0
"""
.. figure:: det_streamlines.png
 :align: center

 **Deterministic streamlines using EuDX (new framework)**

To learn more about this process, try varying the number of seed points or,
even better, restrict the seeds to specific regions of interest in the brain.

Save the resulting streamlines in the TrackVis (.trk) format and the FA map as
NIfTI-1 (.nii.gz).
"""

# Wrap the streamlines with the image affine, then write the tractogram and
# the FA map next to each other.
det_tractogram = Tractogram(streamlines, affine_to_rasmm=img.affine)
save_trk(det_tractogram, 'det_streamlines.trk')

save_nifti('fa_map.nii.gz', fa, img.affine)

"""
On Windows, if you get a runtime error about a frozen executable, wrap the
code above in a ``main`` function and start your script with:

``
if __name__ == '__main__':
    import multiprocessing
    multiprocessing.freeze_support()
    main()
``
示例#24
0
 def run_test(self):
     """Uncompress a tractogram holding two copies of ``self.strl_data`` and
     check the first result against ``self.gt_indices``."""
     duplicated = [self.strl_data, self.strl_data]
     tracto = Tractogram(streamlines=duplicated)
     indices_new = uncompress(tracto.streamlines)
     first_matches = np.allclose(indices_new[0], self.gt_indices)
     self.assertTrue(first_matches)
示例#25
0
def main():
    """Run probabilistic tractography and write the result to ``track.tck``.

    Reads ``config.json`` from the working directory (keys ``data_file``,
    ``data_bval`` and ``data_bvec``), builds a white matter mask from the
    label volume ``volume.nii.gz``, fits a SHORE model (for the GFA-based
    stopping criterion) and a CSD model (for the direction getter), tracks
    from every white matter voxel and saves the streamlines.

    Progress and timings are printed to stdout; nothing is returned.
    """
    start = time.time()

    with open('config.json') as config_json:
        config = json.load(config_json)

    # Load the data
    dmri_image = nib.load(config['data_file'])
    dmri = dmri_image.get_data()
    affine = dmri_image.affine
    # NOTE: the label volume path is hard-coded; config['freesurfer'] is
    # not read.
    aparc_im = nib.load('volume.nii.gz')
    aparc = aparc_im.get_data()
    end = time.time()
    print('Loaded Files: ' + str((end - start)))
    print(dmri.shape)
    print(aparc.shape)

    # Create the white matter and callosal masks
    start = time.time()
    wm_regions = [
        2, 41, 16, 17, 28, 60, 51, 53, 12, 52, 13, 18, 54, 50, 11, 251,
        252, 253, 254, 255, 10, 49, 46, 7
    ]

    # 1.0 wherever the label is one of the listed regions, 0.0 elsewhere.
    wm_mask = np.isin(aparc, wm_regions).astype(np.float64)

    # Create the gradient table from the bvals and bvecs
    bvals, bvecs = read_bvals_bvecs(config['data_bval'], config['data_bvec'])

    gtab = gradient_table(bvals, bvecs, b0_threshold=100)
    end = time.time()
    print('Created Gradient Table: ' + str((end - start)))

    ##The probabilistic model##
    # (A CSA ODF model with sh_order=6 was tried here before SHORE.)

    # Use the SHORE model to find Orientation Dist. Function
    # All timings printed from here on are measured from this point.
    start = time.time()
    shore_model = ShoreModel(gtab)
    shore_peaks = peaks_from_model(shore_model,
                                   dmri,
                                   default_sphere,
                                   relative_peak_threshold=.8,
                                   min_separation_angle=45,
                                   mask=wm_mask)
    print('Creating Shore Model: ' + str(time.time() - start))

    # Begins the seed in the wm tracts
    seeds = utils.seeds_from_mask(wm_mask, density=[1, 1, 1], affine=affine)
    print('Created White Matter seeds: ' + str(time.time() - start))

    # Create a CSD model to measure Fiber Orientation Dist
    print('Begin the probabilistic model')

    response, _ = auto_response(gtab, dmri, roi_radius=10, fa_thr=0.7)
    csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=6)
    csd_fit = csd_model.fit(data=dmri, mask=wm_mask)
    print('Created the CSD model: ' + str(time.time() - start))

    # Set the Direction Getter to randomly choose directions

    prob_dg = ProbabilisticDirectionGetter.from_shcoeff(csd_fit.shm_coeff,
                                                        max_angle=30.,
                                                        sphere=default_sphere)
    print('Created the Direction Getter: ' + str(time.time() - start))

    # Restrict the white matter tracking
    classifier = ThresholdTissueClassifier(shore_peaks.gfa, .25)

    print('Created the Tissue Classifier: ' + str(time.time() - start))

    # Create the probabilistic model
    streamlines = LocalTracking(prob_dg,
                                tissue_classifier=classifier,
                                seeds=seeds,
                                step_size=.5,
                                max_cross=1,
                                affine=affine)
    print('Created the probabilistic model: ' + str(time.time() - start))

    # Compute streamlines and store as a list.
    streamlines = list(streamlines)
    print('Computed streamlines: ' + str(time.time() - start))

    # Create a tractogram from the streamlines and save it.
    # Seeds and tracking both used ``affine``, so the points are already in
    # the image's world (RAS+mm) space. Declaring ``affine`` as
    # affine_to_rasmm would apply it a second time on save, hence identity.
    tractogram = Tractogram(streamlines, affine_to_rasmm=np.eye(4))
    save(tractogram, 'track.tck')
    end = time.time()
    print("Created the tck file: " + str((end - start)))
示例#26
0
    def clean_tractogram(self, tractogram, affine_vox2mask):
        """
        Drop streamlines that are too short, too long or that loop on
        themselves, keeping the per-streamline data aligned.

        A streamline is discarded when its length is at most
        ``self.min_length``, when its length exceeds 200, or when its
        winding exceeds 330.

        Parameters:
        -----------
        tractogram: Tractogram
            Full tractogram
        affine_vox2mask:
            Not used at the moment: the endpoint-in-mask and curvature
            filters are disabled.

        Returns:
        --------
        tractogram: Tractogram
            Filtered tractogram
        """
        print('Cleaning tractogram ... ', end='', flush=True)

        all_streamlines = tractogram.streamlines

        # Length filter (both bounds).
        # Example case: Bundle 3 of fibercup tends to be shorter than 35
        lengths = np.asarray([slength(s) for s in all_streamlines])
        too_short = lengths <= self.min_length
        too_long = lengths > 200.

        # Loop filter.
        # For example: A streamline ending and starting in the same roi
        windings = np.array([winding(s) for s in all_streamlines])
        is_looping = windings > 330

        rejected = np.logical_or(np.logical_or(too_short, too_long),
                                 is_looping)

        # Only keep valid streamlines; the index tuple is applied to both
        # the streamlines and their per-streamline data.
        valid_indices = np.nonzero(np.logical_not(rejected))
        kept_streamlines = all_streamlines[valid_indices]
        kept_dps = tractogram.data_per_streamline[valid_indices]
        print('Done !')

        nb_kept = len(valid_indices[0])
        print('Kept {}/{} streamlines'.format(nb_kept, len(all_streamlines)))

        return Tractogram(kept_streamlines, data_per_streamline=kept_dps)