示例#1
0
def probal(Threshold=.2,
           data_list=None,
           seed='.',
           one_node=False,
           two_node=False):
    """Run probabilistic CSD tractography and save it as a TRK file.

    Parameters
    ----------
    Threshold : float
        GFA value handed to ``ThresholdStoppingCriterion``.
    data_list : dict
        Must provide the keys 'DWI', 'affine', 'img', 'labels', 'gtab' and
        'head_mask'.
    seed : str or array
        Any string (such as the default '.') means "seed from the white
        matter mask"; any other value is used directly as the seed mask.
    one_node, two_node : bool
        When either is set, the streamlines are filtered with
        ``reduct_seed_ROI`` (in voxel space) before saving.

    Notes
    -----
    The result is written to 'tractogram_probabilistic.trk' in the current
    working directory and progress timings are printed to stdout.
    """
    time0 = time.time()
    print("begin loading data, time:", time.time() - time0)

    data = data_list['DWI']
    affine = data_list['affine']
    img = data_list['img']
    labels = data_list['labels']
    gtab = data_list['gtab']
    head_mask = data_list['head_mask']

    # Label 2 inside the head mask is what the model fits are restricted to.
    white_matter = (labels == 2) * (head_mask == 1)

    # isinstance (rather than an exact type comparison) so that str
    # subclasses are also treated as "no explicit seed mask given".
    seed_mask = white_matter if isinstance(seed, str) else seed
    seeds = utils.seeds_from_mask(seed_mask, affine, density=1)

    print("begin reconstruction, time:", time.time() - time0)
    response, _ratio = auto_response_ssst(gtab, data, roi_radii=10,
                                          fa_thr=0.7)
    csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=6)
    csd_fit = csd_model.fit(data, mask=white_matter)

    csa_model = CsaOdfModel(gtab, sh_order=6)
    gfa = csa_model.fit(data, mask=white_matter).gfa
    stopping_criterion = ThresholdStoppingCriterion(gfa, Threshold)

    print("begin tracking, time:", time.time() - time0)
    fod = csd_fit.odf(small_sphere)
    # Negative ODF values cannot be used as probabilities.
    pmf = fod.clip(min=0)
    prob_dg = ProbabilisticDirectionGetter.from_pmf(pmf,
                                                    max_angle=30.,
                                                    sphere=small_sphere)
    streamline_generator = LocalTracking(prob_dg,
                                         stopping_criterion,
                                         seeds,
                                         affine,
                                         step_size=.5)
    streamlines = Streamlines(streamline_generator)

    sft = StatefulTractogram(streamlines, img, Space.RASMM)

    if one_node or two_node:
        # The ROI reduction works on voxel coordinates.
        sft.to_vox()
        streamlines = reduct_seed_ROI(sft.streamlines, seed_mask, one_node,
                                      two_node)
        sft = StatefulTractogram(streamlines, img, Space.VOX)
        # Use the public state-aware transform instead of the private
        # `_vox_to_rasmm` helper.
        sft.to_rasmm()

    print("begin saving, time:", time.time() - time0)

    output = 'tractogram_probabilistic.trk'
    save_trk(sft, output)

    print("finished, time:", time.time() - time0)
def get_tractogram_in_voxel_space(tract_fname,
                                  ref_anat_fname,
                                  tracts_attribs={'orientation': 'unknown'},
                                  origin=Origin.NIFTI):
    """Load a tractogram and return it in voxel space with a given origin.

    Parameters
    ----------
    tract_fname : str
        Tractogram file to read.
    ref_anat_fname : str
        Reference anatomy passed on to the tractogram constructor/loader.
    tracts_attribs : dict
        Only the 'orientation' entry is read. 'unknown' selects the generic
        ``load_tractogram`` path (with bounding-box and TRK header checks
        disabled); any other value selects the VTK loader, which is asked to
        convert to LPS only when the value is exactly 'LPS'.
    origin : Origin
        Origin convention applied after moving to voxel space.

    Returns
    -------
    StatefulTractogram
    """
    orientation = tracts_attribs['orientation']

    if orientation == 'unknown':
        sft = load_tractogram(tract_fname,
                              ref_anat_fname,
                              bbox_valid_check=False,
                              trk_header_check=False)
    else:
        vtk_streamlines = load_vtk_streamlines(tract_fname,
                                               orientation == 'LPS')
        sft = StatefulTractogram(vtk_streamlines, ref_anat_fname,
                                 Space.RASMM)

    sft.to_vox()
    sft.to_origin(origin)
    return sft
示例#3
0
def test_fit_data():
    """Fit LiFE to the 'small_25' data and compare with the Matlab results."""
    fdata, fbval, fbvec = dpd.get_fnames('small_25')
    fstreamlines = dpd.get_fnames('small_25_streamlines')
    gtab = grad.gradient_table(fbval, fbvec)
    ni_data = nib.load(fdata)
    # `get_data()` is deprecated in nibabel; reading `dataobj` through numpy
    # is the documented way to keep its behaviour (on-disk dtype preserved).
    data = np.asanyarray(ni_data.dataobj)
    tensor_streamlines = nib.streamlines.load(fstreamlines).streamlines
    # The streamlines are declared as RASMM and moved to voxel space, which
    # is why the fit below is given an identity affine.
    sft = StatefulTractogram(tensor_streamlines, ni_data, Space.RASMM)
    sft.to_vox()
    tensor_streamlines_vox = sft.streamlines

    life_model = life.FiberModel(gtab)
    life_fit = life_model.fit(data, tensor_streamlines_vox, np.eye(4))
    model_error = life_fit.predict() - life_fit.data
    model_rmse = np.sqrt(np.mean(model_error**2, -1))
    matlab_rmse, matlab_weights = dpd.matlab_life_results()
    # Lower error than the matlab implementation for these data:
    npt.assert_(np.median(model_rmse) < np.median(matlab_rmse))
    # And a moderate correlation with the Matlab implementation weights:
    npt.assert_(np.corrcoef(matlab_weights, life_fit.beta)[0, 1] > 0.6)
示例#4
0
def load_tractogram(filename,
                    reference,
                    to_space=Space.RASMM,
                    shifted_origin=False,
                    bbox_valid_check=True,
                    trk_header_check=True):
    """ Load the stateful tractogram from any format (trk, tck, vtk, fib, dpy)

    Parameters
    ----------
    filename : string
        Filename with valid extension
    reference : Nifti or Trk filename, Nifti1Image or TrkFile, Nifti1Header or
        trk.header (dict), or 'same' if the input is a trk file.
        Reference that provides the spatial attribute.
        Typically a nifti-related object from the native diffusion used for
        streamlines generation
    to_space : Enum (dipy.io.stateful_tractogram.Space)
        Space to which the streamlines will be transformed after loading.
    shifted_origin : bool
        Information on the position of the origin,
        False is Trackvis standard, default (center of the voxel)
        True is NIFTI standard (corner of the voxel)
    bbox_valid_check : bool
        Verification for negative voxel coordinates or values above the
        volume dimensions. Default is True, to enforce valid file.
    trk_header_check : bool
        Verification that the reference has the same header as the spatial
        attributes as the input tractogram when a Trk is loaded

    Returns
    -------
    output : StatefulTractogram or bool
        The loaded tractogram. ``False`` is returned (after logging an
        error) when the extension is unsupported, `to_space` is not a
        ``Space`` member, 'same' is used with a non-trk file, or the trk
        header does not match the reference.

    Raises
    ------
    ValueError
        If `bbox_valid_check` is set and the bounding box is not valid in
        voxel space.
    """
    _, extension = os.path.splitext(filename)
    if extension not in ['.trk', '.tck', '.vtk', '.fib', '.dpy']:
        logging.error('Input filename is not one of the supported formats')
        return False

    # isinstance instead of `in Space`: the containment test raises
    # TypeError for non-members on several Python versions.
    if not isinstance(to_space, Space):
        logging.error('Space MUST be one of the 3 choices (Enum)')
        return False

    # Only compare against 'same' for strings; other reference objects
    # (images, headers, arrays) should never be compared to a str.
    if isinstance(reference, str) and reference == 'same':
        if extension == '.trk':
            reference = filename
        else:
            logging.error('Reference must be provided, "same" is only '
                          'available for Trk file.')
            return False

    if trk_header_check and extension == '.trk':
        if not is_header_compatible(filename, reference):
            logging.error('Trk file header does not match the provided '
                          'reference')
            return False

    timer = time.time()
    data_per_point = None
    data_per_streamline = None
    if extension in ['.trk', '.tck']:
        tractogram_obj = nib.streamlines.load(filename).tractogram
        streamlines = tractogram_obj.streamlines
        # Per-point / per-streamline data is only carried over for trk.
        if extension == '.trk':
            data_per_point = tractogram_obj.data_per_point
            data_per_streamline = tractogram_obj.data_per_streamline

    elif extension in ['.vtk', '.fib']:
        streamlines = load_vtk_streamlines(filename)
    elif extension in ['.dpy']:
        dpy_obj = Dpy(filename, mode='r')
        try:
            streamlines = list(dpy_obj.read_tracks())
        finally:
            # Release the file handle even if reading fails.
            dpy_obj.close()
    logging.debug('Load %s with %s streamlines in %s seconds', filename,
                  len(streamlines), round(time.time() - timer, 3))

    # Streamlines are declared as RASMM here, then moved to the requested
    # space below.
    sft = StatefulTractogram(streamlines,
                             reference,
                             Space.RASMM,
                             shifted_origin=shifted_origin,
                             data_per_point=data_per_point,
                             data_per_streamline=data_per_streamline)

    if to_space == Space.VOX:
        sft.to_vox()
    elif to_space == Space.VOXMM:
        sft.to_voxmm()

    if bbox_valid_check and not sft.is_bbox_in_vox_valid():
        raise ValueError('Bounding box is not valid in voxel space, cannot '
                         'load a valid file if some coordinates are invalid. '
                         'Please set bbox_valid_check to False and then use '
                         'the function remove_invalid_streamlines to discard '
                         'invalid streamlines.')

    return sft
示例#5
0
import AFQ.registration as reg
import AFQ.segmentation as seg
import AFQ.dti as dti
from AFQ.utils.volume import patch_up_roi

# Module-level fixtures shared by the tests in this file.
# Fetch the Stanford HARDI data set and build the paths of its data,
# b-value and b-vector files inside dipy's home directory.
dpd.fetch_stanford_hardi()
hardi_dir = op.join(fetcher.dipy_home, "stanford_hardi")
hardi_fdata = op.join(hardi_dir, "HARDI150.nii.gz")
hardi_img = nib.load(hardi_fdata)
hardi_fbval = op.join(hardi_dir, "HARDI150.bval")
hardi_fbvec = op.join(hardi_dir, "HARDI150.bvec")
# Read the example tractography files; only the 'mapping.nii.gz' and
# 'tractography_subsampled.trk' entries are used here.
file_dict = afd.read_stanford_hardi_tractography()
mapping = file_dict['mapping.nii.gz']
streamlines = file_dict['tractography_subsampled.trk']
# Declare the streamlines as RASMM relative to the HARDI image, then move
# them to voxel space; `streamlines` is rebound to the voxel-space version.
tg = StatefulTractogram(streamlines, hardi_img, Space.RASMM)
tg.to_vox()
streamlines = tg.streamlines

# streamlines = dts.Streamlines(
#     dtu.transform_tracking_output(
#         streamlines[streamlines._lengths > 10],
#         np.linalg.inv(hardi_img.affine)))


def test_segment():

    templates = afd.read_templates()
    bundles = {
        'CST_L': {
            'ROIs': [templates['CST_roi1_L'], templates['CST_roi2_L']],
            'rules': [True, True],
def empty_space_change():
    """Check that space changes on an empty tractogram leave it empty."""
    empty_sft = StatefulTractogram([], filepath_dix['gs.nii'], Space.VOX)

    # Cycle through every space in turn: vox -> voxmm -> rasmm.
    for change_space in (empty_sft.to_vox,
                         empty_sft.to_voxmm,
                         empty_sft.to_rasmm):
        change_space()

    assert_array_equal([], empty_sft.streamlines.data)
示例#7
0
class Segmentation:
    def __init__(self,
                 nb_points=False,
                 seg_algo='AFQ',
                 progressive=True,
                 greater_than=50,
                 rm_small_clusters=50,
                 model_clust_thr=40,
                 reduction_thr=40,
                 refine=False,
                 pruning_thr=6,
                 b0_threshold=0,
                 prob_threshold=0,
                 rng=None,
                 return_idx=False,
                 filter_by_endpoints=True,
                 dist_to_aal=4):
        """
        Segment streamlines into bundles.

        Parameters
        ----------
        nb_points : int, boolean
            Resample streamlines to nb_points number of points.
            If False, no resampling is done. Default: False

        seg_algo : string
            Algorithm for segmentation (case-insensitive):
            'AFQ': Segment streamlines into bundles,
                based on inclusion/exclusion ROIs.
            'Reco': Segment streamlines using the RecoBundles algorithm
            [Garyfallidis2017].
            Default: 'AFQ'

        rm_small_clusters : int
            Using RecoBundles Algorithm.
            Remove clusters that have less than this value
                during whole brain SLR.
            Default: 50

        progressive : boolean, optional
            Using RecoBundles Algorithm.
            Whether or not to use progressive technique
                during whole brain SLR.
            Default: True.

        greater_than : int, optional
            Using RecoBundles Algorithm.
            Keep streamlines that have length greater than this value
                during whole brain SLR.
            Default: 50.

        b0_theshold : float.
            Using AFQ Algorithm.
            All b-values with values less than or equal to `bo_threshold` are
            considered as b0s i.e. without diffusion weighting.
            Default: 0.

        prob_threshold : float.
            Using AFQ Algorithm.
            Initial cleaning of fiber groups is done using probability maps
            from [Hua2008]_. Here, we choose an average probability that
            needs to be exceeded for an individual streamline to be retained.
            Default: 0.

        rng : RandomState
            If None, creates RandomState. Used in RecoBundles Algorithm.
            Default: None.

        return_idx : bool
            Whether to return the indices in the original streamlines as part
            of the output of segmentation.

        filter_by_endpoints: bool
            Whether to filter the bundles based on their endpoints relative
            to regions defined in the AAL atlas. Applies only to the waypoint
            approach (XXX for now). Default: True.

        dist_to_aal : float
            If filter_by_endpoints is True, this is the distance from the
            endpoints to the AAL atlas ROIs that is required.

        References
        ----------
        .. [Hua2008] Hua K, Zhang J, Wakana S, Jiang H, Li X, et al. (2008)
        Tract probability maps in stereotaxic spaces: analyses of white
        matter anatomy and tract-specific quantification. Neuroimage 39:
        336-347
        """
        self.logger = logging.getLogger('AFQ.Segmentation')
        self.nb_points = nb_points

        if rng is None:
            self.rng = np.random.RandomState()
        else:
            self.rng = rng

        self.seg_algo = seg_algo.lower()
        self.prob_threshold = prob_threshold
        self.b0_threshold = b0_threshold
        self.progressive = progressive
        self.greater_than = greater_than
        self.rm_small_clusters = rm_small_clusters
        self.model_clust_thr = model_clust_thr
        self.reduction_thr = reduction_thr
        self.refine = refine
        self.pruning_thr = pruning_thr
        self.return_idx = return_idx
        self.filter_by_endpoints = filter_by_endpoints
        self.dist_to_aal = dist_to_aal

    def segment(self,
                bundle_dict,
                tg,
                fdata=None,
                fbval=None,
                fbvec=None,
                mapping=None,
                reg_prealign=None,
                reg_template=None,
                b0_threshold=0,
                img_affine=None):
        """
        Segment streamlines into bundles based on either waypoint ROIs
        [Yeatman2012]_ or RecoBundles [Garyfallidis2017]_.

        Parameters
        ----------
        bundle_dict: dict
            Meta-data for the segmentation. The format is something like::

                {'name': {'ROIs':[img1, img2],
                'rules':[True, True]},
                'prob_map': img3,
                'cross_midline': False}

        tg : StatefulTractogram
            Bundles to segment

        fdata, fbval, fbvec : str
            Full path to data, bvals, bvecs

        mapping : DiffeomorphicMap object, str or nib.Nifti1Image, optional.
            A mapping between DWI space and a template. If None, mapping will
            be registered from data used in prepare_img. Default: None.

        reg_prealign : array, optional.
            The linear transformation to be applied to align input images to
            the reference space before warping under the deformation field.
            Default: None.

        reg_template : str or nib.Nifti1Image, optional.
            Template to use for registration. Default: MNI T2.

        b0_threshold : float, optional.
            Accepted for backward compatibility but not read by this method;
            `prepare_img` uses the value given to the constructor.

        img_affine : array, optional.
            The spatial transformation from the measurement to the scanner
            space.

        Returns
        -------
        The result of `segment_afq` or `segment_reco`, depending on the
        algorithm chosen at construction. For any other algorithm name an
        error is logged and None is returned.

        References
        ----------
        .. [Yeatman2012] Yeatman, Jason D., Robert F. Dougherty, Nathaniel J.
        Myall, Brian A. Wandell, and Heidi M. Feldman. 2012. "Tract Profiles of
        White Matter Properties: Automating Fiber-Tract Quantification"
        PloS One 7 (11): e49790.

        .. [Garyfallidis2017] Garyfallidis et al. Recognition of white matter
        bundles using local and global streamline-based registration and
        clustering, Neuroimage, 2017.
        """
        if img_affine is not None:
            # With an explicit affine, a mapping must come with it and the
            # data paths must not be given. This is only reported, not
            # enforced.
            if (mapping is None or fdata is not None or fbval is not None
                    or fbvec is not None):

                self.logger.error(
                    "Provide either the full path to data, bvals, bvecs, "
                    "or provide the affine of the image and the mapping")

        self.logger.info("Preparing Segmentation Parameters")
        self.img_affine = img_affine
        self.prepare_img(fdata, fbval, fbvec)
        self.logger.info("Preprocessing Streamlines")
        self.tg = tg

        # If resampling over-write the sft:
        if self.nb_points:
            self.tg = StatefulTractogram(
                dps.set_number_of_points(self.tg.streamlines, self.nb_points),
                self.tg, self.tg.space)

            # NOTE(review): `resample_streamlines` is not defined in the part
            # of the class visible here - confirm it exists and is still
            # needed after the resampling above.
            self.resample_streamlines(self.nb_points)

        self.prepare_map(mapping, reg_prealign, reg_template)
        self.bundle_dict = bundle_dict
        self.cross_streamlines()

        if self.seg_algo == "afq":
            return self.segment_afq()
        elif self.seg_algo == "reco":
            return self.segment_reco()

        # Previously an unknown algorithm fell through silently after all
        # the preprocessing; keep returning None but say why.
        self.logger.error(
            "Unknown segmentation algorithm %r; expected 'afq' or 'reco'",
            self.seg_algo)
        return None

    def prepare_img(self, fdata, fbval, fbvec):
        """
        Prepare image data from DWI data.

        Parameters
        ----------
        fdata, fbval, fbvec : str
            Full path to data, bvals, bvecs
        """
        if self.img_affine is None:
            self.img, _, _, _ = \
                ut.prepare_data(fdata, fbval, fbvec,
                                b0_threshold=self.b0_threshold)
            self.img_affine = self.img.affine

        self.fdata = fdata
        self.fbval = fbval
        self.fbvec = fbvec

    def prepare_map(self, mapping=None, reg_prealign=None, reg_template=None):
        """
        Set mapping between DWI space and a template.

        Sets `self.mapping` and `self.reg_template`; may also set `self.img`
        when a mapping file/image is given and no image was loaded yet.

        Parameters
        ----------
        mapping : DiffeomorphicMap object, str or nib.Nifti1Image, optional.
            A mapping between DWI space and a template.
            If None, mapping will be registered from data used in prepare_img.
            Default: None.

        reg_template : str or nib.Nifti1Image, optional.
            Template to use for registration (defaults to the MNI T2)
            Default: None.

        reg_prealign : array, optional.
            The linear transformation to be applied to align input images to
            the reference space before warping under the deformation field.
            Default: None.
        """
        if reg_template is None:
            reg_template = dpd.read_mni_template()

        # `segment_afq` reads `self.reg_template`; keep the resolved template
        # on the instance so that attribute exists.
        self.reg_template = reg_template

        if mapping is None:
            gtab = dpg.gradient_table(self.fbval, self.fbvec)
            self.mapping = reg.syn_register_dwi(self.fdata, gtab)[1]
        elif isinstance(mapping, (str, nib.Nifti1Image)):
            if reg_prealign is None:
                reg_prealign = np.eye(4)
            # `self.img` is only set by `prepare_img` when it loads the data
            # itself, so the attribute may not exist when an affine was
            # supplied directly.
            if getattr(self, 'img', None) is None:
                self.img, _, _, _ = \
                    ut.prepare_data(self.fdata,
                                    self.fbval,
                                    self.fbvec,
                                    b0_threshold=self.b0_threshold)
            # The inverse of the pre-alignment is passed on purpose (_aNNe):
            # with the non-inverted matrix the warped ROIs and the warped
            # probability maps came out wrong.
            self.mapping = reg.read_mapping(
                mapping,
                self.img,
                reg_template,
                prealign=np.linalg.inv(reg_prealign))
        else:
            self.mapping = mapping

    def cross_streamlines(self, tg=None, template=None, low_coord=10):
        """
        Classify the streamlines by whether they cross the midline.

        Creates a crosses attribute which is an array of booleans. Each boolean
        corresponds to a streamline, and is whether or not that streamline
        crosses the midline.

        Parameters
        ----------
        tg : StatefulTractogram class instance.

        template : nibabel.Nifti1Image class instance
            An affine transformation into a template space.
        """
        if tg is None:
            tg = self.tg
        if template is None:
            template_affine = self.img_affine
        else:
            template_affine = template.affine

        # What is the x,y,z coordinate of 0,0,0 in the template space?
        zero_coord = np.dot(np.linalg.inv(template_affine),
                            np.array([0, 0, 0, 1]))

        self.crosses = np.zeros(len(tg), dtype=bool)
        # already_split = 0
        for sl_idx, sl in enumerate(tg.streamlines):
            if np.any(sl[:, 0] > zero_coord[0]) and \
                    np.any(sl[:, 0] < zero_coord[0]):
                self.crosses[sl_idx] = True
            else:
                self.crosses[sl_idx] = False

    def _get_bundle_info(self, bundle_idx, bundle):
        """
        Get fiber probabilites and ROIs for a given bundle.
        """
        rules = self.bundle_dict[bundle]['rules']
        include_rois = []
        exclude_rois = []
        for rule_idx, rule in enumerate(rules):
            roi = self.bundle_dict[bundle]['ROIs'][rule_idx]
            if not isinstance(roi, np.ndarray):
                roi = roi.get_fdata()
            warped_roi = auv.patch_up_roi(
                self.mapping.transform_inverse(roi.astype(np.float32),
                                               interpolation='linear'))

            if rule:
                # include ROI:
                include_rois.append(np.array(np.where(warped_roi)).T)
            else:
                # Exclude ROI:
                exclude_rois.append(np.array(np.where(warped_roi)).T)


########for debugging: save the warped ROI that is actually used in segment_afq() ##########
#       path_for_debugging = '/debugpath/'
#            nib.save(nib.Nifti1Image(warped_roi.astype(np.float32),
#                                     self.img_affine),
#                     debugpath+'warpedROI_'+bundle+'as_used.nii.gz')
############################################################################################

# The probability map if doesn't exist is all ones with the same
# shape as the ROIs:
        prob_map = self.bundle_dict[bundle].get('prob_map', np.ones(roi.shape))

        if not isinstance(prob_map, np.ndarray):
            prob_map = prob_map.get_fdata()
        warped_prob_map = \
            self.mapping.transform_inverse(prob_map,
                                           interpolation='nearest')
        return warped_prob_map, include_rois, exclude_rois

    def _check_sl_with_inclusion(self, sl, include_rois, tol):
        """
        Helper function to check that a streamline is close to a list of
        inclusion ROIS.
        """
        dist = []
        for roi in include_rois:
            # Use squared Euclidean distance, because it's faster:
            dist.append(cdist(sl, roi, 'sqeuclidean'))
            if np.min(dist[-1]) > tol:
                # Too far from one of them:
                return False, []
        # Apparently you checked all the ROIs and it was close to all of them
        return True, dist

    def _check_sl_with_exclusion(self, sl, exclude_rois, tol):
        """ Helper function to check that a streamline is not too close to a
        list of exclusion ROIs.
        """
        for roi in exclude_rois:
            # Use squared Euclidean distance, because it's faster:
            if np.min(cdist(sl, roi, 'sqeuclidean')) < tol:
                return False
        # Either there are no exclusion ROIs, or you are not close to any:
        return True

    def _return_empty(self, bundle):
        """
        Record an empty result for `bundle` in `self.fiber_groups`.

        With `return_idx` set, the entry is a dict holding an empty
        tractogram under 'sl' and an empty index array under 'idx';
        otherwise the entry is the empty tractogram itself.
        """
        empty_sft = StatefulTractogram([], self.img, Space.VOX)

        if not self.return_idx:
            self.fiber_groups[bundle] = empty_sft
            return

        self.fiber_groups[bundle] = {'sl': empty_sft, 'idx': np.array([])}

    def segment_afq(self, tg=None):
        """
        Assign streamlines to bundles using the waypoint ROI approach

        Parameters
        ----------
        tg : StatefulTractogram class instance
        """
        if tg is None:
            tg = self.tg
        else:
            self.tg = tg

        self.tg.to_vox()
        # For expedience, we approximate each streamline as a 100 point curve:
        fgarray = np.array(_resample_tg(tg, 100))

        # NOTE (_aNNe): resampling to a fixed 100 points can cause errors.
        # If an ROI is traversed by streamlines in just a few voxels, and the
        # streamlines are so long (or the resolution so high) that one
        # hundredth of a streamline is longer than a voxel, then the contact
        # check below (closest streamline-to-ROI distance < voxel width) can
        # fail after resampling to 100 points.
        # To be certain that the resampling does not cause problems, the
        # number of resampling points has to be larger than the length of the
        # streamline in voxels in native space.

        n_streamlines = fgarray.shape[0]

        streamlines_in_bundles = np.zeros(
            (n_streamlines, len(self.bundle_dict)))
        min_dist_coords = np.zeros((n_streamlines, len(self.bundle_dict), 2),
                                   dtype=int)
        self.fiber_groups = {}

        if self.return_idx:
            out_idx = np.arange(n_streamlines, dtype=int)

        if self.filter_by_endpoints:
            aal_atlas = afd.read_aal_atlas()['atlas'].get_fdata()
            # This atlas is not yet aligned to template space
            resample_to = self.reg_template
            if isinstance(resample_to, str):
                resample_to = nib.load(resample_to)
            allVolumes = []
            # The AAL atlas (and others) has multiple volumes to represent
            # overlapping areas separately. Move through all volumes,
            # register (resample) each of them to the template, put them
            # back together and save the result with the affine of the
            # template. This puts the AAL atlas in the same space as the
            # template before it is warped to native space (_aNNe).
            for ii in range(aal_atlas.get_fdata().shape[-1]):
                vol = aal_atlas.get_fdata()
                vol = vol[..., ii]
                trafo = reg.resample(
                    vol,  # moving (according to reg.resample)
                    resample_to,  # static
                    aal_atlas.affine,  # moving affine
                    resample_to.affine)  # static affine
                allVolumes.append(np.asarray(trafo))
            aal_atlas = np.stack(allVolumes, axis=3)
            aal_atlas = nib.Nifti1Image(aal_atlas, resample_to.affine)
            ################for debugging: save AAL Atlas after registering to template ############
            #            path_for_debugging = '/debugpath/'
            #            nib.save(atlas_inFSL_space,debugpath+'AAL_registered_to_template.nii.gz')
            #########################################################################################

            # We need to calculate the size of a voxel, so we can transform
            # from mm to voxel units:
            R = self.img_affine[0:3, 0:3]
            vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R))))
            dist_to_aal = self.dist_to_aal / vox_dim

        self.logger.info("Assigning Streamlines to Bundles")
        # Tolerance is set to the square of the distance to the corner
        # because we are using the squared Euclidean distance in calls to
        # `cdist` to make those calls faster.
        tol = dts.dist_to_corner(self.img_affine)**2
        for bundle_idx, bundle in enumerate(self.bundle_dict):
            self.logger.info(f"Finding Streamlines for {bundle}")
            warped_prob_map, include_roi, exclude_roi = \
                self._get_bundle_info(bundle_idx, bundle)
            ########for debugging: save the warped probability map that is actually used in segment_afq() ##########
            #            path_for_debugging = '/debugpath/'
            #            nib.save(nib.Nifti1Image(warped_prob_map.astype(np.float32),
            #                                     self.img_affine),
            #                      debugpath+'warpedprobmap_'+bundle+'as_used.nii.gz')
            ############################################################################################
            fiber_probabilities = dts.values_from_volume(
                warped_prob_map, fgarray, np.eye(4))
            fiber_probabilities = np.mean(fiber_probabilities, -1)
            idx_above_prob = np.where(
                fiber_probabilities > self.prob_threshold)
            self.logger.info((f"{len(idx_above_prob[0])} streamlines exceed"
                              " the probability threshold."))
            crosses_midline = self.bundle_dict[bundle]['cross_midline']
            for sl_idx in tqdm(idx_above_prob[0]):
                sl = tg.streamlines[sl_idx]
                if fiber_probabilities[sl_idx] > self.prob_threshold:
                    if crosses_midline is not None:
                        if self.crosses[sl_idx]:
                            # This means that the streamline does
                            # cross the midline:
                            if crosses_midline:
                                # This is what we want, keep going
                                pass
                            else:
                                # This is not what we want,
                                # skip to next streamline
                                continue

                    is_close, dist = \
                        self._check_sl_with_inclusion(sl,
                                                      include_roi,
                                                      tol)
                    if is_close:
                        is_far = \
                            self._check_sl_with_exclusion(sl,
                                                          exclude_roi,
                                                          tol)
                        if is_far:
                            min_dist_coords[sl_idx, bundle_idx, 0] =\
                                np.argmin(dist[0], 0)[0]
                            min_dist_coords[sl_idx, bundle_idx, 1] =\
                                np.argmin(dist[1], 0)[0]
                            streamlines_in_bundles[sl_idx, bundle_idx] =\
                                fiber_probabilities[sl_idx]
            self.logger.info(
                (f"{np.sum(streamlines_in_bundles[:, bundle_idx] > 0)} "
                 "streamlines selected with waypoint ROIs"))

        # Eliminate any fibers not selected using the plane ROIs:
        possible_fibers = np.sum(streamlines_in_bundles, -1) > 0
        tg = StatefulTractogram(tg.streamlines[possible_fibers], self.img,
                                Space.VOX)
        if self.return_idx:
            out_idx = out_idx[possible_fibers]

        streamlines_in_bundles = streamlines_in_bundles[possible_fibers]
        min_dist_coords = min_dist_coords[possible_fibers]
        bundle_choice = np.argmax(streamlines_in_bundles, -1)

        # We do another round through, so that we can orient all the
        # streamlines within a bundle in the same orientation with respect to
        # the ROIs. This order is ARBITRARY but CONSISTENT (going from ROI0
        # to ROI1).
        self.logger.info("Re-orienting streamlines to consistent directions")
        for bundle_idx, bundle in enumerate(self.bundle_dict):
            self.logger.info(f"Processing {bundle}")

            select_idx = np.where(bundle_choice == bundle_idx)

            if len(select_idx[0]) == 0:
                # There's nothing here, set and move to the next bundle:
                self._return_empty(bundle)
                continue

            # Use a list here, because ArraySequence doesn't support item
            # assignment:
            select_sl = list(tg.streamlines[select_idx])
            # Sub-sample min_dist_coords:
            min_dist_coords_bundle = min_dist_coords[select_idx]
            for idx in range(len(select_sl)):
                min0 = min_dist_coords_bundle[idx, bundle_idx, 0]
                min1 = min_dist_coords_bundle[idx, bundle_idx, 1]
                if min0 > min1:
                    select_sl[idx] = select_sl[idx][::-1]

            # Set this to StatefulTractogram object for filtering/output:
            select_sl = StatefulTractogram(select_sl, self.img, Space.VOX)

            if self.filter_by_endpoints:
                self.logger.info("Filtering by endpoints")
                # Create binary masks and warp these into subject's DWI space:
                aal_targets = afd.bundles_to_aal([bundle], atlas=aal_atlas)[0]
                aal_idx = []
                for targ in aal_targets:
                    if targ is not None:
                        aal_roi = np.zeros(aal_atlas.shape[:3])
                        aal_roi[targ[:, 0], targ[:, 1], targ[:, 2]] = 1
                        warped_roi = self.mapping.transform_inverse(
                            aal_roi, interpolation='nearest')
                        aal_idx.append(np.array(np.where(warped_roi > 0)).T)
                    else:
                        aal_idx.append(None)

                self.logger.info("Before filtering "
                                 f"{len(select_sl)} streamlines")

                new_select_sl = clean_by_endpoints(select_sl.streamlines,
                                                   aal_idx[0],
                                                   aal_idx[1],
                                                   tol=dist_to_aal,
                                                   return_idx=self.return_idx)
                # Generate immediately:
                new_select_sl = list(new_select_sl)

                # We need to check this again:
                if len(new_select_sl) == 0:
                    # There's nothing here, set and move to the next bundle:
                    self._return_empty(bundle)
                    continue

                if self.return_idx:
                    temp_select_sl = []
                    temp_select_idx = np.empty(len(new_select_sl), int)
                    for ii, ss in enumerate(new_select_sl):
                        temp_select_sl.append(ss[0])
                        temp_select_idx[ii] = ss[1]
                    select_idx = select_idx[0][temp_select_idx]
                    new_select_sl = temp_select_sl

                select_sl = StatefulTractogram(new_select_sl, self.img,
                                               Space.RASMM)

                self.logger.info("After filtering "
                                 f"{len(select_sl)} streamlines")

            if self.return_idx:
                self.fiber_groups[bundle] = {}
                self.fiber_groups[bundle]['sl'] = select_sl
                self.fiber_groups[bundle]['idx'] = out_idx[select_idx]
            else:
                self.fiber_groups[bundle] = select_sl
        return self.fiber_groups

    def segment_reco(self, tg=None):
        """
        Segment streamlines using the RecoBundles algorithm [Garyfallidis2017]

        Parameters
        ----------
        tg : StatefulTractogram class instance, optional
            A whole-brain tractogram to be segmented. When omitted, the
            tractogram already stored in ``self.tg`` is used; when given, it
            also replaces ``self.tg``.

        Returns
        -------
        fiber_groups : dict
            Keys are the names of the bundles (every key of
            ``self.bundle_dict`` except ``'whole_brain'``).
            When ``self.return_idx`` is false, each value is a
            StatefulTractogram (built with ``Space.RASMM``) holding the
            streamlines assigned to that bundle. When it is true, each value
            is a dict with the keys ``'idx'`` (the labels used to index
            ``tg.streamlines``) and ``'sl'`` (that StatefulTractogram).
            The streamlines in each bundle have all been oriented to have the
            same orientation (using `dts.orient_by_streamline`).
            The same dict is also stored in ``self.fiber_groups``.
        """
        if tg is None:
            tg = self.tg
        else:
            self.tg = tg

        fiber_groups = {}
        self.logger.info("Registering Whole-brain with SLR")
        # We start with whole-brain SLR; only the moved streamlines are
        # needed afterwards:
        atlas = self.bundle_dict['whole_brain']
        moved, _, _, _ = whole_brain_slr(
            atlas,
            self.tg.streamlines,
            x0='affine',
            verbose=False,
            progressive=self.progressive,
            greater_than=self.greater_than,
            rm_small_clusters=self.rm_small_clusters,
            rng=self.rng)

        # We generate our instance of RB with the moved streamlines:
        self.logger.info("Extracting Bundles")
        rb = RecoBundles(moved, verbose=False, rng=self.rng)

        # Next we'll iterate over bundles, registering each one. The
        # whole-brain atlas entry is not a bundle, so it is left out:
        bundle_list = [key for key in self.bundle_dict
                       if key != 'whole_brain']

        self.logger.info("Assigning Streamlines to Bundles")
        for bundle in bundle_list:
            model_sl = self.bundle_dict[bundle]['sl']
            _, rec_labels = rb.recognize(model_bundle=model_sl,
                                         model_clust_thr=self.model_clust_thr,
                                         reduction_thr=self.reduction_thr,
                                         reduction_distance='mdf',
                                         slr=True,
                                         slr_metric='asymmetric',
                                         pruning_distance='mdf')

            # Use the streamlines in the original space:
            recognized_sl = tg.streamlines[rec_labels]

            # Refinement starts from the recognized streamlines, so there is
            # nothing to refine when recognition came back empty; keep the
            # empty result instead of calling into RecoBundles with no input.
            if self.refine and len(recognized_sl) > 0:
                _, rec_labels = rb.refine(model_sl,
                                          recognized_sl,
                                          self.model_clust_thr,
                                          reduction_thr=self.reduction_thr,
                                          pruning_thr=self.pruning_thr)
                recognized_sl = tg.streamlines[rec_labels]

            standard_sl = self.bundle_dict[bundle]['centroid']
            oriented_sl = dts.orient_by_streamline(recognized_sl, standard_sl)
            bundle_tg = StatefulTractogram(oriented_sl, self.img, Space.RASMM)

            if self.return_idx:
                fiber_groups[bundle] = {'idx': rec_labels, 'sl': bundle_tg}
            else:
                fiber_groups[bundle] = bundle_tg

        self.fiber_groups = fiber_groups
        return fiber_groups
示例#8
0
def main():
    """
    Command-line entry point: fix the coordinates of a DSI-Studio tractogram.

    The input streamlines are flipped along x and y in voxel space
    (LPS -> RAS) and shifted. Without ``--in_native_fa`` the result is saved
    as is. With it, the streamlines are additionally moved onto that
    reference using a rigid transformation that is either loaded from disk
    or estimated by registering the two FA volumes.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.load_transfo and args.in_native_fa is None:
        parser.error('When loading a transformation, the final reference is '
                     'needed, use --in_native_fa.')

    required_inputs = [args.in_dsi_tractogram, args.in_dsi_fa]
    assert_inputs_exist(parser, required_inputs, optional=args.in_native_fa)
    assert_outputs_exist(parser, args, args.out_tractogram)

    dsi_sft = load_tractogram(args.in_dsi_tractogram, 'same',
                              bbox_valid_check=False)

    # LPS -> RAS convention in voxel space
    dsi_sft.to_vox()
    lps_axes = ['x', 'y']
    vox_sft = StatefulTractogram(dsi_sft.streamlines, args.in_dsi_fa,
                                 Space.VOXMM)
    vox_sft.to_vox()
    vox_sft.streamlines._data -= get_axis_shift_vector(lps_axes)

    ras_sft = flip_sft(vox_sft, lps_axes)
    ras_sft.to_rasmm()
    ras_sft.streamlines._data -= [0.5, 0.5, -0.5]

    if args.in_native_fa:
        static_img = nib.load(args.in_native_fa)
        static_data = static_img.get_fdata()
        moving_img = nib.load(args.in_dsi_fa)
        moving_data = moving_img.get_fdata()

        # DSI-Studio flips the volume without changing the affine (I think)
        # So this has to be reversed (not the same problem as above)
        vox_order = get_reference_info(moving_img)[3]
        keep = slice(None)
        reverse = slice(None, None, -1)
        undo_axes = []
        for axis, (axis_name, flipped_code) in enumerate(zip('xyz', 'LPI')):
            if vox_order[axis] == flipped_code:
                selector = [keep, keep, keep]
                selector[axis] = reverse
                moving_data = moving_data[tuple(selector)]
                undo_axes.append(axis_name)
        unflipped_sft = flip_sft(ras_sft, undo_axes)

        if args.load_transfo:
            transfo = np.loadtxt(args.load_transfo)
        else:
            # Sometimes DSI studio has quite a lot of skull left
            # Dipy Median Otsu does not work with FA/GFA
            if args.auto_crop:
                moving_data = cube_crop_data(moving_data)
                static_data = cube_crop_data(static_data)

            # DSI Studio registers to AC/PC without saving the
            # transformation, so it has to be estimated here; being rigid,
            # that is 'easy'.
            c_of_mass = transform_centers_of_mass(static_data,
                                                  static_img.affine,
                                                  moving_data,
                                                  moving_img.affine)

            # Mutual information with 32 bins and no sampling proportion,
            # optimised over a three-level pyramid.
            affreg = AffineRegistration(
                metric=MutualInformationMetric(32, None),
                level_iters=[1000, 100, 10],
                sigmas=[3.0, 2.0, 1.0],
                factors=[3, 2, 1])
            rigid = affreg.optimize(static_data,
                                    moving_data,
                                    RigidTransform3D(),
                                    None,
                                    static_img.affine,
                                    moving_img.affine,
                                    starting_affine=c_of_mass.affine)
            transfo = rigid.affine
            if args.save_transfo:
                np.savetxt(args.save_transfo, transfo)

        out_sft = transform_warp_sft(unflipped_sft,
                                     transfo,
                                     static_img,
                                     inverse=True,
                                     remove_invalid=args.remove_invalid,
                                     cut_invalid=args.cut_invalid)
    else:
        out_sft = ras_sft

    # Both paths end the same way: optional clean-up of invalid streamlines,
    # then writing the tractogram.
    if args.cut_invalid:
        out_sft, _ = cut_invalid_streamlines(out_sft)
    elif args.remove_invalid:
        out_sft.remove_invalid_streamlines()
    save_tractogram(out_sft,
                    args.out_tractogram,
                    bbox_valid_check=not args.keep_invalid)
示例#9
0
def sfm_tracking(name=None,
                 data_path=None,
                 output_path='.',
                 Threshold=.20,
                 data_list=None,
                 return_streamlines=False,
                 save_track=True,
                 seed='.',
                 minus_ROI_mask='.',
                 one_node=False,
                 two_node=False):
    """
    Track streamlines from Sparse Fascicle Model peaks and optionally save
    them as a ``.trk`` file.

    Parameters
    ----------
    name : str, optional
        Dataset identifier. Passed to `get_data` when `data_list` is None
        and used in the output file name
        (``<output_path>/tractogram_sfm_<name>.trk``); it must therefore be
        a string whenever `save_track` is true.
    data_path : optional
        Passed to `get_data` together with `name` when `data_list` is None.
    output_path : str
        Prefix of the saved file path.
    Threshold : float
        Threshold given to `ThresholdStoppingCriterion` along with the GFA
        of the fitted peaks.
    data_list : dict, optional
        Already loaded inputs, read from the keys ``'DWI'``, ``'affine'``,
        ``'img'``, ``'labels'``, ``'gtab'`` and ``'head_mask'``. When None,
        the inputs are loaded with `get_data`.
    return_streamlines : bool
        Return the streamlines instead of None.
    save_track : bool
        Write the tractogram to disk.
    seed : array or str
        Mask to seed from. Any string (the default) selects the
        ``white_matter`` mask, i.e. ``(labels == 2) * (head_mask == 1)``.
    minus_ROI_mask : array or str
        ROI handed to `minus_ROI`; any string (the default) skips that step.
        Only looked at when `one_node` or `two_node` is set.
    one_node, two_node : bool
        When either is true the streamlines are passed through
        `reduct_seed_ROI` (in voxel space) before saving.

    Returns
    -------
    streamlines or None
        The streamlines when `return_streamlines` is true, otherwise None.
        NOTE: the ROI reduction is part of the saving branch, so with
        ``save_track=False`` the returned streamlines are never reduced.

    Raises
    ------
    TypeError
        If `save_track` is true and `name` is not a string.
    """
    # The output file name is only assembled after tracking has finished;
    # reject an unusable name up front instead of losing the whole run.
    if save_track and not isinstance(name, str):
        raise TypeError("name must be a string when save_track is True; "
                        "it is used to build the output file name")

    from dipy.direction import peaks_from_model
    from dipy.io.stateful_tractogram import Space, StatefulTractogram
    from dipy.io.streamline import save_trk
    from dipy.reconst.csdeconv import auto_response_ssst

    time0 = time.time()
    print("begin loading data, time:", time.time() - time0)

    if data_list is None:
        data, affine, img, labels, gtab, head_mask = get_data(name, data_path)
    else:
        data = data_list['DWI']
        affine = data_list['affine']
        img = data_list['img']
        labels = data_list['labels']
        gtab = data_list['gtab']
        head_mask = data_list['head_mask']

    white_matter = (labels == 2) * (head_mask == 1)
    # A string seed is only a placeholder meaning "no custom mask".
    seed_mask = white_matter if isinstance(seed, str) else seed
    seeds = utils.seeds_from_mask(seed_mask, affine, density=1)

    print('begin reconstruction, time:', time.time() - time0)

    response, _ = auto_response_ssst(gtab, data, roi_radii=10, fa_thr=0.7)

    sphere = get_sphere()
    sf_model = sfm.SparseFascicleModel(gtab,
                                       sphere=sphere,
                                       l1_ratio=0.5,
                                       alpha=0.001,
                                       response=response[0])

    pnm = peaks_from_model(sf_model,
                           data,
                           sphere,
                           relative_peak_threshold=.5,
                           min_separation_angle=25,
                           mask=white_matter,
                           parallel=True)

    stopping_criterion = ThresholdStoppingCriterion(pnm.gfa, Threshold)

    print('begin tracking, time:', time.time() - time0)

    streamline_generator = LocalTracking(pnm,
                                         stopping_criterion,
                                         seeds,
                                         affine,
                                         step_size=.5)
    streamlines = Streamlines(streamline_generator)

    print('begin saving, time:', time.time() - time0)

    if save_track:
        sft = StatefulTractogram(streamlines, img, Space.RASMM)

        if one_node or two_node:
            # The ROI helpers are given voxel-space streamlines.
            sft.to_vox()
            streamlines = reduct_seed_ROI(sft.streamlines, seed_mask, one_node,
                                          two_node)

            if not isinstance(minus_ROI_mask, str):
                streamlines = minus_ROI(streamlines=streamlines,
                                        ROI=minus_ROI_mask)

            sft = StatefulTractogram(streamlines, img, Space.VOX)
            sft.to_rasmm()

        output = output_path + '/tractogram_sfm_' + name + '.trk'
        # save_trk takes (sft, filename, bbox_valid_check); the streamlines
        # already live inside sft and must not be passed as a third argument.
        save_trk(sft, output)

    if return_streamlines:
        return streamlines
    return None
示例#10
0
def determine(name=None,
              data_path=None,
              output_path='.',
              Threshold=.20,
              data_list=None,
              seed='.',
              minus_ROI_mask='.',
              one_node=False,
              two_node=False):
    """
    Run deterministic (maximum-direction) tracking on a CSD fit and save the
    result as ``<output_path>/tractogram_deterministic_<name>.trk``.

    Parameters
    ----------
    name : str
        Dataset identifier. Passed to `get_data` when `data_list` is None
        and used in the output file name, so it must be a string.
    data_path : optional
        Passed to `get_data` together with `name` when `data_list` is None.
    output_path : str
        Prefix of the saved file path.
    Threshold : float
        Threshold given to `ThresholdStoppingCriterion` along with the GFA
        of the CSA fit.
    data_list : dict, optional
        Already loaded inputs, read from the keys ``'DWI'``, ``'affine'``,
        ``'img'``, ``'labels'``, ``'gtab'`` and ``'head_mask'``. When None,
        the inputs are loaded with `get_data`.
    seed : array or str
        Mask to seed from. Any string (the default) selects the
        ``white_matter`` mask, i.e. ``(labels == 2) * (head_mask == 1)``.
    minus_ROI_mask : array or str
        ROI handed to `minus_ROI`; any string (the default) skips that step.
        Only looked at when `one_node` or `two_node` is set.
    one_node, two_node : bool
        When either is true the streamlines are passed through
        `reduct_seed_ROI` (in voxel space) before saving.

    Raises
    ------
    TypeError
        If `name` is not a string.
    """
    # The output file name is only assembled after tracking has finished;
    # reject an unusable name up front instead of losing the whole run.
    if not isinstance(name, str):
        raise TypeError("name must be a string; "
                        "it is used to build the output file name")

    time0 = time.time()
    print("begin loading data, time:", time.time() - time0)

    if data_list is None:
        data, affine, img, labels, gtab, head_mask = get_data(name, data_path)
    else:
        data = data_list['DWI']
        affine = data_list['affine']
        img = data_list['img']
        labels = data_list['labels']
        gtab = data_list['gtab']
        head_mask = data_list['head_mask']

    white_matter = (labels == 2) * (head_mask == 1)
    # A string seed is only a placeholder meaning "no custom mask".
    seed_mask = white_matter if isinstance(seed, str) else seed
    seeds = utils.seeds_from_mask(seed_mask, affine, density=1)

    print("begin reconstruction, time:", time.time() - time0)
    response, _ = auto_response_ssst(gtab, data, roi_radii=10, fa_thr=0.7)
    csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=6)
    csd_fit = csd_model.fit(data, mask=white_matter)

    # The CSA model is fitted only for its GFA map, which drives the
    # stopping criterion.
    csa_model = CsaOdfModel(gtab, sh_order=6)
    gfa = csa_model.fit(data, mask=white_matter).gfa
    stopping_criterion = ThresholdStoppingCriterion(gfa, Threshold)

    print("begin tracking, time:", time.time() - time0)
    detmax_dg = DeterministicMaximumDirectionGetter.from_shcoeff(
        csd_fit.shm_coeff, max_angle=30., sphere=default_sphere)
    streamline_generator = LocalTracking(detmax_dg,
                                         stopping_criterion,
                                         seeds,
                                         affine,
                                         step_size=.5)
    streamlines = Streamlines(streamline_generator)
    sft = StatefulTractogram(streamlines, img, Space.RASMM)

    if one_node or two_node:
        # The ROI helpers are given voxel-space streamlines.
        sft.to_vox()
        streamlines = reduct_seed_ROI(sft.streamlines, seed_mask, one_node,
                                      two_node)

        if not isinstance(minus_ROI_mask, str):
            streamlines = minus_ROI(streamlines=streamlines,
                                    ROI=minus_ROI_mask)

        sft = StatefulTractogram(streamlines, img, Space.VOX)
        sft.to_rasmm()

    print("begin saving, time:", time.time() - time0)

    output = output_path + '/tractogram_deterministic_' + name + '.trk'
    save_trk(sft, output)

    print("finished, time:", time.time() - time0)