def transform_warp_sft(sft, linear_transfo, target, inverse=False,
                       reverse_op=False, deformation_data=None,
                       remove_invalid=True, cut_invalid=False):
    """ Transform tractogram using an affine.

    Subsequently apply a warp from antsRegistration (optional).
    Remove/Cut invalid streamlines to preserve sft validity.

    Parameters
    ----------
    sft: StatefulTractogram
        Stateful tractogram object containing the streamlines to transform.
    linear_transfo: numpy.ndarray
        Linear transformation matrix to apply to the tractogram.
    target: Nifti filepath, image object, header
        Final reference for the tractogram after registration.
    inverse: boolean
        Apply the inverse linear transformation.
    reverse_op: boolean
        Apply both transformations in the reverse order.
    deformation_data: np.ndarray
        4D array containing a 3D displacement vector in each voxel.
    remove_invalid: boolean
        Remove the streamlines landing out of the bounding box.
    cut_invalid: boolean
        Cut invalid streamlines rather than removing them. Keep the longest
        segment only.

    Returns
    -------
    new_sft : StatefulTractogram
        Transformed tractogram, expressed in the input's original
        space/origin but anchored on `target`.
    """
    # Keep track of the streamlines' original space/origin so both the
    # input and the output can be put back at the end.
    space = sft.space
    origin = sft.origin
    dtype = sft.streamlines._data.dtype

    # All transformations below are defined in RASMM / center.
    sft.to_rasmm()
    sft.to_center()

    if len(sft.streamlines) == 0:
        # BUGFIX: previously the input sft was left in RASMM/center and the
        # empty result was returned in RASMM regardless of the caller's
        # space/origin. Restore both, mirroring the non-empty path.
        empty_sft = StatefulTractogram(sft.streamlines, target, Space.RASMM)
        sft.to_space(space)
        sft.to_origin(origin)
        empty_sft.to_space(space)
        empty_sft.to_origin(origin)
        return empty_sft

    if inverse:
        linear_transfo = np.linalg.inv(linear_transfo)

    if not reverse_op:
        streamlines = transform_streamlines(sft.streamlines, linear_transfo)
    else:
        streamlines = sft.streamlines

    if deformation_data is not None:
        if not reverse_op:
            affine, _, _, _ = get_reference_info(target)
        else:
            affine = sft.affine

        # Because of duplication, an iteration over chunks of points is
        # necessary for a big dataset (especially if not compressed).
        streamlines = ArraySequence(streamlines)
        nb_points = len(streamlines._data)
        cur_position = 0
        chunk_size = 1000000
        nb_iteration = int(np.ceil(nb_points / chunk_size))
        inv_affine = np.linalg.inv(affine)

        while nb_iteration > 0:
            max_position = min(cur_position + chunk_size, nb_points)
            points = streamlines._data[cur_position:max_position]

            # To access the deformation information, we need to go in VOX
            # space. No need for corner shift since we are doing
            # interpolation.
            cur_points_vox = np.array(transform_streamlines(
                points, inv_affine)).T

            # Trilinear interpolation of the displacement field at each
            # point, one component at a time.
            x_def = map_coordinates(deformation_data[..., 0],
                                    cur_points_vox.tolist(), order=1)
            y_def = map_coordinates(deformation_data[..., 1],
                                    cur_points_vox.tolist(), order=1)
            z_def = map_coordinates(deformation_data[..., 2],
                                    cur_points_vox.tolist(), order=1)

            # ITK is in LPS and nibabel is in RAS, a flip is necessary for
            # ANTs displacement fields (negate x and y components).
            final_points = np.array([-1 * x_def, -1 * y_def, z_def])
            final_points += np.array(points).T

            streamlines._data[cur_position:max_position] = final_points.T
            cur_position = max_position
            nb_iteration -= 1

    if reverse_op:
        streamlines = transform_streamlines(streamlines, linear_transfo)

    # Interpolation promotes to float64; cast back to the input dtype.
    streamlines._data = streamlines._data.astype(dtype)
    new_sft = StatefulTractogram(streamlines, target, Space.RASMM,
                                 data_per_point=sft.data_per_point,
                                 data_per_streamline=sft.data_per_streamline)
    if cut_invalid:
        new_sft, _ = cut_invalid_streamlines(new_sft)
    elif remove_invalid:
        new_sft.remove_invalid_streamlines()

    # Move the streamlines back to the original space/origin.
    sft.to_space(space)
    sft.to_origin(origin)
    new_sft.to_space(space)
    new_sft.to_origin(origin)

    return new_sft
def load_tractogram(filename, reference, to_space=Space.RASMM,
                    to_origin=Origin.NIFTI, bbox_valid_check=True,
                    trk_header_check=True):
    """ Load the stateful tractogram from any format (trk, tck, vtk, fib,
    dpy).

    Parameters
    ----------
    filename : string
        Filename with valid extension.
    reference : Nifti or Trk filename, Nifti1Image or TrkFile,
        Nifti1Header or trk.header (dict), or 'same' if the input is a trk
        file. Reference that provides the spatial attribute.
        Typically a nifti-related object from the native diffusion used for
        streamlines generation.
    to_space : Enum (dipy.io.stateful_tractogram.Space)
        Space to which the streamlines will be transformed after loading.
    to_origin : Enum (dipy.io.stateful_tractogram.Origin)
        Origin to which the streamlines will be transformed after loading:
        NIFTI standard, default (center of the voxel);
        TRACKVIS standard (corner of the voxel).
    bbox_valid_check : bool
        Verification for negative voxel coordinates or values above the
        volume dimensions. Default is True, to enforce valid file.
    trk_header_check : bool
        Verification that the reference has the same header as the spatial
        attributes as the input tractogram when a Trk is loaded.

    Returns
    -------
    output : StatefulTractogram
        The tractogram to load (must have been saved properly).
        Returns False (and logs an error) on invalid extension, space,
        reference, or header mismatch.
    """
    _, extension = os.path.splitext(filename)
    if extension not in ['.trk', '.tck', '.vtk', '.fib', '.dpy']:
        # BUGFIX: message previously said "Output filename" (copy-paste
        # from the save routine); this is the *input* being validated.
        logging.error('Input filename is not one of the supported format.')
        return False

    if to_space not in Space:
        logging.error('Space MUST be one of the 3 choices (Enum).')
        return False

    if reference == 'same':
        if extension == '.trk':
            reference = filename
        else:
            logging.error('Reference must be provided, "same" is only '
                          'available for Trk file.')
            return False

    if trk_header_check and extension == '.trk':
        if not is_header_compatible(filename, reference):
            logging.error('Trk file header does not match the provided '
                          'reference.')
            return False

    timer = time.time()
    data_per_point = None
    data_per_streamline = None
    if extension in ['.trk', '.tck']:
        tractogram_obj = nib.streamlines.load(filename).tractogram
        streamlines = tractogram_obj.streamlines
        if extension == '.trk':
            # Only trk carries per-point / per-streamline metadata.
            data_per_point = tractogram_obj.data_per_point
            data_per_streamline = tractogram_obj.data_per_streamline
    elif extension in ['.vtk', '.fib']:
        streamlines = load_vtk_streamlines(filename)
    elif extension in ['.dpy']:
        dpy_obj = Dpy(filename, mode='r')
        streamlines = list(dpy_obj.read_tracks())
        dpy_obj.close()
    logging.debug('Load %s with %s streamlines in %s seconds.',
                  filename, len(streamlines),
                  round(time.time() - timer, 3))

    # Streamlines on disk are in RASMM/NIFTI convention; move them to the
    # requested space/origin afterwards.
    sft = StatefulTractogram(streamlines, reference, Space.RASMM,
                             origin=Origin.NIFTI,
                             data_per_point=data_per_point,
                             data_per_streamline=data_per_streamline)
    sft.to_space(to_space)
    sft.to_origin(to_origin)

    if bbox_valid_check and not sft.is_bbox_in_vox_valid():
        raise ValueError('Bounding box is not valid in voxel space, cannot '
                         'load a valid file if some coordinates are invalid.'
                         '\nPlease set bbox_valid_check to False and then '
                         'use the function remove_invalid_streamlines to '
                         'discard invalid streamlines.')

    return sft
class Segmentation:
    # Segments a whole-brain tractogram into named bundles, using either
    # waypoint-ROI ("AFQ") or RecoBundles ("Reco") strategies.

    def __init__(self, nb_points=False, nb_streamlines=False, seg_algo='AFQ',
                 reg_algo=None, clip_edges=False, progressive=True,
                 greater_than=50, rm_small_clusters=50, model_clust_thr=5,
                 reduction_thr=20, refine=False, pruning_thr=5,
                 b0_threshold=50, prob_threshold=0, rng=None,
                 return_idx=False, filter_by_endpoints=True, dist_to_aal=4,
                 save_intermediates=None):
        """
        Segment streamlines into bundles.

        Parameters
        ----------
        nb_points : int, boolean
            Resample streamlines to nb_points number of points.
            If False, no resampling is done. Default: False
        nb_streamlines : int, boolean
            Subsample streamlines to nb_streamlines.
            If False, no subsampling is don. Default: False
        seg_algo : string
            Algorithm for segmentation (case-insensitive):
            'AFQ': Segment streamlines into bundles,
                based on inclusion/exclusion ROIs.
            'Reco': Segment streamlines using the RecoBundles algorithm
                [Garyfallidis2017].
            Default: 'AFQ'
        reg_algo : string or None, optional
            Algorithm for streamline registration (case-insensitive):
            'slr' : Use Streamlinear Registration [Garyfallidis2015]_
            'syn' : Use image-based nonlinear registration
            If None, will use SyN if a mapping is provided, slr otherwise.
            If seg_algo="AFQ", SyN is always used.
            Default: None
        clip_edges : bool
            Whether to clip the streamlines to be only in between the ROIs.
            Default: False
        rm_small_clusters : int
            Using RecoBundles Algorithm.
            Remove clusters that have less than this value
            during whole brain SLR. Default: 50
        model_clust_thr : int
            Parameter passed on to recognize for Recobundles.
            See Recobundles documentation. Default: 5
        reduction_thr : int
            Parameter passed on to recognize for Recobundles.
            See Recobundles documentation. Default: 20
        refine : bool
            Parameter passed on to recognize for Recobundles.
            See Recobundles documentation. Default: False
        pruning_thr : int
            Parameter passed on to recognize for Recobundles.
            See Recobundles documentation. Default: 5
        progressive : boolean, optional
            Using RecoBundles Algorithm.
            Whether or not to use progressive technique
            during whole brain SLR. Default: True.
        greater_than : int, optional
            Using RecoBundles Algorithm.
            Keep streamlines that have length greater than this value
            during whole brain SLR. Default: 50.
        b0_threshold : float.
            Using AFQ Algorithm.
            All b-values with values less than or equal to `bo_threshold`
            are considered as b0s i.e. without diffusion weighting.
            Default: 50.
        prob_threshold : float.
            Using AFQ Algorithm.
            Initial cleaning of fiber groups is done using probability maps
            from [Hua2008]_. Here, we choose an average probability that
            needs to be exceeded for an individual streamline to be retained.
            Default: 0.
        rng : RandomState or int
            If None, creates RandomState. If int, creates RandomState with
            seed rng. Used in RecoBundles Algorithm. Default: None.
        return_idx : bool
            Whether to return the indices in the original streamlines as part
            of the output of segmentation.
        filter_by_endpoints: bool
            Whether to filter the bundles based on their endpoints relative
            to regions defined in the AAL atlas. Applies only to the
            waypoint approach (XXX for now). Default: True.
        dist_to_aal : float
            If filter_by_endpoints is True, this is the distance from the
            endpoints to the AAL atlas ROIs that is required.
        save_intermediates : str, optional
            The full path to a folder into which intermediate products
            are saved. Default: None, means no saving of intermediates.

        References
        ----------
        .. [Hua2008] Hua K, Zhang J, Wakana S, Jiang H, Li X, et al. (2008)
           Tract probability maps in stereotaxic spaces: analyses of white
           matter anatomy and tract-specific quantification. Neuroimage 39:
           336-347
        """
        self.logger = logging.getLogger('AFQ.Segmentation')
        self.nb_points = nb_points
        self.nb_streamlines = nb_streamlines

        # Normalize rng to a RandomState instance.
        if rng is None:
            self.rng = np.random.RandomState()
        elif isinstance(rng, int):
            self.rng = np.random.RandomState(rng)
        else:
            self.rng = rng

        # Algorithm names are compared case-insensitively downstream.
        self.seg_algo = seg_algo.lower()
        if reg_algo is not None:
            reg_algo = reg_algo.lower()
        self.reg_algo = reg_algo
        self.prob_threshold = prob_threshold
        self.b0_threshold = b0_threshold
        self.progressive = progressive
        self.greater_than = greater_than
        self.rm_small_clusters = rm_small_clusters
        self.model_clust_thr = model_clust_thr
        self.reduction_thr = reduction_thr
        self.refine = refine
        self.pruning_thr = pruning_thr
        self.return_idx = return_idx
        self.filter_by_endpoints = filter_by_endpoints
        self.dist_to_aal = dist_to_aal
        # Create the intermediates folder up front if it does not exist.
        if (save_intermediates is not None) and \
                (not op.exists(save_intermediates)):
            os.makedirs(save_intermediates, exist_ok=True)
        self.save_intermediates = save_intermediates
        self.clip_edges = clip_edges

    def _read_tg(self, tg=None):
        """Cache `tg` on self (or reuse the cached one), remember its
        original space, and optionally subsample to `self.nb_streamlines`.

        Returns the tractogram as passed in (NOT the subsampled copy, which
        is stored on ``self.tg``).
        """
        if tg is None:
            tg = self.tg
        else:
            self.tg = tg
        # Remembered so segment(..., reset_tg_space=True) can restore it.
        self._tg_orig_space = self.tg.space

        if self.nb_streamlines and len(self.tg) > self.nb_streamlines:
            self.tg = StatefulTractogram.from_sft(
                dts.select_random_set_of_streamlines(self.tg.streamlines,
                                                     self.nb_streamlines),
                self.tg)

        return tg

    def segment(self, bundle_dict, tg, fdata=None, fbval=None, fbvec=None,
                mapping=None, reg_prealign=None, reg_template=None,
                b0_threshold=50, img_affine=None, reset_tg_space=False):
        """
        Segment streamlines into bundles based on either waypoint ROIs
        [Yeatman2012]_ or RecoBundles [Garyfallidis2017]_.

        Parameters
        ----------
        bundle_dict: dict
            Meta-data for the segmentation. The format is something like::

                {'name': {'ROIs':[img1, img2],
                'rules':[True, True]},
                'prob_map': img3,
                'cross_midline': False}

        tg : StatefulTractogram
            Bundles to segment
        fdata, fbval, fbvec : str
            Full path to data, bvals, bvecs
        mapping : DiffeomorphicMap object, str or nib.Nifti1Image, optional.
            A mapping between DWI space and a template.
            If None, mapping will be registered from data used in
            prepare_img. Default: None.
        reg_prealign : array, optional.
            The linear transformation to be applied to align input images to
            the reference space before warping under the deformation field.
            Default: None.
        reg_template : str or nib.Nifti1Image, optional.
            Template to use for registration. Default: MNI T2.
        b0_threshold : float
            NOTE(review): accepted but never read here; prepare_img uses
            ``self.b0_threshold`` set in __init__ — confirm intent.
        img_affine : array, optional.
            The spatial transformation from the measurement to the scanner
            space.
        reset_tg_space : bool, optional
            Whether to reset the space of the input tractogram after
            segmentation is complete. Default: False.

        Returns
        -------
        dict : Where keys are bundle names, values are tractograms of
            these bundles.

        References
        ----------
        .. [Yeatman2012] Yeatman, Jason D., Robert F. Dougherty, Nathaniel
           J. Myall, Brian A. Wandell, and Heidi M. Feldman. 2012. "Tract
           Profiles of White Matter Properties: Automating Fiber-Tract
           Quantification" PloS One 7 (11): e49790.
        .. [Garyfallidis17] Garyfallidis et al. Recognition of white matter
           bundles using local and global streamline-based registration and
           clustering, Neuroimage, 2017.
        """
        if img_affine is not None:
            # img_affine + mapping must be provided together, and are
            # mutually exclusive with fdata/fbval/fbvec.
            if (mapping is None
                    or fdata is not None
                    or fbval is not None
                    or fbvec is not None):
                self.logger.error(
                    "Provide either the full path to data, bvals, bvecs,"
                    + "or provide the affine of the image and the mapping")

        self.logger.info("Preparing Segmentation Parameters")
        self.img_affine = img_affine
        self.prepare_img(fdata, fbval, fbvec)
        self.logger.info("Preprocessing Streamlines")
        tg = self._read_tg(tg)

        # If resampling over-write the sft:
        if self.nb_points:
            self.tg = StatefulTractogram(
                dps.set_number_of_points(self.tg.streamlines,
                                         self.nb_points),
                self.tg, self.tg.space)

        self.prepare_map(mapping, reg_prealign, reg_template)
        self.bundle_dict = bundle_dict
        self.cross_streamlines()

        if self.seg_algo == "afq":
            fiber_groups = self.segment_afq()
        elif self.seg_algo.startswith("reco"):
            fiber_groups = self.segment_reco()
        else:
            # NOTE(review): ValueError is given two positional args, so the
            # message renders as a tuple — probably meant string
            # concatenation; left as-is.
            raise ValueError(f"The seg_algo input is {self.seg_algo}, which",
                             "is not recognized")
        if reset_tg_space:
            # Return the input to the original space when you are done:
            self.tg.to_space(self._tg_orig_space)

        return fiber_groups

    def prepare_img(self, fdata, fbval, fbvec):
        """
        Prepare image data from DWI data.

        Parameters
        ----------
        fdata, fbval, fbvec : str
            Full path to data, bvals, bvecs
        """
        # Only load from disk when no affine was supplied to segment().
        if self.img_affine is None:
            self.img, _, _, _ = \
                ut.prepare_data(fdata, fbval, fbvec,
                                b0_threshold=self.b0_threshold)
            self.img_affine = self.img.affine

        self.fdata = fdata
        self.fbval = fbval
        self.fbvec = fbvec

    def prepare_map(self, mapping=None, reg_prealign=None,
                    reg_template=None):
        """
        Set mapping between DWI space and a template.

        Parameters
        ----------
        mapping : DiffeomorphicMap object, str or nib.Nifti1Image, optional.
            A mapping between DWI space and a template.
            If None, mapping will be registered from data used in
            prepare_img. Default: None.
        reg_template : str or nib.Nifti1Image, optional.
            Template to use for registration (defaults to the MNI T2)
            Default: None.
        reg_prealign : array, optional.
            The linear transformation to be applied to align input images to
            the reference space before warping under the deformation field.
            Default: None.
        """
        if reg_template is None:
            reg_template = afd.read_mni_template()
        self.reg_template = reg_template

        if mapping is None:
            # AFQ always uses SyN; otherwise only if explicitly requested.
            if self.seg_algo == "afq" or self.reg_algo == "syn":
                gtab = dpg.gradient_table(self.fbval, self.fbvec)
                self.mapping = reg.syn_register_dwi(self.fdata, gtab,
                                                    template=reg_template)[1]
            else:
                self.mapping = None
        elif isinstance(mapping, str) or isinstance(mapping,
                                                    nib.Nifti1Image):
            if reg_prealign is None:
                reg_prealign = np.eye(4)
            if self.img is None:
                self.img, _, _, _ = \
                    ut.prepare_data(self.fdata,
                                    self.fbval,
                                    self.fbvec,
                                    b0_threshold=self.b0_threshold)
            # The prealign is inverted before being handed to read_mapping.
            self.mapping = reg.read_mapping(
                mapping, self.img, reg_template,
                prealign=np.linalg.inv(reg_prealign))
        else:
            self.mapping = mapping

    def cross_streamlines(self, tg=None, template=None, low_coord=10):
        """
        Classify the streamlines by whether they cross the midline.

        Creates a crosses attribute which is an array of booleans. Each
        boolean corresponds to a streamline, and is whether or not that
        streamline crosses the midline.

        Parameters
        ----------
        tg : StatefulTractogram class instance.
        template : nibabel.Nifti1Image class instance
            An affine transformation into a template space.
        low_coord : int
            NOTE(review): currently unused in this method.
        """
        if tg is None:
            tg = self.tg
        if template is None:
            template_affine = self.img_affine
        else:
            template_affine = template.affine

        # What is the x,y,z coordinate of 0,0,0 in the template space?
        zero_coord = np.dot(np.linalg.inv(template_affine),
                            np.array([0, 0, 0, 1]))

        # A streamline "crosses" if it has x-coordinates on both sides of
        # the template origin's x.
        self.crosses = np.zeros(len(tg), dtype=bool)
        # already_split = 0
        for sl_idx, sl in enumerate(tg.streamlines):
            if np.any(sl[:, 0] > zero_coord[0]) and \
                    np.any(sl[:, 0] < zero_coord[0]):
                self.crosses[sl_idx] = True
            else:
                self.crosses[sl_idx] = False

    def _get_bundle_info(self, bundle_idx, bundle):
        """
        Get fiber probabilites and ROIs for a given bundle.

        Returns
        -------
        warped_prob_map : ndarray
            Probability map warped into subject space.
        include_rois, exclude_rois : list of ndarray
            Voxel coordinates (N x 3) of each warped inclusion/exclusion
            ROI.
        """
        rules = self.bundle_dict[bundle]['rules']
        include_rois = []
        exclude_rois = []
        for rule_idx, rule in enumerate(rules):
            roi = self.bundle_dict[bundle]['ROIs'][rule_idx]
            if not isinstance(roi, np.ndarray):
                roi = roi.get_fdata()
            # Warp template-space ROI into subject space and patch holes.
            warped_roi = auv.patch_up_roi(self.mapping.transform_inverse(
                roi.astype(np.float32),
                interpolation='linear'),
                bundle_name=bundle)

            if rule:
                # include ROI:
                include_rois.append(np.array(np.where(warped_roi)).T)
            else:
                # Exclude ROI:
                exclude_rois.append(np.array(np.where(warped_roi)).T)

            # For debugging purposes, we can save the variable as it is:
            if self.save_intermediates is not None:
                # NOTE(review): this joins 'warpedROI_' and the bundle name
                # as *directory* components — possibly meant a single
                # filename; confirm against intended output layout.
                nib.save(
                    nib.Nifti1Image(warped_roi.astype(np.float32),
                                    self.img_affine),
                    op.join(self.save_intermediates, 'warpedROI_',
                            bundle, 'as_used.nii.gz'))

        # The probability map if doesn't exist is all ones with the same
        # shape as the ROIs:
        prob_map = self.bundle_dict[bundle].get(
            'prob_map', np.ones(roi.shape))

        if not isinstance(prob_map, np.ndarray):
            prob_map = prob_map.get_fdata()
        warped_prob_map = \
            self.mapping.transform_inverse(prob_map,
                                           interpolation='nearest')

        return warped_prob_map, include_rois, exclude_rois

    def _check_sl_with_inclusion(self, sl, include_rois, tol):
        """
        Helper function to check that a streamline is close to a list of
        inclusion ROIS.

        Returns (True, per-ROI distance matrices) if close to every ROI,
        otherwise (False, []).
        """
        dist = []
        for roi in include_rois:
            # Use squared Euclidean distance, because it's faster:
            dist.append(cdist(sl, roi, 'sqeuclidean'))
            if np.min(dist[-1]) > tol:
                # Too far from one of them:
                return False, []
        # Apparently you checked all the ROIs and it was close to all of
        # them
        return True, dist

    def _check_sl_with_exclusion(self, sl, exclude_rois, tol):
        """
        Helper function to check that a streamline is not too close to a
        list of exclusion ROIs.

        Returns False as soon as the streamline is within ``tol`` (squared
        distance) of any exclusion ROI.
        """
        for roi in exclude_rois:
            # Use squared Euclidean distance, because it's faster:
            if np.min(cdist(sl, roi, 'sqeuclidean')) < tol:
                return False
        # Either there are no exclusion ROIs, or you are not close to any:
        return True

    def _return_empty(self, bundle):
        """
        Helper function for segment_afq, to return an empty dict under
        some conditions.

        Stores an empty StatefulTractogram (and empty index array when
        return_idx is set) in ``self.fiber_groups[bundle]``.
        """
        if self.return_idx:
            self.fiber_groups[bundle] = {}
            self.fiber_groups[bundle]['sl'] = StatefulTractogram(
                [], self.img, Space.VOX)
            self.fiber_groups[bundle]['idx'] = np.array([])
        else:
            self.fiber_groups[bundle] = StatefulTractogram(
                [], self.img, Space.VOX)

    def segment_afq(self, tg=None):
        """
        Assign streamlines to bundles using the waypoint ROI approach

        Parameters
        ----------
        tg : StatefulTractogram class instance

        Returns
        -------
        dict
            ``self.fiber_groups``: bundle name -> StatefulTractogram (or a
            dict with 'sl'/'idx' when return_idx is set).
        """
        tg = self._read_tg(tg=tg)
        self.tg.to_vox()

        # For expedience, we approximate each streamline as a 100 point
        # curve. This is only used in extracting the values from the
        # probability map, so will not affect measurement of distance from
        # the waypoint ROIs
        fgarray = np.array(_resample_tg(tg, 100))
        n_streamlines = fgarray.shape[0]

        # Per-(streamline, bundle) selection probability, and the point
        # indices closest to each of the first two waypoint ROIs.
        streamlines_in_bundles = np.zeros(
            (n_streamlines, len(self.bundle_dict)))
        min_dist_coords = np.zeros(
            (n_streamlines, len(self.bundle_dict), 2), dtype=int)
        self.fiber_groups = {}

        if self.return_idx:
            out_idx = np.arange(n_streamlines, dtype=int)

        if self.filter_by_endpoints:
            aal_atlas = afd.read_aal_atlas(self.reg_template)
            if self.save_intermediates is not None:
                nib.save(
                    aal_atlas['atlas'],
                    op.join(self.save_intermediates,
                            'AAL_registered_to_template.nii.gz'))
            aal_atlas = aal_atlas['atlas'].get_fdata()

            # We need to calculate the size of a voxel, so we can transform
            # from mm to voxel units:
            R = self.img_affine[0:3, 0:3]
            vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R))))
            dist_to_aal = self.dist_to_aal / vox_dim

        self.logger.info("Assigning Streamlines to Bundles")
        # Tolerance is set to the square of the distance to the corner
        # because we are using the squared Euclidean distance in calls to
        # `cdist` to make those calls faster.
        tol = dts.dist_to_corner(self.img_affine)**2
        for bundle_idx, bundle in enumerate(self.bundle_dict):
            self.logger.info(f"Finding Streamlines for {bundle}")
            warped_prob_map, include_roi, exclude_roi = \
                self._get_bundle_info(bundle_idx, bundle)
            if self.save_intermediates is not None:
                nib.save(
                    nib.Nifti1Image(warped_prob_map.astype(np.float32),
                                    self.img_affine),
                    op.join(self.save_intermediates,
                            'warpedprobmap',
                            bundle, 'as_used.nii.gz'))
            # Mean probability-map value along each resampled streamline.
            fiber_probabilities = dts.values_from_volume(
                warped_prob_map, fgarray, np.eye(4))
            fiber_probabilities = np.mean(fiber_probabilities, -1)
            idx_above_prob = np.where(
                fiber_probabilities > self.prob_threshold)
            self.logger.info((f"{len(idx_above_prob[0])} streamlines exceed"
                              " the probability threshold."))
            crosses_midline = self.bundle_dict[bundle]['cross_midline']
            for sl_idx in tqdm(idx_above_prob[0]):
                sl = tg.streamlines[sl_idx]
                if fiber_probabilities[sl_idx] > self.prob_threshold:
                    if crosses_midline is not None:
                        if self.crosses[sl_idx]:
                            # This means that the streamline does
                            # cross the midline:
                            if crosses_midline:
                                # This is what we want, keep going
                                pass
                            else:
                                # This is not what we want,
                                # skip to next streamline
                                continue
                    is_close, dist = \
                        self._check_sl_with_inclusion(sl, include_roi, tol)
                    if is_close:
                        is_far = \
                            self._check_sl_with_exclusion(sl, exclude_roi,
                                                          tol)
                        if is_far:
                            # NOTE(review): indexing dist[0]/dist[1]
                            # assumes exactly two inclusion ROIs — confirm
                            # against bundle_dict conventions.
                            min_dist_coords[sl_idx, bundle_idx, 0] =\
                                np.argmin(dist[0], 0)[0]
                            min_dist_coords[sl_idx, bundle_idx, 1] =\
                                np.argmin(dist[1], 0)[0]
                            streamlines_in_bundles[sl_idx, bundle_idx] =\
                                fiber_probabilities[sl_idx]
            self.logger.info(
                (f"{np.sum(streamlines_in_bundles[:, bundle_idx] > 0)} "
                 "streamlines selected with waypoint ROIs"))

        # Eliminate any fibers not selected using the waypoint ROIs:
        possible_fibers = np.sum(streamlines_in_bundles, -1) > 0
        tg = StatefulTractogram(tg.streamlines[possible_fibers], self.img,
                                Space.VOX)
        if self.return_idx:
            out_idx = out_idx[possible_fibers]

        streamlines_in_bundles = streamlines_in_bundles[possible_fibers]
        min_dist_coords = min_dist_coords[possible_fibers]
        # Each surviving streamline is assigned to its highest-probability
        # bundle.
        bundle_choice = np.argmax(streamlines_in_bundles, -1)

        # We do another round through, so that we can orient all the
        # streamlines within a bundle in the same orientation with respect
        # to the ROIs. This order is ARBITRARY but CONSISTENT (going from
        # ROI0 to ROI1).
        self.logger.info("Re-orienting streamlines to consistent directions")
        for bundle_idx, bundle in enumerate(self.bundle_dict):
            self.logger.info(f"Processing {bundle}")

            select_idx = np.where(bundle_choice == bundle_idx)

            if len(select_idx[0]) == 0:
                # There's nothing here, set and move to the next bundle:
                self._return_empty(bundle)
                continue

            # Use a list here, because ArraySequence doesn't support item
            # assignment:
            select_sl = list(tg.streamlines[select_idx])
            # Sub-sample min_dist_coords:
            min_dist_coords_bundle = min_dist_coords[select_idx]
            # Flip any streamline whose ROI0-closest point comes after its
            # ROI1-closest point.
            for idx in range(len(select_sl)):
                min0 = min_dist_coords_bundle[idx, bundle_idx, 0]
                min1 = min_dist_coords_bundle[idx, bundle_idx, 1]
                if min0 > min1:
                    select_sl[idx] = select_sl[idx][::-1]

            if self.filter_by_endpoints:
                self.logger.info("Filtering by endpoints")
                # Create binary masks and warp these into subject's DWI
                # space:
                aal_targets = afd.bundles_to_aal([bundle],
                                                 atlas=aal_atlas)[0]
                aal_idx = []
                for targ in aal_targets:
                    if targ is not None:
                        aal_roi = np.zeros(aal_atlas.shape[:3])
                        aal_roi[targ[:, 0], targ[:, 1], targ[:, 2]] = 1
                        warped_roi = self.mapping.transform_inverse(
                            aal_roi,
                            interpolation='nearest')
                        aal_idx.append(np.array(
                            np.where(warped_roi > 0)).T)
                    else:
                        aal_idx.append(None)

                self.logger.info("Before filtering "
                                 f"{len(select_sl)} streamlines")

                new_select_sl = clean_by_endpoints(
                    select_sl, aal_idx[0], aal_idx[1], tol=dist_to_aal,
                    return_idx=self.return_idx)
                # Generate immediately:
                new_select_sl = list(new_select_sl)

                # We need to check this again:
                if len(new_select_sl) == 0:
                    # There's nothing here, set and move to the next
                    # bundle:
                    self._return_empty(bundle)
                    continue

                if self.return_idx:
                    # When return_idx is set, clean_by_endpoints yields
                    # (streamline, index) pairs: unzip and remap indices.
                    temp_select_sl = []
                    temp_select_idx = np.empty(len(new_select_sl), int)
                    for ii, ss in enumerate(new_select_sl):
                        temp_select_sl.append(ss[0])
                        temp_select_idx[ii] = ss[1]
                    select_idx = select_idx[0][temp_select_idx]
                    new_select_sl = temp_select_sl

                select_sl = new_select_sl
                self.logger.info("After filtering "
                                 f"{len(select_sl)} streamlines")

            if self.clip_edges:
                self.logger.info("Clipping Streamlines by ROI")
                for idx in range(len(select_sl)):
                    min0 = min_dist_coords_bundle[idx, bundle_idx, 0]
                    min1 = min_dist_coords_bundle[idx, bundle_idx, 1]

                    # If the point that is closest to the first ROI
                    # is the same as the point closest to the second ROI,
                    # include the surrounding points to make a streamline.
                    if min0 == min1:
                        min1 = min1 + 1
                        min0 = min0 - 1

                    select_sl[idx] = select_sl[idx][min0:min1]

            select_sl = StatefulTractogram(select_sl, self.img, Space.RASMM)

            if self.return_idx:
                self.fiber_groups[bundle] = {}
                self.fiber_groups[bundle]['sl'] = select_sl
                self.fiber_groups[bundle]['idx'] = out_idx[select_idx]
            else:
                self.fiber_groups[bundle] = select_sl
        return self.fiber_groups

    def move_streamlines(self, tg=None, reg_algo='slr'):
        """Streamline-based registration of a whole-brain tractogram to
        the MNI whole-brain atlas.

        Sets ``self.moved_sl`` as a side effect (no return value).

        registration_algo : str
            "slr" or "syn"
        """
        tg = self._read_tg(tg=tg)
        # Choose the algorithm from the available mapping when unspecified.
        if reg_algo is None:
            if self.mapping is None:
                reg_algo = 'slr'
            else:
                reg_algo = 'syn'

        if reg_algo == "slr":
            self.logger.info("Registering tractogram with SLR")
            atlas = self.bundle_dict['whole_brain']
            self.moved_sl, _, _, _ = whole_brain_slr(
                atlas, self.tg.streamlines, x0='affine', verbose=False,
                progressive=self.progressive,
                greater_than=self.greater_than,
                rm_small_clusters=self.rm_small_clusters,
                rng=self.rng)
        elif reg_algo == "syn":
            self.logger.info("Registering tractogram based on syn")
            # Apply the forward displacement field point-wise in RASMM.
            self.tg.to_rasmm()
            delta = dts.values_from_volume(self.mapping.forward,
                                           self.tg.streamlines, np.eye(4))
            self.moved_sl = dts.Streamlines(
                [d + s for d, s in zip(delta, self.tg.streamlines)])
            self.tg.to_vox()

        if self.save_intermediates is not None:
            moved_sft = StatefulTractogram(self.moved_sl,
                                           self.reg_template,
                                           Space.RASMM)
            save_tractogram(moved_sft,
                            op.join(self.save_intermediates,
                                    'sls_in_mni.trk'),
                            bbox_valid_check=False)

    def segment_reco(self, tg=None):
        """
        Segment streamlines using the RecoBundles algorithm
        [Garyfallidis2017]

        Parameters
        ----------
        tg : StatefulTractogram class instance
            A whole-brain tractogram to be segmented.

        Returns
        -------
        fiber_groups : dict
            Keys are names of the bundles, values are Streamline objects.
            The streamlines in each object have all been oriented to have
            the same orientation (using `dts.orient_by_streamline`).
        """
        tg = self._read_tg(tg=tg)
        fiber_groups = {}

        self.move_streamlines(tg, self.reg_algo)
        # We generate our instance of RB with the moved streamlines:
        self.logger.info("Extracting Bundles")
        rb = RecoBundles(self.moved_sl, verbose=False, rng=self.rng)

        # Next we'll iterate over bundles, registering each one:
        bundle_list = list(self.bundle_dict.keys())
        # 'whole_brain' is the SLR atlas entry, not a bundle to recognize.
        bundle_list.remove('whole_brain')

        self.logger.info("Assigning Streamlines to Bundles")
        for bundle in bundle_list:
            model_sl = self.bundle_dict[bundle]['sl']
            _, rec_labels = rb.recognize(
                model_bundle=model_sl,
                model_clust_thr=self.model_clust_thr,
                reduction_thr=self.reduction_thr,
                reduction_distance='mdf',
                slr=True,
                slr_metric='asymmetric',
                pruning_distance='mdf')

            # Use the streamlines in the original space:
            recognized_sl = tg.streamlines[rec_labels]
            if self.refine and len(recognized_sl) > 0:
                _, rec_labels = rb.refine(model_sl, recognized_sl,
                                          self.model_clust_thr,
                                          reduction_thr=self.reduction_thr,
                                          pruning_thr=self.pruning_thr)
                recognized_sl = tg.streamlines[rec_labels]

            # Orient all recognized streamlines like the bundle centroid.
            standard_sl = self.bundle_dict[bundle]['centroid']
            oriented_sl = dts.orient_by_streamline(recognized_sl,
                                                   standard_sl)
            if self.return_idx:
                fiber_groups[bundle] = {}
                fiber_groups[bundle]['idx'] = rec_labels
                fiber_groups[bundle]['sl'] = StatefulTractogram(
                    oriented_sl, self.img, Space.RASMM)
            else:
                fiber_groups[bundle] = StatefulTractogram(
                    oriented_sl, self.img, Space.RASMM)
        self.fiber_groups = fiber_groups
        return fiber_groups