Example #1
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.DEBUG
    logging.basicConfig(level=log_level)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    smoothed_streamlines = []
    for streamline in sft.streamlines:
        if args.gaussian:
            tmp_streamlines = smooth_line_gaussian(streamline, args.gaussian)
        else:
            tmp_streamlines = smooth_line_spline(streamline, args.spline[0],
                                                 args.spline[1])

        if args.error_rate:
            smoothed_streamlines.append(
                compress_streamlines(tmp_streamlines, args.error_rate))
        else:
            # Keep the smoothed streamline even when no compression is asked;
            # otherwise the output tractogram would silently end up empty.
            smoothed_streamlines.append(tmp_streamlines)

    smoothed_sft = StatefulTractogram.from_sft(
        smoothed_streamlines, sft, data_per_streamline=sft.data_per_streamline)
    save_tractogram(smoothed_sft, args.out_tractogram)
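
For context, here is a minimal sketch of calling Dipy's compress_streamlines directly, assuming the dipy.tracking.streamlinespeed import path; the array values are made up for illustration:

import numpy as np
from dipy.tracking.streamlinespeed import compress_streamlines

# A nearly straight synthetic streamline: the intermediate collinear
# points should be merged away by the compression.
streamline = np.array([[i, 0., 0.] for i in range(20)], dtype=np.float32)

# tol_error is the maximal distance (in mm) a removed point may lie
# from the compressed streamline.
compressed = compress_streamlines(streamline, tol_error=0.1)
print(len(streamline), '->', len(compressed))  # e.g. 20 -> 2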
Example #2
def get_streamlines(tracker,
                    mask,
                    seed,
                    chunk_id,
                    pft_tracker,
                    param,
                    compress=False,
                    compression_error_threshold=0.1):
    """
    Generate streamlines from all initial positions
    following the tracking parameters.

    Parameters
    ----------
    tracker : Tracker, tracking object.
    mask : Mask, tracking volume(s).
    seed : Seed, seeding volume.
    chunk_id: int, chunk id.
    pft_tracker: Tracker, tracking object for pft module.
    param: Dict, tracking parameters.
    compress : bool, enable streamlines compression.
    compression_error_threshold : float,
        maximal distance threshold for compression.

    Returns
    -------
    streamlines: list, list of (streamline, None, None) tuples, where each
        streamline is an array of 3D positions.
    """

    streamlines = []
    # Initialize the random number generator to cover multiprocessing, skip,
    # which voxel to seed and the subvoxel random position
    chunk_size = int(param['nbr_seeds'] / param['processes'])
    skip = param['skip']

    first_seed_of_chunk = chunk_id * chunk_size + skip
    random_generator, indices = seed.init_pos(param['random'],
                                              first_seed_of_chunk)

    if chunk_id == param['processes'] - 1:
        chunk_size += param['nbr_seeds'] % param['processes']

    for s in range(chunk_size):
        if s % 1000 == 0:
            print(str(os.getpid()) + " : " + str(s) + " / " + str(chunk_size))

        pos = seed.get_next_pos(random_generator, indices,
                                first_seed_of_chunk + s)
        line = get_line_from_seed(tracker, mask, pos, pft_tracker, param)
        if line is not None:
            if compress:
                streamlines.append(
                    (compress_streamlines(np.array(line, dtype='float32'),
                                          compression_error_threshold), None,
                     None))
            else:
                streamlines.append((np.array(line,
                                             dtype='float32'), None, None))
    return streamlines
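
The chunking arithmetic above splits nbr_seeds evenly across processes and hands the remainder to the last chunk; a self-contained sketch of the same logic (the numbers are illustrative):

nbr_seeds, processes, skip = 1003, 4, 0
chunk_size = nbr_seeds // processes              # 250 seeds per chunk
for chunk_id in range(processes):
    size = chunk_size
    if chunk_id == processes - 1:
        size += nbr_seeds % processes            # last chunk absorbs the 3 extras
    first_seed_of_chunk = chunk_id * chunk_size + skip
    print(chunk_id, first_seed_of_chunk, size)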
Example #3
def compress_sft(sft, tol_error=0.01):
    """ Compress a stateful tractogram. Uses Dipy's compress_streamlines, but
    deals with space better.

    Dipy's description:
    The compression consists in merging consecutive segments that are
    nearly collinear. The merging is achieved by removing the point the two
    segments have in common.

    The linearization process [Presseau15]_ ensures that every removed point
    is within a certain margin (in mm) of the resulting streamline.
    Recommendations for setting this margin can be found in [Presseau15]_
    (where it is called the tolerance error).

    The compression also ensures that two consecutive points are never more
    than `max_segment_length` mm apart.
    This is a tradeoff to speed up the linearization process [Rheault15]_. A
    low value will result in a faster linearization but low compression,
    whereas a high value will result in a slower linearization but high
    compression.

    [Presseau C. et al., A new compression format for fiber tracking datasets,
    NeuroImage, no 109, 73-83, 2015.]

    Parameters
    ----------
    sft: StatefulTractogram
        The sft to compress.
    tol_error: float (optional)
        Tolerance error in mm (default: 0.01). A rule of thumb is to set it
        to 0.01mm for deterministic streamlines and 0.1mm for probabilistic
        streamlines.

    Returns
    -------
    compressed_sft : StatefulTractogram
    """
    # Go to world space
    orig_space = sft.space
    sft.to_rasmm()

    # Compress streamlines
    compressed_streamlines = compress_streamlines(sft.streamlines,
                                                  tol_error=tol_error)
    if sft.data_per_point:
        logging.warning("Initial StatefulTractogram contained data_per_point. "
                        "This information will not be carried in the final "
                        "tractogram.")

    compressed_sft = StatefulTractogram.from_sft(
        compressed_streamlines,
        sft,
        data_per_streamline=sft.data_per_streamline)

    # Return to original space
    compressed_sft.to_space(orig_space)

    return compressed_sft
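
A hedged usage sketch for compress_sft, assuming Dipy's load_tractogram/save_tractogram helpers and a hypothetical bundle.trk input file:

from dipy.io.streamline import load_tractogram, save_tractogram

# 'same' reuses the trk file's own header as the spatial reference.
sft = load_tractogram('bundle.trk', 'same')
compressed_sft = compress_sft(sft, tol_error=0.1)  # 0.1 mm, e.g. probabilistic tracts
save_tractogram(compressed_sft, 'bundle_compressed.trk')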
Example #4
def get_n_streamlines(tracker,
                      mask,
                      seeding_mask,
                      pft_tracker,
                      param,
                      compress=False,
                      compression_error_threshold=0.1,
                      max_tries=100,
                      save_seeds=True):
    """
    Generate N valid streamlines

    Parameters
    ----------
    tracker : Tracker, tracking object.
    mask : Mask, tracking volume(s).
    seeding_mask : Seed, seeding volume.
    pft_tracker: Tracker, tracking object for pft module.
    param: TrackingParams, tracking parameters.
    compress : bool, enable streamlines compression.
    compression_error_threshold : float,
        maximal distance threshold for compression.
    max_tries : int, maximum number of attempts per requested streamline
        before giving up.
    save_seeds : bool, whether to also return the seed of each streamline.

    Returns
    -------
    streamlines: list, list of arrays of 3D positions (streamlines).
    seeds: list, seed position of each streamline (empty if save_seeds is
        False).
    """

    i = 0
    streamlines = []
    seeds = []
    skip = 0
    # Initialize the random number generator, skip,
    # which voxel to seed and the subvoxel random position
    first_seed_of_chunk = np.int32(param.skip)
    random_generator, indices =\
        seeding_mask.init_pos(param.random, first_seed_of_chunk)
    while (len(streamlines) < param.nbr_streamlines
           and skip < param.nbr_streamlines * max_tries):
        if i % 1000 == 0:
            logging.info(
                str(os.getpid()) + " : " + str(len(streamlines)) + " / " +
                str(param.nbr_streamlines))
        seed = seeding_mask.get_next_pos(random_generator, indices,
                                         first_seed_of_chunk + i)
        line = get_line_from_seed(tracker, mask, seed, pft_tracker, param)
        if line is not None:
            if compress:
                streamlines.append(
                    compress_streamlines(np.array(line, dtype='float32'),
                                         compression_error_threshold))
            else:
                streamlines.append(np.array(line, dtype='float32'))
            if save_seeds:
                seeds.append(np.asarray(seed, dtype='float32'))
        else:
            # Count failed seeds toward the max_tries budget; without this,
            # the while condition above could never exhaust its retry limit.
            skip += 1

        i += 1
    return streamlines, seeds
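
The while condition doubles as a retry budget: tracking stops once nbr_streamlines are collected, or once nbr_streamlines * max_tries seeds have failed. A schematic of that control flow, with the hypothetical try_seed standing in for the seeding-and-tracking step:

def collect_streamlines(n_wanted, max_tries, try_seed):
    # try_seed(i) returns a streamline, or None when tracking fails.
    results, failures, i = [], 0, 0
    while len(results) < n_wanted and failures < n_wanted * max_tries:
        line = try_seed(i)
        if line is not None:
            results.append(line)
        else:
            failures += 1
        i += 1
    return results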
Example #5
    def _get_streamlines(self, chunk_id):
        """
        Tracks the n streamlines associated with the current process
        (identified by chunk_id), where n is the total number of seeds divided
        by the number of processes. If requested by the user, may compress the
        streamlines and save the seeds.

        Parameters
        ----------
        chunk_id: int
            This process ID.

        Returns
        -------
        streamlines: list
            The successful streamlines.
        seeds: list
            The list of seeds for each streamline, if self.save_seeds. Else, an
            empty list.
        """
        streamlines = []
        seeds = []

        # Initialize the random number generator to cover multiprocessing,
        # skip, which voxel to seed and the subvoxel random position
        chunk_size = int(self.nbr_seeds / self.nbr_processes)
        first_seed_of_chunk = chunk_id * chunk_size + self.skip
        random_generator, indices = self.seed_generator.init_generator(
            self.rng_seed, first_seed_of_chunk)
        if chunk_id == self.nbr_processes - 1:
            chunk_size += self.nbr_seeds % self.nbr_processes

        # Getting streamlines
        for s in range(chunk_size):
            if s % 1000 == 0:
                logging.info(
                    str(os.getpid()) + " : " + str(s) + " / " +
                    str(chunk_size))

            seed = self.seed_generator.get_next_pos(random_generator, indices,
                                                    first_seed_of_chunk + s)

            # Forward and backward tracking
            line = self._get_line_both_directions(seed)

            if line is not None:
                if self.compression_th and self.compression_th > 0:
                    streamlines.append(
                        compress_streamlines(np.array(line, dtype='float32'),
                                             self.compression_th))
                else:
                    streamlines.append(np.array(line, dtype='float32'))

                if self.save_seeds:
                    seeds.append(np.asarray(seed, dtype='float32'))
        return streamlines, seeds
Example #6
def get_n_streamlines(tracker,
                      mask,
                      seed,
                      pft_tracker,
                      param,
                      compress=False,
                      compression_error_threshold=0.1,
                      max_tries=100):
    """
    Generate N valid streamlines

    Parameters
    ----------
    tracker : Tracker, tracking object.
    mask : Mask, tracking volume(s).
    seed : Seed, seeding volume.
    pft_tracker: Tracker, tracking object for pft module.
    param: Dict, tracking parameters.
    compress : bool, enable streamlines compression.
    compression_error_threshold : float,
        maximal distance threshold for compression.
    max_tries : int, maximum number of attempts per requested streamline
        before giving up.

    Returns
    -------
    streamlines: list, list of (streamline, None, None) tuples, where each
        streamline is an array of 3D positions.
    """

    i = 0
    streamlines = []
    skip = 0
    # Initialize the random number generator, skip,
    # which voxel to seed and the subvoxel random position
    random_generator, indices = seed.init_pos(param['random'], param['skip'])
    while (len(streamlines) < param['nbr_streamlines']
           and skip < param['nbr_streamlines'] * max_tries):

        if i % 1000 == 0:
            print(
                str(os.getpid()) + " : " + str(len(streamlines)) + " / " +
                str(param['nbr_streamlines']))

        line = get_line_from_seed(
            tracker, mask,
            seed.get_next_pos(random_generator, indices, param['skip'] + i),
            pft_tracker, param)
        if line is not None:
            if compress:
                streamlines.append(
                    (compress_streamlines(np.array(line, dtype='float32'),
                                          compression_error_threshold), None,
                     None))
            else:
                streamlines.append((np.array(line, dtype='float32'),
                                    None, None))
        else:
            # Count failed seeds toward the max_tries budget; without this,
            # the while condition above could never exhaust its retry limit.
            skip += 1
        i += 1
    return streamlines
Example #7
    def tracks_generator_wrapper():
        for strl, seed in tracker.track():
            # seed must be saved in voxel space, with origin `center`.
            dps = {'seeds': seed - 0.5} if args.save_seeds else {}

            # TODO: Investigate why the streamline must NOT be shifted to
            # origin `corner` for LazyTractogram.
            strl *= voxel_size  # in mm.
            if args.compress:
                strl = compress_streamlines(strl, args.compress)
            yield TractogramItem(strl, dps, {})
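
A generator of TractogramItem like this one is typically consumed through nibabel's LazyTractogram.from_data_func, so each streamline is compressed and written on the fly; a minimal sketch, assuming the tracker from the surrounding scope and a hypothetical output path:

import numpy as np
import nibabel as nib
from nibabel.streamlines import LazyTractogram

# from_data_func expects a callable returning a TractogramItem generator.
tractogram = LazyTractogram.from_data_func(tracks_generator_wrapper)
tractogram.affine_to_rasmm = np.eye(4)  # illustrative affine
nib.streamlines.save(tractogram, 'out.trk')  # hypothetical output path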
Example #8
def compression_wrapper(tract_filename, out_filename, error_rate):
    tracts_format = tc.detect_format(tract_filename)
    tracts_file = tracts_format(tract_filename)

    out_hdr = tracts_file.hdr
    out_format = tc.detect_format(out_filename)
    out_tracts = out_format.create(out_filename, out_hdr)

    for s in tracts_file:
        # TODO we should chunk this.
        out_tracts += np.array(compress_streamlines([s], error_rate))

    out_tracts.close()
Example #9
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, [args.in_tractogram, args.in_mask])
    assert_outputs_exist(parser, args, args.out_tractogram)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    if args.step_size is not None:
        sft = resample_streamlines_step_size(sft, args.step_size)

    mask_img = nib.load(args.in_mask)
    binary_mask = get_data_as_mask(mask_img)

    if not is_header_compatible(sft, mask_img):
        parser.error('Incompatible header between the tractogram and mask.')

    bundle_disjoint, _ = ndi.label(binary_mask)
    unique, count = np.unique(bundle_disjoint, return_counts=True)
    if args.biggest_blob:
        val = unique[np.argmax(count[1:]) + 1]
        binary_mask[bundle_disjoint != val] = 0
        unique = [0, val]
    if len(unique) == 2:
        logging.info('The provided mask has 1 entity: '
                     'cut_outside_of_mask_streamlines function selected.')
        new_sft = cut_outside_of_mask_streamlines(sft, binary_mask)
    elif len(unique) == 3:
        logging.info('The provided mask has 2 entities: '
                     'cut_between_masks_streamlines function selected.')
        new_sft = cut_between_masks_streamlines(sft, binary_mask)
    else:
        logging.error('The provided mask has more than 2 entities. Cannot cut '
                      'between more than two.')
        return

    if len(new_sft) == 0:
        logging.warning('No streamline intersected the provided mask. '
                        'Saving empty tractogram.')
    elif args.error_rate is not None:
        compressed_strs = [
            compress_streamlines(s, args.error_rate)
            for s in new_sft.streamlines
        ]
        new_sft = StatefulTractogram.from_sft(compressed_strs, sft)

    save_tractogram(new_sft, args.out_tractogram)
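
The biggest-blob selection above relies on scipy.ndimage.label, where label 0 is the background, so count[1:] holds the blob sizes; a small worked example:

import numpy as np
from scipy import ndimage as ndi

mask = np.zeros((6, 6), dtype=np.uint8)
mask[0:2, 0:2] = 1                        # blob of 4 voxels
mask[4:6, 3:6] = 1                        # blob of 6 voxels (the biggest)

bundle_disjoint, _ = ndi.label(mask)
unique, count = np.unique(bundle_disjoint, return_counts=True)
val = unique[np.argmax(count[1:]) + 1]    # +1 skips the background label
mask[bundle_disjoint != val] = 0          # keep only the biggest blob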
Example #10
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()
    for param in ['theta', 'curvature']:
        # Default was removed for consistency.
        if param not in args:
            setattr(args, param, None)

    assert_inputs_exist(parser, [args.sh_file, args.seed_file, args.mask_file])
    assert_outputs_exists(parser, args, [args.output_file])

    np.random.seed(args.seed)

    mask_img = nib.load(args.mask_file)
    mask_data = mask_img.get_data()

    seeds = random_seeds_from_mask(
        nib.load(args.seed_file).get_data(),
        seeds_count=args.nts if 'nts' in args else args.npv,
        seed_count_per_voxel='nts' not in args)

    # Tracking is performed in voxel space
    streamlines = LocalTracking(_get_direction_getter(args, mask_data),
                                BinaryTissueClassifier(mask_data),
                                seeds,
                                np.eye(4),
                                step_size=args.step_size,
                                max_cross=1,
                                maxlen=int(args.max_len / args.step_size) + 1,
                                fixedstep=True,
                                return_all=True)

    filtered_streamlines = (s for s in streamlines
                            if args.min_len <= length(s) <= args.max_len)
    if args.compress_streamlines:
        filtered_streamlines = (compress_streamlines(s, args.tolerance_error)
                                for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                affine_to_rasmm=mask_img.affine)

    # Header with the affine/shape from mask image
    header = {
        Field.VOXEL_TO_RASMM: mask_img.affine.copy(),
        Field.VOXEL_SIZES: mask_img.header.get_zooms(),
        Field.DIMENSIONS: mask_img.shape,
        Field.VOXEL_ORDER: ''.join(aff2axcodes(mask_img.affine))
    }

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.output_file, header=header)
Example #11
def compress_streamlines_wrapper(tractogram, error_rate):
    # `yield` in a comprehension is invalid in Python 3; use a real generator.
    def generate_compressed():
        for s in tractogram.streamlines:
            yield compress_streamlines(s, error_rate)
    return generate_compressed
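
Because the fixed wrapper returns a fresh generator on each call, it can be handed straight to LazyTractogram, which may invoke the callable more than once; a self-contained sketch with a tiny in-memory tractogram (compress_streamlines is assumed to be in scope, as in the snippet above):

import numpy as np
from nibabel.streamlines import Tractogram, LazyTractogram

tractogram = Tractogram([np.array([[0, 0, 0], [1, 0, 0], [2, 0, 0]],
                                  dtype=np.float32)],
                        affine_to_rasmm=np.eye(4))
lazy = LazyTractogram(compress_streamlines_wrapper(tractogram, 0.1),
                      affine_to_rasmm=np.eye(4))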
Example #12
def get_streamlines(tracker,
                    mask,
                    seeding_mask,
                    chunk_id,
                    pft_tracker,
                    param,
                    compress=False,
                    compression_error_threshold=0.1,
                    save_seeds=True):
    """
    Generate streamlines from all initial positions
    following the tracking parameters.

    Parameters
    ----------
    tracker : Tracker, tracking object.
    mask : Mask, tracking volume(s).
    seeding_mask : Seed, seeding volume.
    chunk_id: int, chunk id.
    pft_tracker: Tracker, tracking object for pft module.
    param: TrackingParams, tracking parameters.
    compress : bool, enable streamlines compression.
    compression_error_threshold : float,
        maximal distance threshold for compression.
    save_seeds : bool, whether to also return the seed of each streamline.

    Returns
    -------
    streamlines: list, list of arrays of 3D positions.
    seeds: list, seed position of each streamline (empty if save_seeds is
        False).
    """

    streamlines = []
    seeds = []
    # Initialize the random number generator to cover multiprocessing, skip,
    # which voxel to seed and the subvoxel random position
    chunk_size = int(param.nbr_seeds / param.processes)
    skip = param.skip

    first_seed_of_chunk = chunk_id * chunk_size + skip
    random_generator, indices = seeding_mask.init_pos(param.random,
                                                      first_seed_of_chunk)

    if chunk_id == param.processes - 1:
        chunk_size += param.nbr_seeds % param.processes
    for s in range(chunk_size):
        if s % 1000 == 0:
            logging.info(
                str(os.getpid()) + " : " + str(s) + " / " + str(chunk_size))

        seed = seeding_mask.get_next_pos(random_generator, indices,
                                         first_seed_of_chunk + s)
        line = get_line_from_seed(tracker, mask, seed, pft_tracker, param)
        if line is not None:
            if compress:
                streamlines.append(
                    compress_streamlines(np.array(line, dtype='float32'),
                                         compression_error_threshold))
            else:
                streamlines.append(np.array(line, dtype='float32'))

            if save_seeds:
                seeds.append(np.asarray(seed, dtype='float32'))

    return streamlines, seeds
Example #13
def main():
    parser = _build_args_parser()
    args = parser.parse_args()

    if args.isVerbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(parser, [
        args.sh_file, args.seed_file, args.map_include_file,
        args.map_exclude_file
    ])
    assert_outputs_exist(parser, args, [args.output_file])

    if not nib.streamlines.is_supported(args.output_file):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.output_file))

    if not args.min_length > 0:
        parser.error('minL must be > 0, {}mm was provided.'.format(
            args.min_length))
    if args.max_length < args.min_length:
        parser.error(
            'maxL must be greater than minL (minL={}mm, maxL={}mm).'.format(
                args.min_length, args.max_length))

    if args.compress:
        if args.compress < 0.001 or args.compress > 1:
            logging.warning(
                'You are using an error rate of {}.\nWe recommend setting it '
                'between 0.001 and 1.\n0.001 will do almost nothing to the '
                'tracts while 1 will highly compress/linearize the tracts.'
                .format(args.compress))

    if args.particles <= 0:
        parser.error('--particles must be >= 1.')

    if args.back_tracking <= 0:
        parser.error('PFT backtracking distance must be > 0.')

    if args.forward_tracking <= 0:
        parser.error('PFT forward tracking distance must be > 0.')

    if args.npv and args.npv <= 0:
        parser.error('Number of seeds per voxel must be > 0.')

    if args.nt and args.nt <= 0:
        parser.error('Total number of seeds must be > 0.')

    fodf_sh_img = nib.load(args.sh_file)
    if not np.allclose(np.mean(fodf_sh_img.header.get_zooms()[:3]),
                       fodf_sh_img.header.get_zooms()[0],
                       atol=1.e-3):
        parser.error(
            'SH file is not isotropic. Tracking cannot be run robustly.')

    tracking_sphere = HemiSphere.from_sphere(get_sphere('repulsion724'))

    # Check if sphere is unit, since we couldn't find such check in Dipy.
    if not np.allclose(np.linalg.norm(tracking_sphere.vertices, axis=1), 1.):
        raise RuntimeError('Tracking sphere should be unit normed.')

    sh_basis = args.sh_basis

    if args.algo == 'det':
        dgklass = DeterministicMaximumDirectionGetter
    else:
        dgklass = ProbabilisticDirectionGetter

    theta = get_theta(args.theta, args.algo)

    # Reminder for the future:
    # pmf_threshold == clip pmf under this
    # relative_peak_threshold is for initial directions filtering
    # min_separation_angle is the initial separation angle for peak extraction
    dg = dgklass.from_shcoeff(fodf_sh_img.get_data().astype(np.double),
                              max_angle=theta,
                              sphere=tracking_sphere,
                              basis_type=sh_basis,
                              pmf_threshold=args.sf_threshold,
                              relative_peak_threshold=args.sf_threshold_init)

    map_include_img = nib.load(args.map_include_file)
    map_exclude_img = nib.load(args.map_exclude_file)
    voxel_size = np.average(map_include_img.get_header()['pixdim'][1:4])

    tissue_classifier = None
    if not args.act:
        tissue_classifier = CmcTissueClassifier(map_include_img.get_data(),
                                                map_exclude_img.get_data(),
                                                step_size=args.step_size,
                                                average_voxel_size=voxel_size)
    else:
        tissue_classifier = ActTissueClassifier(map_include_img.get_data(),
                                                map_exclude_img.get_data())

    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    voxel_size = fodf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.seed_file)
    seeds = track_utils.random_seeds_from_mask(
        seed_img.get_data(),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Note that max steps is used once for the forward pass, and
    # once for the backwards. This doesn't, in fact, control the real
    # max length
    max_steps = int(args.max_length / args.step_size) + 1
    pft_streamlines = ParticleFilteringTracking(
        dg,
        tissue_classifier,
        seeds,
        np.eye(4),
        max_cross=1,
        step_size=vox_step_size,
        maxlen=max_steps,
        pft_back_tracking_dist=args.back_tracking,
        pft_front_tracking_dist=args.forward_tracking,
        particle_count=args.particles,
        return_all=args.keep_all,
        random_seed=args.seed)

    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size
    filtered_streamlines = (
        s for s in pft_streamlines
        if scaled_min_length <= length(s) <= scaled_max_length)
    if args.compress:
        filtered_streamlines = (compress_streamlines(s, args.compress)
                                for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.output_file)
    header = create_header_from_anat(seed_img, base_filetype=filetype)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.output_file, header=header)
Example #14
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.isVerbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(parser, [args.sh_file, args.seed_file, args.mask_file])
    assert_outputs_exist(parser, args, [args.output_file])

    if not nib.streamlines.is_supported(args.output_file):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.output_file))

    if not args.min_length > 0:
        parser.error('minL must be > 0, {}mm was provided.'.format(
            args.min_length))
    if args.max_length < args.min_length:
        parser.error(
            'maxL must be greater than minL (minL={}mm, maxL={}mm).'.format(
                args.min_length, args.max_length))

    if args.compress:
        if args.compress < 0.001 or args.compress > 1:
            logging.warning(
                'You are using an error rate of {}.\nWe recommend setting it '
                'between 0.001 and 1.\n0.001 will do almost nothing to the '
                'tracts while 1 will highly compress/linearize the tracts.'
                .format(args.compress))

    if args.npv and args.npv <= 0:
        parser.error('Number of seeds per voxel must be > 0.')

    if args.nt and args.nt <= 0:
        parser.error('Total number of seeds must be > 0.')

    mask_img = nib.load(args.mask_file)
    mask_data = mask_img.get_data()

    # Make sure the mask is isotropic. Else, the strategy used
    # when providing information to dipy (i.e. working as if in voxel space)
    # will not yield correct results.
    fodf_sh_img = nib.load(args.sh_file)
    if not np.allclose(np.mean(fodf_sh_img.header.get_zooms()[:3]),
                       fodf_sh_img.header.get_zooms()[0],
                       atol=1.e-3):
        parser.error(
            'SH file is not isotropic. Tracking cannot be run robustly.')

    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    voxel_size = fodf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.seed_file)
    seeds = track_utils.random_seeds_from_mask(
        seed_img.get_data(),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Tracking is performed in voxel space
    max_steps = int(args.max_length / args.step_size) + 1
    streamlines = LocalTracking(_get_direction_getter(args, mask_data),
                                BinaryTissueClassifier(mask_data),
                                seeds,
                                np.eye(4),
                                step_size=vox_step_size,
                                max_cross=1,
                                maxlen=max_steps,
                                fixedstep=True,
                                return_all=True,
                                random_seed=args.seed)

    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size

    filtered_streamlines = (
        s for s in streamlines
        if scaled_min_length <= length(s) <= scaled_max_length)
    if args.compress:
        filtered_streamlines = (compress_streamlines(s, args.compress)
                                for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.output_file)
    header = create_header_from_anat(seed_img, base_filetype=filetype)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.output_file, header=header)
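
Since tracking runs in voxel space (the np.eye(4) affine), every millimetric quantity is divided by the voxel size first; a worked example of the conversions used above, with illustrative numbers:

voxel_size = 2.0                           # mm, isotropic
step_size = 0.5                            # mm
vox_step_size = step_size / voxel_size     # 0.25 voxel per step
max_steps = int(200.0 / step_size) + 1     # 401 steps for max_length=200 mm
scaled_min_length = 20.0 / voxel_size      # min_length=20 mm  -> 10 voxels
scaled_max_length = 200.0 / voxel_size     # max_length=200 mm -> 100 voxels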
Example #15
def get_streamlines(tracker, mask, seed_generator, chunk_id, params,
                    compression_th=0.1, nbr_processes=1,
                    save_seeds=True):
    """
    Generate streamlines from all initial positions following the tracking
    parameters.

    Parameters
    ----------
    tracker : AbstractTracker
        Tracking object.
    mask : BinaryMask
        Tracking volume(s).
    seed_generator : SeedGenerator
        Seeding volume.
    chunk_id: int
        Id of this chunk of seeds. Chunk sizes depend on the number of
        processes.
    params: TrackingParams
        Tracking parameters, see scilpy.tracking.utils.py.
    compression_th : float,
        Maximal distance threshold for compression. If None or 0, no
        compression is applied.
    nbr_processes: int
        Number of sub processes to use.
    save_seeds: bool
        Whether to save the seeds associated to their respective streamlines.

    Returns
    -------
    streamlines: list, list of arrays of 3D positions.
    seeds: list, seed position of each streamline (empty if save_seeds is
        False).
    """

    streamlines = []
    seeds = []
    # Initialize the random number generator to cover multiprocessing, skip,
    # which voxel to seed and the subvoxel random position
    chunk_size = int(params.nbr_seeds / nbr_processes)
    skip = params.skip

    first_seed_of_chunk = chunk_id * chunk_size + skip
    random_generator, indices = seed_generator.init_pos(params.random,
                                                        first_seed_of_chunk)

    if chunk_id == nbr_processes - 1:
        chunk_size += params.nbr_seeds % nbr_processes
    for s in range(chunk_size):
        if s % 1000 == 0:
            logging.info(str(os.getpid()) + " : " + str(s)
                         + " / " + str(chunk_size))

        seed = seed_generator.get_next_pos(random_generator, indices,
                                           first_seed_of_chunk + s)
        line = get_line_both_directions(tracker, mask, seed, params)
        if line is not None:
            if compression_th and compression_th > 0:
                streamlines.append(
                    compress_streamlines(np.array(line, dtype='float32'),
                                         compression_th))
            else:
                streamlines.append(np.array(line, dtype='float32'))

            if save_seeds:
                seeds.append(np.asarray(seed, dtype='float32'))

    return streamlines, seeds


def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(parser, [args.in_odf, args.in_seed, args.in_mask])
    assert_outputs_exist(parser, args, args.out_tractogram)

    if not nib.streamlines.is_supported(args.out_tractogram):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.out_tractogram))

    verify_streamline_length_options(parser, args)
    verify_compression_th(args.compress)
    verify_seed_options(parser, args)

    mask_img = nib.load(args.in_mask)
    mask_data = get_data_as_mask(mask_img, dtype=bool)

    # Make sure the data is isotropic. Else, the strategy used
    # when providing information to dipy (i.e. working as if in voxel space)
    # will not yield correct results.
    odf_sh_img = nib.load(args.in_odf)
    if not np.allclose(np.mean(odf_sh_img.header.get_zooms()[:3]),
                       odf_sh_img.header.get_zooms()[0], atol=1e-03):
        parser.error(
            'ODF SH file is not isotropic. Tracking cannot be run robustly.')

    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    voxel_size = odf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.in_seed)
    seeds = track_utils.random_seeds_from_mask(
        seed_img.get_fdata(dtype=np.float32),
        np.eye(4),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Tracking is performed in voxel space
    max_steps = int(args.max_length / args.step_size) + 1
    streamlines_generator = LocalTracking(
        _get_direction_getter(args),
        BinaryStoppingCriterion(mask_data),
        seeds, np.eye(4),
        step_size=vox_step_size, max_cross=1,
        maxlen=max_steps,
        fixedstep=True, return_all=True,
        random_seed=args.seed,
        save_seeds=args.save_seeds)

    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size

    if args.save_seeds:
        filtered_streamlines, seeds = \
            zip(*((s, p) for s, p in streamlines_generator
                  if scaled_min_length <= length(s) <= scaled_max_length))
        data_per_streamlines = {'seeds': lambda: seeds}
    else:
        filtered_streamlines = \
            (s for s in streamlines_generator
             if scaled_min_length <= length(s) <= scaled_max_length)
        data_per_streamlines = {}

    if args.compress:
        filtered_streamlines = (
            compress_streamlines(s, args.compress)
            for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                data_per_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.out_tractogram)
    reference = get_reference_info(seed_img)
    header = create_tractogram_header(filetype, *reference)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.out_tractogram, header=header)
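
When --save_seeds is set, each kept streamline's seed travels with the tractogram under data_per_streamline; a minimal sketch of the same pattern with an eager Tractogram (values are illustrative):

import numpy as np
from nibabel.streamlines import Tractogram

streamlines = [np.array([[0, 0, 0], [1, 1, 1]], dtype=np.float32)]
seeds = [np.array([0.2, 0.4, 0.1], dtype=np.float32)]
tractogram = Tractogram(streamlines,
                        data_per_streamline={'seeds': seeds},
                        affine_to_rasmm=np.eye(4))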