def main():
    """Cluster a tractogram with qbx_and_merge and save each cluster.

    Every cluster is written as ``cluster_<i>.trk`` inside
    ``args.out_clusters_dir``. The cluster centroids are also saved when
    ``args.out_centroids`` is provided.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, [], optional=args.out_centroids)
    assert_output_dirs_exist_and_empty(parser, args, args.out_clusters_dir,
                                       create_dir=True)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    streamlines = sft.streamlines

    # Three fixed thresholds followed by the user-provided one.
    thresholds = [40, 30, 20, args.dist_thresh]
    clusters = qbx_and_merge(streamlines, thresholds,
                             nb_pts=args.nb_points, verbose=False)

    for cluster_id, cluster in enumerate(clusters):
        if len(cluster.indices) > 1:
            # itemgetter with several indices returns a tuple of streamlines.
            members = itemgetter(*cluster.indices)(streamlines)
        else:
            members = streamlines[cluster.indices]

        out_name = os.path.join(args.out_clusters_dir,
                                'cluster_{}.trk'.format(cluster_id))
        cluster_sft = StatefulTractogram.from_sft(members, sft)
        save_tractogram(cluster_sft, out_name)

    if args.out_centroids:
        centroids_sft = StatefulTractogram.from_sft(clusters.centroids, sft)
        save_tractogram(centroids_sft, args.out_centroids)
def get_subset_streamlines(sft, max_streamlines, rng_seed=None):
    """
    Extract a specific number of streamlines.

    The streamline order is shuffled with a RandomState seeded by rng_seed
    and the first max_streamlines of the shuffled order are kept.

    Parameters
    ----------
    sft: StatefulTractogram
        SFT containing the streamlines to subsample.
    max_streamlines: int
        Maximum number of streamlines to output.
    rng_seed: int
        Random number to use for shuffling the data.

    Return
    ------
    subset_sft: StatefulTractogram
        The filtered streamlines as a sft.
    """
    rng = np.random.RandomState(rng_seed)
    ind = np.arange(len(sft.streamlines))
    rng.shuffle(ind)
    chosen = ind[:max_streamlines]

    # Index the streamlines container directly instead of converting it with
    # np.asarray(): streamlines usually have different numbers of points, and
    # building such a ragged array without dtype=object raises on recent
    # NumPy versions.
    subset_streamlines = sft.streamlines[chosen]
    subset_data_per_point = sft.data_per_point[chosen]
    subset_data_per_streamline = sft.data_per_streamline[chosen]

    subset_sft = StatefulTractogram.from_sft(
        subset_streamlines, sft,
        data_per_point=subset_data_per_point,
        data_per_streamline=subset_data_per_streamline)

    return subset_sft
def main():
    """Assign one RGB color to every point of a tractogram and save it.

    The color is read from ``args.color`` ("#RRGGBB" or "0xRRGGBB") and
    stored in ``data_per_point["color"]``. Only .trk outputs are accepted.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    if not args.out_tractogram.endswith('.trk'):
        parser.error('Output file needs to end with .trk.')

    color_error = ('Hexadecimal RGB color should be formatted as "#RRGGBB"'
                   ' or 0xRRGGBB.')
    color = args.color
    if len(color) == 7 and color.startswith('#'):
        hex_digits = color[1:]
    elif len(color) == 8 and color[:2].lower() == '0x':
        hex_digits = color[2:]
    else:
        parser.error(color_error)

    # Validate the digits explicitly: int() alone would accept signs,
    # underscores or other bases, and would raise a raw ValueError instead of
    # a clean command-line error.
    if not set(hex_digits) <= set('0123456789abcdefABCDEF'):
        parser.error(color_error)

    color_int = int(hex_digits, 16)
    red = color_int >> 16
    green = (color_int & 0x00FF00) >> 8
    blue = color_int & 0x0000FF

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    # One (nb_points, 3) array per streamline, every row being the same RGB.
    sft.data_per_point["color"] = [np.tile([red, green, blue],
                                           (len(i), 1))
                                   for i in sft.streamlines]

    sft = StatefulTractogram.from_sft(sft.streamlines, sft,
                                      data_per_point=sft.data_per_point)
    save_tractogram(sft, args.out_tractogram)
def main():
    """Remove invalid streamlines from a tractogram and save the result.

    Streamlines rejected by ``sft.remove_invalid_streamlines()`` are always
    removed. Optionally, streamlines with at most one point
    (``args.remove_single_point``) and streamlines containing two consecutive
    points closer than 0.001 (``args.remove_overlapping_points``) are
    removed as well.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram, args.reference)
    assert_outputs_exist(parser, args, args.out_tractogram)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram,
                                         bbox_check=False)
    ori_len = len(sft)
    sft.remove_invalid_streamlines()

    keep = np.ones(len(sft), dtype=bool)
    for i, streamline in enumerate(sft.streamlines):
        nb_points = len(streamline)

        # Will try to do a PR in Dipy
        if args.remove_single_point and nb_points <= 1:
            keep[i] = False
            continue

        # Distances between consecutive points. np.diff is used (rather than
        # np.gradient) because central differences do not vanish when only
        # two neighbouring points coincide, and np.gradient raises on
        # streamlines with fewer than two points.
        if args.remove_overlapping_points and nb_points > 1:
            step_norms = np.linalg.norm(np.diff(streamline, axis=0), axis=1)
            if (step_norms < 0.001).any():
                keep[i] = False

    indices = np.flatnonzero(keep)
    new_sft = StatefulTractogram.from_sft(
        sft.streamlines[indices], sft,
        data_per_point=sft.data_per_point[indices],
        data_per_streamline=sft.data_per_streamline[indices])
    logging.warning('Removed {} invalid streamlines.'.format(
        ori_len - len(new_sft)))
    save_tractogram(new_sft, args.out_tractogram)
def main():
    """Smooth every streamline of a tractogram and save the result.

    Each streamline is smoothed with ``smooth_line_gaussian`` when
    ``args.gaussian`` is set, otherwise with ``smooth_line_spline``. When
    ``args.error_rate`` is set the smoothed streamline is additionally
    compressed. data_per_point is not carried to the output.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    log_level = logging.DEBUG if args.verbose else logging.WARNING
    logging.basicConfig(level=log_level)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    smoothed_streamlines = []
    for streamline in sft.streamlines:
        if args.gaussian:
            smoothed = smooth_line_gaussian(streamline, args.gaussian)
        else:
            smoothed = smooth_line_spline(streamline, args.spline[0],
                                          args.spline[1])

        # Compression is optional, but the streamline must always be kept:
        # otherwise the output would be empty (and out of sync with
        # data_per_streamline) whenever no error rate is given.
        if args.error_rate:
            smoothed = compress_streamlines(smoothed, args.error_rate)
        smoothed_streamlines.append(smoothed)

    smoothed_sft = StatefulTractogram.from_sft(
        smoothed_streamlines, sft,
        data_per_streamline=sft.data_per_streamline)
    save_tractogram(smoothed_sft, args.out_tractogram)
def main():
    """Build a voting map (and optionally voted streamlines) from bundles.

    A voxel is kept in the output mask when it is traversed by enough of the
    input bundles (``args.ratio_voxels``). When ``args.same_tractogram`` is
    set, the streamlines found in enough bundles
    (``args.ratio_streamlines``) are saved too.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    output_streamlines_filename = '{}streamlines.trk'.format(
        args.output_prefix)
    output_voxels_filename = '{}voxels.nii.gz'.format(args.output_prefix)
    assert_outputs_exist(parser, args, [output_voxels_filename,
                                        output_streamlines_filename])

    if not (0 <= args.ratio_voxels <= 1
            and 0 <= args.ratio_streamlines <= 1):
        parser.error('Ratios must be between 0 and 1.')

    # The first bundle acts as the reference when none is given.
    reference_file = args.reference if args.reference else args.in_bundles[0]

    sft_list = []
    fusion_streamlines = []
    for name in args.in_bundles:
        tmp_sft = load_tractogram_with_reference(parser, args, name)
        tmp_sft.to_vox()
        tmp_sft.to_corner()

        if not is_header_compatible(reference_file, tmp_sft):
            raise ValueError('Headers are not compatible.')
        sft_list.append(tmp_sft)
        fusion_streamlines.append(tmp_sft.streamlines)

    fusion_streamlines, _ = union_robust(fusion_streamlines)

    transformation, dimensions, _, _ = get_reference_info(reference_file)
    nb_bundles = len(args.in_bundles)
    volume = np.zeros(dimensions)
    # Rows: streamlines of the union, columns: input bundles.
    streamlines_vote = dok_matrix((len(fusion_streamlines), nb_bundles))

    for i, sft in enumerate(sft_list):
        binary = compute_tract_counts_map(sft.streamlines, dimensions)
        volume[binary > 0] += 1

        if args.same_tractogram:
            _, indices = intersection_robust(
                [fusion_streamlines, sft.streamlines])
            streamlines_vote[list(indices), [i]] += 1

    if args.same_tractogram:
        ratio_value = int(args.ratio_streamlines * nb_bundles)
        votes = np.sum(streamlines_vote, axis=1)
        real_indices = np.where(votes >= ratio_value)[0]
        new_sft = StatefulTractogram.from_sft(
            fusion_streamlines[real_indices], sft_list[0])
        save_tractogram(new_sft, output_streamlines_filename)

    # Binarize the voxel votes.
    volume[volume < int(args.ratio_voxels * nb_bundles)] = 0
    volume[volume > 0] = 1
    nib.save(nib.Nifti1Image(volume.astype(np.uint8), transformation),
             output_voxels_filename)
def compress_sft(sft, tol_error=0.01):
    """
    Compress a stateful tractogram. Uses Dipy's compress_streamlines, but
    deals with space better.

    Dipy's description:
    The compression consists in merging consecutive segments that are
    nearly collinear. The merging is achieved by removing the point the two
    segments have in common.

    The linearization process [Presseau15]_ ensures that every point being
    removed are within a certain margin (in mm) of the resulting streamline.
    Recommendations for setting this margin can be found in [Presseau15]_
    (in which they called it tolerance error).

    The compression also ensures that two consecutive points won't be too far
    from each other (precisely less or equal than `max_segment_length`mm).
    This is a tradeoff to speed up the linearization process [Rheault15]_. A
    low value will result in a faster linearization but low compression,
    whereas a high value will result in a slower linearization but high
    compression.

    [Presseau C. et al., A new compression format for fiber tracking datasets,
    NeuroImage, no 109, 73-83, 2015.]

    Parameters
    ----------
    sft: StatefulTractogram
        The sft to compress. It is temporarily moved to rasmm and put back in
        its original space before returning.
    tol_error: float (optional)
        Tolerance error in mm (default: 0.01). A rule of thumb is to set it
        to 0.01mm for deterministic streamlines and 0.1mm for probabilitic
        streamlines.

    Returns
    -------
    compressed_sft : StatefulTractogram
        Compressed tractogram, in the same space as the input.
        data_per_point is not carried over.
    """
    # Go to world space
    orig_space = sft.space
    sft.to_rasmm()

    # Compress streamlines
    compressed_streamlines = compress_streamlines(sft.streamlines,
                                                  tol_error=tol_error)

    # Only warn when there really is per-point data to lose; a bare
    # "is not None" check is also true for an empty container.
    if sft.data_per_point is not None and len(sft.data_per_point) > 0:
        logging.warning("Initial StatefulTractogram contained data_per_point. "
                        "This information will not be carried in the final "
                        "tractogram.")

    compressed_sft = StatefulTractogram.from_sft(
        compressed_streamlines, sft,
        data_per_streamline=sft.data_per_streamline)

    # Return to original space (the input must not stay in rasmm either).
    sft.to_space(orig_space)
    compressed_sft.to_space(orig_space)

    return compressed_sft
def filter_grid_roi_both(sft, mask_1, mask_2):
    """
    Filters streamlines with one end in a mask and the other in another mask.

    The input sft is moved to voxel space with a corner origin.

    Parameters
    ----------
    sft : StatefulTractogram
        StatefulTractogram containing the streamlines to segment.
    mask_1: numpy.ndarray
        Binary mask in which the streamlines should start or end.
    mask_2: numpy.ndarray
        Binary mask in which the streamlines should start or end.

    Returns
    -------
    new_sft : StatefulTractogram
        Filtered sft (data_per_point and data_per_streamline are kept).
    line_based_indices : numpy.ndarray
        Ids (int32) of the streamlines having one end in each mask.
    """
    sft.to_vox()
    sft.to_corner()
    streamline_vox = sft.streamlines

    # Nothing to filter: the endpoint arrays below would be 1-D and could not
    # be transposed.
    if len(streamline_vox) == 0:
        return (StatefulTractogram.from_sft([], sft),
                np.array([], dtype=np.int32))

    # For endpoint filtering, we need to keep 2 separately
    # Could be faster for either end, but the code look cleaner like this
    voxel_beg = np.asarray([s[0] for s in streamline_vox],
                           dtype=np.int16).transpose(1, 0)
    voxel_end = np.asarray([s[-1] for s in streamline_vox],
                           dtype=np.int16).transpose(1, 0)

    # order=0: nearest-neighbour lookup of the mask value at each endpoint.
    map1_beg = map_coordinates(mask_1, voxel_beg, order=0, mode='nearest')
    map2_beg = map_coordinates(mask_2, voxel_beg, order=0, mode='nearest')

    map1_end = map_coordinates(mask_1, voxel_end, order=0, mode='nearest')
    map2_end = map_coordinates(mask_2, voxel_end, order=0, mode='nearest')

    # Either direction is accepted: 1 -> 2 or 2 -> 1.
    line_based_indices = np.logical_or(np.logical_and(map1_beg, map2_end),
                                       np.logical_and(map1_end, map2_beg))
    line_based_indices = \
        np.arange(len(line_based_indices))[line_based_indices].astype(np.int32)

    # From indices to sft
    streamlines = sft.streamlines[line_based_indices]
    data_per_streamline = sft.data_per_streamline[line_based_indices]
    data_per_point = sft.data_per_point[line_based_indices]

    new_sft = StatefulTractogram.from_sft(
        streamlines, sft,
        data_per_streamline=data_per_streamline,
        data_per_point=data_per_point)

    return new_sft, line_based_indices
def _warn_and_save(new_streamlines, sft):
    """Last step of the two resample functions:
    Warn that we loose data_per_point, then create resampled SFT.

    Parameters
    ----------
    new_streamlines: list
        The resampled streamlines.
    sft: StatefulTractogram
        The original tractogram; its data_per_streamline is carried over.

    Returns
    -------
    new_sft: StatefulTractogram
        Tractogram built from new_streamlines, without data_per_point.
    """
    # Only report a loss when there is per-point data; a bare "is not None"
    # check is also true for an empty container.
    if sft.data_per_point is not None and len(sft.data_per_point) > 0:
        logging.debug("Initial stateful tractogram contained data_per_point. "
                      "This information will not be carried in the final "
                      "tractogram.")
    new_sft = StatefulTractogram.from_sft(
        new_streamlines, sft, data_per_streamline=sft.data_per_streamline)

    return new_sft
def cut_outside_of_mask_streamlines(sft, binary_mask, min_len=0):
    """ Trim each streamline to its longest run of points inside a mask.

    The sft is moved to voxel space with a corner origin. The output does not
    carry data_per_point nor data_per_streamline.

    Parameters
    ----------
    sft: StatefulTractogram
        The sft to cut streamlines (using a single mask with 1 entities) from.
    binary_mask: np.ndarray
        Boolean array representing the region (must contain 1 entities)
    min_len: float
        Minimum length from the resulting streamlines.

    Returns
    -------
    new_sft : StatefulTractogram
        New object with the streamlines trimmed within the mask.
    """
    sft.to_vox()
    sft.to_corner()

    new_streamlines = []
    for streamline in sft.streamlines:
        inside = False
        run_start = 0
        run_len = 0
        best = (0, 0)  # [start, stop) of the longest run found so far

        for ind, pos in enumerate(streamline):
            voxel = tuple(pos.astype(np.int16))
            if binary_mask[voxel]:
                if inside:
                    run_len += 1
                    if run_len > best[1] - best[0]:
                        best = (run_start, ind + 1)
                else:
                    # Entering the mask: start a new run here.
                    inside = True
                    run_start = ind
                    run_len = 0
            elif inside:
                # Leaving the mask: close the current run.
                inside = False
                if run_len > best[1] - best[0]:
                    best = (run_start, ind)
                    run_len = 0

        if best[1] != 0:
            new_streamlines.append(streamline[best[0]:best[1]])

    new_sft = StatefulTractogram.from_sft(new_streamlines, sft)
    return filter_streamlines_by_length(new_sft, min_length=min_len)
def main():
    """Cut the streamlines of a tractogram with a mask and save the result.

    The mask is labelled into connected entities: with one entity the
    streamlines are cut outside of it, with two they are cut between them.
    With more than two, an error is logged and nothing is saved.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, [args.in_tractogram, args.in_mask])
    assert_outputs_exist(parser, args, args.out_tractogram)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    if args.step_size is not None:
        sft = resample_streamlines_step_size(sft, args.step_size)

    mask_img = nib.load(args.in_mask)
    binary_mask = get_data_as_mask(mask_img)

    if not is_header_compatible(sft, mask_img):
        parser.error('Incompatible header between the tractogram and mask.')

    bundle_disjoint, _ = ndi.label(binary_mask)
    unique, count = np.unique(bundle_disjoint, return_counts=True)
    if args.biggest_blob:
        # Skip the first label when searching for the most frequent one, then
        # erase every other entity from the mask.
        val = unique[np.argmax(count[1:]) + 1]
        binary_mask[bundle_disjoint != val] = 0
        unique = [0, val]

    nb_labels = len(unique)
    if nb_labels not in (2, 3):
        logging.error('The provided mask has more than 2 entities. Cannot cut '
                      'between >2.')
        return

    if nb_labels == 2:
        logging.info('The provided mask has 1 entity '
                     'cut_outside_of_mask_streamlines function selected.')
        new_sft = cut_outside_of_mask_streamlines(sft, binary_mask)
    else:
        logging.info('The provided mask has 2 entity '
                     'cut_between_masks_streamlines function selected.')
        new_sft = cut_between_masks_streamlines(sft, binary_mask)

    if len(new_sft) == 0:
        logging.warning('No streamline intersected the provided mask. '
                        'Saving empty tractogram.')
    elif args.error_rate is not None:
        compressed_strs = []
        for cut_streamline in new_sft.streamlines:
            compressed_strs.append(
                compress_streamlines(cut_streamline, args.error_rate))
        new_sft = StatefulTractogram.from_sft(compressed_strs, sft)

    save_tractogram(new_sft, args.out_tractogram)
def cut_outside_of_mask_streamlines(sft, binary_mask):
    """ Cut streamlines so their longest segment are within the bounding box
    or a binary mask.
    The data_per_point and data_per_streamline are not carried to the output.

    The point coordinates are used directly as voxel indices, so the sft is
    assumed to already be in voxel space -- TODO confirm with callers.

    Parameters
    ----------
    sft: StatefulTractogram
        The sft to remove invalid points from.
    binary_mask: np.ndarray
        Array that is truthy in the voxels to keep.

    Returns
    -------
    new_sft : StatefulTractogram
        New object with the invalid points removed from each streamline,
        filtered to the [min, max] length range of the input streamlines.
    """
    # min()/max() below cannot handle an empty tractogram.
    if len(sft.streamlines) == 0:
        return StatefulTractogram.from_sft([], sft)

    new_streamlines = []
    length_list = length(sft.streamlines)
    min_len, max_len = min(length_list), max(length_list)

    for streamline in sft.streamlines:
        entry_found = False
        last_success = 0
        curr_len = 0
        longest_seq = (0, 0)  # [start, stop) of the longest run inside
        for ind, pos in enumerate(streamline):
            pos = tuple(pos.astype(np.int16))
            if binary_mask[pos]:
                if not entry_found:
                    entry_found = True
                    last_success = ind
                    curr_len = 0
                else:
                    curr_len += 1
                    if curr_len > longest_seq[1] - longest_seq[0]:
                        # Slice stop is exclusive: use ind + 1 so that the
                        # current (valid) point is part of the segment.
                        longest_seq = (last_success, ind + 1)
            else:
                if entry_found:
                    entry_found = False
                    if curr_len > longest_seq[1] - longest_seq[0]:
                        # ind is the first point outside: exclusive stop.
                        longest_seq = (last_success, ind)
                        curr_len = 0

        if longest_seq[1] != 0:
            new_streamlines.append(streamline[longest_seq[0]:longest_seq[1]])

    new_sft = StatefulTractogram.from_sft(new_streamlines, sft)
    return filter_streamlines_by_length(new_sft,
                                        min_length=min_len,
                                        max_length=max_len)
def _read_tg(self, tg=None):
    """Register the tractogram to work on and optionally subsample it.

    When ``tg`` is given it replaces ``self.tg``; otherwise ``self.tg`` is
    used. The tractogram space is remembered in ``self._tg_orig_space``. If
    ``self.nb_streamlines`` is set and smaller than the number of
    streamlines, ``self.tg`` is replaced by a random subset of that size.

    Returns the tractogram as received (or ``self.tg`` as it was on entry).
    """
    if tg is not None:
        self.tg = tg
    else:
        tg = self.tg
    self._tg_orig_space = self.tg.space

    limit = self.nb_streamlines
    if limit and len(self.tg) > limit:
        subset = dts.select_random_set_of_streamlines(self.tg.streamlines,
                                                      limit)
        self.tg = StatefulTractogram.from_sft(subset, self.tg)

    # NOTE(review): the returned object is not the subsampled self.tg --
    # confirm this is what callers expect.
    return tg
def add_noise_to_streamlines(sft: StatefulTractogram, noise_sigma: float,
                             noise_rng: np.random.RandomState):
    """Add gaussian noise (truncated to +/- 2*noise_sigma) to the coordinates
    of every streamline point.

    Parameters
    ----------
    sft : StatefulTractogram
        Streamlines.
    noise_sigma : float
        Standard deviation of the gaussian noise to add to the streamlines.
        CAREFUL. Space is not handled here: the noise must be expressed in
        the same space as the sft.
        CAREFUL. Choose noise_sigma < step_size/4, so that the maximum noise
        stays < step_size/2. In the worst case the starting point of a
        segment advances by step_size/2 while its ending point rewinds by
        step_size/2, but not further, so the segment direction cannot flip.
    noise_rng : np.random.RandomState object
        Random number generator.

    Returns
    -------
    noisy_sft : StatefulTractogram
        Noisy streamlines. Note. Adding noise may create invalid streamlines
        (i.e. out of the box in voxel space). If you want to save noisy_sft,
        please perform noisy_sft.remove_invalid_streamlines() first.
    """
    logging.info("Please note your sft space is in {}. We suppose that the "
                 "noise, {}, fits.".format(sft.space, noise_sigma))

    # All points are stacked in one array so the noise is drawn in one call.
    point_counts = [len(s) for s in sft.streamlines]
    all_points = np.concatenate(sft.streamlines, axis=0)
    noise = truncnorm.rvs(-2, 2, size=all_points.shape, scale=noise_sigma,
                          random_state=noise_rng)
    all_points += noise

    # Back to one array per streamline.
    noisy_streamlines = split_array_at_lengths(all_points, point_counts)

    noisy_sft = StatefulTractogram.from_sft(
        noisy_streamlines, sft,
        data_per_point=sft.data_per_point,
        data_per_streamline=sft.data_per_streamline)
    return noisy_sft
def filter_streamlines_by_length(sft, min_length=0., max_length=np.inf):
    """
    Filter streamlines using minimum and max length.

    Lengths are measured in rasmm; the input sft is temporarily moved to
    that space and put back in its original space before returning.

    Parameters
    ----------
    sft: StatefulTractogram
        SFT containing the streamlines to filter.
    min_length: float
        Minimum length of streamlines, in mm.
    max_length: float
        Maximum length of streamlines, in mm.

    Return
    ------
    filtered_sft : StatefulTractogram
        A tractogram containing only the streamlines whose length is within
        [min_length, max_length], in the same space as the input.
    """
    # Make sure we are in world space
    orig_space = sft.space
    sft.to_rasmm()

    if sft.streamlines:
        # Compute streamlines lengths
        lengths = length(sft.streamlines)

        # Filter lengths
        filter_stream = np.logical_and(lengths >= min_length,
                                       lengths <= max_length)
    else:
        filter_stream = []

    # Index the streamlines container directly: converting it to an object
    # array first yields object-dtype streamlines when all of them have the
    # same number of points.
    filtered_streamlines = sft.streamlines[filter_stream]
    filtered_data_per_point = sft.data_per_point[filter_stream]
    filtered_data_per_streamline = sft.data_per_streamline[filter_stream]

    # Create final sft
    filtered_sft = StatefulTractogram.from_sft(
        filtered_streamlines, sft,
        data_per_point=filtered_data_per_point,
        data_per_streamline=filtered_data_per_streamline)

    # Return to original space (the input must not stay in rasmm either).
    sft.to_space(orig_space)
    filtered_sft.to_space(orig_space)

    return filtered_sft
def main():
    """Filter a tractogram by streamline length and save the result.

    Streamlines are kept when their length is within
    [``args.minL``, ``args.maxL``]. With ``args.no_empty`` nothing is written
    when no streamline remains. With ``args.display_counts`` the streamline
    counts before and after filtering are printed as JSON.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    # filter_streamlines_by_length returns a single StatefulTractogram that
    # already carries the filtered data_per_point / data_per_streamline, so
    # it must not be unpacked into three values.
    new_sft = filter_streamlines_by_length(sft, args.minL, args.maxL)
    nb_after = len(new_sft.streamlines)

    if nb_after == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written "
                          "(0 streamline).".format(args.out_tractogram))
            return

        logging.debug('The file {} contains 0 streamline'.format(
            args.out_tractogram))

    save_tractogram(new_sft, args.out_tractogram)

    if args.display_counts:
        tc_bf = len(sft.streamlines)
        tc_af = nb_after
        print(json.dumps({'tract_count_before_filtering': int(tc_bf),
                          'tract_count_after_filtering': int(tc_af)},
                         indent=args.indent))
def cut_between_masks_streamlines(sft, binary_mask, min_len=0):
    """ Cut streamlines so their segment are within the bounding box
    or going from binary mask #1 to binary mask #2.
    This function erases the data_per_point and data_per_streamline.

    The sft is moved to voxel space with a corner origin.

    Parameters
    ----------
    sft: StatefulTractogram
        The sft to cut streamlines (using a single mask with 2 entities) from.
    binary_mask: np.ndarray
        Boolean array representing the region (must contain 2 entities)
    min_len: float
        Minimum length from the resulting streamlines.

    Returns
    -------
    new_sft : StatefulTractogram
        New object with the streamlines trimmed within the masks.
    """
    sft.to_vox()
    sft.to_corner()
    streamlines = sft.streamlines

    # (An endpoints density map used to be computed here but its result was
    # never read, so it has been dropped.)
    roi_data_1, roi_data_2 = split_heads_tails_kmeans(binary_mask)

    new_streamlines = []
    indices, points_to_idx = uncompress(streamlines, return_mapping=True)

    for strl_idx, strl in enumerate(streamlines):
        strl_indices = indices[strl_idx]

        in_strl_idx, out_strl_idx = intersects_two_rois(roi_data_1,
                                                        roi_data_2,
                                                        strl_indices)

        # Streamlines that do not reach both ROIs are discarded.
        if in_strl_idx is not None and out_strl_idx is not None:
            points_to_indices = points_to_idx[strl_idx]
            tmp = compute_streamline_segment(strl, strl_indices, in_strl_idx,
                                             out_strl_idx, points_to_indices)
            new_streamlines.append(tmp)

    new_sft = StatefulTractogram.from_sft(new_streamlines, sft)
    return filter_streamlines_by_length(new_sft, min_length=min_len)
def main():
    """Compute the centroid of a bundle and save it as a tractogram.

    The centroid is obtained from ``get_streamlines_centroid`` with
    ``args.nb_points`` points, which must be at least 2.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundle)
    assert_outputs_exist(parser, args, args.out_centroid)

    nb_points = args.nb_points
    if nb_points < 2:
        parser.error('--nb_points {} should be >= 2'.format(nb_points))

    bundle_sft = load_tractogram_with_reference(parser, args, args.in_bundle)
    centroid_streamlines = get_streamlines_centroid(bundle_sft.streamlines,
                                                    nb_points)

    centroid_sft = StatefulTractogram.from_sft(centroid_streamlines,
                                               bundle_sft)
    save_tractogram(centroid_sft, args.out_centroid)
def test_create_from_sft():
    """from_sft must give an equal tractogram that is independent of the
    original one."""
    sft_1 = load_tractogram(filepath_dix['gs.tck'], filepath_dix['gs.nii'])
    sft_2 = StatefulTractogram.from_sft(
        sft_1.streamlines, sft_1,
        data_per_point=sft_1.data_per_point,
        data_per_streamline=sft_1.data_per_streamline)

    # Streamlines, spatial attributes and data must all match.
    if not np.array_equal(sft_1.streamlines, sft_2.streamlines):
        raise AssertionError()
    if not (sft_1.space_attributes == sft_2.space_attributes):
        raise AssertionError()
    if not (sft_1.space == sft_2.space):
        raise AssertionError()
    if not (sft_1.origin == sft_2.origin):
        raise AssertionError()
    if not (sft_1.data_per_point == sft_2.data_per_point):
        raise AssertionError()
    if not (sft_1.data_per_streamline == sft_2.data_per_streamline):
        raise AssertionError()

    # Replacing the streamlines of the source must not affect the copy.
    sft_1.streamlines = np.arange(6000).reshape((100, 20, 3))
    if np.array_equal(sft_1.streamlines, sft_2.streamlines):
        raise AssertionError()
def flip_sft(sft, flip_axes):
    """Return a new sft whose streamlines are flipped along flip_axes.

    Each streamline is shifted by the vector from get_shift_vector,
    multiplied by the vector from get_axis_flip_vector, then shifted back.
    data_per_point and data_per_streamline are carried over.
    """
    flip_vector = get_axis_flip_vector(flip_axes)
    shift_vector = get_shift_vector(sft)

    def _flip(streamline):
        # The addition allocates a new array; the next two steps then work
        # in place on that copy, leaving the input streamline untouched.
        flipped = streamline + shift_vector
        flipped *= flip_vector
        flipped -= shift_vector
        return flipped

    flipped_streamlines = [_flip(s) for s in sft.streamlines]

    return StatefulTractogram.from_sft(
        flipped_streamlines, sft,
        data_per_point=sft.data_per_point,
        data_per_streamline=sft.data_per_streamline)
def get_binary_maps(streamlines, sft):
    """
    Extract binary masks of a bundle and of its endpoints.

    Parameters
    ----------
    streamlines: list
        List of streamlines.
    sft : StatefulTractogram
        Reference tractogram; it provides the dimensions of the masks.

    Returns
    -------
    bundles_voxels: numpy.ndarray
        Mask representing the bundle volume.
    endpoints_voxels: numpy.ndarray
        Mask representing the bundle's endpoints.
        Both masks are all zeros when there is no streamline, or when the
        resulting tractogram holds a single streamline.
    """
    dimensions = sft.dimensions

    def _empty_maps():
        # Two distinct arrays, as callers may modify them independently.
        return np.zeros(dimensions), np.zeros(dimensions)

    nb_streamlines = len(streamlines)
    if not nb_streamlines:
        return _empty_maps()
    if nb_streamlines == 1:
        streamlines = [streamlines]

    tmp_sft = StatefulTractogram.from_sft(streamlines, sft)
    tmp_sft.to_vox()
    tmp_sft.to_corner()

    if len(tmp_sft) == 1:
        return _empty_maps()

    voxel_streamlines = tmp_sft.streamlines
    bundles_voxels = compute_tract_counts_map(
        voxel_streamlines, dimensions).astype(np.int16)
    endpoints_voxels = get_endpoints_density_map(
        voxel_streamlines, dimensions).astype(np.int16)

    # Counts -> binary masks.
    bundles_voxels[bundles_voxels > 0] = 1
    endpoints_voxels[endpoints_voxels > 0] = 1

    return bundles_voxels, endpoints_voxels
def extract_false_connections(sft, mask_1_filename, mask_2_filename,
                              dilate_endpoints):
    """
    Extract false connections based on two regions from a tractogram.

    Parameters
    ----------
    sft: StatefulTractogram
        Tractogram containing the streamlines to be extracted.
    mask_1_filename: str
        Filename of the "head" of the bundle.
    mask_2_filename: str
        Filename of the "tail" of the bundle.
    dilate_endpoints: int or None
        If set, dilate the masks for n iterations.

    Returns
    -------
    fc_sft: StatefulTractogram
        SFT of false connections.
    sft: StatefulTractogram
        SFT of remaining streamlines.
        When the input has no streamline, it is returned for both values.
    """
    mask_1_img = nib.load(mask_1_filename)
    mask_2_img = nib.load(mask_2_filename)
    mask_1 = get_data_as_mask(mask_1_img)
    mask_2 = get_data_as_mask(mask_2_img)

    if dilate_endpoints:
        mask_1 = binary_dilation(mask_1, iterations=dilate_endpoints)
        mask_2 = binary_dilation(mask_2, iterations=dilate_endpoints)

    # Nothing to extract from an empty tractogram.
    if not len(sft.streamlines) > 0:
        return sft, sft

    extracted_sft, sft = extract_streamlines(mask_1, mask_2, sft)
    fc_sft = StatefulTractogram.from_sft(extracted_sft.streamlines, sft)
    return fc_sft, sft
def main():
    """Split a tractogram into consecutive chunks, one file per chunk.

    Either the number of chunks (``args.nb_chunk``) or the size of a chunk
    (``args.chunk_size``) is used. Files are named
    ``<out_prefix>_<i><extension of the input>`` inside ``args.out_dir``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    _, out_extension = os.path.splitext(args.in_tractogram)

    assert_output_dirs_exist_and_empty(parser, args, [],
                                       optional=args.out_dir)
    # Check only the first potential output filename
    assert_outputs_exist(parser, args, os.path.join(
        args.out_dir, '{}_0{}'.format(args.out_prefix, out_extension)))

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    streamlines_count = len(sft.streamlines)

    if args.nb_chunk:
        nb_chunk = args.nb_chunk
        chunk_size = streamlines_count // nb_chunk
    else:
        chunk_size = args.chunk_size
        # Ceiling division, so that no empty trailing chunk is produced when
        # the count is a multiple of chunk_size. At least one chunk is kept
        # so an empty input still yields one (empty) file.
        nb_chunk = max(1, -(-streamlines_count // chunk_size))

    # All chunks will be equal except the last one, which absorbs the
    # remainder. Plain Python ints are used: an int16 array cannot hold the
    # sizes of large tractograms.
    chunk_sizes = [chunk_size] * nb_chunk
    chunk_sizes[-1] += streamlines_count - chunk_size * nb_chunk

    curr_count = 0
    for i, size in enumerate(chunk_sizes):
        chunk = slice(curr_count, curr_count + size)
        curr_count += size

        new_sft = StatefulTractogram.from_sft(
            sft.streamlines[chunk], sft,
            data_per_point=sft.data_per_point[chunk],
            data_per_streamline=sft.data_per_streamline[chunk])

        out_name = os.path.join(
            args.out_dir,
            '{0}_{1}{2}'.format(args.out_prefix, i, out_extension))
        save_tractogram(new_sft, out_name)
def main():
    """Shuffle the streamlines of a tractogram and save the result.

    The order is drawn from a generator seeded with ``args.seed``;
    data_per_streamline and data_per_point follow their streamline.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    indices = np.arange(len(sft.streamlines))
    # Seed a dedicated generator. random.shuffle's ``random`` argument
    # expected a function returning floats, not a seed, and no longer exists
    # in recent Python versions.
    rng = np.random.RandomState(args.seed)
    rng.shuffle(indices)

    streamlines = sft.streamlines[indices]
    data_per_streamline = sft.data_per_streamline[indices]
    data_per_point = sft.data_per_point[indices]

    shuffled_sft = StatefulTractogram.from_sft(
        streamlines, sft,
        data_per_streamline=data_per_streamline,
        data_per_point=data_per_point)
    save_tractogram(shuffled_sft, args.out_tractogram)
def reverse_streamlines(sft: StatefulTractogram,
                        reverse_ids: np.ndarray = None):
    """Reverse streamlines, i.e. inverse the beginning and end

    Parameters
    ----------
    sft: StatefulTractogram
        Dipy object containing your streamlines
    reverse_ids: np.ndarray, optional
        List of streamlines to reverse.
        If not provided, all streamlines are reversed.

    Returns
    -------
    new_sft: StatefulTractogram
        Dipy object with reversed streamlines and data_per_point.
        data_per_streamline is carried over unchanged.
    """
    if reverse_ids is None:
        reverse_ids = range(len(sft.streamlines))
    else:
        # Membership is tested once per streamline and per data key below;
        # a set makes each test O(1) instead of scanning the whole array.
        reverse_ids = set(np.asarray(reverse_ids).ravel().tolist())

    new_streamlines = [s[::-1] if i in reverse_ids else s
                       for i, s in enumerate(sft.streamlines)]

    # Per-point data must be reversed along with its streamline.
    new_data_per_point = copy.deepcopy(sft.data_per_point)
    for key in sft.data_per_point:
        new_data_per_point[key] = [
            d[::-1] if i in reverse_ids else d
            for i, d in enumerate(new_data_per_point[key])]

    new_sft = StatefulTractogram.from_sft(
        new_streamlines, sft,
        data_per_point=new_data_per_point,
        data_per_streamline=sft.data_per_streamline)

    return new_sft
def filter_tractogram_data(tractogram, streamline_ids):
    """ Filter tractogram according to streamline ids and keep the data

    Parameters:
    -----------
    tractogram: StatefulTractogram
        Tractogram containing the data to be filtered
    streamline_ids: array_like
        List of streamline ids the data corresponds to

    Returns:
    --------
    new_tractogram: Tractogram or StatefulTractogram
        Returns a new tractogram with only the selected streamlines
        and data

    Raises:
    -------
    AssertionError
        If an id is outside of [0, number of streamlines).
    """

    streamline_ids = np.asarray(streamline_ids, dtype=int)

    # Explicit raise instead of an assert statement, so that the validation
    # is not stripped when Python runs with -O. AssertionError is kept for
    # callers that already catch it.
    nb_streamlines = len(tractogram.streamlines)
    in_range = np.logical_and(streamline_ids >= 0,
                              streamline_ids < nb_streamlines)
    if not np.all(in_range):
        raise AssertionError("Received ids outside of streamline range")

    new_streamlines = tractogram.streamlines[streamline_ids]
    new_data_per_streamline = tractogram.data_per_streamline[streamline_ids]
    new_data_per_point = tractogram.data_per_point[streamline_ids]

    # Could have been nice to deepcopy the tractogram modify the attributes in
    # place instead of creating a new one, but tractograms cant be subsampled
    # if they have data
    return StatefulTractogram.from_sft(
        new_streamlines,
        tractogram,
        data_per_point=new_data_per_point,
        data_per_streamline=new_data_per_streamline)
def main():
    """Score a tractogram against ground-truth bundles and endpoint masks.

    Steps, in order:
    1. Parse and validate the arguments, load the tractogram.
    2. Check header compatibility with the ground-truth bundles/endpoints.
    3. Extract, for each ground-truth (tail, head) pair, the true, wrong-path
       and false connections, and save the non-empty ones.
    4. Extract false connections for every other pair of endpoint masks.
    5. Save the remaining streamlines as no-connections ("nc").
    6. Write global and bundle-wise statistics to results.json in out_dir.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram] + args.gt_bundles)
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       create_dir=True)

    if (args.gt_tails and not args.gt_heads) \
            or (args.gt_heads and not args.gt_tails):
        parser.error("Both --gt_heads and --gt_tails are needed.")
    if args.gt_endpoints and (args.gt_tails or args.gt_heads):
        parser.error("Can only provide --gt_endpoints or --gt_tails/gt_heads")
    if not args.gt_endpoints and (not args.gt_tails and not args.gt_heads):
        parser.error(
            "Either input --gt_endpoints or --gt_heads and --gt_tails.")

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    # The extension of the input is reused for every tractogram saved below.
    _, ext = os.path.splitext(args.in_tractogram)
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram,
                                         bbox_check=False)

    if args.remove_invalid:
        sft.remove_invalid_streamlines()

    # Kept to verify at the end that every streamline was classified once.
    initial_count = len(sft)

    logging.info("Verifying compatibility with ground-truth")
    for gt in args.gt_bundles:
        compatible = is_header_compatible(sft, gt)
        if not compatible:
            parser.error("Input tractogram incompatible with"
                         " {}".format(gt))

    logging.info("Computing ground-truth masks")
    gt_bundle_masks, gt_bundle_inv_masks, affine, dimensions, = \
        compute_gt_masks(args.gt_bundles, parser, args)

    # If endpoints without heads/tails are loaded, split them and continue
    # normally after. Q/C of the output is important
    if args.gt_endpoints:
        logging.info("Extracting ground-truth end and tail masks")
        gt_tails, gt_heads, affine, dimensions = \
            extract_tails_heads_from_endpoints(
                args.gt_endpoints, args.out_dir)
    else:
        gt_tails, gt_heads = args.gt_tails, args.gt_heads

    logging.info("Verifying compatibility with endpoints")
    for gt in gt_tails + gt_heads:
        compatible = is_header_compatible(sft, gt)
        if not compatible:
            parser.error("Input tractogram incompatible with"
                         " {}".format(gt))

    # Load the endpoints heads/tails, keep the correct combinations
    # separately from all the possible combinations
    tc_filenames = list(zip(gt_tails, gt_heads))

    length_dict = {}
    if args.gt_config:
        with open(args.gt_config, "r") as json_file:
            length_dict = json.load(json_file)

    tc_streamlines_list = []
    wpc_streamlines_list = []
    fc_streamlines_list = []
    nc_streamlines = []

    logging.info("Scoring true connections")
    for i, (mask_1_filename, mask_2_filename) in enumerate(tc_filenames):
        # Automatically generate filename for Q/C
        prefix_1 = extract_prefix(mask_1_filename)
        prefix_2 = extract_prefix(mask_2_filename)

        # `sft` is rebound to the tractogram returned by the helper, which is
        # what gets scored in the next iteration.
        tc_sft, wpc_sft, fc_sft, nc, sft = extract_true_connections(
            sft, mask_1_filename, mask_2_filename, args.gt_config,
            length_dict, extract_prefix(args.gt_bundles[i]),
            gt_bundle_inv_masks[i], args.dilate_endpoints,
            args.wrong_path_as_separate)
        nc_streamlines.extend(nc)

        if len(tc_sft) > 0:
            save_tractogram(tc_sft, os.path.join(
                args.out_dir,
                "{}_{}_tc{}".format(prefix_1, prefix_2, ext)),
                bbox_valid_check=False)

        if len(wpc_sft) > 0:
            save_tractogram(wpc_sft, os.path.join(
                args.out_dir,
                "{}_{}_wpc{}".format(prefix_1, prefix_2, ext)),
                bbox_valid_check=False)

        if len(fc_sft) > 0:
            save_tractogram(fc_sft, os.path.join(
                args.out_dir,
                "{}_{}_fc{}".format(prefix_1, prefix_2, ext)),
                bbox_valid_check=False)

        tc_streamlines_list.append(tc_sft.streamlines)
        wpc_streamlines_list.append(wpc_sft.streamlines)
        fc_streamlines_list.append(fc_sft.streamlines)

        logging.info("Recognized {} streamlines between {} and {}".format(
            len(tc_sft.streamlines) + len(wpc_sft.streamlines) +
            len(fc_sft.streamlines) + len(nc), prefix_1, prefix_2))

    # Every possible pair of endpoint masks (tails and heads interleaved)
    comb_filename = list(
        itertools.combinations(itertools.chain(*zip(gt_tails, gt_heads)),
                               r=2))

    # Remove the true connections from all combinations, leaving only
    # false connections
    for tc_f in tc_filenames:
        comb_filename.remove(tc_f)

    # fc_streamlines_list already holds one entry per true-connection pair;
    # the entries for the combinations below start at this offset.
    fc_comb_offset = len(fc_streamlines_list)

    logging.info("Scoring false connections")
    # Go through all the possible combinations of endpoints masks
    for mask_1_filename, mask_2_filename in comb_filename:
        # Automatically generate filename for Q/C
        prefix_1 = extract_prefix(mask_1_filename)
        prefix_2 = extract_prefix(mask_2_filename)

        fc_sft, sft = extract_false_connections(sft, mask_1_filename,
                                                mask_2_filename,
                                                args.dilate_endpoints)

        if len(fc_sft) > 0:
            save_tractogram(fc_sft, os.path.join(
                args.out_dir,
                "{}_{}_fc{}".format(prefix_1, prefix_2, ext)),
                bbox_valid_check=False)

        logging.info("Recognized {} streamlines between {} and {}".format(
            len(fc_sft.streamlines), prefix_1, prefix_2))

        fc_streamlines_list.append(fc_sft.streamlines)

    # Whatever is left in sft after all extractions is a no-connection.
    nc_streamlines.extend(sft.streamlines)

    final_results = {}
    no_conn_sft = StatefulTractogram.from_sft(nc_streamlines, sft)
    save_tractogram(no_conn_sft, os.path.join(
        args.out_dir, "nc{}".format(ext)), bbox_valid_check=False)

    # Total number of streamlines for each category
    # and statistic that are not "bundle-wise"
    tc_streamlines_count = len(list(itertools.chain(*tc_streamlines_list)))
    fc_streamlines_count = len(list(itertools.chain(*fc_streamlines_list)))

    if args.wrong_path_as_separate:
        wpc_streamlines_count = len(
            list(itertools.chain(*wpc_streamlines_list)))
    else:
        wpc_streamlines_count = 0

    nc_streamlines_count = len(nc_streamlines)
    total_count = tc_streamlines_count + fc_streamlines_count + \
        wpc_streamlines_count + nc_streamlines_count

    # Sanity check: each input streamline falls in exactly one category.
    assert total_count == initial_count

    final_results["tractogram_filename"] = str(args.in_tractogram)
    final_results["tractogram_overlap"] = 0.0
    final_results["tc_streamlines"] = tc_streamlines_count
    final_results["fc_streamlines"] = fc_streamlines_count
    final_results["nc_streamlines"] = nc_streamlines_count

    final_results["tc_bundle"] = len([x for x in tc_streamlines_list if x])
    final_results["fc_bundle"] = len([x for x in fc_streamlines_list if x])

    final_results["tc_streamlines_ratio"] = tc_streamlines_count / total_count
    final_results["fc_streamlines_ratio"] = fc_streamlines_count / total_count
    final_results["nc_streamlines_ratio"] = nc_streamlines_count / total_count

    if args.wrong_path_as_separate:
        final_results["wpc_streamlines"] = wpc_streamlines_count
        final_results["wpc_streamlines_ratio"] = \
            wpc_streamlines_count / total_count
        final_results["wpc_bundle"] = len(
            [x for x in wpc_streamlines_list if x])

    final_results["total_streamlines"] = total_count
    final_results["bundle_wise"] = {}
    final_results["bundle_wise"]["true_connections"] = {}
    final_results["bundle_wise"]["false_connections"] = {}
    tractogram_overlap = 0.0

    for i, filename in enumerate(tc_filenames):
        current_tc_streamlines = tc_streamlines_list[i]
        current_tc_voxels, current_tc_endpoints_voxels = get_binary_maps(
            current_tc_streamlines, sft)

        if args.wrong_path_as_separate:
            current_wpc_streamlines = wpc_streamlines_list[i]
            current_wpc_voxels, _ = get_binary_maps(current_wpc_streamlines,
                                                    sft)

        tmp_dict = {}
        tmp_dict["tc_streamlines"] = len(current_tc_streamlines)

        tmp_dict["tc_dice"] = compute_dice_voxel(gt_bundle_masks[i],
                                                 current_tc_voxels)[0]

        # Overlap: voxels in both; overreach: voxels visited outside of the
        # ground truth; lacking: ground-truth voxels never visited.
        bundle_overlap = gt_bundle_masks[i] * current_tc_voxels
        bundle_overreach = np.zeros(dimensions)
        bundle_overreach[np.where(
            (gt_bundle_masks[i] == 0) & (current_tc_voxels >= 1))] = 1
        bundle_lacking = np.zeros(dimensions)
        bundle_lacking[np.where(
            (gt_bundle_masks[i] == 1) & (current_tc_voxels == 0))] = 1

        if args.wrong_path_as_separate:
            tmp_dict["wpc_streamlines"] = len(current_wpc_streamlines)
            tmp_dict["wpc_dice"] = \
                compute_dice_voxel(gt_bundle_masks[i],
                                   current_wpc_voxels)[0]
            # Add wrong path to overreach
            bundle_overreach[np.where(
                (gt_bundle_masks[i] == 0) & (current_wpc_voxels >= 1))] = 1

        tmp_dict["tc_bundle_overlap"] = np.count_nonzero(bundle_overlap)
        tmp_dict["tc_bundle_overreach"] = \
            np.count_nonzero(bundle_overreach)
        tmp_dict["tc_bundle_lacking"] = np.count_nonzero(bundle_lacking)
        tmp_dict["tc_bundle_overlap_PCT"] = \
            tmp_dict["tc_bundle_overlap"] / \
            (tmp_dict["tc_bundle_overlap"] +
             tmp_dict["tc_bundle_lacking"])
        tractogram_overlap += tmp_dict["tc_bundle_overlap_PCT"]

        endpoints_overlap = \
            gt_bundle_masks[i] * current_tc_endpoints_voxels
        endpoints_overreach = np.zeros(dimensions)
        endpoints_overreach[np.where(
            (gt_bundle_masks[i] == 0) &
            (current_tc_endpoints_voxels >= 1))] = 1
        tmp_dict["tc_endpoints_overlap"] = np.count_nonzero(endpoints_overlap)
        tmp_dict["tc_endpoints_overreach"] = np.count_nonzero(
            endpoints_overreach)

        final_results["bundle_wise"]["true_connections"][str(filename)] = \
            tmp_dict

    # Bundle-wise statistics, useful for more complex phantom
    for i, filename in enumerate(comb_filename):
        # Skip the leading entries that belong to the true-connection pairs;
        # indexing with `i` alone would report each combination with the
        # streamlines of a different pair.
        current_fc_streamlines = fc_streamlines_list[fc_comb_offset + i]
        current_fc_voxels, _ = get_binary_maps(current_fc_streamlines, sft)

        tmp_dict = {}
        if len(current_fc_streamlines):
            tmp_dict["fc_streamlines"] = len(current_fc_streamlines)
            tmp_dict["fc_voxels"] = np.count_nonzero(current_fc_voxels)

            final_results["bundle_wise"]["false_connections"][str(filename)] =\
                tmp_dict

    final_results["tractogram_overlap"] = \
        tractogram_overlap / len(gt_bundle_masks)

    with open(os.path.join(args.out_dir, "results.json"), "w") as f:
        json.dump(final_results, f,
                  indent=args.indent,
                  sort_keys=args.sort_keys)
def main():
    """Decompose a tractogram into label-to-label connections (HDF5 output).

    Steps, in order:
    1. Parse and validate the arguments, set up logging.
    2. Load the labels (which must be isotropic) and the tractogram.
    3. Compute the voxels crossed by each streamline and, from them, the
       connectivity information between labels.
    4. For every pair of labels (self-pairs included), rebuild the connecting
       segments and run the optional cleaning steps (length, loops, outliers,
       curvature), handing each stage to `_save_if_needed`.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_labels],
                        args.reference)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    if (args.save_raw_connections or args.save_intermediate
            or args.save_discarded) and not args.out_dir:
        parser.error('To save outputs in the streamlines form, provide the '
                     'output directory using --out_dir.')

    if args.out_dir:
        if os.path.abspath(args.out_dir) == os.getcwd():
            parser.error('Do not use the current path as output directory.')
        assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                           create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    # The smallest value is dropped.
    # NOTE(review): presumably the background label (0) - confirm it is
    # always present in the label image.
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram,
                                         bbox_check=False)
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info(' Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    if not is_header_compatible(sft, img_labels):
        # The message used to lack the space before "do".
        raise IOError('{} and {} do not have a compatible header'.format(
            args.in_tractogram, args.in_labels))

    sft.to_vox()
    sft.to_corner()
    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info(' Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices,
                                    data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info(' Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info(' This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    # All unordered pairs of distinct labels, then the self-connections.
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    iteration_counter = 0
    with h5py.File(args.out_hdf5, 'w') as hdf5_file:
        affine, dimensions, voxel_sizes, voxel_order = get_reference_info(sft)
        hdf5_file.attrs['affine'] = affine
        hdf5_file.attrs['dimensions'] = dimensions
        hdf5_file.attrs['voxel_sizes'] = voxel_sizes
        hdf5_file.attrs['voxel_order'] = voxel_order

        # Each connections is processed independently. Multiprocessing would
        # be a burden on the I/O of most SSD/HD
        for in_label, out_label in comb_list:
            if iteration_counter > 0 and iteration_counter % 100 == 0:
                logging.info('Split {} nodes out of {}'.format(
                    iteration_counter, len(comb_list)))
            iteration_counter += 1

            pair_info = []
            if in_label not in con_info:
                continue
            elif out_label in con_info[in_label]:
                pair_info.extend(con_info[in_label][out_label])

            # For a self-connection both lookups address the very same list;
            # adding it a second time would duplicate every streamline of
            # that connection.
            if in_label != out_label:
                if out_label not in con_info:
                    continue
                elif in_label in con_info[out_label]:
                    pair_info.extend(con_info[out_label][in_label])

            if not len(pair_info):
                continue

            connecting_streamlines = []
            connecting_ids = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                curr_streamlines = compute_streamline_segment(
                    sft.streamlines[strl_idx],
                    indices[strl_idx],
                    connection['in_idx'],
                    connection['out_idx'],
                    points_to_idx[strl_idx])
                connecting_streamlines.append(curr_streamlines)
                connecting_ids.append(strl_idx)

            # Each step is processed from the previous 'success'
            #   1. raw -> length pass/fail
            #   2. length pass -> loops pass/fail
            #   3. loops pass -> outlier detection pass/fail
            #   4. outlier detection pass -> qb curvature pass/fail
            #   5. qb curvature pass == final connections
            connecting_streamlines = ArraySequence(connecting_streamlines)
            raw_dps = sft.data_per_streamline[connecting_ids]
            raw_sft = StatefulTractogram.from_sft(connecting_streamlines, sft,
                                                  data_per_streamline=raw_dps,
                                                  data_per_point={})
            _save_if_needed(raw_sft, hdf5_file, args,
                            'raw', 'raw', in_label, out_label)

            # Doing all post-processing
            if not args.no_pruning:
                valid_length_ids, invalid_length_ids = _prune_segments(
                    raw_sft.streamlines,
                    args.min_length,
                    args.max_length,
                    vox_sizes[0])

                invalid_length_sft = raw_sft[invalid_length_ids]
                valid_length = connecting_streamlines[valid_length_ids]
                _save_if_needed(invalid_length_sft, hdf5_file, args,
                                'discarded', 'invalid_length',
                                in_label, out_label)
            else:
                valid_length = connecting_streamlines
                valid_length_ids = range(len(connecting_streamlines))

            if not len(valid_length):
                continue

            valid_length_sft = raw_sft[valid_length_ids]
            _save_if_needed(valid_length_sft, hdf5_file, args,
                            'intermediate', 'valid_length',
                            in_label, out_label)

            if not args.no_remove_loops:
                no_loop_ids = remove_loops_and_sharp_turns(
                    valid_length, args.loop_max_angle)
                # The discarded ids are the complement of the kept ones.
                loop_ids = np.setdiff1d(np.arange(len(valid_length)),
                                        no_loop_ids)

                loops_sft = valid_length_sft[loop_ids]
                no_loops = valid_length[no_loop_ids]
                _save_if_needed(loops_sft, hdf5_file, args,
                                'discarded', 'loops', in_label, out_label)
            else:
                no_loops = valid_length
                no_loop_ids = range(len(valid_length))

            if not len(no_loops):
                continue

            no_loops_sft = valid_length_sft[no_loop_ids]
            _save_if_needed(no_loops_sft, hdf5_file, args,
                            'intermediate', 'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                outliers_ids, inliers_ids = remove_outliers(
                    no_loops, args.outlier_threshold,
                    nb_samplings=10, fast_approx=True)

                outliers_sft = no_loops_sft[outliers_ids]
                inliers = no_loops[inliers_ids]
                _save_if_needed(outliers_sft, hdf5_file, args,
                                'discarded', 'outliers', in_label, out_label)
            else:
                inliers = no_loops
                inliers_ids = range(len(no_loops))

            if not len(inliers):
                continue

            inliers_sft = no_loops_sft[inliers_ids]
            _save_if_needed(inliers_sft, hdf5_file, args,
                            'intermediate', 'inliers', in_label, out_label)

            if not args.no_remove_curv_dev:
                no_qb_curv_ids = remove_loops_and_sharp_turns(
                    inliers,
                    args.loop_max_angle,
                    use_qb=True,
                    qb_threshold=args.curv_qb_distance)
                qb_curv_ids = np.setdiff1d(np.arange(len(inliers)),
                                           no_qb_curv_ids)

                qb_curv_sft = inliers_sft[qb_curv_ids]
                _save_if_needed(qb_curv_sft, hdf5_file, args,
                                'discarded', 'qb_curv', in_label, out_label)
            else:
                no_qb_curv_ids = range(len(inliers))

            no_qb_curv_sft = inliers_sft[no_qb_curv_ids]
            _save_if_needed(no_qb_curv_sft, hdf5_file, args,
                            'final', 'final', in_label, out_label)

    time2 = time.time()
    logging.info(
        ' Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
def filter_grid_roi(sft, mask, filter_type, is_exclude):
    """Select the streamlines of a tractogram using a binary mask.

    Parameters
    ----------
    sft : StatefulTractogram
        StatefulTractogram containing the streamlines to segment. For the
        endpoint modes, sft.to_vox() and sft.to_corner() are called on it.
    mask : numpy.ndarray
        Binary mask in which the streamlines should pass.
    filter_type: str
        One of the 3 following choices, 'any', 'either_end', 'both_ends'.
        Any other value selects nothing (before the optional inversion).
    is_exclude: bool
        Value to indicate if the ROI is an AND (false) or a NOT (true).

    Returns
    -------
    new_sft : StatefulTractogram
        Tractogram restricted to the selected streamlines, with the matching
        data_per_streamline and data_per_point.
    ids : numpy.ndarray (int32)
        Ids of the selected streamlines in the input sft.
    """
    if filter_type == 'any':
        selected_ids = streamlines_in_mask(sft, mask)
    else:
        selected_ids = []

        sft.to_vox()
        sft.to_corner()

        # Each endpoint is tracked in its own list, so both endpoint modes
        # can be derived from the same pass over the streamlines.
        first_point_hits = []
        last_point_hits = []
        for strl_id, points in enumerate(sft.streamlines):
            first_voxel = tuple(points[0].astype(np.int16))
            last_voxel = tuple(points[-1].astype(np.int16))
            if mask[first_voxel]:
                first_point_hits.append(strl_id)
            if mask[last_voxel]:
                last_point_hits.append(strl_id)

        if filter_type == 'both_ends':
            # AND: the two endpoints are in the mask
            selected_ids = np.intersect1d(first_point_hits, last_point_hits)
        elif filter_type == 'either_end':
            # OR: at least one endpoint is in the mask
            selected_ids = np.union1d(first_point_hits, last_point_hits)

    if is_exclude:
        # Exclusion keeps the complement of the selection
        selected_ids = np.setdiff1d(range(len(sft)), np.unique(selected_ids))
    selected_ids = np.asarray(selected_ids).astype(np.int32)

    # Build the output tractogram from the retained ids
    new_sft = StatefulTractogram.from_sft(
        sft.streamlines[selected_ids], sft,
        data_per_streamline=sft.data_per_streamline[selected_ids],
        data_per_point=sft.data_per_point[selected_ids])

    return new_sft, selected_ids
def filter_cuboid(sft, cuboid_radius, cuboid_center, filter_type, is_exclude):
    """Select the streamlines of a tractogram using a cuboid ROI.

    Parameters
    ----------
    sft : StatefulTractogram
        StatefulTractogram containing the streamlines to segment.
    cuboid_radius : numpy.ndarray (3)
        Size in mm, x/y/z of the cuboid.
    cuboid_center: numpy.ndarray (3)
        Center x/y/z of the cuboid.
    filter_type: str
        One of the 3 following choices, 'any', 'either_end', 'both_ends'.
    is_exclude: bool
        Value to indicate if the ROI is an AND (false) or a NOT (true).

    Returns
    -------
    new_sft : StatefulTractogram
        Tractogram restricted to the selected streamlines, with the matching
        data_per_streamline and data_per_point.
    ids : numpy.ndarray (int32)
        Ids of the selected streamlines in the input sft.
    """
    pre_filtered_sft, pre_filtered_indices = \
        pre_filtering_for_geometrical_shape(sft, cuboid_radius,
                                            cuboid_center, filter_type,
                                            False)
    pre_filtered_sft.to_rasmm()
    pre_filtered_sft.to_center()
    pre_filtered_streamlines = pre_filtered_sft.streamlines
    _, _, res, _ = sft.space_attributes

    selected_by_cuboid = []
    line_based_indices_1 = []
    line_based_indices_2 = []
    # Also here I am not using a mathematical intersection and
    # I am not using vtkPolyData like in MI-Brain, so not exactly the same
    cuboid_radius = np.asarray(cuboid_radius)
    cuboid_center = np.asarray(cuboid_center)
    for i, line in enumerate(pre_filtered_streamlines):
        if filter_type == 'any':
            # Resample to 1/10 of the voxel size
            nb_points = max(int(length(line) / np.average(res) * 10), 2)
            line = set_number_of_points(line, nb_points)
            # Per point, count on how many of the 3 axes it lies within the
            # cuboid extent.
            points_in_cuboid = np.abs(line - cuboid_center) / cuboid_radius
            points_in_cuboid = np.sum(np.where(points_in_cuboid <= 1, 1, 0),
                                      axis=1)

            # Test the condition itself rather than the indices returned by
            # np.argwhere: `.any()` on indices is False when the only
            # matching point is the first one (index 0).
            if np.any(points_in_cuboid == 3):
                # If at least one point was in the cuboid in x/y/z,
                # we selected that streamline
                selected_by_cuboid.append(pre_filtered_indices[i])
        else:
            # Faster to do it twice than trying to do in using an array of 2
            points_in_cuboid = np.abs(line[0] - cuboid_center) / cuboid_radius
            points_in_cuboid = np.sum(np.where(points_in_cuboid <= 1, 1, 0))

            if points_in_cuboid == 3:
                line_based_indices_1.append(pre_filtered_indices[i])

            points_in_cuboid = \
                np.abs(line[-1] - cuboid_center) / cuboid_radius
            points_in_cuboid = np.sum(np.where(points_in_cuboid <= 1, 1, 0))

            if points_in_cuboid == 3:
                line_based_indices_2.append(pre_filtered_indices[i])

    # Both endpoints need to be in the mask (AND)
    if filter_type == 'both_ends':
        selected_by_cuboid = np.intersect1d(line_based_indices_1,
                                            line_based_indices_2)
    # Only one endpoint need to be in the mask (OR)
    elif filter_type == 'either_end':
        selected_by_cuboid = np.union1d(line_based_indices_1,
                                        line_based_indices_2)

    # If the 'exclude' option is used, the selection is inverted
    if is_exclude:
        selected_by_cuboid = np.setdiff1d(range(len(sft)),
                                          np.unique(selected_by_cuboid))
    line_based_indices = np.asarray(selected_by_cuboid).astype(np.int32)

    # From indices to sft
    streamlines = sft.streamlines[line_based_indices]
    data_per_streamline = sft.data_per_streamline[line_based_indices]
    data_per_point = sft.data_per_point[line_based_indices]

    new_sft = StatefulTractogram.from_sft(
        streamlines, sft,
        data_per_streamline=data_per_streamline,
        data_per_point=data_per_point)

    return new_sft, line_based_indices