def clean_bundles(self, **kwargs):
    """
    Clean each segmented bundle based on the Mahalanobis distance of each
    streamline.

    Each entry of ``self.bundles`` is replaced, in place, by a new
    StatefulTractogram built from the cleaned streamlines (using
    ``self.reference``, ``self.space`` and ``self.origin``).

    Parameters
    ----------
    clean_rounds : int, optional
        Number of rounds of cleaning based on the Mahalanobis distance from
        the mean of extracted bundles. Default: 5
    clean_threshold : float, optional
        Threshold of cleaning based on the Mahalanobis distance (the units
        are standard deviations). Default: 3.
    min_sl : int, optional
        Number of streamlines in a bundle under which we will not bother
        with cleaning outliers. Default: 20.
    stat : callable, optional
        The statistic of each node relative to which the Mahalanobis is
        calculated. Default: `np.mean` (but can also use median, etc.)

    Notes
    -----
    ``return_idx`` is chosen by this method for each bundle: True when the
    bundle carries an ``'idx'`` data-per-streamline entry, False otherwise.
    A ``return_idx`` passed in ``kwargs`` is ignored rather than forwarded
    twice to ``seg.clean_bundle``.

    While cleaning runs, log records at WARNING level and below are
    disabled. ``logging.disable(logging.NOTSET)`` is always called
    afterwards, even when cleaning raises.
    """
    # return_idx is decided per bundle below; forwarding a caller-supplied
    # copy as well would raise "got multiple values for keyword argument".
    clean_kwargs = {k: v for k, v in kwargs.items() if k != 'return_idx'}

    logging.disable(level=logging.WARNING)
    try:
        # Snapshot the items so rebinding entries never interacts with the
        # live dict iterator.
        for bundle_name, bundle in list(self.bundles.items()):
            dps = bundle.data_per_streamline
            # On a StatefulTractogram this container exists even when it is
            # empty, so test for the 'idx' key itself, not for None.
            if dps is not None and 'idx' in dps:
                new_sls, idx_in_bundle = seg.clean_bundle(
                    bundle, return_idx=True, **clean_kwargs)
                new_dps = {'idx': dps['idx'][idx_in_bundle]}
            else:
                new_sls = seg.clean_bundle(
                    bundle, return_idx=False, **clean_kwargs)
                # Nothing to carry over. A {'idx': None} entry cannot be
                # turned into a per-streamline array, so pass no data.
                new_dps = None
            self.bundles[bundle_name] = StatefulTractogram(
                new_sls.streamlines,
                self.reference,
                self.space,
                origin=self.origin,
                data_per_streamline=new_dps)
    finally:
        logging.disable(logging.NOTSET)
def clean_bundles(subses_dict, bundles_file, bundle_dict, clean_params,
                  tracking_params, segmentation_params):
    """Clean each segmented bundle stored in ``bundles_file``.

    Parameters
    ----------
    subses_dict : dict
        Must contain ``'dwi_file'``. That image is the spatial reference used
        to load ``bundles_file`` in voxel space.
    bundles_file : str
        Path to the segmented tractography. Streamlines are assigned to
        bundles through its ``'bundle'`` data-per-streamline entry.
    bundle_dict : dict
        Maps bundle name to a dict holding that bundle's ``'uid'``. The
        ``"whole_brain"`` entry is skipped.
    clean_params : dict
        Keyword arguments forwarded to ``seg.clean_bundle``. When
        ``clean_params.get('return_idx')`` is true, the per-bundle index
        lists are read from the sidecar JSON (the part of ``bundles_file``
        before its first '.', plus ``'.json'``, key ``"idx"``). They are
        filtered to the retained streamlines and stored in ``meta["idx"]``.
    tracking_params, segmentation_params : dict
        Not used by this function.

    Returns
    -------
    sft : StatefulTractogram
        Cleaned streamlines of all bundles in voxel space, each tagged with
        its bundle ``uid`` under ``'bundle'``.
    meta : dict
        Contains:

        - ``source``: ``bundles_file``.
        - ``Parameters``: ``seg.clean_bundle`` defaults overridden by
          ``clean_params``, with callables replaced by their names.
        - ``idx``: only present when ``return_idx`` is requested.
        - ``Timing``: elapsed ``time()`` since the input was loaded.
    """
    img = nib.load(subses_dict['dwi_file'])
    sft = load_tractogram(bundles_file, img, Space.VOX)
    start_time = time()
    want_idx = clean_params.get('return_idx', False)

    tgram = nib.streamlines.Tractogram([], {'bundle': []})
    return_idx = {}
    source_idx = None  # sidecar index lists; read once, on first use
    for b, b_info in bundle_dict.items():
        if b == "whole_brain":
            continue
        uid = b_info['uid']
        idx = np.where(sft.data_per_streamline['bundle'] == uid)[0]
        this_tg = StatefulTractogram(sft.streamlines[idx], img, Space.VOX)
        this_tg = seg.clean_bundle(this_tg, **clean_params)
        if want_idx:
            # With return_idx, clean_bundle yields (tractogram, kept indices).
            this_tg, this_idx = this_tg
            if source_idx is None:
                idx_file = bundles_file.split('.')[0] + '.json'
                with open(idx_file) as ff:
                    source_idx = json.load(ff)["idx"]
            # Translate bundle-local positions into the sidecar's indices.
            return_idx[b] = np.array(source_idx[b])[this_idx].tolist()
        this_tgram = nib.streamlines.Tractogram(
            this_tg.streamlines,
            data_per_streamline={'bundle': len(this_tg) * [uid]},
            affine_to_rasmm=img.affine)
        tgram = aus.add_bundles(tgram, this_tgram)

    sft = StatefulTractogram(tgram.streamlines, sft, Space.VOX,
                             data_per_streamline=tgram.data_per_streamline)

    # Record the parameters actually used, not just clean_bundle's defaults.
    seg_args = get_default_args(seg.clean_bundle)
    seg_args.update(clean_params)
    seg_args = {
        k: (getattr(v, '__name__', repr(v)) if callable(v) else v)
        for k, v in seg_args.items()
    }
    meta = dict(source=bundles_file, Parameters=seg_args)
    if want_idx:
        meta["idx"] = return_idx
    meta["Timing"] = time() - start_time
    return sft, meta
def test_segment():
    """Segment the CST bundle pair and sanity-check the right-hemisphere one.

    Uses fixtures that are not defined in this function (``bundles``,
    ``tg``, ``hardi_fdata``, ``hardi_fbval``, ``hardi_fbvec``, ``mapping``).
    Presumably they are module-level test data.
    """
    segmenter = seg.Segmentation()
    segmenter.segment(bundles, tg, hardi_fdata, hardi_fbval, hardi_fbvec,
                      mapping=mapping)
    groups = segmenter.fiber_groups

    # Two bundles were requested, so two fiber groups should come back.
    npt.assert_equal(len(groups), 2)

    # The right CST must contain at least one streamline.
    cst_right = groups['CST_R']
    npt.assert_(len(cst_right) > 0)

    # Profiling a volume of all ones must give ones at every node.
    ones_volume = np.ones(nib.load(hardi_fdata).shape[:3])
    profile = afq_profile(ones_volume, cst_right.streamlines, np.eye(4))
    npt.assert_almost_equal(profile, np.ones(100))

    # With default settings, cleaning is expected to keep every streamline.
    cleaned = seg.clean_bundle(cst_right)
    npt.assert_equal(len(cleaned), len(cst_right))
def _clean_bundles(self, row):
    """Clean the segmented bundles for one row and cache the result on disk.

    The cleaned tractography is written to a ``.trk`` file named from the
    tracking and segmentation parameters. Its metadata goes into a sibling
    ``.json`` file; when ``return_idx`` is requested, the retained
    per-bundle indices go into ``_idx.json``. The work is skipped when the
    output already exists and ``self.force_recompute`` is false.

    Parameters
    ----------
    row : mapping
        Per-subject/session record. ``row['dwi_img']`` and
        ``row['dwi_affine']`` are used here; ``row`` is also passed to
        ``self._get_fname`` and ``self._segment``.

    Returns
    -------
    str
        Path of the cleaned tractography file.
    """
    odf_model = self.tracking_params['odf_model']
    directions = self.tracking_params['directions']
    seg_algo = self.segmentation_params['seg_algo']
    clean_bundles_file = self._get_fname(
        row,
        f'_space-RASMM_model-{odf_model}_desc-{directions}-'
        f'{seg_algo}-clean_tractography.trk')

    if self.force_recompute or not op.exists(clean_bundles_file):
        bundles_file = self._segment(row)
        sft = load_tractogram(bundles_file, row['dwi_img'], Space.VOX)
        want_idx = self.clean_params.get('return_idx', False)

        tgram = nib.streamlines.Tractogram([], {'bundle': []})
        return_idx = {}
        source_idx = None  # segmentation's idx sidecar; read on first use
        for b, b_info in self.bundle_dict.items():
            if b == "whole_brain":
                continue
            uid = b_info['uid']
            idx = np.where(sft.data_per_streamline['bundle'] == uid)[0]
            this_tg = StatefulTractogram(
                sft.streamlines[idx], row['dwi_img'], Space.VOX)
            this_tg = seg.clean_bundle(this_tg, **self.clean_params)
            if want_idx:
                # With return_idx, clean_bundle yields (tractogram, indices).
                this_tg, this_idx = this_tg
                if source_idx is None:
                    idx_file = bundles_file.split('.')[0] + '_idx.json'
                    with open(idx_file) as ff:
                        source_idx = json.load(ff)
                # Translate bundle-local positions into the sidecar's indices.
                return_idx[b] = np.array(source_idx[b])[this_idx].tolist()
            this_tgram = nib.streamlines.Tractogram(
                this_tg.streamlines,
                data_per_streamline={'bundle': len(this_tg) * [uid]},
                affine_to_rasmm=row['dwi_affine'])
            tgram = aus.add_bundles(tgram, this_tgram)

        save_tractogram(
            StatefulTractogram(
                tgram.streamlines,
                sft,
                Space.VOX,
                data_per_streamline=tgram.data_per_streamline),
            clean_bundles_file)

        # Record the parameters actually used, not just clean_bundle's
        # defaults; callables are stored by name so the JSON stays valid.
        seg_args = get_default_args(seg.clean_bundle)
        seg_args.update(self.clean_params)
        seg_args = {
            k: (getattr(v, '__name__', repr(v)) if callable(v) else v)
            for k, v in seg_args.items()
        }
        meta = dict(source=bundles_file, Parameters=seg_args)
        meta_fname = clean_bundles_file.split('.')[0] + '.json'
        afd.write_json(meta_fname, meta)

        if want_idx:
            afd.write_json(
                clean_bundles_file.split('.')[0] + '_idx.json',
                return_idx)

    return clean_bundles_file
mapping=mapping, reg_template=MNI_T2_img) fiber_groups = segmentation.fiber_groups ########################################################################## # Cleaning # -------- # Each fiber group is cleaned to exclude streamlines that are outliers in terms # of their trajector and/or length. print("Cleaning fiber groups...") for bundle in bundles: print(f"Cleaning {bundle}") print(f"Before cleaning: {len(fiber_groups[bundle]['sl'])} streamlines") new_fibers, idx_in_bundle = seg.clean_bundle(fiber_groups[bundle]['sl'], return_idx=True) print(f"Afer cleaning: {len(new_fibers)} streamlines") idx_in_global = fiber_groups[bundle]['idx'][idx_in_bundle] np.save(op.join(working_dir, f'{bundle}_idx.npy'), idx_in_global) sft = StatefulTractogram(new_fibers.streamlines, img, Space.VOX) sft.to_rasmm() save_tractogram(sft, op.join(working_dir, f'{bundle}_afq.trk'), bbox_valid_check=False) ########################################################################## # Bundle profiles # --------------- # Streamlines are represented in the original diffusion space (`Space.VOX`) and # scalar properties along the length of each bundle are queried from this
def test_segment():
    """Exercise waypoint and RecoBundles segmentation on the left/right CST.

    Covers four cases:

    - waypoint ROIs with probability maps;
    - waypoint ROIs without probability maps;
    - waypoint ROIs with ``return_idx=True``;
    - the 'Reco' algorithm, with and without ``return_idx``.

    Uses fixtures not defined in this function (``tg``, ``hardi_fdata``,
    ``hardi_fbval``, ``hardi_fbvec``, ``mapping``), presumably module-level
    test data.
    """
    templates = afd.read_templates()

    def cst_spec(roi_keys, prob_map_key, cross_midline):
        # One waypoint-bundle definition. 'prob_map' is included only when a
        # probability-map template key is given.
        spec = {'ROIs': [templates[key] for key in roi_keys],
                'rules': [True, True]}
        if prob_map_key is not None:
            spec['prob_map'] = templates[prob_map_key]
        spec['cross_midline'] = cross_midline
        return spec

    left_rois = ['CST_roi1_L', 'CST_roi2_L']
    # NOTE(review): CST_R lists 'CST_roi1_R' twice, while CST_L uses
    # roi1/roi2. It was probably meant to be 'CST_roi2_R'. It is kept as-is
    # because changing it may alter which streamlines the test data yields;
    # confirm before fixing.
    right_rois = ['CST_roi1_R', 'CST_roi1_R']

    bundles = {
        'CST_L': cst_spec(left_rois, 'CST_L_prob_map', None),
        'CST_R': cst_spec(right_rois, 'CST_R_prob_map', None),
    }

    segmenter = seg.Segmentation()
    segmenter.segment(bundles, tg, hardi_fdata, hardi_fbval, hardi_fbvec,
                      mapping=mapping)
    groups = segmenter.fiber_groups

    # Two bundles were requested, so two fiber groups should come back.
    npt.assert_equal(len(groups), 2)
    cst_right = groups['CST_R']
    npt.assert_(len(cst_right) > 0)

    # Profiling a volume of all ones must give ones at every node.
    ones_volume = np.ones(nib.load(hardi_fdata).shape[:3])
    profile = afq_profile(ones_volume, cst_right.streamlines, np.eye(4))
    npt.assert_almost_equal(profile, np.ones(100))

    cleaned = seg.clean_bundle(cst_right)
    npt.assert_equal(len(cleaned), len(cst_right))

    # Without probability maps, both bundles should still be found.
    bundles = {
        'CST_L': cst_spec(left_rois, None, False),
        'CST_R': cst_spec(right_rois, None, False),
    }
    segmenter.segment(bundles, tg, hardi_fdata, hardi_fbval, hardi_fbvec,
                      mapping=mapping)
    groups = segmenter.fiber_groups
    npt.assert_equal(len(groups), 2)
    npt.assert_(len(groups['CST_R']) > 0)

    # With return_idx=True, each group is indexed by 'sl' and 'idx'.
    segmenter = seg.Segmentation(return_idx=True)
    segmenter.segment(bundles, tg, hardi_fdata, hardi_fbval, hardi_fbvec,
                      mapping=mapping)
    groups = segmenter.fiber_groups
    npt.assert_equal(len(groups), 2)
    npt.assert_(len(groups['CST_R']['sl']) > 0)
    npt.assert_(len(groups['CST_R']['idx']) > 0)

    # RecoBundles works from atlas bundles; keep only the ones under test.
    bundles = afd.read_hcp_atlas_16_bundles()
    wanted = ['whole_brain', 'CST_R', 'CST_L']
    for name in [key for key in bundles if key not in wanted]:
        bundles.pop(name, None)

    def reco_segmenter(**extra):
        # A fresh, identically seeded RNG for every Reco run.
        return seg.Segmentation(seg_algo='Reco',
                                progressive=False,
                                greater_than=10,
                                rm_small_clusters=1,
                                rng=np.random.RandomState(seed=8),
                                **extra)

    reco = reco_segmenter()
    groups = reco.segment(bundles, tg, hardi_fdata, hardi_fbval, hardi_fbvec)
    npt.assert_equal(len(groups), 2)
    npt.assert_(len(groups['CST_R']) > 0)

    reco = reco_segmenter(return_idx=True)
    reco.segment(bundles, tg, hardi_fdata, hardi_fbval, hardi_fbvec)
    groups = reco.fiber_groups
    npt.assert_equal(len(groups), 2)
    npt.assert_(len(groups['CST_R']['sl']) > 0)
    npt.assert_(len(groups['CST_R']['idx']) > 0)