def probal(Threshold=.2, data_list=None, seed='.', one_node=False,
           two_node=False, output_path='.'):
    """Run probabilistic CSD tractography and save it as a ``.trk`` file.

    A constrained spherical deconvolution (CSD) model is fit inside the
    white-matter mask. The GFA of a CSA ODF model is thresholded as the
    stopping criterion, and streamlines are tracked from the seed mask with
    a probabilistic direction getter.

    Parameters
    ----------
    Threshold : float
        GFA value below which tracking stops.
    data_list : dict
        Pre-loaded data with keys ``'DWI'``, ``'affine'``, ``'img'``,
        ``'labels'``, ``'gtab'`` and ``'head_mask'``. Required.
    seed : str or array
        Seed mask. Any string (the default ``'.'``) is a sentinel meaning
        "seed from the white-matter mask" (``labels == 2`` inside
        ``head_mask``).
    one_node, two_node : bool
        Forwarded to ``reduct_seed_ROI``. When either is set, streamlines
        are filtered against the seed mask in voxel space before saving.
    output_path : str
        Directory the tractogram is written to. The default ``'.'`` keeps
        the historical location ``./tractogram_probabilistic.trk``.

    Raises
    ------
    TypeError
        If ``data_list`` is not provided (this function cannot load data
        by itself).
    """
    if data_list is None:
        raise TypeError("probal requires `data_list` (a dict of pre-loaded "
                        "data); it cannot load data by itself.")

    time0 = time.time()
    print("begin loading data, time:", time.time() - time0)
    data = data_list['DWI']
    affine = data_list['affine']
    img = data_list['img']
    labels = data_list['labels']
    gtab = data_list['gtab']
    head_mask = data_list['head_mask']

    # A string seed is a sentinel: seed from the white-matter mask.
    if isinstance(seed, str):
        seed_mask = (labels == 2) * (head_mask == 1)
    else:
        seed_mask = seed
    white_matter = (labels == 2) * (head_mask == 1)
    seeds = utils.seeds_from_mask(seed_mask, affine, density=1)

    print("begin reconstruction, time:", time.time() - time0)
    response, _ = auto_response_ssst(gtab, data, roi_radii=10, fa_thr=0.7)
    csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=6)
    csd_fit = csd_model.fit(data, mask=white_matter)
    csa_model = CsaOdfModel(gtab, sh_order=6)
    gfa = csa_model.fit(data, mask=white_matter).gfa
    stopping_criterion = ThresholdStoppingCriterion(gfa, Threshold)

    print("begin tracking, time:", time.time() - time0)
    fod = csd_fit.odf(small_sphere)
    # Negative ODF lobes are not valid probabilities.
    pmf = fod.clip(min=0)
    prob_dg = ProbabilisticDirectionGetter.from_pmf(pmf, max_angle=30.,
                                                    sphere=small_sphere)
    streamline_generator = LocalTracking(prob_dg, stopping_criterion, seeds,
                                         affine, step_size=.5)
    streamlines = Streamlines(streamline_generator)
    sft = StatefulTractogram(streamlines, img, Space.RASMM)
    if one_node or two_node:
        # reduct_seed_ROI compares against a voxel mask, so go to voxel space
        sft.to_vox()
        streamlines = reduct_seed_ROI(sft.streamlines, seed_mask,
                                      one_node, two_node)
        sft = StatefulTractogram(streamlines, img, Space.VOX)
        # Public equivalent of the private _vox_to_rasmm() from VOX space.
        sft.to_rasmm()

    print("begin saving, time:", time.time() - time0)
    output = output_path + '/tractogram_probabilistic.trk'
    save_trk(sft, output)
    print("finished, time:", time.time() - time0)
def get_tractogram_in_voxel_space(tract_fname, ref_anat_fname,
                                  tracts_attribs=None,
                                  origin=Origin.NIFTI):
    """Load a tractogram and return it in voxel space with the given origin.

    Parameters
    ----------
    tract_fname : str
        Path to the tractogram.
    ref_anat_fname : str
        Reference anatomy providing the spatial attributes.
    tracts_attribs : dict, optional
        Tractogram attributes. Only ``'orientation'`` is read. When it is
        anything other than ``'unknown'``, the file is read with
        ``load_vtk_streamlines`` (flipping to LPS if the value is ``'LPS'``).
        A missing dict or a missing key is treated as ``'unknown'``, and the
        file is read with ``load_tractogram`` without bbox/header checks.
    origin : Origin
        Origin convention to convert the tractogram to.

    Returns
    -------
    StatefulTractogram
        The tractogram in voxel space.
    """
    # None instead of a dict literal avoids a shared mutable default.
    if tracts_attribs is None:
        tracts_attribs = {'orientation': 'unknown'}
    orientation = tracts_attribs.get('orientation', 'unknown')

    if orientation != 'unknown':
        to_lps = orientation == 'LPS'
        streamlines = load_vtk_streamlines(tract_fname, to_lps)
        sft = StatefulTractogram(streamlines, ref_anat_fname, Space.RASMM)
    else:
        sft = load_tractogram(tract_fname, ref_anat_fname,
                              bbox_valid_check=False,
                              trk_header_check=False)

    sft.to_vox()
    sft.to_origin(origin)
    return sft
def test_fit_data():
    """LiFE on ``small_25`` should match or beat the Matlab reference.

    The median RMSE must be below the Matlab implementation's. The fitted
    weights must correlate (r > 0.6) with the Matlab weights.
    """
    fdata, fbval, fbvec = dpd.get_fnames('small_25')
    fstreamlines = dpd.get_fnames('small_25_streamlines')
    gtab = grad.gradient_table(fbval, fbvec)
    ni_data = nib.load(fdata)
    # img.get_data() was removed in nibabel 5.0. np.asanyarray(img.dataobj)
    # is the documented replacement that keeps the on-disk dtype.
    data = np.asanyarray(ni_data.dataobj)
    tensor_streamlines = nib.streamlines.load(fstreamlines).streamlines
    # The model is fit with an identity affine, so streamlines must be
    # expressed in voxel coordinates.
    sft = StatefulTractogram(tensor_streamlines, ni_data, Space.RASMM)
    sft.to_vox()
    tensor_streamlines_vox = sft.streamlines

    life_model = life.FiberModel(gtab)
    life_fit = life_model.fit(data, tensor_streamlines_vox, np.eye(4))
    model_error = life_fit.predict() - life_fit.data
    model_rmse = np.sqrt(np.mean(model_error**2, -1))
    matlab_rmse, matlab_weights = dpd.matlab_life_results()
    # Lower error than the matlab implementation for these data:
    npt.assert_(np.median(model_rmse) < np.median(matlab_rmse))
    # And a moderate correlation with the Matlab implementation weights:
    npt.assert_(np.corrcoef(matlab_weights, life_fit.beta)[0, 1] > 0.6)
def load_tractogram(filename, reference, to_space=Space.RASMM,
                    shifted_origin=False, bbox_valid_check=True,
                    trk_header_check=True):
    """ Load the stateful tractogram from any format (trk, tck, vtk, fib, dpy)

    Parameters
    ----------
    filename : string
        Filename with valid extension
    reference : Nifti or Trk filename, Nifti1Image or TrkFile, Nifti1Header or
        trk.header (dict), or 'same' if the input is a trk file.
        Reference that provides the spatial attribute.
        Typically a nifti-related object from the native diffusion used for
        streamlines generation
    to_space : Enum (dipy.io.stateful_tractogram.Space)
        Space to which the streamlines will be transformed after loading.
    shifted_origin : bool
        Information on the position of the origin,
        False is Trackvis standard, default (center of the voxel)
        True is NIFTI standard (corner of the voxel)
    bbox_valid_check : bool
        Verification for negative voxel coordinates or values above the
        volume dimensions. Default is True, to enforce valid file.
    trk_header_check : bool
        Verification that the reference has the same header as the spatial
        attributes as the input tractogram when a Trk is loaded

    Returns
    -------
    output : StatefulTractogram or False
        The tractogram to load (must have been saved properly). False is
        returned (after logging an error) for an unsupported extension, an
        invalid space, a 'same' reference on a non-Trk file, or a Trk header
        mismatch.

    Raises
    ------
    ValueError
        If ``bbox_valid_check`` is True and some coordinates fall outside the
        volume in voxel space.
    """
    _, extension = os.path.splitext(filename)
    if extension not in ['.trk', '.tck', '.vtk', '.fib', '.dpy']:
        logging.error('Input filename is not one of the supported formats')
        return False

    if to_space not in Space:
        logging.error('Space MUST be one of the 3 choices (Enum)')
        return False

    # Only compare strings: header/image objects may not support == with str.
    if isinstance(reference, str) and reference == 'same':
        if extension == '.trk':
            reference = filename
        else:
            logging.error('Reference must be provided, "same" is only '
                          'available for Trk file.')
            return False

    if trk_header_check and extension == '.trk':
        if not is_header_compatible(filename, reference):
            logging.error('Trk file header does not match the provided '
                          'reference')
            return False

    timer = time.time()
    data_per_point = None
    data_per_streamline = None
    if extension in ['.trk', '.tck']:
        tractogram_obj = nib.streamlines.load(filename).tractogram
        streamlines = tractogram_obj.streamlines
        # Only Trk carries per-point / per-streamline data.
        if extension == '.trk':
            data_per_point = tractogram_obj.data_per_point
            data_per_streamline = tractogram_obj.data_per_streamline
    elif extension in ['.vtk', '.fib']:
        streamlines = load_vtk_streamlines(filename)
    elif extension == '.dpy':
        dpy_obj = Dpy(filename, mode='r')
        # Close the handle even if reading fails.
        try:
            streamlines = list(dpy_obj.read_tracks())
        finally:
            dpy_obj.close()
    logging.debug('Load %s with %s streamlines in %s seconds',
                  filename, len(streamlines), round(time.time() - timer, 3))

    # All supported readers return RAS mm coordinates.
    sft = StatefulTractogram(streamlines, reference, Space.RASMM,
                             shifted_origin=shifted_origin,
                             data_per_point=data_per_point,
                             data_per_streamline=data_per_streamline)

    if to_space == Space.VOX:
        sft.to_vox()
    elif to_space == Space.VOXMM:
        sft.to_voxmm()

    if bbox_valid_check and not sft.is_bbox_in_vox_valid():
        raise ValueError('Bounding box is not valid in voxel space, cannot '
                         'load a valid file if some coordinates are invalid. '
                         'Please set bbox_valid_check to False and then use '
                         'the function remove_invalid_streamlines to discard '
                         'invalid streamlines.')

    return sft
import AFQ.registration as reg import AFQ.segmentation as seg import AFQ.dti as dti from AFQ.utils.volume import patch_up_roi dpd.fetch_stanford_hardi() hardi_dir = op.join(fetcher.dipy_home, "stanford_hardi") hardi_fdata = op.join(hardi_dir, "HARDI150.nii.gz") hardi_img = nib.load(hardi_fdata) hardi_fbval = op.join(hardi_dir, "HARDI150.bval") hardi_fbvec = op.join(hardi_dir, "HARDI150.bvec") file_dict = afd.read_stanford_hardi_tractography() mapping = file_dict['mapping.nii.gz'] streamlines = file_dict['tractography_subsampled.trk'] tg = StatefulTractogram(streamlines, hardi_img, Space.RASMM) tg.to_vox() streamlines = tg.streamlines # streamlines = dts.Streamlines( # dtu.transform_tracking_output( # streamlines[streamlines._lengths > 10], # np.linalg.inv(hardi_img.affine))) def test_segment(): templates = afd.read_templates() bundles = { 'CST_L': { 'ROIs': [templates['CST_roi1_L'], templates['CST_roi2_L']], 'rules': [True, True],
def empty_space_change():
    """An empty tractogram must survive every space conversion unchanged."""
    sft = StatefulTractogram([], filepath_dix['gs.nii'], Space.VOX)
    for convert in (sft.to_vox, sft.to_voxmm, sft.to_rasmm):
        convert()
    assert_array_equal([], sft.streamlines.data)
class Segmentation:
    def __init__(self,
                 nb_points=False,
                 seg_algo='AFQ',
                 progressive=True,
                 greater_than=50,
                 rm_small_clusters=50,
                 model_clust_thr=40,
                 reduction_thr=40,
                 refine=False,
                 pruning_thr=6,
                 b0_threshold=0,
                 prob_threshold=0,
                 rng=None,
                 return_idx=False,
                 filter_by_endpoints=True,
                 dist_to_aal=4):
        """
        Segment streamlines into bundles.

        Parameters
        ----------
        nb_points : int, boolean
            Resample streamlines to nb_points number of points.
            If False, no resampling is done. Default: False
        seg_algo : string
            Algorithm for segmentation (case-insensitive):
            'AFQ': Segment streamlines into bundles,
                based on inclusion/exclusion ROIs.
            'Reco': Segment streamlines using the RecoBundles algorithm
                [Garyfallidis2017].
            Default: 'AFQ'
        progressive : boolean, optional
            Using RecoBundles Algorithm.
            Whether or not to use progressive technique
            during whole brain SLR.
            Default: True.
        greater_than : int, optional
            Using RecoBundles Algorithm.
            Keep streamlines that have length greater than
            this value during whole brain SLR.
            Default: 50.
        rm_small_clusters : int
            Using RecoBundles Algorithm.
            Remove clusters that have less than this value
            during whole brain SLR.
            Default: 50
        model_clust_thr : float
            Using RecoBundles Algorithm. Passed as `model_clust_thr` to
            RecoBundles.recognize (and refine). Default: 40.
        reduction_thr : float
            Using RecoBundles Algorithm. Passed as `reduction_thr` to
            RecoBundles.recognize (and refine). Default: 40.
        refine : bool
            Using RecoBundles Algorithm. Whether to run RecoBundles.refine
            after recognition. Default: False.
        pruning_thr : float
            Using RecoBundles Algorithm. Passed as `pruning_thr` to
            RecoBundles.refine. Default: 6.
        b0_threshold : float.
            Using AFQ Algorithm.
            All b-values with values less than or equal to `b0_threshold` are
            considered as b0s i.e. without diffusion weighting.
            Default: 0.
        prob_threshold : float.
            Using AFQ Algorithm.
            Initial cleaning of fiber groups is done using probability maps
            from [Hua2008]_. Here, we choose an average probability that
            needs to be exceeded for an individual streamline to be retained.
            Default: 0.
        rng : RandomState
            If None, creates RandomState.
            Used in RecoBundles Algorithm.
            Default: None.
        return_idx : bool
            Whether to return the indices in the original streamlines as part
            of the output of segmentation.
        filter_by_endpoints: bool
            Whether to filter the bundles based on their endpoints relative
            to regions defined in the AAL atlas. Applies only to the waypoint
            approach (XXX for now). Default: True.
        dist_to_aal : float
            If filter_by_endpoints is True, this is the distance (in mm,
            converted to voxels using the image affine) from the endpoints
            to the AAL atlas ROIs that is required.

        References
        ----------
        .. [Hua2008] Hua K, Zhang J, Wakana S, Jiang H, Li X, et al. (2008)
        Tract probability maps in stereotaxic spaces: analyses of white
        matter anatomy and tract-specific quantification. Neuroimage 39:
        336-347
        """
        self.logger = logging.getLogger('AFQ.Segmentation')
        self.nb_points = nb_points

        if rng is None:
            self.rng = np.random.RandomState()
        else:
            self.rng = rng

        self.seg_algo = seg_algo.lower()
        self.prob_threshold = prob_threshold
        self.b0_threshold = b0_threshold
        self.progressive = progressive
        self.greater_than = greater_than
        self.rm_small_clusters = rm_small_clusters
        self.model_clust_thr = model_clust_thr
        self.reduction_thr = reduction_thr
        self.refine = refine
        self.pruning_thr = pruning_thr
        self.return_idx = return_idx
        self.filter_by_endpoints = filter_by_endpoints
        self.dist_to_aal = dist_to_aal
        # Populated by prepare_img / prepare_map. They are initialised here
        # so `self.img is None` checks work when only an affine is supplied,
        # and segment_afq can rely on `self.reg_template` existing.
        self.img = None
        self.reg_template = None

    def segment(self, bundle_dict, tg, fdata=None, fbval=None, fbvec=None,
                mapping=None, reg_prealign=None, reg_template=None,
                b0_threshold=0, img_affine=None):
        """
        Segment streamlines into bundles based on either waypoint ROIs
        [Yeatman2012]_ or RecoBundles [Garyfallidis2017]_.

        Parameters
        ----------
        bundle_dict: dict
            Meta-data for the segmentation. The format is something like::

                {'name': {'ROIs':[img1, img2],
                'rules':[True, True]},
                'prob_map': img3,
                'cross_midline': False}

        tg : StatefulTractogram
            Bundles to segment
        fdata, fbval, fbvec : str
            Full path to data, bvals, bvecs
        mapping : DiffeomorphicMap object, str or nib.Nifti1Image, optional.
            A mapping between DWI space and a template. If None, mapping will
            be registered from data used in prepare_img. Default: None.
        reg_prealign : array, optional.
            The linear transformation to be applied to align input images to
            the reference space before warping under the deformation field.
            Default: None.
        reg_template : str or nib.Nifti1Image, optional.
            Template to use for registration. Default: MNI T2.
        b0_threshold : float, optional.
            Not used; the `b0_threshold` given to the constructor is used.
            Kept for backward compatibility.
        img_affine : array, optional.
            The spatial transformation from the measurement to the scanner
            space.

        Returns
        -------
        dict
            The fiber groups from `segment_afq` or `segment_reco`, or None
            (after logging an error) if `seg_algo` is not recognised.

        References
        ----------
        .. [Yeatman2012] Yeatman, Jason D., Robert F. Dougherty, Nathaniel J.
        Myall, Brian A. Wandell, and Heidi M. Feldman. 2012. "Tract Profiles
        of White Matter Properties: Automating Fiber-Tract Quantification"
        PloS One 7 (11): e49790.

        .. [Garyfallidis2017] Garyfallidis et al. Recognition of white matter
        bundles using local and global streamline-based registration and
        clustering, Neuroimage, 2017.
        """
        if img_affine is not None:
            if (mapping is None
                    or fdata is not None
                    or fbval is not None
                    or fbvec is not None):

                self.logger.error(
                    "Provide either the full path to data, bvals, bvecs, "
                    + "or provide the affine of the image and the mapping")

        self.logger.info("Preparing Segmentation Parameters")
        self.img_affine = img_affine
        self.prepare_img(fdata, fbval, fbvec)
        self.logger.info("Preprocessing Streamlines")
        self.tg = tg

        # If resampling over-write the sft. (The resampled tractogram is
        # stored directly; there is no separate resampling method.)
        if self.nb_points:
            self.tg = StatefulTractogram(
                dps.set_number_of_points(self.tg.streamlines, self.nb_points),
                self.tg, self.tg.space)

        self.prepare_map(mapping, reg_prealign, reg_template)
        self.bundle_dict = bundle_dict
        self.cross_streamlines()

        if self.seg_algo == "afq":
            return self.segment_afq()
        elif self.seg_algo == "reco":
            return self.segment_reco()
        self.logger.error(
            f"Unknown segmentation algorithm: {self.seg_algo}")
        return None

    def prepare_img(self, fdata, fbval, fbvec):
        """
        Prepare image data from DWI data.

        The image is loaded only when no `img_affine` was supplied to
        `segment`; in that case `img_affine` is taken from the loaded image.

        Parameters
        ----------
        fdata, fbval, fbvec : str
            Full path to data, bvals, bvecs
        """
        if self.img_affine is None:
            self.img, _, _, _ = \
                ut.prepare_data(fdata, fbval, fbvec,
                                b0_threshold=self.b0_threshold)
            self.img_affine = self.img.affine

        self.fdata = fdata
        self.fbval = fbval
        self.fbvec = fbvec

    def prepare_map(self, mapping=None, reg_prealign=None, reg_template=None):
        """
        Set mapping between DWI space and a template.

        Parameters
        ----------
        mapping : DiffeomorphicMap object, str or nib.Nifti1Image, optional.
            A mapping between DWI space and a template.
            If None, mapping will be registered from data used in prepare_img.
            Default: None.
        reg_template : str or nib.Nifti1Image, optional.
            Template to use for registration (defaults to the MNI T2).
            Stored as `self.reg_template` (used by segment_afq).
            Default: None.
        reg_prealign : array, optional.
            The linear transformation to be applied to align input images to
            the reference space before warping under the deformation field.
            Default: None.
        """
        if reg_template is None:
            reg_template = dpd.read_mni_template()
        # segment_afq resamples the AAL atlas onto this template.
        self.reg_template = reg_template

        if mapping is None:
            gtab = dpg.gradient_table(self.fbval, self.fbvec)
            self.mapping = reg.syn_register_dwi(self.fdata, gtab)[1]
        elif isinstance(mapping, (str, nib.Nifti1Image)):
            if reg_prealign is None:
                reg_prealign = np.eye(4)
            if self.img is None:
                self.img, _, _, _ = \
                    ut.prepare_data(self.fdata,
                                    self.fbval,
                                    self.fbvec,
                                    b0_threshold=self.b0_threshold)
            # The inverse of reg_prealign is passed on purpose: passing
            # reg_prealign itself produced wrongly warped ROIs and
            # probability maps.
            self.mapping = reg.read_mapping(
                mapping, self.img, reg_template,
                prealign=np.linalg.inv(reg_prealign))
        else:
            self.mapping = mapping

    def cross_streamlines(self, tg=None, template=None, low_coord=10):
        """
        Classify the streamlines by whether they cross the midline.

        Creates a `crosses` attribute: a boolean array with one entry per
        streamline, True when the streamline has points on both sides of the
        x coordinate that world (0, 0, 0) maps to.

        Parameters
        ----------
        tg : StatefulTractogram class instance.
            Defaults to `self.tg`.
        template : nibabel.Nifti1Image class instance
            Image whose affine defines the midline. If None, `self.img_affine`
            is used.
        low_coord : int
            Unused; kept for backward compatibility.
        """
        if tg is None:
            tg = self.tg
        if template is None:
            template_affine = self.img_affine
        else:
            template_affine = template.affine

        # What is the x,y,z coordinate of 0,0,0 in the template space?
        zero_coord = np.dot(np.linalg.inv(template_affine),
                            np.array([0, 0, 0, 1]))

        self.crosses = np.zeros(len(tg), dtype=bool)
        for sl_idx, sl in enumerate(tg.streamlines):
            self.crosses[sl_idx] = (np.any(sl[:, 0] > zero_coord[0])
                                    and np.any(sl[:, 0] < zero_coord[0]))

    def _get_bundle_info(self, bundle_idx, bundle):
        """
        Get fiber probabilites and ROIs for a given bundle.

        Returns
        -------
        warped_prob_map : array
            Probability map warped into subject space (all ones, shaped like
            the last ROI, if the bundle has no 'prob_map').
        include_rois, exclude_rois : list of arrays
            Coordinates of the non-zero voxels of each warped ROI, one
            ``np.where(...)``-derived (n_voxels, ndim) array per ROI.
        """
        rules = self.bundle_dict[bundle]['rules']
        include_rois = []
        exclude_rois = []
        for rule_idx, rule in enumerate(rules):
            roi = self.bundle_dict[bundle]['ROIs'][rule_idx]
            if not isinstance(roi, np.ndarray):
                roi = roi.get_fdata()
            warped_roi = auv.patch_up_roi(
                self.mapping.transform_inverse(
                    roi.astype(np.float32),
                    interpolation='linear'))

            if rule:
                # include ROI:
                include_rois.append(np.array(np.where(warped_roi)).T)
            else:
                # Exclude ROI:
                exclude_rois.append(np.array(np.where(warped_roi)).T)

        # The probability map if doesn't exist is all ones with the same
        # shape as the ROIs:
        prob_map = self.bundle_dict[bundle].get(
            'prob_map', np.ones(roi.shape))

        if not isinstance(prob_map, np.ndarray):
            prob_map = prob_map.get_fdata()
        warped_prob_map = \
            self.mapping.transform_inverse(prob_map,
                                           interpolation='nearest')
        return warped_prob_map, include_rois, exclude_rois

    def _check_sl_with_inclusion(self, sl, include_rois, tol):
        """
        Helper function to check that a streamline is close to a list of
        inclusion ROIS.

        Returns (True, distances-per-ROI) if every ROI has a point within
        squared distance `tol`, otherwise (False, []).
        """
        dist = []
        for roi in include_rois:
            # Use squared Euclidean distance, because it's faster:
            dist.append(cdist(sl, roi, 'sqeuclidean'))
            if np.min(dist[-1]) > tol:
                # Too far from one of them:
                return False, []
        # Apparently you checked all the ROIs and it was close to all of them
        return True, dist

    def _check_sl_with_exclusion(self, sl, exclude_rois, tol):
        """ Helper function to check that a streamline is not too close to a
        list of exclusion ROIs.
        """
        for roi in exclude_rois:
            # Use squared Euclidean distance, because it's faster:
            if np.min(cdist(sl, roi, 'sqeuclidean')) < tol:
                return False
        # Either there are no exclusion ROIs, or you are not close to any:
        return True

    def _return_empty(self, bundle):
        """
        Helper function for segment_afq, to return an empty dict under
        some conditions.
        """
        if self.return_idx:
            self.fiber_groups[bundle] = {}
            self.fiber_groups[bundle]['sl'] = StatefulTractogram(
                [], self.img, Space.VOX)
            self.fiber_groups[bundle]['idx'] = np.array([])
        else:
            self.fiber_groups[bundle] = StatefulTractogram(
                [], self.img, Space.VOX)

    def _aal_atlas_in_template_space(self):
        """Resample every volume of the AAL atlas onto `self.reg_template`."""
        atlas_img = afd.read_aal_atlas()['atlas']
        resample_to = self.reg_template
        if isinstance(resample_to, str):
            resample_to = nib.load(resample_to)
        atlas_data = atlas_img.get_fdata()
        # The atlas stores several volumes (last axis) so that overlapping
        # areas can be represented separately; resample each one.
        all_volumes = []
        for ii in range(atlas_data.shape[-1]):
            resampled = reg.resample(
                atlas_data[..., ii],  # moving (according to reg.resample)
                resample_to,  # static
                atlas_img.affine,  # moving affine
                resample_to.affine)  # static affine
            all_volumes.append(np.asarray(resampled))
        return nib.Nifti1Image(np.stack(all_volumes, axis=3),
                               resample_to.affine)

    def segment_afq(self, tg=None):
        """
        Assign streamlines to bundles using the waypoint ROI approach

        Parameters
        ----------
        tg : StatefulTractogram class instance
        """
        if tg is None:
            tg = self.tg
        else:
            self.tg = tg

        self.tg.to_vox()

        # For expedience, we approximate each streamline as a 100 point
        # curve. NOTE: if a streamline is longer than ~100 voxels, resampled
        # points can be more than a voxel apart, and the contact check below
        # (distance to ROI < voxel corner) can miss ROIs that are only a few
        # voxels thick.
        fgarray = np.array(_resample_tg(tg, 100))

        n_streamlines = fgarray.shape[0]
        streamlines_in_bundles = np.zeros(
            (n_streamlines, len(self.bundle_dict)))
        min_dist_coords = np.zeros(
            (n_streamlines, len(self.bundle_dict), 2), dtype=int)
        self.fiber_groups = {}

        if self.return_idx:
            out_idx = np.arange(n_streamlines, dtype=int)

        if self.filter_by_endpoints:
            # The atlas is not aligned to template space; resample it first.
            aal_atlas = self._aal_atlas_in_template_space()
            # We need to calculate the size of a voxel, so we can transform
            # from mm to voxel units:
            R = self.img_affine[0:3, 0:3]
            vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R))))
            dist_to_aal = self.dist_to_aal / vox_dim

        self.logger.info("Assigning Streamlines to Bundles")
        # Tolerance is set to the square of the distance to the corner
        # because we are using the squared Euclidean distance in calls to
        # `cdist` to make those calls faster.
        tol = dts.dist_to_corner(self.img_affine)**2
        for bundle_idx, bundle in enumerate(self.bundle_dict):
            self.logger.info(f"Finding Streamlines for {bundle}")
            warped_prob_map, include_roi, exclude_roi = \
                self._get_bundle_info(bundle_idx, bundle)
            fiber_probabilities = dts.values_from_volume(
                warped_prob_map,
                fgarray, np.eye(4))
            fiber_probabilities = np.mean(fiber_probabilities, -1)
            idx_above_prob = np.where(
                fiber_probabilities > self.prob_threshold)
            self.logger.info((f"{len(idx_above_prob[0])} streamlines exceed"
                              " the probability threshold."))
            crosses_midline = self.bundle_dict[bundle]['cross_midline']
            # Every index here already exceeds prob_threshold.
            for sl_idx in tqdm(idx_above_prob[0]):
                sl = tg.streamlines[sl_idx]
                # A bundle with cross_midline=False rejects streamlines that
                # cross the midline; None disables the check.
                if (crosses_midline is not None
                        and self.crosses[sl_idx]
                        and not crosses_midline):
                    continue

                is_close, dist = \
                    self._check_sl_with_inclusion(sl, include_roi, tol)
                if not is_close:
                    continue
                if not self._check_sl_with_exclusion(sl, exclude_roi, tol):
                    continue
                # Remember which points are closest to ROI0/ROI1 so the
                # streamline can be oriented consistently later.
                min_dist_coords[sl_idx, bundle_idx, 0] = \
                    np.argmin(dist[0], 0)[0]
                min_dist_coords[sl_idx, bundle_idx, 1] = \
                    np.argmin(dist[1], 0)[0]
                streamlines_in_bundles[sl_idx, bundle_idx] = \
                    fiber_probabilities[sl_idx]
            self.logger.info(
                (f"{np.sum(streamlines_in_bundles[:, bundle_idx] > 0)} "
                 "streamlines selected with waypoint ROIs"))

        # Eliminate any fibers not selected using the plane ROIs:
        possible_fibers = np.sum(streamlines_in_bundles, -1) > 0
        tg = StatefulTractogram(tg.streamlines[possible_fibers],
                                self.img,
                                Space.VOX)
        if self.return_idx:
            out_idx = out_idx[possible_fibers]

        streamlines_in_bundles = streamlines_in_bundles[possible_fibers]
        min_dist_coords = min_dist_coords[possible_fibers]
        # Each streamline goes to the bundle with the highest probability.
        bundle_choice = np.argmax(streamlines_in_bundles, -1)

        # We do another round through, so that we can orient all the
        # streamlines within a bundle in the same orientation with respect to
        # the ROIs. This order is ARBITRARY but CONSISTENT (going from ROI0
        # to ROI1).
        self.logger.info("Re-orienting streamlines to consistent directions")
        for bundle_idx, bundle in enumerate(self.bundle_dict):
            self.logger.info(f"Processing {bundle}")

            select_idx = np.where(bundle_choice == bundle_idx)

            if len(select_idx[0]) == 0:
                # There's nothing here, set and move to the next bundle:
                self._return_empty(bundle)
                continue

            # Use a list here, because ArraySequence doesn't support item
            # assignment:
            select_sl = list(tg.streamlines[select_idx])
            # Sub-sample min_dist_coords:
            min_dist_coords_bundle = min_dist_coords[select_idx]
            for idx in range(len(select_sl)):
                min0 = min_dist_coords_bundle[idx, bundle_idx, 0]
                min1 = min_dist_coords_bundle[idx, bundle_idx, 1]
                if min0 > min1:
                    select_sl[idx] = select_sl[idx][::-1]

            # Set this to StatefulTractogram object for filtering/output:
            select_sl = StatefulTractogram(select_sl, self.img, Space.VOX)

            if self.filter_by_endpoints:
                self.logger.info("Filtering by endpoints")
                # Create binary masks and warp these into subject's DWI space:
                aal_targets = afd.bundles_to_aal(
                    [bundle], atlas=aal_atlas)[0]
                aal_idx = []
                for targ in aal_targets:
                    if targ is not None:
                        aal_roi = np.zeros(aal_atlas.shape[:3])
                        aal_roi[targ[:, 0],
                                targ[:, 1],
                                targ[:, 2]] = 1
                        warped_roi = self.mapping.transform_inverse(
                            aal_roi,
                            interpolation='nearest')
                        aal_idx.append(
                            np.array(np.where(warped_roi > 0)).T)
                    else:
                        aal_idx.append(None)

                self.logger.info("Before filtering "
                                 f"{len(select_sl)} streamlines")

                new_select_sl = clean_by_endpoints(select_sl.streamlines,
                                                   aal_idx[0],
                                                   aal_idx[1],
                                                   tol=dist_to_aal,
                                                   return_idx=self.return_idx)
                # Generate immediately:
                new_select_sl = list(new_select_sl)

                # We need to check this again:
                if len(new_select_sl) == 0:
                    # There's nothing here, set and move to the next bundle:
                    self._return_empty(bundle)
                    continue

                if self.return_idx:
                    # With return_idx, each item is (streamline, index).
                    temp_select_sl = []
                    temp_select_idx = np.empty(len(new_select_sl), int)
                    for ii, ss in enumerate(new_select_sl):
                        temp_select_sl.append(ss[0])
                        temp_select_idx[ii] = ss[1]
                    select_idx = select_idx[0][temp_select_idx]
                    new_select_sl = temp_select_sl

                # The filtered streamlines are still in voxel coordinates.
                select_sl = StatefulTractogram(new_select_sl,
                                               self.img,
                                               Space.VOX)
                self.logger.info("After filtering "
                                 f"{len(select_sl)} streamlines")

            if self.return_idx:
                self.fiber_groups[bundle] = {}
                self.fiber_groups[bundle]['sl'] = select_sl
                self.fiber_groups[bundle]['idx'] = out_idx[select_idx]
            else:
                self.fiber_groups[bundle] = select_sl
        return self.fiber_groups

    def segment_reco(self, tg=None):
        """
        Segment streamlines using the RecoBundles algorithm [Garyfallidis2017]

        Parameters
        ----------
        tg : StatefulTractogram class instance
            A whole-brain tractogram to be segmented.

        Returns
        -------
        fiber_groups : dict
            Keys are names of the bundles, values are Streamline objects.
            The streamlines in each object have all been oriented to have the
            same orientation (using `dts.orient_by_streamline`).

        Notes
        -----
        `self.bundle_dict` must contain a 'whole_brain' entry (atlas
        streamlines) and, for every other bundle, 'sl' (model streamlines)
        and 'centroid' (orientation reference).
        """
        if tg is None:
            tg = self.tg
        else:
            self.tg = tg
        fiber_groups = {}
        self.logger.info("Registering Whole-brain with SLR")
        # We start with whole-brain SLR:
        atlas = self.bundle_dict['whole_brain']
        moved, _, _, _ = whole_brain_slr(
            atlas, self.tg.streamlines, x0='affine', verbose=False,
            progressive=self.progressive,
            greater_than=self.greater_than,
            rm_small_clusters=self.rm_small_clusters,
            rng=self.rng)

        # We generate our instance of RB with the moved streamlines:
        self.logger.info("Extracting Bundles")
        rb = RecoBundles(moved, verbose=False, rng=self.rng)

        # Next we'll iterate over bundles, registering each one:
        bundle_list = list(self.bundle_dict.keys())
        bundle_list.remove('whole_brain')

        self.logger.info("Assigning Streamlines to Bundles")
        for bundle in bundle_list:
            model_sl = self.bundle_dict[bundle]['sl']
            _, rec_labels = rb.recognize(model_bundle=model_sl,
                                         model_clust_thr=self.model_clust_thr,
                                         reduction_thr=self.reduction_thr,
                                         reduction_distance='mdf',
                                         slr=True,
                                         slr_metric='asymmetric',
                                         pruning_distance='mdf')

            # Use the streamlines in the original space:
            recognized_sl = tg.streamlines[rec_labels]
            if self.refine:
                _, rec_labels = rb.refine(model_sl, recognized_sl,
                                          self.model_clust_thr,
                                          reduction_thr=self.reduction_thr,
                                          pruning_thr=self.pruning_thr)
                recognized_sl = tg.streamlines[rec_labels]

            standard_sl = self.bundle_dict[bundle]['centroid']
            oriented_sl = dts.orient_by_streamline(recognized_sl, standard_sl)
            if self.return_idx:
                fiber_groups[bundle] = {}
                fiber_groups[bundle]['idx'] = rec_labels
                fiber_groups[bundle]['sl'] = StatefulTractogram(
                    oriented_sl, self.img, Space.RASMM)
            else:
                fiber_groups[bundle] = StatefulTractogram(
                    oriented_sl, self.img, Space.RASMM)
        self.fiber_groups = fiber_groups
        return fiber_groups
def main():
    """Convert a DSI Studio tractogram to RAS space (optionally to native).

    The tractogram (``args.in_dsi_tractogram``) is re-referenced to the DSI
    FA (``args.in_dsi_fa``) and its x/y axes are flipped. Then one of:

    * without ``--in_native_fa``: the result is saved directly;
    * with ``--in_native_fa``: a rigid transform between the DSI FA and the
      native FA is estimated (or loaded with ``--load_transfo``) and
      applied before saving.

    ``--cut_invalid`` / ``--remove_invalid`` / ``--keep_invalid`` control
    how out-of-volume streamlines are handled.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # A loaded transform is only meaningful relative to the native FA.
    if args.load_transfo and args.in_native_fa is None:
        parser.error('When loading a transformation, the final reference is '
                     'needed, use --in_native_fa.')
    assert_inputs_exist(parser, [args.in_dsi_tractogram, args.in_dsi_fa],
                        optional=args.in_native_fa)
    assert_outputs_exist(parser, args, args.out_tractogram)

    sft = load_tractogram(args.in_dsi_tractogram, 'same',
                          bbox_valid_check=False)

    # LPS -> RAS convention in voxel space
    sft.to_vox()
    flip_axis = ['x', 'y']
    # NOTE(review): the voxel-space coordinates of `sft` are re-interpreted
    # as VOXMM relative to the DSI FA; presumably DSI Studio stores them
    # that way — confirm.
    sft_fix = StatefulTractogram(sft.streamlines, args.in_dsi_fa,
                                 Space.VOXMM)
    sft_fix.to_vox()
    # Streamline data is shifted in place before flipping the x/y axes.
    sft_fix.streamlines._data -= get_axis_shift_vector(flip_axis)

    sft_flip = flip_sft(sft_fix, flip_axis)

    sft_flip.to_rasmm()
    # NOTE(review): fixed half-voxel-like offset applied in RAS mm; origin
    # of these exact values is not documented here — confirm.
    sft_flip.streamlines._data -= [0.5, 0.5, -0.5]

    if not args.in_native_fa:
        # No native reference: save in the DSI FA space.
        if args.cut_invalid:
            sft_flip, _ = cut_invalid_streamlines(sft_flip)
        elif args.remove_invalid:
            sft_flip.remove_invalid_streamlines()
        save_tractogram(sft_flip, args.out_tractogram,
                        bbox_valid_check=not args.keep_invalid)
    else:
        static_img = nib.load(args.in_native_fa)
        static_data = static_img.get_fdata()
        moving_img = nib.load(args.in_dsi_fa)
        moving_data = moving_img.get_fdata()

        # DSI-Studio flips the volume without changing the affine (I think)
        # So this has to be reversed (not the same problem as above)
        vox_order = get_reference_info(moving_img)[3]
        flip_axis = []
        # Flip the data (and record the axis) for every non-RAS voxel axis.
        if vox_order[0] == 'L':
            moving_data = moving_data[::-1, :, :]
            flip_axis.append('x')
        if vox_order[1] == 'P':
            moving_data = moving_data[:, ::-1, :]
            flip_axis.append('y')
        if vox_order[2] == 'I':
            moving_data = moving_data[:, :, ::-1]
            flip_axis.append('z')
        sft_flip_back = flip_sft(sft_flip, flip_axis)

        if args.load_transfo:
            transfo = np.loadtxt(args.load_transfo)
        else:
            # Sometimes DSI studio has quite a lot of skull left
            # Dipy Median Otsu does not work with FA/GFA
            if args.auto_crop:
                moving_data = cube_crop_data(moving_data)
                static_data = cube_crop_data(static_data)

            # Since DSI Studio register to AC/PC and does not save the
            # transformation We must estimate the transformation, since it's
            # rigid it is 'easy'
            # Centre-of-mass alignment is used as the starting point.
            c_of_mass = transform_centers_of_mass(static_data,
                                                  static_img.affine,
                                                  moving_data,
                                                  moving_img.affine)

            # Mutual-information rigid registration, coarse to fine over
            # three levels.
            nbins = 32
            sampling_prop = None
            level_iters = [1000, 100, 10]
            sigmas = [3.0, 2.0, 1.0]
            factors = [3, 2, 1]

            metric = MutualInformationMetric(nbins, sampling_prop)
            affreg = AffineRegistration(metric=metric,
                                        level_iters=level_iters,
                                        sigmas=sigmas,
                                        factors=factors)
            transform = RigidTransform3D()
            rigid = affreg.optimize(static_data, moving_data, transform,
                                    None, static_img.affine,
                                    moving_img.affine,
                                    starting_affine=c_of_mass.affine)
            transfo = rigid.affine
            if args.save_transfo:
                np.savetxt(args.save_transfo, transfo)

        new_sft = transform_warp_sft(sft_flip_back, transfo,
                                     static_img, inverse=True,
                                     remove_invalid=args.remove_invalid,
                                     cut_invalid=args.cut_invalid)

        # NOTE(review): transform_warp_sft already receives remove_invalid /
        # cut_invalid; this second pass is probably redundant — confirm.
        if args.cut_invalid:
            new_sft, _ = cut_invalid_streamlines(new_sft)
        elif args.remove_invalid:
            new_sft.remove_invalid_streamlines()
        save_tractogram(new_sft, args.out_tractogram,
                        bbox_valid_check=not args.keep_invalid)
def sfm_tracking(name=None, data_path=None, output_path='.', Threshold=.20,
                 data_list=None, return_streamlines=False, save_track=True,
                 seed='.', minus_ROI_mask='.', one_node=False,
                 two_node=False):
    """Track with a Sparse Fascicle Model (SFM) and optionally save a .trk.

    Parameters
    ----------
    name : str, optional
        Subject name. Used by ``get_data`` when ``data_list`` is None, and
        appended to the output filename (``tractogram_sfm_<name>.trk``). When
        None, the file is named ``tractogram_sfm.trk``.
    data_path : str, optional
        Forwarded to ``get_data`` when ``data_list`` is None.
    output_path : str
        Directory the tractogram is written to.
    Threshold : float
        GFA value (from the SFM peaks) below which tracking stops.
    data_list : dict, optional
        Pre-loaded data with keys ``'DWI'``, ``'affine'``, ``'img'``,
        ``'labels'``, ``'gtab'`` and ``'head_mask'``.
    return_streamlines : bool
        Return the streamlines when True.
    save_track : bool
        Write the tractogram to disk when True.
    seed : str or array
        Seed mask. Any string (default ``'.'``) means "seed from the
        white-matter mask" (``labels == 2`` inside ``head_mask``).
    minus_ROI_mask : str or array
        If not a string, passed as ``ROI`` to ``minus_ROI``. This is only
        applied together with one_node/two_node filtering.
    one_node, two_node : bool
        Forwarded to ``reduct_seed_ROI``. When either is set (and
        ``save_track`` is True), streamlines are filtered in voxel space.

    Returns
    -------
    Streamlines or None
        The streamlines if ``return_streamlines`` is True. NOTE(review): when
        one_node/two_node filtering ran, these were produced in voxel space.
        Whether they are later modified by the RAS conversion depends on
        StatefulTractogram copy semantics — confirm.
    """
    from dipy.reconst.csdeconv import auto_response_ssst
    from dipy.direction import peaks_from_model
    from dipy.io.stateful_tractogram import Space, StatefulTractogram
    from dipy.io.streamline import save_trk

    time0 = time.time()
    print("begin loading data, time:", time.time() - time0)

    if data_list is None:
        data, affine, img, labels, gtab, head_mask = get_data(name, data_path)
    else:
        data = data_list['DWI']
        affine = data_list['affine']
        img = data_list['img']
        labels = data_list['labels']
        gtab = data_list['gtab']
        head_mask = data_list['head_mask']

    # A string seed is a sentinel: seed from the white-matter mask.
    if isinstance(seed, str):
        seed_mask = (labels == 2) * (head_mask == 1)
    else:
        seed_mask = seed
    white_matter = (labels == 2) * (head_mask == 1)
    seeds = utils.seeds_from_mask(seed_mask, affine, density=1)

    print('begin reconstruction, time:', time.time() - time0)
    response, _ = auto_response_ssst(gtab, data, roi_radii=10, fa_thr=0.7)
    sphere = get_sphere()
    # response[0] holds the response eigenvalues used by the SFM kernel.
    sf_model = sfm.SparseFascicleModel(gtab, sphere=sphere,
                                       l1_ratio=0.5, alpha=0.001,
                                       response=response[0])
    pnm = peaks_from_model(sf_model, data, sphere,
                           relative_peak_threshold=.5,
                           min_separation_angle=25,
                           mask=white_matter,
                           parallel=True)
    stopping_criterion = ThresholdStoppingCriterion(pnm.gfa, Threshold)

    print('begin tracking, time:', time.time() - time0)
    streamline_generator = LocalTracking(pnm, stopping_criterion, seeds,
                                         affine, step_size=.5)
    streamlines = Streamlines(streamline_generator)

    print('begin saving, time:', time.time() - time0)
    if save_track:
        sft = StatefulTractogram(streamlines, img, Space.RASMM)
        if one_node or two_node:
            # Seed/ROI filtering works against voxel masks.
            sft.to_vox()
            streamlines = reduct_seed_ROI(sft.streamlines, seed_mask,
                                          one_node, two_node)
            if not isinstance(minus_ROI_mask, str):
                streamlines = minus_ROI(streamlines=streamlines,
                                        ROI=minus_ROI_mask)
            sft = StatefulTractogram(streamlines, img, Space.VOX)
            # Public equivalent of the private _vox_to_rasmm() from VOX.
            sft.to_rasmm()

        suffix = '' if name is None else '_' + name
        output = output_path + '/tractogram_sfm' + suffix + '.trk'
        # save_trk(sft, filename, bbox_valid_check=True): the streamlines
        # were previously passed by mistake as bbox_valid_check.
        save_trk(sft, output)

    if return_streamlines:
        return streamlines
def determine(name=None, data_path=None, output_path='.', Threshold=.20,
              data_list=None, seed='.', minus_ROI_mask='.', one_node=False,
              two_node=False):
    """Run deterministic CSD tractography and save it as a ``.trk`` file.

    A CSD model is fit inside the white-matter mask. The GFA of a CSA ODF
    model is thresholded as the stopping criterion, and streamlines are
    tracked with a deterministic maximum direction getter on
    ``default_sphere``.

    Parameters
    ----------
    name : str, optional
        Subject name. Used by ``get_data`` when ``data_list`` is None, and
        appended to the output filename
        (``tractogram_deterministic_<name>.trk``). When None, the file is
        named ``tractogram_deterministic.trk``.
    data_path : str, optional
        Forwarded to ``get_data`` when ``data_list`` is None.
    output_path : str
        Directory the tractogram is written to.
    Threshold : float
        GFA value below which tracking stops.
    data_list : dict, optional
        Pre-loaded data with keys ``'DWI'``, ``'affine'``, ``'img'``,
        ``'labels'``, ``'gtab'`` and ``'head_mask'``.
    seed : str or array
        Seed mask. Any string (default ``'.'``) means "seed from the
        white-matter mask" (``labels == 2`` inside ``head_mask``).
    minus_ROI_mask : str or array
        If not a string, passed as ``ROI`` to ``minus_ROI``. This is only
        applied together with one_node/two_node filtering.
    one_node, two_node : bool
        Forwarded to ``reduct_seed_ROI``. When either is set, streamlines are
        filtered in voxel space before saving.
    """
    time0 = time.time()
    print("begin loading data, time:", time.time() - time0)

    if data_list is None:
        data, affine, img, labels, gtab, head_mask = get_data(name, data_path)
    else:
        data = data_list['DWI']
        affine = data_list['affine']
        img = data_list['img']
        labels = data_list['labels']
        gtab = data_list['gtab']
        head_mask = data_list['head_mask']

    # A string seed is a sentinel: seed from the white-matter mask.
    if isinstance(seed, str):
        seed_mask = (labels == 2) * (head_mask == 1)
    else:
        seed_mask = seed
    white_matter = (labels == 2) * (head_mask == 1)
    seeds = utils.seeds_from_mask(seed_mask, affine, density=1)

    print("begin reconstruction, time:", time.time() - time0)
    response, _ = auto_response_ssst(gtab, data, roi_radii=10, fa_thr=0.7)
    csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=6)
    csd_fit = csd_model.fit(data, mask=white_matter)
    csa_model = CsaOdfModel(gtab, sh_order=6)
    gfa = csa_model.fit(data, mask=white_matter).gfa
    stopping_criterion = ThresholdStoppingCriterion(gfa, Threshold)

    print("begin tracking, time:", time.time() - time0)
    detmax_dg = DeterministicMaximumDirectionGetter.from_shcoeff(
        csd_fit.shm_coeff, max_angle=30., sphere=default_sphere)
    streamline_generator = LocalTracking(detmax_dg, stopping_criterion,
                                         seeds, affine, step_size=.5)
    streamlines = Streamlines(streamline_generator)
    sft = StatefulTractogram(streamlines, img, Space.RASMM)

    if one_node or two_node:
        # Seed/ROI filtering works against voxel masks.
        sft.to_vox()
        streamlines = reduct_seed_ROI(sft.streamlines, seed_mask,
                                      one_node, two_node)
        if not isinstance(minus_ROI_mask, str):
            streamlines = minus_ROI(streamlines=streamlines,
                                    ROI=minus_ROI_mask)
        sft = StatefulTractogram(streamlines, img, Space.VOX)
        # Public equivalent of the private _vox_to_rasmm() from VOX space.
        sft.to_rasmm()

    print("begin saving, time:", time.time() - time0)
    suffix = '' if name is None else '_' + name
    output = output_path + '/tractogram_deterministic' + suffix + '.trk'
    save_trk(sft, output)
    print("finished, time:", time.time() - time0)