Example #1
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, [], optional=args.out_centroids)
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_clusters_dir,
                                       create_dir=True)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    streamlines = sft.streamlines
    thresholds = [40, 30, 20, args.dist_thresh]
    clusters = qbx_and_merge(streamlines,
                             thresholds,
                             nb_pts=args.nb_points,
                             verbose=False)

    for i, cluster in enumerate(clusters):
        if len(cluster.indices) > 1:
            cluster_streamlines = itemgetter(*cluster.indices)(streamlines)
        else:
            cluster_streamlines = streamlines[cluster.indices]

        new_sft = StatefulTractogram.from_sft(cluster_streamlines, sft)
        save_tractogram(
            new_sft,
            os.path.join(args.out_clusters_dir, 'cluster_{}.trk'.format(i)))

    if args.out_centroids:
        new_sft = StatefulTractogram.from_sft(clusters.centroids, sft)
        save_tractogram(new_sft, args.out_centroids)
Example #2
def get_subset_streamlines(sft, max_streamlines, rng_seed=None):
    """
    Extract a specific number of streamlines.

    Parameters
    ----------
    sft: StatefulTractogram
        SFT containing the streamlines to subsample.
    max_streamlines: int
        Maximum number of streamlines to output.
    rng_seed: int, optional
        Seed for the random number generator used to shuffle the data.

    Returns
    -------
    subset_sft: StatefulTractogram
        The subsampled streamlines as an sft.
    """

    rng = np.random.RandomState(rng_seed)
    ind = np.arange(len(sft.streamlines))
    rng.shuffle(ind)

    subset_streamlines = list(
        np.asarray(sft.streamlines)[ind[:max_streamlines]])
    subset_data_per_point = sft.data_per_point[ind[:max_streamlines]]
    subset_data_per_streamline = sft.data_per_streamline[ind[:max_streamlines]]

    subset_sft = StatefulTractogram.from_sft(
        subset_streamlines,
        sft,
        data_per_point=subset_data_per_point,
        data_per_streamline=subset_data_per_streamline)

    return subset_sft
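
A minimal usage sketch for get_subset_streamlines above. The file paths, the dipy loading helpers, and the chosen parameters are illustrative assumptions, not part of the original example:

# Hypothetical usage sketch: keep at most 1000 streamlines, reproducibly.
from dipy.io.streamline import load_tractogram, save_tractogram

sft = load_tractogram('whole_brain.trk', 'same')  # placeholder path
subset_sft = get_subset_streamlines(sft, max_streamlines=1000, rng_seed=0)
save_tractogram(subset_sft, 'whole_brain_subset.trk')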
Example #3
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    if not args.out_tractogram.endswith('.trk'):
        parser.error('Output file needs to end with .trk.')

    if len(args.color) == 7:
        args.color = '0x' + args.color.lstrip('#')

    if len(args.color) == 8:
        color_int = int(args.color, 0)
        red = color_int >> 16
        green = (color_int & 0x00FF00) >> 8
        blue = color_int & 0x0000FF
    else:
        parser.error('Hexadecimal RGB color should be formatted as "#RRGGBB"'
                     ' or 0xRRGGBB.')

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    sft.data_per_point["color"] = [np.tile([red, green, blue],
                                           (len(i), 1)) for i in sft.streamlines]

    sft = StatefulTractogram.from_sft(sft.streamlines, sft,
                                      data_per_point=sft.data_per_point)

    save_tractogram(sft, args.out_tractogram)
Example #4
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram, args.reference)
    assert_outputs_exist(parser, args, args.out_tractogram)

    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)
    ori_len = len(sft)
    sft.remove_invalid_streamlines()

    indices = []
    if args.remove_single_point:
        # Will try to do a PR in Dipy
        indices = [i for i in range(len(sft)) if len(sft.streamlines[i]) <= 1]

    if args.remove_overlapping_points:
        for i in range(len(sft)):
            norm = np.linalg.norm(np.gradient(sft.streamlines[i], axis=0),
                                  axis=1)
            if (norm < 0.001).any():
                indices.append(i)

    indices = np.setdiff1d(range(len(sft)), indices)
    new_sft = StatefulTractogram.from_sft(
        sft.streamlines[indices],
        sft,
        data_per_point=sft.data_per_point[indices],
        data_per_streamline=sft.data_per_streamline[indices])
    logging.warning('Removed {} invalid streamlines.'.format(ori_len -
                                                             len(new_sft)))
    save_tractogram(new_sft, args.out_tractogram)
Example #5
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.DEBUG
    logging.basicConfig(level=log_level)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    smoothed_streamlines = []
    for streamline in sft.streamlines:
        if args.gaussian:
            tmp_streamlines = smooth_line_gaussian(streamline, args.gaussian)
        else:
            tmp_streamlines = smooth_line_spline(streamline, args.spline[0],
                                                 args.spline[1])

        if args.error_rate:
            tmp_streamlines = compress_streamlines(tmp_streamlines,
                                                   args.error_rate)
        smoothed_streamlines.append(tmp_streamlines)

    smoothed_sft = StatefulTractogram.from_sft(
        smoothed_streamlines, sft, data_per_streamline=sft.data_per_streamline)
    save_tractogram(smoothed_sft, args.out_tractogram)
Example #6
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    output_streamlines_filename = '{}streamlines.trk'.format(
        args.output_prefix)
    output_voxels_filename = '{}voxels.nii.gz'.format(args.output_prefix)
    assert_outputs_exist(parser, args,
                         [output_voxels_filename, output_streamlines_filename])

    if not 0 <= args.ratio_voxels <= 1 or not 0 <= args.ratio_streamlines <= 1:
        parser.error('Ratios must be between 0 and 1.')

    fusion_streamlines = []
    if args.reference:
        reference_file = args.reference
    else:
        reference_file = args.in_bundles[0]
    sft_list = []
    for name in args.in_bundles:
        tmp_sft = load_tractogram_with_reference(parser, args, name)
        tmp_sft.to_vox()
        tmp_sft.to_corner()

        if not is_header_compatible(reference_file, tmp_sft):
            raise ValueError('Headers are not compatible.')
        sft_list.append(tmp_sft)
        fusion_streamlines.append(tmp_sft.streamlines)

    fusion_streamlines, _ = union_robust(fusion_streamlines)

    transformation, dimensions, _, _ = get_reference_info(reference_file)
    volume = np.zeros(dimensions)
    streamlines_vote = dok_matrix(
        (len(fusion_streamlines), len(args.in_bundles)))

    for i in range(len(args.in_bundles)):
        sft = sft_list[i]
        binary = compute_tract_counts_map(sft.streamlines, dimensions)
        volume[binary > 0] += 1

        if args.same_tractogram:
            _, indices = intersection_robust(
                [fusion_streamlines, sft.streamlines])
            streamlines_vote[list(indices), [i]] += 1

    if args.same_tractogram:
        ratio_value = int(args.ratio_streamlines * len(args.in_bundles))
        real_indices = np.where(
            np.sum(streamlines_vote, axis=1) >= ratio_value)[0]
        new_sft = StatefulTractogram.from_sft(fusion_streamlines[real_indices],
                                              sft_list[0])
        save_tractogram(new_sft, output_streamlines_filename)

    volume[volume < int(args.ratio_voxels * len(args.in_bundles))] = 0
    volume[volume > 0] = 1
    nib.save(nib.Nifti1Image(volume.astype(np.uint8), transformation),
             output_voxels_filename)
Example #7
def compress_sft(sft, tol_error=0.01):
    """ Compress a stateful tractogram. Uses Dipy's compress_streamlines, but
    deals with space better.

    Dipy's description:
    The compression consists in merging consecutive segments that are
    nearly collinear. The merging is achieved by removing the point the two
    segments have in common.

    The linearization process [Presseau15]_ ensures that every point being
    removed are within a certain margin (in mm) of the resulting streamline.
    Recommendations for setting this margin can be found in [Presseau15]_
    (in which they called it tolerance error).

    The compression also ensures that two consecutive points won't be too far
    from each other (precisely, less than or equal to `max_segment_length` mm).
    This is a tradeoff to speed up the linearization process [Rheault15]_. A
    low value will result in a faster linearization but low compression,
    whereas a high value will result in a slower linearization but high
    compression.

    [Presseau C. et al., A new compression format for fiber tracking datasets,
    NeuroImage, no 109, 73-83, 2015.]

    Parameters
    ----------
    sft: StatefulTractogram
        The sft to compress.
    tol_error: float (optional)
        Tolerance error in mm (default: 0.01). A rule of thumb is to set it
        to 0.01 mm for deterministic streamlines and 0.1 mm for probabilistic
        streamlines.

    Returns
    -------
    compressed_sft : StatefulTractogram
    """
    # Go to world space
    orig_space = sft.space
    sft.to_rasmm()

    # Compress streamlines
    compressed_streamlines = compress_streamlines(sft.streamlines,
                                                  tol_error=tol_error)
    if sft.data_per_point is not None:
        logging.warning("Initial StatefulTractogram contained data_per_point. "
                        "This information will not be carried in the final"
                        "tractogram.")

    compressed_sft = StatefulTractogram.from_sft(
        compressed_streamlines,
        sft,
        data_per_streamline=sft.data_per_streamline)

    # Return to original space
    compressed_sft.to_space(orig_space)

    return compressed_sft
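
A hedged usage sketch for compress_sft above; the file names are placeholders and the chosen tolerance follows the docstring's rule of thumb:

# Hypothetical usage sketch; ~0.1 mm would suit probabilistic streamlines.
from dipy.io.streamline import load_tractogram, save_tractogram

sft = load_tractogram('tracking.trk', 'same')  # placeholder path
compressed_sft = compress_sft(sft, tol_error=0.01)
save_tractogram(compressed_sft, 'tracking_compressed.trk')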
Example #8
def filter_grid_roi_both(sft, mask_1, mask_2):
    """ Filters streamlines with one end in a mask and the other in
    another mask.

    Parameters
    ----------
    sft : StatefulTractogram
        StatefulTractogram containing the streamlines to segment.
    mask_1: numpy.ndarray
        Binary mask in which the streamlines should start or end.
    mask_2: numpy.ndarray
        Binary mask in which the streamlines should start or end.
    Returns
    -------
    new_sft : StatefulTractogram
        Filtered sft.
    ids : numpy.ndarray
        Ids of the streamlines with one endpoint in each mask.
    """
    sft.to_vox()
    sft.to_corner()
    streamline_vox = sft.streamlines
    # For endpoint filtering, we need to keep the two endpoints separately
    # Could be faster for either end, but the code looks cleaner like this
    voxel_beg = np.asarray([s[0] for s in streamline_vox],
                           dtype=np.int16).transpose(1, 0)
    voxel_end = np.asarray([s[-1] for s in streamline_vox],
                           dtype=np.int16).transpose(1, 0)

    map1_beg = map_coordinates(mask_1, voxel_beg, order=0, mode='nearest')
    map2_beg = map_coordinates(mask_2, voxel_beg, order=0, mode='nearest')

    map1_end = map_coordinates(mask_1, voxel_end, order=0, mode='nearest')
    map2_end = map_coordinates(mask_2, voxel_end, order=0, mode='nearest')
    line_based_indices = np.logical_or(np.logical_and(map1_beg, map2_end),
                                       np.logical_and(map1_end, map2_beg))

    line_based_indices = \
        np.arange(len(line_based_indices))[line_based_indices].astype(np.int32)

    # From indices to sft
    streamlines = sft.streamlines[line_based_indices]
    data_per_streamline = sft.data_per_streamline[line_based_indices]
    data_per_point = sft.data_per_point[line_based_indices]

    new_sft = StatefulTractogram.from_sft(
        streamlines,
        sft,
        data_per_streamline=data_per_streamline,
        data_per_point=data_per_point)

    return new_sft, line_based_indices
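
A hedged usage sketch for filter_grid_roi_both above. The paths and mask loading are illustrative assumptions; the masks must live on the same voxel grid as the tractogram:

# Hypothetical usage sketch; masks must match the tractogram's voxel grid.
import nibabel as nib
from dipy.io.streamline import load_tractogram, save_tractogram

sft = load_tractogram('tracking.trk', 'same')  # placeholder path
mask_1 = nib.load('roi_head.nii.gz').get_fdata().astype(bool)
mask_2 = nib.load('roi_tail.nii.gz').get_fdata().astype(bool)
filtered_sft, kept_ids = filter_grid_roi_both(sft, mask_1, mask_2)
save_tractogram(filtered_sft, 'head_to_tail.trk')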
Example #9
def _warn_and_save(new_streamlines, sft):
    """Last step of the two resample functions:
    Warn that we lose data_per_point, then create the resampled SFT."""

    if sft.data_per_point is not None:
        logging.debug("Initial stateful tractogram contained data_per_point. "
                      "This information will not be carried in the final"
                      "tractogram.")
    new_sft = StatefulTractogram.from_sft(
        new_streamlines, sft, data_per_streamline=sft.data_per_streamline)

    return new_sft
Example #10
def cut_outside_of_mask_streamlines(sft, binary_mask, min_len=0):
    """ Cut streamlines so their longest segment are within the bounding box
    or a binary mask.
    This function erases the data_per_point and data_per_streamline.

    Parameters
    ----------
    sft: StatefulTractogram
        The sft to cut streamlines (using a single mask with 1 entity) from.
    binary_mask: np.ndarray
        Boolean array representing the region (must contain 1 entity).
    min_len: float
        Minimum length of the resulting streamlines.

    Returns
    -------
    new_sft : StatefulTractogram
        New object with the streamlines trimmed within the mask.
    """
    sft.to_vox()
    sft.to_corner()
    streamlines = sft.streamlines

    new_streamlines = []
    for _, streamline in enumerate(streamlines):
        entry_found = False
        last_success = 0
        curr_len = 0
        longest_seq = (0, 0)
        for ind, pos in enumerate(streamline):
            pos = tuple(pos.astype(np.int16))
            if binary_mask[pos]:
                if not entry_found:
                    entry_found = True
                    last_success = ind
                    curr_len = 0
                else:
                    curr_len += 1
                    if curr_len > longest_seq[1] - longest_seq[0]:
                        longest_seq = (last_success, ind + 1)
            else:
                if entry_found:
                    entry_found = False
                    if curr_len > longest_seq[1] - longest_seq[0]:
                        longest_seq = (last_success, ind)
                        curr_len = 0
        if longest_seq[1] != 0:
            new_streamlines.append(streamline[longest_seq[0]:longest_seq[1]])

    new_sft = StatefulTractogram.from_sft(new_streamlines, sft)
    return filter_streamlines_by_length(new_sft, min_length=min_len)
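
A hedged usage sketch for the function above, under the same assumptions (placeholder paths, mask on the tractogram's grid):

# Hypothetical usage sketch; keeps each streamline's longest in-mask segment.
import nibabel as nib
from dipy.io.streamline import load_tractogram, save_tractogram

sft = load_tractogram('bundle.trk', 'same')  # placeholder path
binary_mask = nib.load('mask.nii.gz').get_fdata().astype(bool)
cut_sft = cut_outside_of_mask_streamlines(sft, binary_mask, min_len=10)
save_tractogram(cut_sft, 'bundle_cut.trk')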
Example #11
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, [args.in_tractogram, args.in_mask])
    assert_outputs_exist(parser, args, args.out_tractogram)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    if args.step_size is not None:
        sft = resample_streamlines_step_size(sft, args.step_size)

    mask_img = nib.load(args.in_mask)
    binary_mask = get_data_as_mask(mask_img)

    if not is_header_compatible(sft, mask_img):
        parser.error('Incompatible header between the tractogram and mask.')

    bundle_disjoint, _ = ndi.label(binary_mask)
    unique, count = np.unique(bundle_disjoint, return_counts=True)
    if args.biggest_blob:
        val = unique[np.argmax(count[1:]) + 1]
        binary_mask[bundle_disjoint != val] = 0
        unique = [0, val]
    if len(unique) == 2:
        logging.info('The provided mask has 1 entity: '
                     'cut_outside_of_mask_streamlines function selected.')
        new_sft = cut_outside_of_mask_streamlines(sft, binary_mask)
    elif len(unique) == 3:
        logging.info('The provided mask has 2 entities: '
                     'cut_between_masks_streamlines function selected.')
        new_sft = cut_between_masks_streamlines(sft, binary_mask)
    else:
        logging.error('The provided mask has more than 2 entities. Cannot cut '
                      'between >2.')
        return

    if len(new_sft) == 0:
        logging.warning('No streamline intersected the provided mask. '
                        'Saving empty tractogram.')
    elif args.error_rate is not None:
        compressed_strs = [
            compress_streamlines(s, args.error_rate)
            for s in new_sft.streamlines
        ]
        new_sft = StatefulTractogram.from_sft(compressed_strs, sft)

    save_tractogram(new_sft, args.out_tractogram)
Example #12
def cut_outside_of_mask_streamlines(sft, binary_mask):
    """ Cut streamlines so their longest segment are within the bounding box
    or a binary mask.
    This function keeps the data_per_point and data_per_streamline.

    Parameters
    ----------
    sft: StatefulTractogram
        The sft to remove invalid points from.

    Returns
    -------
    new_sft : StatefulTractogram
        New object with the invalid points removed from each streamline.
    cutting_counter : int
        Number of streamlines that were cut.
    """
    new_streamlines = []
    length_list = length(sft.streamlines)
    min_len, max_len = min(length_list), max(length_list)
    for streamline in sft.streamlines:
        entry_found = False
        last_success = 0
        curr_len = 0
        longest_seq = (0, 0)
        for ind, pos in enumerate(streamline):
            pos = tuple(pos.astype(np.int16))
            if binary_mask[pos]:
                if not entry_found:
                    entry_found = True
                    last_success = ind
                    curr_len = 0
                else:
                    curr_len += 1
                    if curr_len > longest_seq[1] - longest_seq[0]:
                        longest_seq = (last_success, ind)
            else:
                if entry_found:
                    entry_found = False
                    if curr_len > longest_seq[1] - longest_seq[0]:
                        longest_seq = (last_success, ind - 1)
                        curr_len = 0
        if longest_seq[1] != 0:
            new_streamlines.append(streamline[longest_seq[0]:longest_seq[1]])

    new_sft = StatefulTractogram.from_sft(new_streamlines, sft)
    return filter_streamlines_by_length(new_sft,
                                        min_length=min_len,
                                        max_length=max_len)
Example #13
    def _read_tg(self, tg=None):
        if tg is None:
            tg = self.tg
        else:
            self.tg = tg
        self._tg_orig_space = self.tg.space

        if self.nb_streamlines and len(self.tg) > self.nb_streamlines:
            self.tg = StatefulTractogram.from_sft(
                dts.select_random_set_of_streamlines(self.tg.streamlines,
                                                     self.nb_streamlines),
                self.tg)

        return tg
Example #14
def add_noise_to_streamlines(sft: StatefulTractogram, noise_sigma: float,
                             noise_rng: np.random.RandomState):
    """Add gaussian noise (truncated to +/- 2*noise_sigma) to streamlines
    coordinates.

    Parameters
    ----------
    sft : StatefulTractogram
        Streamlines.
    noise_sigma : float
        Standard deviation of the gaussian noise to add to the streamlines.
        CAREFUL. We do not deal with space. Make sure your noise is in the
        same space as your sft.
        CAREFUL. Keep in mind that you need to choose
        noise_sigma < step_size/4, so that the maximum noise is < step_size/2.
        In the worst case, the starting point of a segment may advance by
        step_size/2 while the ending point moves back by step_size/2, but not
        further, so the direction of the segment won't flip.
    noise_rng : np.random.RandomState object
        Random number generator.

    Returns
    -------
    noisy_sft : StatefulTractogram
        Noisy streamlines. Note. Adding noise may create invalid streamlines
        (i.e. out of the box in voxel space). If you want to save noisy_sft,
        please perform noisy_sft.remove_invalid_streamlines() first.
    """
    logging.info("Please note your sft space is in {}. We suppose that the "
                 "noise, {}, fits.".format(sft.space, noise_sigma))

    # Perform noise addition (flattening before to go faster)
    flattened_coords = np.concatenate(sft.streamlines, axis=0)
    flattened_coords += truncnorm.rvs(-2,
                                      2,
                                      size=flattened_coords.shape,
                                      scale=noise_sigma,
                                      random_state=noise_rng)
    noisy_streamlines = split_array_at_lengths(
        flattened_coords, [len(s) for s in sft.streamlines])

    # Create output tractogram
    noisy_sft = StatefulTractogram.from_sft(
        noisy_streamlines,
        sft,
        data_per_point=sft.data_per_point,
        data_per_streamline=sft.data_per_streamline)

    return noisy_sft
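
A hedged usage sketch for add_noise_to_streamlines above; the paths and the sigma value are illustrative assumptions, and the sigma must be expressed in the sft's space:

# Hypothetical usage sketch; 0.1 mm sigma assumes the sft is in rasmm space.
import numpy as np
from dipy.io.streamline import load_tractogram, save_tractogram

rng = np.random.RandomState(42)
sft = load_tractogram('tracking.trk', 'same')  # placeholder path
noisy_sft = add_noise_to_streamlines(sft, noise_sigma=0.1, noise_rng=rng)
noisy_sft.remove_invalid_streamlines()  # noise may push points out of the box
save_tractogram(noisy_sft, 'tracking_noisy.trk')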
Example #15
def filter_streamlines_by_length(sft, min_length=0., max_length=np.inf):
    """
    Filter streamlines using minimum and maximum length.

    Parameters
    ----------
    sft: StatefulTractogram
        SFT containing the streamlines to filter.
    min_length: float
        Minimum length of streamlines, in mm.
    max_length: float
        Maximum length of streamlines, in mm.

    Returns
    -------
    filtered_sft : StatefulTractogram
        A tractogram containing only streamlines within the length range.
    """

    # Make sure we are in world space
    orig_space = sft.space
    sft.to_rasmm()

    if sft.streamlines:
        # Compute streamlines lengths
        lengths = length(sft.streamlines)

        # Filter lengths
        filter_stream = np.logical_and(lengths >= min_length,
                                       lengths <= max_length)
    else:
        filter_stream = []

    filtered_streamlines = list(
        np.asarray(sft.streamlines, dtype=object)[filter_stream])
    filtered_data_per_point = sft.data_per_point[filter_stream]
    filtered_data_per_streamline = sft.data_per_streamline[filter_stream]

    # Create final sft
    filtered_sft = StatefulTractogram.from_sft(
        filtered_streamlines,
        sft,
        data_per_point=filtered_data_per_point,
        data_per_streamline=filtered_data_per_streamline)

    # Return to original space
    filtered_sft.to_space(orig_space)

    return filtered_sft
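
A minimal usage sketch for filter_streamlines_by_length above, with placeholder paths and lengths in mm:

# Hypothetical usage sketch: keep streamlines between 20 and 200 mm.
from dipy.io.streamline import load_tractogram, save_tractogram

sft = load_tractogram('tracking.trk', 'same')  # placeholder path
filtered_sft = filter_streamlines_by_length(sft, min_length=20.,
                                            max_length=200.)
save_tractogram(filtered_sft, 'tracking_20_200mm.trk')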
Example #16
def main():

    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    new_sft = filter_streamlines_by_length(sft, args.minL, args.maxL)

    if not new_sft.streamlines:
        if args.no_empty:
            logging.debug("The file {} won't be written "
                          "(0 streamlines).".format(args.out_tractogram))
            return

        logging.debug('The file {} contains 0 streamlines.'.format(
            args.out_tractogram))

    save_tractogram(new_sft, args.out_tractogram)

    if args.display_counts:
        tc_bf = len(sft.streamlines)
        tc_af = len(new_sft.streamlines)
        print(
            json.dumps(
                {
                    'tract_count_before_filtering': int(tc_bf),
                    'tract_count_after_filtering': int(tc_af)
                },
                indent=args.indent))
Example #17
def cut_between_masks_streamlines(sft, binary_mask, min_len=0):
    """ Cut streamlines so their segment are within the bounding box
    or going from binary mask #1 to binary mask #2.
    This function erases the data_per_point and data_per_streamline.

    Parameters
    ----------
    sft: StatefulTractogram
        The sft to cut streamlines (using a single mask with 2 entities) from.
    binary_mask: np.ndarray
        Boolean array representing the region (must contain 2 entities).
    min_len: float
        Minimum length of the resulting streamlines.

    Returns
    -------
    new_sft : StatefulTractogram
        New object with the streamlines trimmed within the masks.
    """
    sft.to_vox()
    sft.to_corner()
    streamlines = sft.streamlines

    density = get_endpoints_density_map(streamlines, binary_mask.shape)
    density[density > 0] = 1
    density[binary_mask == 0] = 0

    roi_data_1, roi_data_2 = split_heads_tails_kmeans(binary_mask)

    new_streamlines = []
    (indices, points_to_idx) = uncompress(streamlines, return_mapping=True)

    for strl_idx, strl in enumerate(streamlines):
        strl_indices = indices[strl_idx]

        in_strl_idx, out_strl_idx = intersects_two_rois(
            roi_data_1, roi_data_2, strl_indices)

        if in_strl_idx is not None and out_strl_idx is not None:
            points_to_indices = points_to_idx[strl_idx]
            tmp = compute_streamline_segment(strl, strl_indices, in_strl_idx,
                                             out_strl_idx, points_to_indices)
            new_streamlines.append(tmp)

    new_sft = StatefulTractogram.from_sft(new_streamlines, sft)
    return filter_streamlines_by_length(new_sft, min_length=min_len)
Example #18
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundle)
    assert_outputs_exist(parser, args, args.out_centroid)

    if args.nb_points < 2:
        parser.error('--nb_points {} should be >= 2'.format(args.nb_points))

    sft = load_tractogram_with_reference(parser, args, args.in_bundle)

    centroid_streamlines = get_streamlines_centroid(sft.streamlines,
                                                    args.nb_points)

    sft = StatefulTractogram.from_sft(centroid_streamlines, sft)

    save_tractogram(sft, args.out_centroid)
Example #19
def test_create_from_sft():
    sft_1 = load_tractogram(filepath_dix['gs.tck'], filepath_dix['gs.nii'])
    sft_2 = StatefulTractogram.from_sft(
        sft_1.streamlines,
        sft_1,
        data_per_point=sft_1.data_per_point,
        data_per_streamline=sft_1.data_per_streamline)

    if not (np.array_equal(sft_1.streamlines, sft_2.streamlines)
            and sft_1.space_attributes == sft_2.space_attributes
            and sft_1.space == sft_2.space and sft_1.origin == sft_2.origin
            and sft_1.data_per_point == sft_2.data_per_point
            and sft_1.data_per_streamline == sft_2.data_per_streamline):
        raise AssertionError()

    sft_1.streamlines = np.arange(6000).reshape((100, 20, 3))
    if np.array_equal(sft_1.streamlines, sft_2.streamlines):
        raise AssertionError()
Example #20
def flip_sft(sft, flip_axes):
    flip_vector = get_axis_flip_vector(flip_axes)
    shift_vector = get_shift_vector(sft)

    flipped_streamlines = []

    streamlines = sft.streamlines

    for streamline in streamlines:
        mod_streamline = streamline + shift_vector
        mod_streamline *= flip_vector
        mod_streamline -= shift_vector
        flipped_streamlines.append(mod_streamline)

    new_sft = StatefulTractogram.from_sft(flipped_streamlines, sft,
                                          data_per_point=sft.data_per_point,
                                          data_per_streamline=sft.data_per_streamline)
    return new_sft
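
A hedged usage sketch for flip_sft above. The paths are placeholders, and the flip_axes value assumes get_axis_flip_vector accepts axis names such as 'x', 'y', 'z':

# Hypothetical usage sketch; axis naming is an assumption about
# get_axis_flip_vector's convention.
from dipy.io.streamline import load_tractogram, save_tractogram

sft = load_tractogram('tracking.trk', 'same')  # placeholder path
flipped_sft = flip_sft(sft, flip_axes=['x'])
save_tractogram(flipped_sft, 'tracking_flipped_x.trk')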
Example #21
def get_binary_maps(streamlines, sft):
    """
    Extract a mask from a bundle

    Parameters
    ----------
    streamlines: list
        List of streamlines.
    sft : StatefulTractogram
        Reference tractogram, used for its dimensions.

    Returns
    -------
    bundles_voxels: numpy.ndarray
        Mask representing the bundle volume.
    endpoints_voxels: numpy.ndarray
        Mask representing the bundle's endpoints.
    """
    dimensions = sft.dimensions
    if not len(streamlines):
        return np.zeros(dimensions), np.zeros(dimensions)
    elif len(streamlines) == 1:
        streamlines = [streamlines]
    tmp_sft = StatefulTractogram.from_sft(streamlines, sft)
    tmp_sft.to_vox()
    tmp_sft.to_corner()

    if len(tmp_sft) == 1:
        return np.zeros(dimensions), np.zeros(dimensions)

    bundles_voxels = compute_tract_counts_map(tmp_sft.streamlines,
                                              dimensions).astype(np.int16)

    endpoints_voxels = get_endpoints_density_map(tmp_sft.streamlines,
                                                 dimensions).astype(np.int16)

    bundles_voxels[bundles_voxels > 0] = 1
    endpoints_voxels[endpoints_voxels > 0] = 1

    return bundles_voxels, endpoints_voxels
Example #22
def extract_false_connections(sft, mask_1_filename, mask_2_filename,
                              dilate_endpoints):
    """
    Extract false connections based on two regions from a tractogram.

    Parameters
    ----------
    sft: StatefulTractogram
        Tractogram containing the streamlines to be extracted.
    mask_1_filename: str
        Filename of the "head" of the bundle.
    mask_2_filename: str
        Filename of the "tail" of the bundle.
    dilate_endpoints: int or None
        If set, dilate the masks for n iterations.

    Returns
    -------
    fc_sft: StatefulTractogram
        SFT of false connections.
    sft: StatefulTractogram
        SFT of remaining streamlines.
    """

    mask_1_img = nib.load(mask_1_filename)
    mask_2_img = nib.load(mask_2_filename)
    mask_1 = get_data_as_mask(mask_1_img)
    mask_2 = get_data_as_mask(mask_2_img)

    if dilate_endpoints:
        mask_1 = binary_dilation(mask_1, iterations=dilate_endpoints)
        mask_2 = binary_dilation(mask_2, iterations=dilate_endpoints)

    if len(sft.streamlines) > 0:
        tmp_sft, sft = extract_streamlines(mask_1, mask_2, sft)

        fc_streamlines = tmp_sft.streamlines

        fc_sft = StatefulTractogram.from_sft(fc_streamlines, sft)
        return fc_sft, sft
    else:
        return sft, sft
Example #23
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    _, out_extension = os.path.splitext(args.in_tractogram)

    assert_output_dirs_exist_and_empty(parser, args, [], optional=args.out_dir)
    # Check only the first potential output filename
    assert_outputs_exist(parser, args, os.path.join(args.out_dir,
                                                    '{}_0{}'.format(args.out_prefix,
                                                                    out_extension)))

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    streamlines_count = len(sft.streamlines)

    if args.nb_chunk:
        chunk_size = int(streamlines_count/args.nb_chunk)
        nb_chunk = args.nb_chunk
    else:
        chunk_size = args.chunk_size
        nb_chunk = int(streamlines_count/chunk_size)+1

    # All chunks will be equal except the last one; int32 avoids overflowing
    # int16 for large chunk sizes
    chunk_sizes = np.ones((nb_chunk,), dtype=np.int32) * chunk_size
    chunk_sizes[-1] += (streamlines_count - chunk_size * nb_chunk)
    curr_count = 0
    for i in range(nb_chunk):
        streamlines = sft.streamlines[curr_count:curr_count + chunk_sizes[i]]
        data_per_streamline = sft.data_per_streamline[curr_count:curr_count
                                                      + chunk_sizes[i]]
        data_per_point = sft.data_per_point[curr_count:curr_count
                                            + chunk_sizes[i]]
        curr_count += chunk_sizes[i]
        new_sft = StatefulTractogram.from_sft(streamlines, sft,
                                              data_per_point=data_per_point,
                                              data_per_streamline=data_per_streamline)

        out_name = os.path.join(args.out_dir,
                                '{0}_{1}{2}'.format(args.out_prefix, i,
                                                    out_extension))
        save_tractogram(new_sft, out_name)
Example #24
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    rng = np.random.RandomState(args.seed)
    indices = np.arange(len(sft.streamlines))
    rng.shuffle(indices)

    streamlines = sft.streamlines[indices]
    data_per_streamline = sft.data_per_streamline[indices]
    data_per_point = sft.data_per_point[indices]

    shuffled_sft = StatefulTractogram.from_sft(
        streamlines,
        sft,
        data_per_streamline=data_per_streamline,
        data_per_point=data_per_point)
    save_tractogram(shuffled_sft, args.out_tractogram)
Example #25
def reverse_streamlines(sft: StatefulTractogram,
                        reverse_ids: np.ndarray = None):
    """Reverse streamlines, i.e. inverse the beginning and end

    Parameters
    ----------
    sft: StatefulTractogram
        Dipy object containing your streamlines
    reverse_ids: np.ndarray, optional
        Ids of the streamlines to reverse. If not provided, all streamlines
        are reversed.

    Returns
    -------
    new_sft: StatefulTractogram
        Dipy object with reversed streamlines and data_per_point.
    """
    if reverse_ids is None:
        reverse_ids = range(len(sft.streamlines))

    new_streamlines = [
        s[::-1] if i in reverse_ids else s
        for i, s in enumerate(sft.streamlines)
    ]
    new_data_per_point = copy.deepcopy(sft.data_per_point)
    for key in sft.data_per_point:
        new_data_per_point[key] = [
            d[::-1] if i in reverse_ids else d
            for i, d in enumerate(new_data_per_point[key])
        ]

    new_sft = StatefulTractogram.from_sft(
        new_streamlines,
        sft,
        data_per_point=new_data_per_point,
        data_per_streamline=sft.data_per_streamline)

    return new_sft
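
A minimal usage sketch for reverse_streamlines above, with placeholder paths and an arbitrary id selection:

# Hypothetical usage sketch: reverse only the first ten streamlines.
import numpy as np
from dipy.io.streamline import load_tractogram, save_tractogram

sft = load_tractogram('tracking.trk', 'same')  # placeholder path
reversed_sft = reverse_streamlines(sft, reverse_ids=np.arange(10))
save_tractogram(reversed_sft, 'tracking_reversed.trk')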
Example #26
def filter_tractogram_data(tractogram, streamline_ids):
    """ Filter tractogram according to streamline ids and keep the data

    Parameters
    ----------
    tractogram: StatefulTractogram
        Tractogram containing the data to be filtered.
    streamline_ids: array_like
        List of streamline ids the data corresponds to.

    Returns
    -------
    new_tractogram: StatefulTractogram
        A new tractogram with only the selected streamlines and data.
    """

    streamline_ids = np.asarray(streamline_ids, dtype=int)

    assert np.all(
        np.in1d(streamline_ids, np.arange(len(tractogram.streamlines)))
    ), "Received ids outside of streamline range"

    new_streamlines = tractogram.streamlines[streamline_ids]
    new_data_per_streamline = tractogram.data_per_streamline[streamline_ids]
    new_data_per_point = tractogram.data_per_point[streamline_ids]

    # It would have been nice to deepcopy the tractogram and modify the
    # attributes in place instead of creating a new one, but tractograms
    # can't be subsampled if they have data

    return StatefulTractogram.from_sft(
        new_streamlines,
        tractogram,
        data_per_point=new_data_per_point,
        data_per_streamline=new_data_per_streamline)
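
A minimal usage sketch for filter_tractogram_data above; the paths and the id selection are illustrative assumptions:

# Hypothetical usage sketch: keep every second streamline and its data.
import numpy as np
from dipy.io.streamline import load_tractogram, save_tractogram

sft = load_tractogram('tracking.trk', 'same')  # placeholder path
kept_ids = np.arange(0, len(sft.streamlines), 2)
sub_sft = filter_tractogram_data(sft, kept_ids)
save_tractogram(sub_sft, 'tracking_half.trk')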
Example #27
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram] + args.gt_bundles)
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       create_dir=True)

    if (args.gt_tails and not args.gt_heads) \
            or (args.gt_heads and not args.gt_tails):
        parser.error("Both --gt_heads and --gt_tails are needed.")
    if args.gt_endpoints and (args.gt_tails or args.gt_heads):
        parser.error("Can only provide --gt_endpoints or --gt_tails/gt_heads")
    if not args.gt_endpoints and (not args.gt_tails and not args.gt_heads):
        parser.error(
            "Either input --gt_endpoints or --gt_heads and --gt_tails.")

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    _, ext = os.path.splitext(args.in_tractogram)
    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)

    if args.remove_invalid:
        sft.remove_invalid_streamlines()

    initial_count = len(sft)

    logging.info("Verifying compatibility with ground-truth")
    for gt in args.gt_bundles:
        compatible = is_header_compatible(sft, gt)
        if not compatible:
            parser.error("Input tractogram incompatible with" " {}".format(gt))

    logging.info("Computing ground-truth masks")
    gt_bundle_masks, gt_bundle_inv_masks, affine, dimensions = \
        compute_gt_masks(args.gt_bundles, parser, args)

    # If endpoints without heads/tails are loaded, split them and continue
    # normally after. Q/C of the output is important
    if args.gt_endpoints:
        logging.info("Extracting ground-truth end and tail masks")
        gt_tails, gt_heads, affine, dimensions = \
            extract_tails_heads_from_endpoints(
                args.gt_endpoints, args.out_dir)
    else:
        gt_tails, gt_heads = args.gt_tails, args.gt_heads

    logging.info("Verifying compatibility with endpoints")
    for gt in gt_tails + gt_heads:
        compatible = is_header_compatible(sft, gt)
        if not compatible:
            parser.error("Input tractogram incompatible with" " {}".format(gt))

    # Load the endpoints heads/tails, keep the correct combinations
    # separately from all the possible combinations
    tc_filenames = list(zip(gt_tails, gt_heads))

    length_dict = {}
    if args.gt_config:
        with open(args.gt_config, "r") as json_file:
            length_dict = json.load(json_file)

    tc_streamlines_list = []
    wpc_streamlines_list = []
    fc_streamlines_list = []
    nc_streamlines = []

    logging.info("Scoring true connections")
    for i, (mask_1_filename, mask_2_filename) in enumerate(tc_filenames):

        # Automatically generate filename for Q/C
        prefix_1 = extract_prefix(mask_1_filename)
        prefix_2 = extract_prefix(mask_2_filename)

        tc_sft, wpc_sft, fc_sft, nc, sft = extract_true_connections(
            sft, mask_1_filename, mask_2_filename, args.gt_config, length_dict,
            extract_prefix(args.gt_bundles[i]), gt_bundle_inv_masks[i],
            args.dilate_endpoints, args.wrong_path_as_separate)
        nc_streamlines.extend(nc)
        if len(tc_sft) > 0:
            save_tractogram(tc_sft,
                            os.path.join(
                                args.out_dir,
                                "{}_{}_tc{}".format(prefix_1, prefix_2, ext)),
                            bbox_valid_check=False)

        if len(wpc_sft) > 0:
            save_tractogram(wpc_sft,
                            os.path.join(
                                args.out_dir,
                                "{}_{}_wpc{}".format(prefix_1, prefix_2, ext)),
                            bbox_valid_check=False)

        if len(fc_sft) > 0:
            save_tractogram(fc_sft,
                            os.path.join(
                                args.out_dir,
                                "{}_{}_fc{}".format(prefix_1, prefix_2, ext)),
                            bbox_valid_check=False)

        tc_streamlines_list.append(tc_sft.streamlines)
        wpc_streamlines_list.append(wpc_sft.streamlines)
        fc_streamlines_list.append(fc_sft.streamlines)

        logging.info("Recognized {} streamlines between {} and {}".format(
            len(tc_sft.streamlines) + len(wpc_sft.streamlines) +
            len(fc_sft.streamlines) + len(nc), prefix_1, prefix_2))

    # Again, keep the correct combinations
    comb_filename = list(
        itertools.combinations(itertools.chain(*zip(gt_tails, gt_heads)), r=2))

    # Remove the true connections from all combinations, leaving only
    # false connections
    for tc_f in tc_filenames:
        comb_filename.remove(tc_f)

    logging.info("Scoring false connections")
    # Go through all the possible combinations of endpoints masks
    for i, roi in enumerate(comb_filename):
        mask_1_filename, mask_2_filename = roi

        # Automatically generate filename for Q/C
        prefix_1 = extract_prefix(mask_1_filename)
        prefix_2 = extract_prefix(mask_2_filename)
        _, ext = os.path.splitext(args.in_tractogram)

        fc_sft, sft = extract_false_connections(sft, mask_1_filename,
                                                mask_2_filename,
                                                args.dilate_endpoints)

        if len(fc_sft) > 0:
            save_tractogram(fc_sft,
                            os.path.join(
                                args.out_dir,
                                "{}_{}_fc{}".format(prefix_1, prefix_2, ext)),
                            bbox_valid_check=False)

        logging.info("Recognized {} streamlines between {} and {}".format(
            len(fc_sft.streamlines), prefix_1, prefix_2))

        fc_streamlines_list.append(fc_sft.streamlines)

    nc_streamlines.extend(sft.streamlines)

    final_results = {}
    no_conn_sft = StatefulTractogram.from_sft(nc_streamlines, sft)
    save_tractogram(no_conn_sft,
                    os.path.join(args.out_dir, "nc{}".format(ext)),
                    bbox_valid_check=False)

    # Total number of streamlines for each category
    # and statistics that are not "bundle-wise"
    tc_streamlines_count = len(list(itertools.chain(*tc_streamlines_list)))
    fc_streamlines_count = len(list(itertools.chain(*fc_streamlines_list)))

    if args.wrong_path_as_separate:
        wpc_streamlines_count = len(
            list(itertools.chain(*wpc_streamlines_list)))
    else:
        wpc_streamlines_count = 0

    nc_streamlines_count = len(nc_streamlines)
    total_count = tc_streamlines_count + fc_streamlines_count + \
        wpc_streamlines_count + nc_streamlines_count

    assert total_count == initial_count

    final_results["tractogram_filename"] = str(args.in_tractogram)
    final_results["tractogram_overlap"] = 0.0
    final_results["tc_streamlines"] = tc_streamlines_count
    final_results["fc_streamlines"] = fc_streamlines_count
    final_results["nc_streamlines"] = nc_streamlines_count

    final_results["tc_bundle"] = len([x for x in tc_streamlines_list if x])
    final_results["fc_bundle"] = len([x for x in fc_streamlines_list if x])

    final_results["tc_streamlines_ratio"] = tc_streamlines_count / total_count
    final_results["fc_streamlines_ratio"] = fc_streamlines_count / total_count
    final_results["nc_streamlines_ratio"] = nc_streamlines_count / total_count

    if args.wrong_path_as_separate:
        final_results["wpc_streamlines"] = wpc_streamlines_count
        final_results["wpc_streamlines_ratio"] = \
            wpc_streamlines_count / total_count
        final_results["wpc_bundle"] = len(
            [x for x in wpc_streamlines_list if x])

    final_results["total_streamlines"] = total_count
    final_results["bundle_wise"] = {}
    final_results["bundle_wise"]["true_connections"] = {}
    final_results["bundle_wise"]["false_connections"] = {}
    tractogram_overlap = 0.0

    for i, filename in enumerate(tc_filenames):
        current_tc_streamlines = tc_streamlines_list[i]
        current_tc_voxels, current_tc_endpoints_voxels = get_binary_maps(
            current_tc_streamlines, sft)

        if args.wrong_path_as_separate:
            current_wpc_streamlines = wpc_streamlines_list[i]
            current_wpc_voxels, _ = get_binary_maps(current_wpc_streamlines,
                                                    sft)

        tmp_dict = {}
        tmp_dict["tc_streamlines"] = len(current_tc_streamlines)

        tmp_dict["tc_dice"] = compute_dice_voxel(gt_bundle_masks[i],
                                                 current_tc_voxels)[0]

        bundle_overlap = gt_bundle_masks[i] * current_tc_voxels
        bundle_overreach = np.zeros(dimensions)
        bundle_overreach[np.where((gt_bundle_masks[i] == 0)
                                  & (current_tc_voxels >= 1))] = 1
        bundle_lacking = np.zeros(dimensions)
        bundle_lacking[np.where((gt_bundle_masks[i] == 1)
                                & (current_tc_voxels == 0))] = 1

        if args.wrong_path_as_separate:
            tmp_dict["wpc_streamlines"] = len(current_wpc_streamlines)
            tmp_dict["wpc_dice"] = \
                compute_dice_voxel(gt_bundle_masks[i],
                                   current_wpc_voxels)[0]
            # Add wrong path to overreach
            bundle_overreach[np.where((gt_bundle_masks[i] == 0)
                                      & (current_wpc_voxels >= 1))] = 1

        tmp_dict["tc_bundle_overlap"] = np.count_nonzero(bundle_overlap)
        tmp_dict["tc_bundle_overreach"] = \
            np.count_nonzero(bundle_overreach)
        tmp_dict["tc_bundle_lacking"] = np.count_nonzero(bundle_lacking)
        tmp_dict["tc_bundle_overlap_PCT"] = \
            tmp_dict["tc_bundle_overlap"] / \
            (tmp_dict["tc_bundle_overlap"] +
                tmp_dict["tc_bundle_lacking"])
        tractogram_overlap += tmp_dict["tc_bundle_overlap_PCT"]

        endpoints_overlap = \
            gt_bundle_masks[i] * current_tc_endpoints_voxels
        endpoints_overreach = np.zeros(dimensions)
        endpoints_overreach[np.where((gt_bundle_masks[i] == 0)
                                     & (current_tc_endpoints_voxels >= 1))] = 1
        tmp_dict["tc_endpoints_overlap"] = np.count_nonzero(endpoints_overlap)
        tmp_dict["tc_endpoints_overreach"] = np.count_nonzero(
            endpoints_overreach)

        final_results["bundle_wise"]["true_connections"][str(filename)] = \
            tmp_dict

    # Bundle-wise statistics, useful for more complex phantoms
    for i, filename in enumerate(comb_filename):
        current_fc_streamlines = fc_streamlines_list[i]
        current_fc_voxels, _ = get_binary_maps(current_fc_streamlines, sft)

        tmp_dict = {}

        if len(current_fc_streamlines):
            tmp_dict["fc_streamlines"] = len(current_fc_streamlines)
            tmp_dict["fc_voxels"] = np.count_nonzero(current_fc_voxels)

            final_results["bundle_wise"]["false_connections"][str(filename)] =\
                tmp_dict

    final_results["tractogram_overlap"] = \
        tractogram_overlap / len(gt_bundle_masks)

    with open(os.path.join(args.out_dir, "results.json"), "w") as f:
        json.dump(final_results,
                  f,
                  indent=args.indent,
                  sort_keys=args.sort_keys)
Example #28
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_labels],
                        args.reference)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    if (args.save_raw_connections or args.save_intermediate
            or args.save_discarded) and not args.out_dir:
        parser.error('To save outputs in the streamlines form, provide the '
                     'output directory using --out_dir.')

    if args.out_dir:
        if os.path.abspath(args.out_dir) == os.getcwd():
            parser.error('Do not use the current path as output directory.')
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.out_dir,
                                           create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info('    Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    if not is_header_compatible(sft, img_labels):
        raise IOError('{} and {} do not have a compatible header'.format(
            args.in_tractogram, args.in_labels))

    sft.to_vox()
    sft.to_corner()
    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    iteration_counter = 0
    with h5py.File(args.out_hdf5, 'w') as hdf5_file:
        affine, dimensions, voxel_sizes, voxel_order = get_reference_info(sft)
        hdf5_file.attrs['affine'] = affine
        hdf5_file.attrs['dimensions'] = dimensions
        hdf5_file.attrs['voxel_sizes'] = voxel_sizes
        hdf5_file.attrs['voxel_order'] = voxel_order

        # Each connection is processed independently. Multiprocessing would be
        # a burden on the I/O of most SSDs/HDs
        for in_label, out_label in comb_list:
            if iteration_counter > 0 and iteration_counter % 100 == 0:
                logging.info('Split {} nodes out of {}'.format(
                    iteration_counter, len(comb_list)))
            iteration_counter += 1

            pair_info = []
            if in_label not in con_info:
                continue
            elif out_label in con_info[in_label]:
                pair_info.extend(con_info[in_label][out_label])

            if out_label not in con_info:
                continue
            elif in_label in con_info[out_label]:
                pair_info.extend(con_info[out_label][in_label])

            if not len(pair_info):
                continue

            connecting_streamlines = []
            connecting_ids = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                curr_streamlines = compute_streamline_segment(
                    sft.streamlines[strl_idx], indices[strl_idx],
                    connection['in_idx'], connection['out_idx'],
                    points_to_idx[strl_idx])
                connecting_streamlines.append(curr_streamlines)
                connecting_ids.append(strl_idx)

            # Each step is processed from the previous 'success'
            #   1. raw         -> length pass/fail
            #   2. length pass -> loops pass/fail
            #   3. loops pass  -> outlier detection pass/fail
            #   4. outlier detection pass -> qb curvature pass/fail
            #   5. qb curvature pass == final connections
            connecting_streamlines = ArraySequence(connecting_streamlines)
            raw_dps = sft.data_per_streamline[connecting_ids]
            raw_sft = StatefulTractogram.from_sft(connecting_streamlines,
                                                  sft,
                                                  data_per_streamline=raw_dps,
                                                  data_per_point={})
            _save_if_needed(raw_sft, hdf5_file, args, 'raw', 'raw', in_label,
                            out_label)

            # Doing all post-processing
            if not args.no_pruning:
                valid_length_ids, invalid_length_ids = _prune_segments(
                    raw_sft.streamlines, args.min_length, args.max_length,
                    vox_sizes[0])

                invalid_length_sft = raw_sft[invalid_length_ids]
                valid_length = connecting_streamlines[valid_length_ids]
                _save_if_needed(invalid_length_sft, hdf5_file, args,
                                'discarded', 'invalid_length', in_label,
                                out_label)
            else:
                valid_length = connecting_streamlines
                valid_length_ids = range(len(connecting_streamlines))

            if not len(valid_length):
                continue

            valid_length_sft = raw_sft[valid_length_ids]
            _save_if_needed(valid_length_sft, hdf5_file, args, 'intermediate',
                            'valid_length', in_label, out_label)

            if not args.no_remove_loops:
                no_loop_ids = remove_loops_and_sharp_turns(
                    valid_length, args.loop_max_angle)
                loop_ids = np.setdiff1d(np.arange(len(valid_length)),
                                        no_loop_ids)

                loops_sft = valid_length_sft[loop_ids]
                no_loops = valid_length[no_loop_ids]
                _save_if_needed(loops_sft, hdf5_file, args, 'discarded',
                                'loops', in_label, out_label)
            else:
                no_loops = valid_length
                no_loop_ids = range(len(valid_length))

            if not len(no_loops):
                continue
            no_loops_sft = valid_length_sft[no_loop_ids]
            _save_if_needed(no_loops_sft, hdf5_file, args, 'intermediate',
                            'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                outliers_ids, inliers_ids = remove_outliers(
                    no_loops,
                    args.outlier_threshold,
                    nb_samplings=10,
                    fast_approx=True)

                outliers_sft = no_loops_sft[outliers_ids]
                inliers = no_loops[inliers_ids]
                _save_if_needed(outliers_sft, hdf5_file, args, 'discarded',
                                'outliers', in_label, out_label)
            else:
                inliers = no_loops
                inliers_ids = range(len(no_loops))

            if not len(inliers):
                continue

            inliers_sft = no_loops_sft[inliers_ids]
            _save_if_needed(inliers_sft, hdf5_file, args, 'intermediate',
                            'inliers', in_label, out_label)

            if not args.no_remove_curv_dev:
                no_qb_curv_ids = remove_loops_and_sharp_turns(
                    inliers,
                    args.loop_max_angle,
                    use_qb=True,
                    qb_threshold=args.curv_qb_distance)
                qb_curv_ids = np.setdiff1d(np.arange(len(inliers)),
                                           no_qb_curv_ids)

                qb_curv_sft = inliers_sft[qb_curv_ids]
                _save_if_needed(qb_curv_sft, hdf5_file, args, 'discarded',
                                'qb_curv', in_label, out_label)
            else:
                no_qb_curv_ids = np.arange(len(inliers))

            no_qb_curv_sft = inliers_sft[no_qb_curv_ids]
            _save_if_needed(no_qb_curv_sft, hdf5_file, args, 'final', 'final',
                            in_label, out_label)

    time2 = time.time()
    logging.info(
        '    Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
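The staged pattern above (each filter consumes the survivors of the previous stage, and both the kept and discarded subsets are saved) generalizes well. Below is a minimal, hypothetical sketch of the same chain expressed as a loop over (name, filter) pairs; run_filter_chain and the filter callables are illustrative placeholders, not scilpy functions.

import numpy as np


def run_filter_chain(sft, filters):
    """Apply filters in sequence. Each filter takes streamlines and
    returns the kept indices (relative to its input). Hypothetical
    helper sketch, not part of scilpy."""
    current = sft
    discarded = {}
    for name, filt in filters:
        kept_ids = np.asarray(filt(current.streamlines), dtype=np.int32)
        lost_ids = np.setdiff1d(np.arange(len(current.streamlines)),
                                kept_ids)
        # Keep the rejected subset per stage, like the 'discarded'
        # saves above
        discarded[name] = current[lost_ids]
        current = current[kept_ids]
        if len(current.streamlines) == 0:
            break
    return current, discarded

Each stage indexes the previous stage's SFT, mirroring how valid_length_sft, no_loops_sft and inliers_sft are chained above.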
Example No. 29
def filter_grid_roi(sft, mask, filter_type, is_exclude):
    """
    Parameters
    ----------
    sft : StatefulTractogram
        StatefulTractogram containing the streamlines to segment.
    target_mask : numpy.ndarray
        Binary mask in which the streamlines should pass.
    filter_type: str
        One of the 3 following choices, 'any', 'either_end', 'both_ends'.
    is_exclude: bool
        Value to indicate if the ROI is an AND (false) or a NOT (true).
    Returns
    -------
    ids : tuple
        Filtered sft.
        Ids of the streamlines passing through the mask.
    """
    line_based_indices = []
    if filter_type == 'any':
        line_based_indices = streamlines_in_mask(sft, mask)
    else:
        sft.to_vox()
        sft.to_corner()
        streamline_vox = sft.streamlines
        # For endpoint filtering, we need to track the two endpoints
        # separately. 'either_end' alone could be faster, but the code
        # looks cleaner this way
        line_based_indices_1 = []
        line_based_indices_2 = []
        for i, line_vox in enumerate(streamline_vox):
            voxel_1 = tuple(line_vox[0].astype(np.int16))
            voxel_2 = tuple(line_vox[-1].astype(np.int16))
            if mask[voxel_1]:
                line_based_indices_1.append(i)
            if mask[voxel_2]:
                line_based_indices_2.append(i)

        # Both endpoints need to be in the mask (AND)
        if filter_type == 'both_ends':
            line_based_indices = np.intersect1d(line_based_indices_1,
                                                line_based_indices_2)
        # Only one endpoint needs to be in the mask (OR)
        elif filter_type == 'either_end':
            line_based_indices = np.union1d(line_based_indices_1,
                                            line_based_indices_2)

    # If the 'exclude' option is used, the selection is inverted
    if is_exclude:
        line_based_indices = np.setdiff1d(range(len(sft)),
                                          np.unique(line_based_indices))
    line_based_indices = np.asarray(line_based_indices).astype(np.int32)

    # From indices to sft
    streamlines = sft.streamlines[line_based_indices]
    data_per_streamline = sft.data_per_streamline[line_based_indices]
    data_per_point = sft.data_per_point[line_based_indices]

    new_sft = StatefulTractogram.from_sft(streamlines, sft,
                                          data_per_streamline=data_per_streamline,
                                          data_per_point=data_per_point)

    return new_sft, line_based_indices
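A typical call, sketched below with assumed file names (mask.nii.gz and bundle.trk are placeholders): load a binary mask, keep only streamlines with both endpoints inside it, or invert the selection with is_exclude=True.

import nibabel as nib
from dipy.io.streamline import load_tractogram

# Hypothetical inputs; any co-registered mask/tractogram pair works
mask = nib.load('mask.nii.gz').get_fdata() > 0
sft = load_tractogram('bundle.trk', 'mask.nii.gz')

# Keep streamlines whose two endpoints both fall inside the mask
kept_sft, kept_ids = filter_grid_roi(sft, mask, 'both_ends',
                                     is_exclude=False)

# Inverted selection: streamlines that never enter the mask
excluded_sft, _ = filter_grid_roi(sft, mask, 'any', is_exclude=True)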
Example No. 30
def filter_cuboid(sft, cuboid_radius, cuboid_center,
                  filter_type, is_exclude):
    """
    Parameters
    ----------
    sft : StatefulTractogram
        StatefulTractogram containing the streamlines to segment.
    cuboid_radius : numpy.ndarray (3)
        Size in mm, x/y/z of the cuboid.
    cuboid_center: numpy.ndarray (3)
        Center x/y/z of the cuboid.
    filter_type: str
        One of the 3 following choices, 'any', 'either_end', 'both_ends'.
    is_exclude: bool
        Value to indicate if the ROI is an AND (false) or a NOT (true).
    is_in_vox: bool
        Value to indicate if the ROI is in voxel space.
    Returns
    -------
    ids : tuple
        Filtered sft.
        Ids of the streamlines passing through the mask.
    """
    pre_filtered_sft, pre_filtered_indices = \
        pre_filtering_for_geometrical_shape(sft, cuboid_radius,
                                            cuboid_center, filter_type,
                                            False)
    pre_filtered_sft.to_rasmm()
    pre_filtered_sft.to_center()
    pre_filtered_streamlines = pre_filtered_sft.streamlines
    _, _, res, _ = sft.space_attributes

    selected_by_cuboid = []
    line_based_indices_1 = []
    line_based_indices_2 = []
    # Note: this is not an exact mathematical intersection and does not
    # use vtkPolyData as in MI-Brain, so results may differ slightly
    cuboid_radius = np.asarray(cuboid_radius)
    cuboid_center = np.asarray(cuboid_center)
    for i, line in enumerate(pre_filtered_streamlines):
        if filter_type == 'any':
            # Resample so the step size is roughly 1/10 of a voxel
            nb_points = max(int(length(line) / np.average(res) * 10), 2)
            line = set_number_of_points(line, nb_points)
            # Normalized per-axis distance to the center: a point lies
            # inside the cuboid when all 3 axes are within the radius
            points_in_cuboid = np.abs(line - cuboid_center) / cuboid_radius
            points_in_cuboid = np.sum(np.where(points_in_cuboid <= 1, 1, 0),
                                      axis=1)

            if (points_in_cuboid == 3).any():
                # If at least one point was in the cuboid in x/y/z,
                # we select that streamline
                selected_by_cuboid.append(pre_filtered_indices[i])
        else:
            # Faster to test the two endpoints separately than to stack
            # them into a single array
            points_in_cuboid = np.abs(line[0] - cuboid_center) / cuboid_radius
            points_in_cuboid = np.sum(np.where(points_in_cuboid <= 1, 1, 0))

            if points_in_cuboid == 3:
                line_based_indices_1.append(pre_filtered_indices[i])

            points_in_cuboid = np.abs(line[-1] - cuboid_center) / cuboid_radius
            points_in_cuboid = np.sum(np.where(points_in_cuboid <= 1, 1, 0))

            if points_in_cuboid == 3:
                line_based_indices_2.append(pre_filtered_indices[i])

    # Both endpoints need to be in the mask (AND)
    if filter_type == 'both_ends':
        selected_by_cuboid = np.intersect1d(line_based_indices_1,
                                            line_based_indices_2)
    # Only one endpoint needs to be in the mask (OR)
    elif filter_type == 'either_end':
        selected_by_cuboid = np.union1d(line_based_indices_1,
                                        line_based_indices_2)

    # If the 'exclude' option is used, the selection is inverted
    if is_exclude:
        selected_by_cuboid = np.setdiff1d(range(len(sft)),
                                          np.unique(selected_by_cuboid))
    line_based_indices = np.asarray(selected_by_cuboid).astype(np.int32)

    # From indices to sft
    streamlines = sft.streamlines[line_based_indices]
    data_per_streamline = sft.data_per_streamline[line_based_indices]
    data_per_point = sft.data_per_point[line_based_indices]

    new_sft = StatefulTractogram.from_sft(streamlines, sft,
                                          data_per_streamline=data_per_streamline,
                                          data_per_point=data_per_point)

    return new_sft, line_based_indices
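To make the per-axis test concrete: a point p is inside the cuboid when |p - center| <= radius holds on all three axes, which is exactly what the points_in_cuboid == 3 count above checks. The values below are illustrative, and sft is assumed to be loaded as in the previous sketch.

import numpy as np

# Illustrative values: a 20 x 20 x 20 mm cuboid (radius = half-size)
cuboid_center = np.array([0.0, 0.0, 0.0])
cuboid_radius = np.array([10.0, 10.0, 10.0])

# Standalone version of the point test used in filter_cuboid
point = np.array([5.0, -3.0, 9.9])
inside = np.all(np.abs(point - cuboid_center) <= cuboid_radius)
print(inside)  # True: every axis is within its half-extent

# Keep streamlines passing anywhere through the cuboid
kept_sft, kept_ids = filter_cuboid(sft, cuboid_radius, cuboid_center,
                                   'any', is_exclude=False)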