예제 #1
0
파일: api.py 프로젝트: jyeatman/pyAFQ
def _bundles(row, wm_labels, odf_model="DTI", directions="det",
             force_recompute=False):
    """
    Segment the streamlines of one data row into bundles, caching on disk.

    Parameters
    ----------
    row : mapping
        Must provide 'dwi_affine', 'dwi_file', 'bval_file' and 'bvec_file',
        plus whatever `_get_fname`, `_streamlines` and `_mapping` read.
    wm_labels
        Forwarded unchanged to `_streamlines`.
    odf_model, directions : str, optional
        Used in the output file name and forwarded to `_streamlines`.
    force_recompute : bool, optional
        When true, redo the segmentation even if the output file exists.

    Returns
    -------
    The file name of the saved bundles tractogram.
    """
    bundles_file = _get_fname(
        row, '%s_%s_bundles.trk' % (odf_model, directions))
    # A cached result is reused unless the caller asked for a recompute.
    if op.exists(bundles_file) and not force_recompute:
        return bundles_file

    streamlines_file = _streamlines(
        row, wm_labels,
        odf_model=odf_model,
        directions=directions,
        force_recompute=force_recompute)
    tractogram = nib.streamlines.load(streamlines_file).tractogram
    # Apply the inverse of the DWI affine to the loaded streamlines.
    inverse_affine = np.linalg.inv(row['dwi_affine'])
    sl = tractogram.apply_affine(inverse_affine).streamlines

    bundle_dict = make_bundle_dict()
    reg_template = dpd.read_mni_template()
    mapping = reg.read_mapping(_mapping(row), row['dwi_file'], reg_template)
    segmented = seg.segment(
        row['dwi_file'], row['bval_file'], row['bvec_file'],
        sl, bundle_dict,
        reg_template=reg_template,
        mapping=mapping)

    nib.streamlines.save(
        _tgramer(segmented, bundle_dict, row['dwi_affine']),
        bundles_file)
    return bundles_file
예제 #2
0
파일: api.py 프로젝트: RafaelNH/pyAFQ
def _bundles(row,
             wm_labels,
             odf_model="DTI",
             directions="det",
             force_recompute=False):
    """
    Produce (or reuse) the segmented-bundles .trk file for one data row.

    The file name is derived by `_get_fname` from ``row`` and the suffix
    ``<odf_model>_<directions>_bundles.trk``. If that file is already on disk
    and ``force_recompute`` is false, it is returned without any computation.
    Otherwise the streamlines from `_streamlines` are segmented with
    `seg.segment` against the MNI template and saved.

    Returns
    -------
    The file name of the bundles tractogram.
    """
    suffix = '%s_%s_bundles.trk' % (odf_model, directions)
    bundles_file = _get_fname(row, suffix)
    needs_work = not op.exists(bundles_file) or force_recompute
    if not needs_work:
        return bundles_file

    tg = nib.streamlines.load(
        _streamlines(row,
                     wm_labels,
                     odf_model=odf_model,
                     directions=directions,
                     force_recompute=force_recompute)).tractogram
    # The streamlines are transformed with the inverse DWI affine before
    # being handed to the segmentation.
    sl = tg.apply_affine(np.linalg.inv(row['dwi_affine'])).streamlines
    bundle_dict = make_bundle_dict()
    template = dpd.read_mni_template()
    mapping = reg.read_mapping(_mapping(row), row['dwi_file'], template)
    bundles = seg.segment(row['dwi_file'],
                          row['bval_file'],
                          row['bvec_file'],
                          sl,
                          bundle_dict,
                          reg_template=template,
                          mapping=mapping)
    tgram = _tgramer(bundles, bundle_dict, row['dwi_affine'])
    nib.streamlines.save(tgram, bundles_file)
    return bundles_file
예제 #3
0
파일: api.py 프로젝트: jyeatman/pyAFQ
def _mapping(row, force_recompute=False):
    """
    Return the file holding the DWI-to-template mapping, creating it if needed.

    The mapping is computed with `reg.syn_register_dwi` against the MNI
    template and written with `reg.write_mapping` when the file is missing
    or ``force_recompute`` is true.
    """
    mapping_file = _get_fname(row, '_mapping.nii.gz')
    if op.exists(mapping_file) and not force_recompute:
        # Cached mapping on disk: nothing to do.
        return mapping_file

    gradients = row['gtab']
    template = dpd.read_mni_template()
    computed = reg.syn_register_dwi(row['dwi_file'], gradients,
                                    template=template)
    reg.write_mapping(computed, mapping_file)
    return mapping_file
예제 #4
0
파일: api.py 프로젝트: RafaelNH/pyAFQ
def _mapping(row, force_recompute=False):
    """
    Compute (or reuse) the mapping between a row's DWI data and the template.

    Parameters
    ----------
    row : mapping
        Must provide 'gtab' and 'dwi_file', plus whatever `_get_fname` reads.
    force_recompute : bool, optional
        When true, register again even if the mapping file already exists.

    Returns
    -------
    The file name of the mapping file.
    """
    mapping_file = _get_fname(row, '_mapping.nii.gz')
    up_to_date = op.exists(mapping_file) and not force_recompute
    if not up_to_date:
        gtab = row['gtab']
        mni_template = dpd.read_mni_template()
        reg.write_mapping(
            reg.syn_register_dwi(row['dwi_file'],
                                 gtab,
                                 template=mni_template),
            mapping_file)
    return mapping_file
예제 #5
0
def main():
    """
    Register DWI data to the MNI template and segment CST/ILF bundles.

    Reads the 'data_file', 'data_bval' and 'data_bvec' entries of
    ``config.json`` (in the working directory) and the tractogram
    ``csa_prob.trk``. Fits DTI via ``dti.fit_dti`` unless
    ``./dti_FA.nii.gz`` already exists, and writes ``./mapping.nii.gz``.
    """
    with open('config.json') as config_json:
        config = json.load(config_json)

    data_file = str(config['data_file'])
    data_bval = str(config['data_bval'])
    data_bvec = str(config['data_bvec'])

    img = nib.load(data_file)

    print("Calculating DTI...")
    if not op.exists('./dti_FA.nii.gz'):
        dti_params = dti.fit_dti(data_file, data_bval, data_bvec, out_dir='.')
    else:
        dti_params = {'FA': './dti_FA.nii.gz', 'params': './dti_params.nii.gz'}

    tg = nib.streamlines.load('csa_prob.trk').tractogram
    # Transform the streamlines with the inverse of the image affine.
    streamlines = tg.apply_affine(np.linalg.inv(img.affine)).streamlines

    # Use only a small portion of the streamlines, for expedience:
    streamlines = streamlines[::100]

    templates = afd.read_templates()
    bundle_names = ["CST", "ILF"]

    bundles = {}
    for name in bundle_names:
        for hemi in ['_R', '_L']:
            # Each bundle is defined by two distinct waypoint ROIs. The
            # second entry previously repeated '_roi1', so the second
            # waypoint was never actually tested.
            bundles[name + hemi] = {
                'ROIs': [
                    templates[name + '_roi1' + hemi],
                    templates[name + '_roi2' + hemi]
                ],
                'rules': [True, True]
            }

    print("Registering to template...")
    MNI_T2_img = dpd.read_mni_template()
    bvals, bvecs = read_bvals_bvecs(data_bval, data_bvec)
    gtab = gradient_table(bvals, bvecs, b0_threshold=100)
    mapping = reg.syn_register_dwi(data_file, gtab)
    reg.write_mapping(mapping, './mapping.nii.gz')

    print("Segmenting fiber groups...")
    fiber_groups = seg.segment(data_file,
                               data_bval,
                               data_bvec,
                               streamlines,
                               bundles,
                               reg_template=MNI_T2_img,
                               mapping=mapping,
                               as_generator=False,
                               affine=img.affine)
    """
예제 #6
0
파일: api.py 프로젝트: akeshavan/pyAFQ
def _reg_prealign(row, force_recompute=False):
    """
    Return the .npy file holding the pre-alignment affine for one data row.

    When the file is missing (or ``force_recompute`` is true), the row's b0
    image (from `_b0`) is registered to the MNI template with
    `reg.affine_registration`, and the resulting affine is saved with
    `np.save`.
    """
    prealign_file = _get_fname(row, '_reg_prealign.npy')
    if op.exists(prealign_file) and not force_recompute:
        return prealign_file

    b0_img = nib.load(_b0(row, force_recompute=force_recompute))
    template_img = dpd.read_mni_template()
    # Only the affine (second return value) is kept.
    registration = reg.affine_registration(
        b0_img.get_data(), template_img.get_data(),
        b0_img.affine, template_img.affine)
    _, aff = registration
    np.save(prealign_file, aff)
    return prealign_file
예제 #7
0
파일: api.py 프로젝트: Anneef/pyAFQ-1
 def _reg_prealign(self, row):
     """
     Return the .npy file with the DWI-to-MNI pre-alignment affine.

     If ``self.force_recompute`` is set or the file does not exist, the b0
     image of ``row`` is registered to the MNI template with
     `reg.affine_registration`; the affine is saved and a JSON sidecar
     containing ``{"type": "rigid"}`` is written next to it.
     """
     prealign_file = self._get_fname(row,
                                     '_prealign_from-DWI_to-MNI_xfm.npy')
     if not self.force_recompute and op.exists(prealign_file):
         return prealign_file

     b0_img = nib.load(self._b0(row))
     template_img = dpd.read_mni_template()
     # Only the affine (second return value) is kept.
     _, aff = reg.affine_registration(
         b0_img.get_fdata(), template_img.get_fdata(),
         b0_img.affine, template_img.affine)
     np.save(prealign_file, aff)

     sidecar = self._get_fname(row,
                               '_prealign_from-DWI_to-MNI_xfm.json')
     afd.write_json(sidecar, dict(type="rigid"))
     return prealign_file
예제 #8
0
    def prepare_map(self, mapping=None, reg_prealign=None, reg_template=None):
        """
        Set mapping between DWI space and a template.

        Parameters
        ----------
        mapping : DiffeomorphicMap object, str or nib.Nifti1Image, optional.
            A mapping between DWI space and a template.
            If None, the mapping is computed by registering ``self.fdata``
            to ``reg_template``. A str or Nifti1Image is read with
            ``reg.read_mapping``; anything else is stored as is.
            Default: None.

        reg_prealign : array, optional.
            The linear transformation to be applied to align input images to
            the reference space before warping under the deformation field.
            Only used when ``mapping`` is a str or Nifti1Image; its inverse
            is what gets passed to ``reg.read_mapping``.
            Default: None (identity).

        reg_template : str or nib.Nifti1Image, optional.
            Template to use for registration (defaults to the MNI T2)
            Default: None.
        """
        if reg_template is None:
            reg_template = dpd.read_mni_template()

        if mapping is None:
            gtab = dpg.gradient_table(self.fbval, self.fbvec)
            # Register against the requested template. Previously the
            # template was not forwarded, so a caller-supplied reg_template
            # was silently ignored on this path.
            self.mapping = reg.syn_register_dwi(self.fdata, gtab,
                                                template=reg_template)[1]
        elif isinstance(mapping, (str, nib.Nifti1Image)):
            if reg_prealign is None:
                reg_prealign = np.eye(4)
            if self.img is None:
                self.img, _, _, _ = \
                    ut.prepare_data(self.fdata,
                                    self.fbval,
                                    self.fbvec,
                                    b0_threshold=self.b0_threshold)
            # The inverse of the pre-alignment is passed on purpose: per the
            # original author's note, using the non-inverted matrix produced
            # wrong warped ROIs and probability maps.
            self.mapping = reg.read_mapping(
                mapping,
                self.img,
                reg_template,
                prealign=np.linalg.inv(reg_prealign))
        else:
            self.mapping = mapping
예제 #9
0
def setup_module():
    """
    Populate the module-level fixtures shared by the tests.

    Loads the MNI T2 template and the Stanford HARDI data through
    ``dipy.data`` and publishes small cubic sub-volumes of each (arrays and
    Nifti1Image wrappers) as globals, together with the gradient table and
    the two affines.
    """
    global subset_b0, subset_dwi_data, subset_t2, subset_b0_img, \
           subset_t2_img, gtab, hardi_affine, MNI_T2_affine
    template_img = dpd.read_mni_template()
    dwi_img, gtab = dpd.read_stanford_hardi()
    template_data = template_img.get_fdata()
    MNI_T2_affine = template_img.affine
    dwi_data = dwi_img.get_fdata()
    hardi_affine = dwi_img.affine

    # We select some arbitrary chunk of data so this goes quicker:
    dwi_chunk = (slice(40, 45),) * 3
    t2_chunk = (slice(40, 50),) * 3

    subset_b0 = np.mean(dwi_data[..., gtab.b0s_mask], -1)[dwi_chunk]
    subset_dwi_data = nib.Nifti1Image(dwi_data[dwi_chunk], hardi_affine)
    subset_t2 = template_data[t2_chunk]
    subset_b0_img = nib.Nifti1Image(subset_b0, hardi_affine)
    subset_t2_img = nib.Nifti1Image(subset_t2, MNI_T2_affine)
예제 #10
0
def syn_register_dwi(dwi, gtab, template=None, **syn_kwargs):
    """
    Register the mean b0 of a DWI dataset to a template.

    Parameters
    -----------
    dwi : nifti image or str
        Image containing DWI data, or full path to a nifti file with DWI.
    gtab : GradientTable or sequence
        The gradients associated with the DWI data. Anything that is not a
        GradientTable is unpacked into ``dpg.gradient_table(*gtab)``
        (e.g. ``[fbval, fbvec]``).
    template : nifti image or str, optional
        Registration target, or path to it. Defaults to the MNI template
        returned by ``dpd.read_mni_template()``.
    syn_kwargs : key-word arguments for :func:`syn_registration`

    Returns
    -------
    warped_b0, mapping
        The two values returned by :func:`syn_registration`.
    """
    if template is None:
        template = dpd.read_mni_template()
    if isinstance(template, str):
        template = nib.load(template)
    static_data = template.get_fdata()
    static_affine = template.affine

    if isinstance(dwi, str):
        dwi = nib.load(dwi)
    if not isinstance(gtab, dpg.GradientTable):
        gtab = dpg.gradient_table(*gtab)

    moving_affine = dwi.affine
    # Average the volumes selected by the gradient table's b0 mask.
    b0_volumes = dwi.get_fdata()[..., gtab.b0s_mask]
    moving_data = np.mean(b0_volumes, -1)

    warped_b0, mapping = syn_registration(moving_data,
                                          static_data,
                                          moving_affine=moving_affine,
                                          static_affine=static_affine,
                                          **syn_kwargs)
    return warped_b0, mapping
예제 #11
0
def syn_register_dwi(dwi, gtab, template=None, **syn_kwargs):
    """
    Register the mean b0 of a DWI dataset to a template.

    Parameters
    -----------
    dwi : nifti image or str
        Image containing DWI data, or full path to a nifti file with DWI.
    gtab : GradientTable or sequence
        The gradients associated with the DWI data. Anything that is not a
        GradientTable is unpacked into ``dpg.gradient_table(*gtab)``
        (e.g. ``[fbval, fbvec]``).
    template : nifti image or str, optional
        Registration target, or path to it. Defaults to the MNI template
        returned by ``dpd.read_mni_template()``.
    syn_kwargs : key-word arguments for :func:`syn_registration`

    Returns
    -------
    DiffeomorphicMap object
    """
    if template is None:
        template = dpd.read_mni_template()
    if isinstance(template, str):
        template = nib.load(template)

    # Use the `affine` property and `get_fdata()`: the older
    # `get_affine()` / `get_data()` accessors were deprecated and then
    # removed from nibabel.
    template_data = template.get_fdata()
    template_affine = template.affine

    if isinstance(dwi, str):
        dwi = nib.load(dwi)

    if not isinstance(gtab, dpg.GradientTable):
        gtab = dpg.gradient_table(*gtab)

    dwi_affine = dwi.affine
    dwi_data = dwi.get_fdata()
    # Average the volumes selected by the gradient table's b0 mask.
    mean_b0 = np.mean(dwi_data[..., gtab.b0s_mask], -1)
    # Only the mapping is returned; the warped b0 is discarded.
    _, mapping = syn_registration(mean_b0, template_data,
                                  moving_affine=dwi_affine,
                                  static_affine=template_affine,
                                  **syn_kwargs)
    return mapping
예제 #12
0
def main():
    """
    Segment the tractogram named in ``config.json`` into CST and ILF bundles.

    Reads 'data_file', 'data_bval', 'data_bvec' and 'tck_data' from
    ``config.json`` in the working directory. Fits DTI unless
    ``./dti_FA.nii.gz`` exists, computes the template mapping (or reads
    ``mapping.nii.gz`` if present), segments the streamlines and writes one
    ``<bundle>.tck`` file per bundle into ``<cwd>/tract1/``.
    """
    # Indentation is normalised to four spaces: the previous mix of tabs
    # and spaces is rejected by Python 3 with a TabError.
    with open('config.json') as config_json:
        config = json.load(config_json)

    # Paths to data
    data_file = str(config['data_file'])
    data_bval = str(config['data_bval'])
    data_bvec = str(config['data_bvec'])

    img = nib.load(data_file)

    print('Loaded Data')

    print('Calculating DTI')
    if not op.exists('./dti_FA.nii.gz'):
        dti_params = dti.fit_dti(data_file, data_bval, data_bvec, out_dir='.')
    else:
        dti_params = {'FA': './dti_FA.nii.gz', 'params': './dti_params.nii.gz'}

    tg = nib.streamlines.load(str(config['tck_data'])).tractogram
    # The inverse affine must be applied here (original author's note:
    # it cannot be removed).
    streamlines = tg.apply_affine(np.linalg.inv(img.affine)).streamlines
    print('Loaded streamlines')

    # Use only a small portion of the streamlines, for expedience:
    streamlines = streamlines[::100]

    templates = afd.read_templates()
    bundle_names = ["CST", "ILF"]

    bundles = {}
    for name in bundle_names:
        for hemi in ['_R', '_L']:
            # Two distinct waypoint ROIs per bundle. The second entry
            # previously repeated '_roi1', so the second waypoint was never
            # actually tested.
            bundles[name + hemi] = {'ROIs': [templates[name + '_roi1' + hemi],
                                             templates[name + '_roi2' + hemi]],
                                    'rules': [True, True]}
    print('Set Bundles')
    MNI_T2_img = dpd.read_mni_template()
    print("Registering to template...")
    bvals, bvecs = read_bvals_bvecs(data_bval, data_bvec)
    if not op.exists('mapping.nii.gz'):
        gtab = gradient_table(bvals, bvecs)
        mapping = reg.syn_register_dwi(data_file, gtab)
        reg.write_mapping(mapping, './mapping.nii.gz')
    else:
        mapping = reg.read_mapping('./mapping.nii.gz', img, MNI_T2_img)

    print("Segmenting fiber groups...")
    fiber_groups = seg.segment(data_file,
                               data_bval,
                               data_bvec,
                               streamlines,
                               bundles,
                               reg_template=MNI_T2_img,
                               mapping=mapping,
                               as_generator=False,
                               affine=img.affine)

    path = os.getcwd() + '/tract1/'
    if not os.path.exists(path):
        os.makedirs(path)

    print('Creating tck files')
    for fg in fiber_groups:
        streamlines = fiber_groups[fg]
        fname = fg + ".tck"
        trg = nib.streamlines.Tractogram(streamlines,
                                         affine_to_rasmm=img.affine)
        nib.streamlines.save(trg, path + fname)
        print('Finished segment')
예제 #13
0
def segment(fdata, fbval, fbvec, streamlines, bundle_dict, mapping,
            reg_prealign=None, b0_threshold=0, reg_template=None,
            prob_threshold=0):
    """
    Segment streamlines into bundles based on inclusion ROIs.

    Parameters
    ----------
    fdata, fbval, fbvec : str
        Full path to data, bvals, bvecs

    streamlines : list of 2D arrays
        Each array is a streamline, shape (3, N).

    bundle_dict: dict
        One entry per bundle. The keys read from each entry are 'ROIs',
        'rules', 'cross_midline' (required) and 'prob_map' (optional)::

            {'name': {'ROIs':[img1, img2],
                      'rules':[True, True],
                      'prob_map': img3,
                      'cross_midline': False}}

        A true rule marks the ROI at the same position as an inclusion ROI,
        a false rule as an exclusion ROI. 'cross_midline' may be None, in
        which case midline crossing is not checked for that bundle.

    mapping : DiffeomorphicMap object, str or nib.Nifti1Image
        A mapping between DWI space and a template, used to align the ROIs
        to the data. A str or Nifti1Image is read with
        ``reg.read_mapping``.

    reg_prealign : array, optional
        Passed as ``prealign`` to ``reg.read_mapping`` when ``mapping`` is a
        str or Nifti1Image. Default: None (identity, ``np.eye(4)``).

    b0_threshold : optional
        Forwarded to ``ut.prepare_data``. Default: 0.

    reg_template : str or nib.Nifti1Image, optional.
        Template to use for registration (defaults to the MNI T2)

    prob_threshold : float.
        Initial cleaning of fiber groups is done using probability maps from
        [Hua2008]_. Here, we choose an average probability that needs to be
        exceeded for an individual streamline to be retained. Default: 0.

    Returns
    -------
    dict
        Maps each key of ``bundle_dict`` to a ``dts.Streamlines`` object
        holding the streamlines assigned to that bundle (empty if none).

    References
    ----------
    .. [Hua2008] Hua K, Zhang J, Wakana S, Jiang H, Li X, et al. (2008)
       Tract probability maps in stereotaxic spaces: analyses of white
       matter anatomy and tract-specific quantification. Neuroimage 39:
       336-347
    """
    # NOTE(review): `gtab` is unpacked here but not used below.
    img, _, gtab, _ = ut.prepare_data(fdata, fbval, fbvec,
                                      b0_threshold=b0_threshold)

    # Distance tolerance used by the inclusion/exclusion ROI checks below.
    tol = dts.dist_to_corner(img.affine)

    if reg_template is None:
        reg_template = dpd.read_mni_template()

    # Classify the streamlines and split those that: 1) cross the
    # midline, and 2) pass under 10 mm below the mid-point of their
    # representation in the template space:
    xform_sl, crosses = split_streamlines(streamlines, img)

    if isinstance(mapping, str) or isinstance(mapping, nib.Nifti1Image):
        if reg_prealign is None:
            reg_prealign = np.eye(4)
        mapping = reg.read_mapping(mapping, img, reg_template,
                                   prealign=reg_prealign)

    # NOTE(review): this initial array is never read; the name is rebound
    # inside the bundle loop below.
    fiber_probabilities = np.zeros((len(xform_sl), len(bundle_dict)))

    # For expedience, we approximate each streamline as a 100 point curve:
    fgarray = _resample_bundle(xform_sl, 100)
    # Per (streamline, bundle): the streamline's mean probability if it
    # passed every test for that bundle, otherwise 0.
    streamlines_in_bundles = np.zeros((len(xform_sl), len(bundle_dict)))
    # Per (streamline, bundle): values derived from the distances to the
    # first two inclusion ROIs; used below to orient the streamlines.
    min_dist_coords = np.zeros((len(xform_sl), len(bundle_dict), 2))

    fiber_groups = {}

    for bundle_idx, bundle in enumerate(bundle_dict):
        rules = bundle_dict[bundle]['rules']
        include_rois = []
        exclude_rois = []
        for rule_idx, rule in enumerate(rules):
            roi = bundle_dict[bundle]['ROIs'][rule_idx]
            if not isinstance(roi, np.ndarray):
                roi = roi.get_data()
            # Warp the ROI with the mapping's inverse transform, binarize
            # it, and patch it up.
            warped_roi = auv.patch_up_roi(
                (mapping.transform_inverse(
                    roi,
                    interpolation='linear')) > 0)

            if rule:
                # include ROI:
                include_rois.append(np.array(np.where(warped_roi)).T)
            else:
                # Exclude ROI:
                exclude_rois.append(np.array(np.where(warped_roi)).T)

        crosses_midline = bundle_dict[bundle]['cross_midline']

        # The probability map if doesn't exist is all ones with the same
        # shape as the ROIs:
        # NOTE(review): `roi` here is whatever the last iteration of the
        # rules loop left behind; with an empty 'rules' list it is undefined
        # (first bundle) or stale (later bundles).
        prob_map = bundle_dict[bundle].get('prob_map', np.ones(roi.shape))

        if not isinstance(prob_map, np.ndarray):
            prob_map = prob_map.get_data()
        warped_prob_map = mapping.transform_inverse(prob_map,
                                                    interpolation='nearest')
        # Sample the warped probability map along the resampled streamlines
        # and average over the last axis to get one value per streamline.
        fiber_probabilities = dts.values_from_volume(warped_prob_map,
                                                     fgarray)
        fiber_probabilities = np.mean(fiber_probabilities, -1)

        for sl_idx, sl in enumerate(xform_sl):
            if fiber_probabilities[sl_idx] > prob_threshold:
                if crosses_midline is not None:
                    if crosses[sl_idx]:
                        # This means that the streamline does
                        # cross the midline:
                        if crosses_midline:
                            # This is what we want, keep going
                            pass
                        else:
                            # This is not what we want, skip to next streamline
                            continue

                is_close, dist = _check_sl_with_inclusion(sl, include_rois,
                                                          tol)
                if is_close:
                    is_far = _check_sl_with_exclusion(sl, exclude_rois,
                                                      tol)
                    if is_far:
                        # NOTE(review): indexes dist[0] and dist[1], so this
                        # assumes at least two inclusion ROIs - confirm.
                        min_dist_coords[sl_idx, bundle_idx, 0] =\
                            np.argmin(dist[0], 0)[0]
                        min_dist_coords[sl_idx, bundle_idx, 1] =\
                            np.argmin(dist[1], 0)[0]
                        streamlines_in_bundles[sl_idx, bundle_idx] =\
                            fiber_probabilities[sl_idx]

    # Eliminate any fibers not selected using the plane ROIs:
    possible_fibers = np.sum(streamlines_in_bundles, -1) > 0
    xform_sl = xform_sl[possible_fibers]
    streamlines_in_bundles = streamlines_in_bundles[possible_fibers]
    min_dist_coords = min_dist_coords[possible_fibers]
    # Each remaining streamline goes to the bundle with the highest value.
    bundle_choice = np.argmax(streamlines_in_bundles, -1)

    # We do another round through, so that we can orient all the
    # streamlines within a bundle in the same orientation with respect to
    # the ROIs. This order is ARBITRARY but CONSISTENT (going from ROI0
    # to ROI1).
    for bundle_idx, bundle in enumerate(bundle_dict):
        select_idx = np.where(bundle_choice == bundle_idx)
        # Use a list here, because Streamlines don't support item assignment:
        select_sl = list(xform_sl[select_idx])
        if len(select_sl) == 0:
            fiber_groups[bundle] = dts.Streamlines([])
            # There's nothing here, move to the next bundle:
            continue

        # Sub-sample min_dist_coords:
        min_dist_coords_bundle = min_dist_coords[select_idx]
        for idx in range(len(select_sl)):
            min0 = min_dist_coords_bundle[idx, bundle_idx, 0]
            min1 = min_dist_coords_bundle[idx, bundle_idx, 1]
            # Flip the streamline when the stored value for the first ROI
            # is larger than the one for the second ROI.
            if min0 > min1:
                select_sl[idx] = select_sl[idx][::-1]
        # Set this to nibabel.Streamlines object for output:
        select_sl = dts.Streamlines(select_sl)
        fiber_groups[bundle] = select_sl

    return fiber_groups
예제 #14
0
파일: api.py 프로젝트: Anneef/pyAFQ-1
    def __init__(self,
                 dmriprep_path,
                 sub_prefix="sub",
                 dwi_folder="dwi",
                 dwi_file="*dwi",
                 anat_folder="anat",
                 anat_file="*T1w*",
                 seg_file='*aparc+aseg*',
                 b0_threshold=0,
                 bundle_names=BUNDLES,
                 dask_it=False,
                 force_recompute=False,
                 reg_template=None,
                 scalars=["dti_fa", "dti_md"],
                 wm_labels=[250, 251, 252, 253, 254, 255, 41, 2, 16, 77],
                 use_prealign=True,
                 tracking_params=None,
                 segmentation_params=None,
                 clean_params=None):
        """
        Discover the input files and set up the processing parameters.

        Parameters
        ----------
        dmriprep_path: str
            The path to the preprocessed diffusion data. Subject
            directories are found by globbing ``<sub_prefix>*`` inside it;
            outputs go to an ``afq`` directory created next to it.

        sub_prefix : str, optional
            Prefix of the subject directories. Default: "sub".

        dwi_folder, dwi_file : str, optional
            Folder (inside each session directory) and glob pattern
            (without extension) of the DWI data. ``.nii.gz``, ``.bvec*``
            and ``.bval*`` files matching the pattern must exist.
            Default: "dwi", "*dwi".

        anat_folder, anat_file, seg_file : str, optional
            Folder and glob patterns (without ``.nii.gz``) of the
            anatomical and segmentation images. These files are optional.
            Default: "anat", "*T1w*", "*aparc+aseg*".

        b0_threshold : int, optional
            The value of b under which it is considered to be b0. Default: 0.

        bundle_names : optional
            Passed to ``make_bundle_dict``. Default: ``BUNDLES``.

        dask_it : bool, optional
            Whether to use a dask DataFrame object

        force_recompute : bool, optional
            Whether to recompute or ignore existing derivatives.
            This parameter can be turned on/off dynamically.
            Default: False

        reg_template : str or nib.Nifti1Image, optional
            Registration template, or a path loaded with ``nib.load``.
            Default: the template from ``dpd.read_mni_template()``.

        scalars : list of str, optional
            Stored on ``self.scalars``. Default: ["dti_fa", "dti_md"].

        wm_labels : list, optional
            A list of the labels of the white matter in the segmentation file
            used. Default: the white matter values for the segmentation
            provided with the HCP data, including labels for midbrain:
            [250, 251, 252, 253, 254, 255, 41, 2, 16, 77].

        use_prealign : bool, optional
            Whether to perform pre-alignment before perforiming the
            diffeomorphic mapping in registration. Default: True

        segmentation_params : dict, optional
            The parameters for segmentation. Default: use the default behavior
            of the seg.Segmentation object. Its "seg_algo" entry (lower-cased)
            selects the segmentation algorithm.

        tracking_params: dict, optional
            The parameters for tracking. Default: use the default behavior of
            the aft.track function.

        clean_params: dict, optional
            The parameters for cleaning. Default: use the default behavior of
            the seg.clean_bundle function.

        Raises
        ------
        IndexError
            If a session directory has no file matching the DWI, bvec or
            bval pattern.

        Notes
        -----
        NOTE(review): earlier documentation listed ``odf_model`` and
        ``directions``; they are not parameters of this method and are
        presumably keys of ``tracking_params`` - confirm against
        ``aft.track``.
        """

        self.force_recompute = force_recompute

        # Copy the list arguments: the defaults are single list objects
        # shared by every call, so storing them directly would let a
        # mutation on one instance leak into all later instances.
        self.wm_labels = list(wm_labels)
        self.use_prealign = use_prealign

        self.scalars = list(scalars)

        default_tracking_params = get_default_args(aft.track)
        # Replace the defaults only for kwargs for which a non-default value was
        # given:
        if tracking_params is not None:
            for k in tracking_params:
                default_tracking_params[k] = tracking_params[k]

        self.tracking_params = default_tracking_params

        default_seg_params = get_default_args(seg.Segmentation.__init__)
        if segmentation_params is not None:
            for k in segmentation_params:
                default_seg_params[k] = segmentation_params[k]

        self.segmentation_params = default_seg_params
        self.seg_algo = self.segmentation_params["seg_algo"].lower()
        # reg_template is handed over as given by the caller (possibly None
        # or a path); it is only loaded/defaulted further down.
        self.bundle_dict = make_bundle_dict(bundle_names=bundle_names,
                                            seg_algo=self.seg_algo,
                                            resample_to=reg_template)

        default_clean_params = get_default_args(seg.clean_bundle)
        if clean_params is not None:
            for k in clean_params:
                default_clean_params[k] = clean_params[k]

        self.clean_params = default_clean_params

        if reg_template is None:
            self.reg_template = dpd.read_mni_template()
        else:
            if not isinstance(reg_template, nib.Nifti1Image):
                reg_template = nib.load(reg_template)
            self.reg_template = reg_template
        # This is the place in which each subject's full data lives
        self.dmriprep_dirs = glob.glob(op.join(dmriprep_path,
                                               f"{sub_prefix}*"))

        # This is where all the outputs will go: an 'afq' directory next to
        # dmriprep_path. Using `.parent` also works when dmriprep_path has a
        # single component (joining an empty `parts[:-1]` raised TypeError).
        self.afq_dir = op.join(str(PurePath(dmriprep_path).parent), 'afq')

        os.makedirs(self.afq_dir, exist_ok=True)

        self.subjects = [op.split(p)[-1] for p in self.dmriprep_dirs]

        sub_list = []
        sess_list = []
        dwi_file_list = []
        bvec_file_list = []
        bval_file_list = []
        anat_file_list = []
        seg_file_list = []
        results_dir_list = []
        for subject, sub_dir in zip(self.subjects, self.dmriprep_dirs):
            sessions = glob.glob(op.join(sub_dir, '*'))
            for sess in sessions:
                results_dir_list.append(
                    op.join(self.afq_dir, subject,
                            PurePath(sess).parts[-1]))

                os.makedirs(results_dir_list[-1], exist_ok=True)

                dwi_file_list.append(
                    glob.glob(f"{sess}/{dwi_folder}/{dwi_file}.nii.gz")[0])

                bvec_file_list.append(
                    glob.glob(f"{sess}/{dwi_folder}/{dwi_file}.bvec*")[0])

                bval_file_list.append(
                    glob.glob(f"{sess}/{dwi_folder}/{dwi_file}.bval*")[0])

                # The following two may or may not exist. `sess` already
                # contains `sub_dir` (it comes from the glob above), so it
                # must not be joined onto `sub_dir` again: with a relative
                # dmriprep_path that doubled the prefix and never matched.
                # A None placeholder keeps these lists aligned with the
                # other per-session lists.
                this_anat_file = glob.glob(
                    op.join(sess, anat_folder, f"{anat_file}.nii.gz"))
                anat_file_list.append(
                    this_anat_file[0] if this_anat_file else None)

                this_seg_file = glob.glob(
                    op.join(sess, anat_folder, f"{seg_file}.nii.gz"))
                seg_file_list.append(
                    this_seg_file[0] if this_seg_file else None)

                sub_list.append(subject)
                sess_list.append(sess)
        self.data_frame = pd.DataFrame(
            dict(subject=sub_list,
                 dwi_file=dwi_file_list,
                 bvec_file=bvec_file_list,
                 bval_file=bval_file_list,
                 sess=sess_list,
                 results_dir=results_dir_list))
        # Add these columns only if at least one session has the file:
        if any(f is not None for f in seg_file_list):
            self.data_frame['seg_file'] = seg_file_list
        if any(f is not None for f in anat_file_list):
            self.data_frame['anat_file'] = anat_file_list

        if dask_it:
            self.data_frame = ddf.from_pandas(self.data_frame,
                                              npartitions=len(sub_list))
        self.set_gtab(b0_threshold)
        self.set_dwi_affine()
        self.set_dwi_img()
예제 #15
0
    import mayavi
    import mayavi.mlab as mlab
    from tvtk.tools import visual
    from tvtk.api import tvtk

# Load the Stanford HARDI data and average the volumes selected by the
# gradient table's b0 mask. The `affine` property and `get_fdata()` replace
# `get_affine()` / `get_data()`, which were deprecated and then removed
# from nibabel.
ni, gtab = dpd.read_stanford_hardi()
hardi_data = ni.get_fdata()
hardi_affine = ni.affine
b0 = hardi_data[..., gtab.b0s_mask]
mean_b0 = np.mean(b0, -1)

# Save the mean b0 and show its middle slice along the last axis.
ni_b0 = nib.Nifti1Image(mean_b0, hardi_affine)
ni_b0.to_filename('mean_b0.nii')
plt.matshow(mean_b0[:, :, mean_b0.shape[-1] // 2], cmap=cm.bone)

MNI_T2 = dpd.read_mni_template()
MNI_T2_data = MNI_T2.get_fdata()
MNI_T2_affine = MNI_T2.affine

# Symmetric diffeomorphic registration with a cross-correlation metric;
# the template is passed as the first (static) image and the mean b0 as
# the second (moving) image.
level_iters = [10, 10, 5]
dim = 3
metric = CCMetric(dim)
sdr = SymmetricDiffeomorphicRegistration(metric, level_iters, step_length=0.25)
sdr.verbosity = VerbosityLevels.DIAGNOSE
mapping = sdr.optimize(MNI_T2_data, mean_b0, MNI_T2_affine, hardi_affine)
warped_b0 = mapping.transform(mean_b0)
plt.matshow(warped_b0[:, :, warped_b0.shape[-1] // 2], cmap=cm.bone)
plt.matshow(MNI_T2_data[:, :, MNI_T2_data.shape[-1] // 2], cmap=cm.bone)

# Save the warped b0 with the template's affine.
new_ni = nib.Nifti1Image(warped_b0, MNI_T2_affine)
new_ni.to_filename('./warped_b0.nii.gz')
예제 #16
0
import nibabel as nib
import nibabel.tmpdirs as nbtmp

import dipy.data as dpd
import dipy.core.gradients as dpg

from AFQ.registration import (syn_registration, register_series, register_dwi,
                              c_of_mass, translation, rigid, affine,
                              streamline_registration, write_mapping,
                              read_mapping, syn_register_dwi, DiffeomorphicMap)

from dipy.tracking.utils import move_streamlines

from AFQ.utils.streamlines import write_trk

# Module-level fixtures: the MNI T2 template and the Stanford HARDI data.
# The `affine` property and `get_fdata()` replace `get_affine()` /
# `get_data()`, which were deprecated and then removed from nibabel.
MNI_T2 = dpd.read_mni_template()
hardi_img, gtab = dpd.read_stanford_hardi()
MNI_T2_data = MNI_T2.get_fdata()
MNI_T2_affine = MNI_T2.affine
hardi_data = hardi_img.get_fdata()
hardi_affine = hardi_img.affine
# Average the volumes selected by the gradient table's b0 mask.
b0 = hardi_data[..., gtab.b0s_mask]
mean_b0 = np.mean(b0, -1)

# We select some arbitrary chunk of data so this goes quicker:
subset_b0 = mean_b0[40:50, 40:50, 40:50]
subset_dwi_data = nib.Nifti1Image(hardi_data[40:50, 40:50, 40:50],
                                  hardi_affine)
subset_t2 = MNI_T2_data[40:60, 40:60, 40:60]
subset_b0_img = nib.Nifti1Image(subset_b0, hardi_affine)
subset_t2_img = nib.Nifti1Image(subset_t2, MNI_T2_affine)
예제 #17
0
파일: pipeline.py 프로젝트: jyeatman/pyAFQ
else:
    brain_mas = nib.load('./brain_mask.nii.gz').get_data().astype(bool)

print("Calculating DTI...")
if not op.exists('./dti_FA.nii.gz'):
    dti_params = dti.fit_dti(fdata, fbval, fbvec,
                             out_dir='.', mask=brain_mask)
else:
    dti_params = {'FA': './dti_FA.nii.gz',
                  'MD': './dti_MD.nii.gz',
                  'RD': './dti_RD.nii.gz',
                  'AD': './dti_AD.nii.gz',
                  'params': './dti_params.nii.gz'}

print("Registering to template...")
MNI_T2_img = dpd.read_mni_template()
if not op.exists('mapping.nii.gz'):
    import dipy.core.gradients as dpg
    gtab = dpg.gradient_table(fbval, fbvec)
    mapping = reg.syn_register_dwi(fdata, gtab)
    reg.write_mapping(mapping, './mapping.nii.gz')
else:
    mapping = reg.read_mapping('./mapping.nii.gz', img, MNI_T2_img)

print("Tracking...")
if not op.exists('dti_streamlines.trk'):
    FA = nib.load(dti_params["FA"]).get_data()
    wm_mask = np.zeros_like(FA)
    wm_mask[FA > 0.2] = 1
    step_size = 1
    min_length_mm = 50
예제 #18
0
def segment(fdata,
            fbval,
            fbvec,
            streamlines,
            bundles,
            reg_template=None,
            mapping=None,
            as_generator=True,
            clip_to_roi=True,
            **reg_kwargs):
    """
    Segment streamlines into bundles.

    Parameters
    ----------
    fdata, fbval, fbvec : str
        Full path to data, bvals, bvecs

    streamlines : list of 2D arrays
        Each array is a streamline (the dipy convention is shape (N, 3)).
        The streamlines are moved with the inverse of the DWI affine before
        they are compared with the ROIs.

    bundles: dict
        The format is something like::

             {'name': {'ROIs':[img, img], 'rules':[True, True]}}

    reg_template : str or nib.Nifti1Image, optional.
        Template to use for registration (defaults to the MNI T2)

    mapping : DiffeomorphicMap object, str or nib.Nifti1Image, optional
        A mapping between DWI space and a template. Defaults to generate this.

    as_generator : bool, optional
        If True, the value stored for each bundle is a generator of
        streamlines. If False, the streamlines are generated here and stored
        as a list. Default: True.

    clip_to_roi : bool, optional
        Whether to clip the streamlines between the ROIs. Clipping is not
        implemented yet, so this argument is currently ignored.

    reg_kwargs : key-word arguments
        Passed to `reg.syn_register_dwi` when `mapping` is None.

    Returns
    -------
    fiber_groups : dict
        Keys are the keys of `bundles`; values are the streamlines selected
        for that bundle (generator or list, see `as_generator`).

    Raises
    ------
    IndexError
        If none of the rules of a bundle is True, because no ROI is then
        available to orient the streamlines.
    """
    img, _, gtab, _ = ut.prepare_data(fdata, fbval, fbvec)
    xform_sl = list(
        dtu.move_streamlines(streamlines, np.linalg.inv(img.affine)))

    if reg_template is None:
        reg_template = dpd.read_mni_template()

    if mapping is None:
        mapping = reg.syn_register_dwi(fdata,
                                       gtab,
                                       template=reg_template,
                                       **reg_kwargs)

    if isinstance(mapping, (str, nib.Nifti1Image)):
        mapping = reg.read_mapping(mapping, img, reg_template)

    fiber_groups = {}
    for bundle in bundles:
        rois = bundles[bundle]['ROIs']
        rules = bundles[bundle]['rules']

        select_sl = xform_sl
        # Keep the warped ROIs around: the orientation step below has to use
        # the same (warped) ROIs that were used to select the streamlines.
        warped_rois = []
        for roi, rule in zip(rois, rules):
            warped_roi = patch_up_roi(
                mapping.transform_inverse(
                    roi.get_data(), interpolation='nearest')).astype(bool)
            warped_rois.append(warped_roi)
            # This function requires lists as inputs:
            select_sl = dts.select_by_rois(select_sl, [warped_roi], [rule])

        # Next, we reorient each streamline according to an ARBITRARY, but
        # CONSISTENT order. To do this, we use the first ROI for which the rule
        # is True as the first one to pass through, and the last ROI for which
        # the rule is True as the last one to pass through:

        # Indices where the 'rule' is True:
        idx = np.where(rules)[0]

        # The warped ROIs are used here (rather than the ROI data as stored
        # in `bundles`), so that orientation is computed with the same ROI
        # arrays that the streamlines were selected with.
        select_sl = dts.orient_by_rois(select_sl,
                                       warped_rois[idx[0]],
                                       warped_rois[idx[-1]],
                                       as_generator=True)

        #  XXX Implement clipping to the ROIs
        #  if clip_to_roi:
        #    dts.clip()

        if as_generator:
            fiber_groups[bundle] = select_sl
        else:
            fiber_groups[bundle] = list(select_sl)

    return fiber_groups
예제 #19
0
파일: api.py 프로젝트: akeshavan/pyAFQ
    def __init__(self,
                 dmriprep_path,
                 seg_algo="planes",
                 sub_prefix="sub",
                 dwi_folder="dwi",
                 dwi_file="*dwi",
                 anat_folder="anat",
                 anat_file="*T1w*",
                 seg_file='*aparc+aseg*',
                 b0_threshold=0,
                 odf_model="CSD",
                 directions="det",
                 n_seeds=2,
                 random_seeds=False,
                 bundle_names=BUNDLES,
                 dask_it=False,
                 force_recompute=False,
                 reg_template=None,
                 wm_labels=[250, 251, 252, 253, 254, 255, 41, 2, 16, 77]):
        """
        Index the input files of every subject/session found under
        `dmriprep_path` in a data frame and create the output folders.

        Parameters
        ----------
        dmriprep_path: str
            The path to the preprocessed diffusion data. Outputs are written
            to an 'afq' folder next to it (same parent folder).

        seg_algo : str
            Which algorithm to use for segmentation.
            Can be one of: {"planes", "recobundles"}

        sub_prefix : str, optional
            Subject folders are the entries of `dmriprep_path` whose name
            starts with this prefix. Default: "sub".

        dwi_folder, dwi_file : str, optional
            Folder within each session, and glob pattern (without the
            extension) of the DWI file in it. The first match of
            '<dwi_file>.nii.gz', '<dwi_file>.bvec*' and '<dwi_file>.bval*'
            is used. Defaults: "dwi", "*dwi".

        anat_folder, anat_file, seg_file : str, optional
            Folder within each session, and glob patterns (without the
            '.nii.gz' extension) of the anatomical and segmentation files in
            it. These files are optional.
            Defaults: "anat", "*T1w*", "*aparc+aseg*".

        b0_threshold : int, optional
            The value of b under which it is considered to be b0. Default: 0.

        odf_model : string, optional
            Which model to use for determining directions in tractography
            {"DTI", "DKI", "CSD"}. Default: "CSD"

        directions : string, optional
            How to select directions for tracking (deterministic or
            probablistic) {"det", "prob"}. Default: "det".

        n_seeds, random_seeds : optional
            Stored on the instance as is; presumably tractography seeding
            settings used by other methods (not visible here).
            Defaults: 2, False.

        bundle_names : optional
            Passed to `make_bundle_dict` to build `self.bundle_dict`.

        dask_it : bool, optional
            Whether to use a dask DataFrame object

        force_recompute : bool, optional
            Whether to ignore previous results, and recompute all, or not.

        reg_template : str or nib.Nifti1Image, optional
            Template for registration. Anything that is not a Nifti1Image is
            loaded with `nib.load`. Default: the MNI template.

        wm_labels : list, optional
            A list of the labels of the white matter in the segmentation file
            used. Default: the white matter values for the segmentation
            provided with the HCP data, including labels for midbrain:
            [250, 251, 252, 253, 254, 255, 41, 2, 16, 77].

        Raises
        ------
        IndexError
            If a session has no DWI, bvec or bval file matching the patterns.

        Notes
        -----
        `self.data_frame` has one row per session with the columns 'subject',
        'dwi_file', 'bvec_file', 'bval_file', 'sess' and 'results_dir'. The
        columns 'seg_file' and 'anat_file' are added if at least one session
        has such a file; sessions without one hold None in that column.
        """
        self.directions = directions
        self.odf_model = odf_model
        self.bundle_dict = make_bundle_dict(bundle_names=bundle_names,
                                            seg_algo=seg_algo)
        self.seg_algo = seg_algo
        self.force_recompute = force_recompute
        # Copy, so that the default list (one object shared by every call)
        # cannot be modified through an instance:
        self.wm_labels = list(wm_labels)
        self.n_seeds = n_seeds
        self.random_seeds = random_seeds
        if reg_template is None:
            self.reg_template = dpd.read_mni_template()
        else:
            if not isinstance(reg_template, nib.Nifti1Image):
                reg_template = nib.load(reg_template)
            self.reg_template = reg_template
        # This is the place in which each subject's full data lives
        self.dmriprep_dirs = glob.glob(
            op.join(dmriprep_path, '%s*' % sub_prefix))

        # This is where all the outputs will go. Using `parent` (rather than
        # joining all but the last path component) also works when
        # `dmriprep_path` has a single component, e.g. "dmriprep":
        self.afq_dir = str(PurePath(dmriprep_path).parent / 'afq')

        os.makedirs(self.afq_dir, exist_ok=True)

        self.subjects = [op.split(p)[-1] for p in self.dmriprep_dirs]

        sub_list = []
        sess_list = []
        dwi_file_list = []
        bvec_file_list = []
        bval_file_list = []
        anat_file_list = []
        seg_file_list = []
        results_dir_list = []
        for subject, sub_dir in zip(self.subjects, self.dmriprep_dirs):
            # Every entry returned here already includes `sub_dir`, so the
            # patterns below are built from `sess` alone. Joining them with
            # `sub_dir` again would duplicate it for relative paths.
            sessions = glob.glob(op.join(sub_dir, '*'))
            for sess in sessions:
                results_dir = op.join(self.afq_dir, subject,
                                      PurePath(sess).parts[-1])
                os.makedirs(results_dir, exist_ok=True)
                results_dir_list.append(results_dir)

                dwi_dir = op.join(sess, dwi_folder)
                # These three are required: [0] raises IndexError if missing.
                dwi_file_list.append(
                    glob.glob(op.join(dwi_dir, '%s.nii.gz' % dwi_file))[0])
                bvec_file_list.append(
                    glob.glob(op.join(dwi_dir, '%s.bvec*' % dwi_file))[0])
                bval_file_list.append(
                    glob.glob(op.join(dwi_dir, '%s.bval*' % dwi_file))[0])

                # The following two may or may not exist. A None placeholder
                # keeps these lists aligned with the rows of the data frame:
                anat_dir = op.join(sess, anat_folder)
                this_anat_file = glob.glob(
                    op.join(anat_dir, '%s.nii.gz' % anat_file))
                anat_file_list.append(
                    this_anat_file[0] if this_anat_file else None)

                this_seg_file = glob.glob(
                    op.join(anat_dir, '%s.nii.gz' % seg_file))
                seg_file_list.append(
                    this_seg_file[0] if this_seg_file else None)

                sub_list.append(subject)
                sess_list.append(sess)
        self.data_frame = pd.DataFrame(
            dict(subject=sub_list,
                 dwi_file=dwi_file_list,
                 bvec_file=bvec_file_list,
                 bval_file=bval_file_list,
                 sess=sess_list,
                 results_dir=results_dir_list))
        # Add these if they exist:
        if any(f is not None for f in seg_file_list):
            self.data_frame['seg_file'] = seg_file_list
        if any(f is not None for f in anat_file_list):
            self.data_frame['anat_file'] = anat_file_list

        if dask_it:
            self.data_frame = ddf.from_pandas(self.data_frame,
                                              npartitions=len(sub_list))
        self.set_gtab(b0_threshold)
        self.set_dwi_affine()
예제 #20
0
def segment(fdata, fbval, fbvec, streamlines, bundles,
            reg_template=None, mapping=None, as_generator=True, **reg_kwargs):
    """
    Segment streamlines into bundles, using inclusion/exclusion ROIs.

    Parameters
    ----------
    fdata, fbval, fbvec : str
        Paths to the data, bvals and bvecs (passed to `ut.prepare_data`).

    streamlines : sequence of streamlines
        Moved with the inverse of the DWI affine before they are compared
        with the ROIs.

    bundles: dict
        The format is something like::

             {'name': {'ROIs':[img, img], 'rules':[True, True]}}

    reg_template : template to use for registration (defaults to the MNI T2)

    mapping : optional
        A mapping between DWI space and the template. If None, it is
        computed with `reg.syn_register_dwi`. A str or nib.Nifti1Image is
        read with `reg.read_mapping`.

    as_generator : bool
        If True, the output of `dts.orient_by_rois` is stored as is for each
        bundle. If False, it is converted to a list. Default: True.

    reg_kwargs : key-word arguments
        Passed to `reg.syn_register_dwi` when `mapping` is None.

    Returns
    -------
    fiber_groups : dict
        Keys are the keys of `bundles`; values are the streamlines selected
        for that bundle.

    Raises
    ------
    IndexError
        If none of the rules of a bundle is True, because no ROI is then
        available to orient the streamlines.
    """
    img, _, gtab, _ = ut.prepare_data(fdata, fbval, fbvec)
    xform_sl = list(dtu.move_streamlines(streamlines,
                                         np.linalg.inv(img.affine)))

    if reg_template is None:
        reg_template = dpd.read_mni_template()

    if mapping is None:
        mapping = reg.syn_register_dwi(fdata, gtab, template=reg_template,
                                       **reg_kwargs)

    if isinstance(mapping, (str, nib.Nifti1Image)):
        mapping = reg.read_mapping(mapping, img, reg_template)

    fiber_groups = {}
    for bundle in bundles:
        rois = bundles[bundle]['ROIs']
        rules = bundles[bundle]['rules']

        select_sl = xform_sl
        # Keep the warped ROIs around: the orientation step below has to use
        # the same (warped) ROIs that were used to select the streamlines.
        warped_rois = []
        for roi, rule in zip(rois, rules):
            warped_roi = patch_up_roi(mapping.transform_inverse(
                roi.get_data(),
                interpolation='nearest')).astype(bool)
            warped_rois.append(warped_roi)
            # This function requires lists as inputs:
            select_sl = dts.select_by_rois(select_sl,
                                           [warped_roi],
                                           [rule])
        # Next, we reorient each streamline according to an ARBITRARY, but
        # CONSISTENT order. To do this, we use the first ROI for which the rule
        # is True as the first one to pass through, and the last ROI for which
        # the rule is True as the last one to pass through:

        # Indices where the 'rule' is True:
        idx = np.where(rules)[0]

        # The warped ROIs are used here (rather than the ROI data as stored
        # in `bundles`), so that orientation is computed with the same ROI
        # arrays that the streamlines were selected with.
        select_sl = dts.orient_by_rois(select_sl,
                                       warped_rois[idx[0]],
                                       warped_rois[idx[-1]],
                                       in_place=True)
        if as_generator:
            fiber_groups[bundle] = select_sl
        else:
            fiber_groups[bundle] = list(select_sl)

    return fiber_groups
예제 #21
0
#os.mkdir("wmc/surfaces")

# Read the input paths from config.json, which must provide the keys
# "dwi", "bvals", "bvecs" and "track":
with open('config.json') as config_f:
    config = json.load(config_f)
    dwi = config["dwi"]
    bvals = config["bvals"]
    bvecs = config["bvecs"]
    track = config["track"]

# Load the DWI image and build the gradient table from the bvals/bvecs
dwi_img = nib.load(dwi)
gtab = dpg.gradient_table(bvals, bvecs)

# Load the MNI template (used below to resample the waypoint ROIs) and
# register the DWI data; the call returns two values, the second of which is
# the mapping
MNI_T2_img = dpd.read_mni_template()
warped_hardi, mapping = reg.syn_register_dwi(dwi, gtab)

# Load the tractogram, with the DWI image passed as its reference
tg = load_tractogram(track, dwi_img)
#tg_acpc = transform_streamlines(tg.streamlines,dwi_img.get_affine())

# Build the bundle dictionary, resampled to the MNI template, and keep the
# bundle names as a list
bundles = api.make_bundle_dict(resample_to=MNI_T2_img)
bundle_names = list(bundles.keys())

print(f"Space before segmentation: {tg.space}")

# Create the segmentation object (configured with return_idx=True)
print("running AFQ segmentation")
segmentation = seg.Segmentation(return_idx=True)
예제 #22
0
def register_dwi_to_template(dwi,
                             gtab,
                             dwi_affine=None,
                             template=None,
                             template_affine=None,
                             reg_method="syn",
                             **reg_kwargs):
    """
    Align DWI data with a template, using the mean of the b0 volumes.

    Parameters
    ----------
    dwi : 4D array, nifti image or str
        The DWI data, or the full path to a nifti file that holds it.
    gtab : GradientTable or sequence of strings
        Gradients that go with the DWI data. Anything that is not a
        GradientTable is unpacked into `dpg.gradient_table`, e.g. a
        (fbval, fbvec) pair of full paths.
    dwi_affine : 4x4 array, optional
        Affine of the DWI data. Required if the data is given as an array;
        if given together with a nifti image/path, it over-rides the affine
        stored in the nifti.
    template : 3D array, nifti image or str, optional
        The template data, or the full path to a nifti file that holds it.
        Default: the MNI template read by `dpd.read_mni_template`.
    template_affine : 4x4 array, optional
        Affine of the template. Required if the template is given as an
        array; if given together with a nifti image/path, it over-rides the
        affine stored in the nifti.
    reg_method : str, optional
        "syn" to register with :func:`syn_registration`, or "aff" to
        register with :func:`affine_registration` (case-insensitive).
        Default: "syn".
    reg_kwargs : key-word arguments
        Passed on to the selected registration function.

    Returns
    -------
    warped_b0, mapping : The first is an array with the b0 volume warped to
    the template. If reg_method is "syn", the second is a DiffeomorphicMap
    class instance that can be used to transform between the two spaces.
    Otherwise, if reg_method is "aff", this is a 4x4 matrix encoding the
    affine transform.

    Raises
    ------
    ValueError
        If `reg_method` is neither "syn" nor "aff".

    Notes
    -----
    This function assumes that the DWI data is already internally registered.
    See :func:`register_dwi_series`.

    """
    moving, moving_affine = read_img_arr_or_path(dwi, affine=dwi_affine)

    static, static_affine = read_img_arr_or_path(
        dpd.read_mni_template() if template is None else template,
        affine=template_affine)

    if not isinstance(gtab, dpg.GradientTable):
        gtab = dpg.gradient_table(*gtab)

    # Average over the volumes that the gradient table flags as b0:
    b0_average = np.mean(moving[..., gtab.b0s_mask], -1)

    # Both registration functions take the same arguments, so pick one by
    # name and call it once:
    registrations = {"syn": syn_registration,
                     "aff": affine_registration}
    method = reg_method.lower()
    if method not in registrations:
        raise ValueError("reg_method should be one of 'aff' or 'syn', but you"
                         " provided %s" % reg_method)

    warped_b0, mapping = registrations[method](b0_average,
                                               static,
                                               moving_affine=moving_affine,
                                               static_affine=static_affine,
                                               **reg_kwargs)
    return warped_b0, mapping
예제 #23
0
def segment(fdata, fbval, fbvec, streamlines, bundles,
            reg_template=None, mapping=None, prob_threshold=0,
            **reg_kwargs):
    """
    Segment streamlines into bundles based on inclusion ROIs.

    Parameters
    ----------
    fdata, fbval, fbvec : str
        Full path to data, bvals, bvecs

    streamlines : list of 2D arrays
        Each array is a streamline, shape (N, 3).

    bundles: dict
        The format is something like::

            {'name': {'ROIs':[img1, img2],
                      'rules':[True, True],
                      'prob_map': img3,
                      'cross_midline': False}}

        Only the first two ROIs of each bundle are used. 'cross_midline' is
        required (it can be None, to skip the midline test); 'prob_map' is
        optional. ROIs and probability maps can be arrays or images.

    reg_template : str or nib.Nifti1Image, optional.
        Template to use for registration (defaults to the MNI T2)

    mapping : DiffeomorphicMap object, str or nib.Nifti1Image, optional
        A mapping between DWI space and a template. Defaults to generate
        this.

    prob_threshold : float.
        Initial cleaning of fiber groups is done using probability maps from
        [Hua2008]_. Here, we choose an average probability that needs to be
        exceeded for an individual streamline to be retained. Default: 0.

    reg_kwargs : key-word arguments
        Passed to `reg.syn_register_dwi` when `mapping` is None.

    Returns
    -------
    fiber_groups : dict
        Keys are the keys of `bundles`, values are Streamlines objects, in
        the coordinates obtained by applying the inverse of the DWI affine
        to the input. A streamline that qualifies for several bundles is
        assigned to the one in which its mean probability is highest. Each
        streamline is flipped, if needed, so that its node closest to the
        first ROI comes before its node closest to the second ROI.

    References
    ----------
    .. [Hua2008] Hua K, Zhang J, Wakana S, Jiang H, Li X, et al. (2008)
       Tract probability maps in stereotaxic spaces: analyses of white
       matter anatomy and tract-specific quantification. Neuroimage 39:
       336-347
    """
    img, _, gtab, _ = ut.prepare_data(fdata, fbval, fbvec)
    tol = dts.dist_to_corner(img.affine)

    xform_sl = dts.Streamlines(dtu.move_streamlines(streamlines,
                                                    np.linalg.inv(img.affine)))

    if reg_template is None:
        reg_template = dpd.read_mni_template()

    if mapping is None:
        mapping = reg.syn_register_dwi(fdata, gtab, template=reg_template,
                                       **reg_kwargs)

    if isinstance(mapping, (str, nib.Nifti1Image)):
        mapping = reg.read_mapping(mapping, img, reg_template)

    n_streamlines = len(xform_sl)
    n_bundles = len(bundles)

    # For expedience, we approximate each streamline as a 100 point curve:
    fgarray = _resample_bundle(xform_sl, 100)
    # Mean probability of each streamline in each bundle that it passes the
    # ROI tests for (0 otherwise):
    streamlines_in_bundles = np.zeros((n_streamlines, n_bundles))
    # Index of the streamline node that is closest to the first voxel of
    # each of the two ROIs; used further down to orient the streamlines:
    min_dist_coords = np.zeros((n_streamlines, n_bundles, 2))

    fiber_groups = {}

    # The middle of the first axis of the image is taken to be the midline:
    midline = img.shape[0] // 2

    for bundle_idx, bundle in enumerate(bundles):
        # Get the ROI coordinates:
        ROI0 = bundles[bundle]['ROIs'][0]
        ROI1 = bundles[bundle]['ROIs'][1]
        if not isinstance(ROI0, np.ndarray):
            ROI0 = ROI0.get_data()

        warped_ROI0 = patch_up_roi(
            mapping.transform_inverse(
                ROI0,
                interpolation='nearest')).astype(bool)

        if not isinstance(ROI1, np.ndarray):
            ROI1 = ROI1.get_data()

        warped_ROI1 = patch_up_roi(
            mapping.transform_inverse(
                ROI1,
                interpolation='nearest')).astype(bool)

        roi_coords0 = np.array(np.where(warped_ROI0)).T
        roi_coords1 = np.array(np.where(warped_ROI1)).T

        crosses_midline = bundles[bundle]['cross_midline']

        # The probability map if doesn't exist is all ones with the same
        # shape as the ROIs:
        prob_map = bundles[bundle].get('prob_map')
        if prob_map is None:
            prob_map = np.ones(ROI0.shape)
        elif not isinstance(prob_map, np.ndarray):
            prob_map = prob_map.get_data()
        warped_prob_map = mapping.transform_inverse(prob_map,
                                                    interpolation='nearest')
        # Mean, along each resampled streamline, of the warped probabilities:
        fiber_probabilities = np.mean(
            dts.values_from_volume(warped_prob_map, fgarray), -1)

        # Only bundles that explicitly must not cross the midline reject
        # streamlines in the midline test.
        # NOTE(review): for bundles with a True 'cross_midline', streamlines
        # that do not cross the midline are still accepted -- confirm that
        # this is intended.
        reject_crossing = (crosses_midline is not None
                           and not crosses_midline)

        for sl_idx, sl in enumerate(xform_sl):
            if not fiber_probabilities[sl_idx] > prob_threshold:
                continue
            if reject_crossing:
                x = sl[:, 0]
                if np.any(x > midline) and np.any(x < midline):
                    # This streamline crosses the midline, which is not
                    # what we want: skip to the next streamline
                    continue
            # The streamline has to come within `tol` of both ROIs:
            dist0 = cdist(sl, roi_coords0, 'euclidean')
            if np.min(dist0) > tol:
                continue
            dist1 = cdist(sl, roi_coords1, 'euclidean')
            if np.min(dist1) > tol:
                continue
            min_dist_coords[sl_idx, bundle_idx, 0] = np.argmin(dist0, 0)[0]
            min_dist_coords[sl_idx, bundle_idx, 1] = np.argmin(dist1, 0)[0]
            streamlines_in_bundles[sl_idx, bundle_idx] = \
                fiber_probabilities[sl_idx]

    # Eliminate any fibers not selected using the plane ROIs:
    possible_fibers = np.sum(streamlines_in_bundles, -1) > 0
    xform_sl = xform_sl[possible_fibers]
    streamlines_in_bundles = streamlines_in_bundles[possible_fibers]
    min_dist_coords = min_dist_coords[possible_fibers]
    # Each remaining streamline goes to the bundle in which its probability
    # is highest:
    bundle_choice = np.argmax(streamlines_in_bundles, -1)

    for bundle_idx, bundle in enumerate(bundles):
        print(bundle)
        # A plain integer array (rather than the tuple that np.where
        # returns), which is a form of index that Streamlines objects accept:
        select_idx = np.where(bundle_choice == bundle_idx)[0]
        if len(select_idx) == 0:
            fiber_groups[bundle] = dts.Streamlines([])
            # There's nothing here, move to the next bundle:
            continue

        # Flip the streamlines in which the node closest to the second ROI
        # comes before the node closest to the first ROI:
        flip = (min_dist_coords[select_idx, bundle_idx, 0] >
                min_dist_coords[select_idx, bundle_idx, 1])
        # Use a list here, because Streamlines don't support item assignment:
        select_sl = [sl[::-1] if flip_this else sl
                     for sl, flip_this in zip(xform_sl[select_idx], flip)]
        # We'll set this to Streamlines object for the next steps (e.g.,
        # cleaning) because these objects support indexing with arrays:
        fiber_groups[bundle] = dts.Streamlines(select_sl)

    return fiber_groups