Esempio n. 1
0
def read_aal_atlas(resample_to=None):
    """
    Reads the AAL atlas [1]_.

    Parameters
    ----------
    resample_to : str or nib.Nifti1Image class instance, optional
        If provided, every volume of the AAL atlas is resampled into the
        space (grid and affine) of this image. A string is treated as the
        path to a nifti file and loaded with nibabel. Default: None, in
        which case the atlas is returned as stored on disk.

    Returns
    -------
    dict with keys:
        'labels' : pandas DataFrame read from the atlas' ``.txt`` file.
        'atlas' : nib.Nifti1Image holding the atlas. The last axis is
            treated as a volume index (each volume is resampled separately).

    .. [1] Tzourio-Mazoyer N, Landeau B, Papathanassiou D, Crivello F, Etard O,
           Delcroix N, Mazoyer B, Joliot M. (2002). Automated anatomical
           labeling of activations in SPM using a macroscopic anatomical
           parcellation of the MNI MRI single-subject brain. Neuroimage. 2002;
           15(1):273-89.
    """
    file_dict, folder = fetch_aal_atlas()
    out_dict = {}
    for f in file_dict:
        if f.endswith('.txt'):
            out_dict['labels'] = pd.read_csv(op.join(folder, f))
        else:
            out_dict['atlas'] = nib.load(op.join(folder, f))
    if resample_to is not None:
        # Accept a path, consistent with read_callosum_templates:
        if isinstance(resample_to, str):
            resample_to = nib.load(resample_to)
        atlas_img = out_dict['atlas']
        data = atlas_img.get_fdata()
        # Resample each volume along the last axis independently, then
        # restack them in the target space:
        resampled = [
            reg.resample(data[..., ii], resample_to,
                         atlas_img.affine, resample_to.affine)
            for ii in range(data.shape[-1])]
        out_dict['atlas'] = nib.Nifti1Image(np.stack(resampled, -1),
                                            resample_to.affine)
    return out_dict
Esempio n. 2
0
def visualize_roi(roi,
                  affine_or_mapping=None,
                  static_img=None,
                  roi_affine=None,
                  static_affine=None,
                  reg_template=None,
                  scene=None,
                  color=None,
                  opacity=1.0,
                  inline=False,
                  interact=False):
    """
    Render a region of interest into a VTK viz as a volume

    Parameters
    ----------
    roi : ndarray, str or nibabel image
        The ROI. A string is loaded with nibabel; an image-like object has
        its data read with ``get_fdata()``.
    affine_or_mapping : ndarray, str, nib.Nifti1Image or mapping, optional
        If an ndarray, it is treated as an affine and the ROI is resampled
        with ``reg.resample`` (requires `static_img`, `roi_affine` and
        `static_affine`). If a str or Nifti1Image, a mapping is read with
        ``reg.read_mapping`` (requires `reg_template` and `static_img`).
        Any other object is assumed to already be a mapping exposing
        ``transform_inverse``.
    static_img, roi_affine, static_affine, reg_template : optional
        Inputs for the transformations described above.
    scene : window.Scene, optional
        Scene to add the ROI actor to; a new one is created if None.
    color : array-like, optional
        Contour color. Default: red, ``np.array([1, 0, 0])``.
    opacity : float, optional
        Contour opacity. Default: 1.0.
    inline : bool, optional
        If True, a PNG snapshot is written to the temp dir and displayed.
    interact : bool, optional
        Passed on to ``_inline_interact``.

    Returns
    -------
    Whatever ``_inline_interact(scene, inline, interact)`` returns.

    Raises
    ------
    ValueError
        If the inputs required by the chosen transformation are missing.
    """
    # Avoid a shared mutable default argument:
    if color is None:
        color = np.array([1, 0, 0])

    if not isinstance(roi, np.ndarray):
        if isinstance(roi, str):
            roi = nib.load(roi).get_fdata()
        else:
            roi = roi.get_fdata()

    if affine_or_mapping is not None:
        if isinstance(affine_or_mapping, np.ndarray):
            # This is an affine:
            if (static_img is None or roi_affine is None
                    or static_affine is None):
                raise ValueError(
                    "If using an affine to transform an ROI, "
                    "need to also specify all of the following "
                    "inputs: `static_img`, `roi_affine`, `static_affine`")
            roi = reg.resample(roi, static_img, roi_affine, static_affine)
        else:
            # Assume it is a mapping:
            if (isinstance(affine_or_mapping, str)
                    or isinstance(affine_or_mapping, nib.Nifti1Image)):
                if reg_template is None or static_img is None:
                    raise ValueError(
                        "If using a mapping to transform an ROI, need to "
                        "also specify all of the following inputs: "
                        "`reg_template`, `static_img`")
                affine_or_mapping = reg.read_mapping(affine_or_mapping,
                                                     static_img, reg_template)

            roi = auv.patch_up_roi(
                affine_or_mapping.transform_inverse(
                    roi, interpolation='nearest')).astype(bool)

    if scene is None:
        scene = window.Scene()

    roi_actor = actor.contour_from_roi(roi, color=color, opacity=opacity)
    scene.add(roi_actor)

    if inline:
        tdir = tempfile.gettempdir()
        fname = op.join(tdir, "fig.png")
        window.snapshot(scene, fname=fname)
        display.display_png(display.Image(fname))

    return _inline_interact(scene, inline, interact)
Esempio n. 3
0
def _streamlines(row,
                 wm_labels,
                 odf_model="DTI",
                 directions="det",
                 n_seeds=2,
                 random_seeds=False,
                 force_recompute=False,
                 wm_fa_thresh=0.2):
    """
    Run tractography for one row and write the streamlines to a .trk file.

    Parameters
    ----------
    row : indexable record
        Must provide ``row['dwi_file']``; ``row['seg_file']`` is used when
        present in ``row.index`` (presumably a pandas Series -- confirm).
    wm_labels : list
        The values within the segmentation that are considered white matter. We
        will use this part of the image both to seed tracking (seeding
        throughout), and for stopping.
    odf_model : str
        "DTI" or "CSD"; selects which model parameter file is tracked on.
    directions : str
        Passed to ``aft.track``; also part of the output file name.
    n_seeds, random_seeds :
        Passed to ``aft.track``.
    force_recompute : bool
        Recompute even if the output file already exists.
    wm_fa_thresh : float
        FA threshold used for the white-matter mask when no segmentation is
        available.

    Returns
    -------
    str : path to the streamlines file.

    Raises
    ------
    NotImplementedError
        If `odf_model` is neither "DTI" nor "CSD".
    """
    streamlines_file = _get_fname(
        row, '%s_%s_streamlines.trk' % (odf_model, directions))
    if not op.exists(streamlines_file) or force_recompute:
        if odf_model == "DTI":
            params_file = _dti(row)
        elif odf_model == "CSD":
            params_file = _csd(row)
        else:
            # Fail early instead of an UnboundLocalError on params_file:
            raise NotImplementedError(
                "odf_model %s is not implemented" % odf_model)

        dwi_img = nib.load(row['dwi_file'])

        if 'seg_file' in row.index:
            # If we found a white matter segmentation in the
            # expected location:
            seg_img = nib.load(row['seg_file'])
            # np.asanyarray(img.dataobj) replaces the removed get_data():
            seg_data_orig = np.asanyarray(seg_img.dataobj)
            # Voxels holding any of the white-matter labels:
            wm_mask = np.isin(seg_data_orig, wm_labels).astype(int)

            # Only the grid of the first DWI volume is needed as the
            # resampling target, so read just that volume:
            dwi_vol0 = np.asanyarray(dwi_img.dataobj[..., 0])
            # Resample to DWI data:
            wm_mask = np.round(
                reg.resample(wm_mask, dwi_vol0, seg_img.affine,
                             dwi_img.affine)).astype(int)
        else:
            # Otherwise, we'll identify the white matter based on FA:
            dti_fa = np.asanyarray(nib.load(_dti_fa(row)).dataobj)
            wm_mask = dti_fa > wm_fa_thresh

        streamlines = aft.track(params_file,
                                directions=directions,
                                n_seeds=n_seeds,
                                random_seeds=random_seeds,
                                seed_mask=wm_mask,
                                stop_mask=wm_mask)

        aus.write_trk(streamlines_file,
                      dtu.move_streamlines(streamlines,
                                           np.linalg.inv(dwi_img.affine)),
                      affine=dwi_img.affine)

    return streamlines_file
Esempio n. 4
0
def save_wm_mask(subject):
    """
    Build a white-matter mask for an HCP subject and upload it to S3.

    The mask combines cerebral white matter of both hemispheres (aparc+aseg
    labels 2 and 41) with the corpus callosum (labels 251-255). It is
    resampled onto the grid of the subject's diffusion data and uploaded to
    ``hcp-dki/<subject>/<subject>_white_matter_mask.nii.gz``.

    Parameters
    ----------
    subject : str or int
        HCP subject identifier, interpolated into the S3 keys.

    Returns
    -------
    tuple
        ``(subject, True)`` on success or if the mask already exists;
        ``(subject, err.args)`` if any exception is raised while building or
        uploading the mask. Errors are returned rather than raised,
        presumably so the function can be mapped over many subjects --
        confirm with callers.
    """
    # Select the AWS profile before `exists` runs. `exists` presumably uses
    # the default boto3 session -- confirm.
    boto3.setup_default_session(profile_name='cirrus')
    bucket_name = 'hcp-dki'
    path = '%s/%s_white_matter_mask.nii.gz' % (subject, subject)
    if exists(path, bucket_name):
        return subject, True

    bucket = setup_boto()
    with tempfile.TemporaryDirectory() as tempdir:
        try:
            dwi_file = op.join(tempdir, 'data.nii.gz')
            seg_file = op.join(tempdir, 'aparc+aseg.nii.gz')
            data_files = {
                dwi_file: 'HCP_900/%s/T1w/Diffusion/data.nii.gz' % subject,
                seg_file: 'HCP_900/%s/T1w/aparc+aseg.nii.gz' % subject,
            }
            for local_file, s3_key in data_files.items():
                if not op.exists(local_file):
                    bucket.download_file(s3_key, local_file)

            seg_img = nib.load(seg_file)
            dwi_img = nib.load(dwi_file)
            # np.asanyarray(img.dataobj) replaces the removed get_data():
            seg_data_orig = np.asanyarray(seg_img.dataobj)
            # Corpus callosum labels:
            cc_mask = np.isin(seg_data_orig, [251, 252, 253, 254, 255])

            # Cerebral white matter in both hemispheres + corpus callosum
            wm_mask = ((seg_data_orig == 41) | (seg_data_orig == 2) |
                       cc_mask)
            # Only the grid of the first diffusion volume is needed as the
            # resampling target, so avoid reading the whole 4D series:
            dwi_vol0 = np.asanyarray(dwi_img.dataobj[..., 0])
            resamp_wm = np.round(
                reg.resample(wm_mask, dwi_vol0, seg_img.affine,
                             dwi_img.affine)).astype(int)
            wm_file = op.join(tempdir, 'wm.nii.gz')
            nib.save(nib.Nifti1Image(resamp_wm, dwi_img.affine), wm_file)
            # Re-select the profile: setup_boto() may have changed the
            # default session.
            boto3.setup_default_session(profile_name='cirrus')
            s3 = boto3.resource('s3')
            s3.meta.client.upload_file(wm_file, bucket_name, path)
            return subject, True
        except Exception as err:
            return subject, err.args
Esempio n. 5
0
def _resample_mask(mask_data, dwi_data, mask_affine, dwi_affine):
    '''
    Helper function
    Resamples mask to dwi if necessary
    '''
    mask_type = mask_data.dtype
    if ((dwi_data is not None) and (dwi_affine is not None)
            and (dwi_data[..., 0].shape != mask_data.shape)):
        return np.round(
            reg.resample(mask_data.astype(float), dwi_data[..., 0],
                         mask_affine, dwi_affine)).astype(mask_type)
    else:
        return mask_data
Esempio n. 6
0
def save_wm_mask(subject):
    """
    Build a white-matter mask for an HCP subject and upload it to S3.

    The mask combines cerebral white matter of both hemispheres (aparc+aseg
    labels 2 and 41) with the corpus callosum (labels 251-255). It is
    resampled onto the grid of the subject's diffusion data and uploaded to
    ``hcp-dki/<subject>/<subject>_white_matter_mask.nii.gz``.

    Parameters
    ----------
    subject : str or int
        HCP subject identifier, interpolated into the S3 keys.

    Returns
    -------
    tuple
        ``(subject, True)`` on success or if the mask already exists;
        ``(subject, err.args)`` if any exception is raised while building or
        uploading the mask. Errors are returned rather than raised,
        presumably so the function can be mapped over many subjects --
        confirm with callers.
    """
    # Select the AWS profile before `exists` runs. `exists` presumably uses
    # the default boto3 session -- confirm.
    boto3.setup_default_session(profile_name='cirrus')
    bucket_name = 'hcp-dki'
    path = '%s/%s_white_matter_mask.nii.gz' % (subject, subject)
    if exists(path, bucket_name):
        return subject, True

    bucket = setup_boto()
    with tempfile.TemporaryDirectory() as tempdir:
        try:
            dwi_file = op.join(tempdir, 'data.nii.gz')
            seg_file = op.join(tempdir, 'aparc+aseg.nii.gz')
            data_files = {
                dwi_file: 'HCP_900/%s/T1w/Diffusion/data.nii.gz' % subject,
                seg_file: 'HCP_900/%s/T1w/aparc+aseg.nii.gz' % subject,
            }
            for local_file, s3_key in data_files.items():
                if not op.exists(local_file):
                    bucket.download_file(s3_key, local_file)

            seg_img = nib.load(seg_file)
            dwi_img = nib.load(dwi_file)
            # np.asanyarray(img.dataobj) replaces the removed get_data():
            seg_data_orig = np.asanyarray(seg_img.dataobj)
            # Corpus callosum labels:
            cc_mask = np.isin(seg_data_orig, [251, 252, 253, 254, 255])

            # Cerebral white matter in both hemispheres + corpus callosum
            wm_mask = ((seg_data_orig == 41) | (seg_data_orig == 2) |
                       cc_mask)
            # Only the grid of the first diffusion volume is needed as the
            # resampling target, so avoid reading the whole 4D series:
            dwi_vol0 = np.asanyarray(dwi_img.dataobj[..., 0])
            resamp_wm = np.round(
                reg.resample(wm_mask, dwi_vol0, seg_img.affine,
                             dwi_img.affine)).astype(int)
            wm_file = op.join(tempdir, 'wm.nii.gz')
            nib.save(nib.Nifti1Image(resamp_wm, dwi_img.affine), wm_file)
            # Re-select the profile: setup_boto() may have changed the
            # default session.
            boto3.setup_default_session(profile_name='cirrus')
            s3 = boto3.resource('s3')
            s3.meta.client.upload_file(wm_file, bucket_name, path)
            return subject, True
        except Exception as err:
            return subject, err.args
Esempio n. 7
0
def _streamlines(row,
                 wm_labels,
                 odf_model="DTI",
                 directions="det",
                 force_recompute=False):
    """
    Run tractography for one row and write the streamlines to a .trk file.

    Parameters
    ----------
    row : indexable record
        Must provide ``row['seg_file']``, ``row['dwi_file']`` and
        ``row['dwi_affine']``.
    wm_labels : list
        The values within the segmentation that are considered white matter. We
        will use this part of the image both to seed tracking (seeding
        throughout), and for stopping.
    odf_model : str
        Only "DTI" is supported.
    directions : str
        Tracking direction mode passed to ``aft.track``; it is also part of
        the output file name, so the two must agree.
    force_recompute : bool
        Recompute even if the output file already exists.

    Returns
    -------
    str : path to the streamlines file.

    Raises
    ------
    NotImplementedError
        If `odf_model` is not "DTI".
    """
    streamlines_file = _get_fname(
        row, '%s_%s_streamlines.trk' % (odf_model, directions))
    if not op.exists(streamlines_file) or force_recompute:
        if odf_model == "DTI":
            params_file = _dti(row)
        else:
            raise NotImplementedError(
                "odf_model %s is not implemented" % odf_model)

        seg_img = nib.load(row['seg_file'])
        dwi_img = nib.load(row['dwi_file'])
        # np.asanyarray(img.dataobj) replaces the removed get_data():
        seg_data_orig = np.asanyarray(seg_img.dataobj)

        # Voxels holding any of the white-matter labels:
        wm_mask = np.isin(seg_data_orig, wm_labels).astype(int)

        # Only the grid of the first DWI volume is needed as the resampling
        # target, so read just that volume:
        dwi_vol0 = np.asanyarray(dwi_img.dataobj[..., 0])
        resamp_wm = np.round(
            reg.resample(wm_mask, dwi_vol0, seg_img.affine,
                         dwi_img.affine)).astype(int)

        # Use the requested `directions` (previously hard-coded to 'det',
        # which mislabelled the output file) and the `n_seeds` keyword of
        # `track`:
        streamlines = aft.track(params_file,
                                directions=directions,
                                n_seeds=2,
                                seed_mask=resamp_wm,
                                stop_mask=resamp_wm)

        aus.write_trk(streamlines_file, streamlines, affine=row['dwi_affine'])

    return streamlines_file
Esempio n. 8
0
def read_callosum_templates(resample_to=False):
    """Load AFQ callosum templates from file

    Parameters
    ----------
    resample_to : str, nibabel image or False, optional
        When truthy, every template is resampled into the space of this
        image (a string is loaded with ``nib.load`` first). Default: False,
        i.e. templates are returned as stored.

    Returns
    -------
    dict with: keys: names of template ROIs and values: nibabel Nifti1Image
    objects from each of the ROI nifti files.
    """
    fnames, template_dir = fetch_callosum_templates()
    templates = {}
    for fname in fnames:
        roi_img = nib.load(op.join(template_dir, fname))
        if resample_to:
            # A path is loaded once; afterwards `resample_to` is an image.
            if isinstance(resample_to, str):
                resample_to = nib.load(resample_to)
            resampled = reg.resample(roi_img.get_fdata(), resample_to,
                                     roi_img.affine, resample_to.affine)
            roi_img = nib.Nifti1Image(resampled, resample_to.affine)
        # Key is the file name up to its first dot:
        roi_name = fname.split('.')[0]
        templates[roi_name] = roi_img
    return templates
Esempio n. 9
0
def _streamlines(row, wm_labels, odf_model="DTI", directions="det",
                 force_recompute=False):
    """
    Run tractography for one row and write the streamlines to a .trk file.

    Parameters
    ----------
    row : indexable record
        Must provide ``row['seg_file']``, ``row['dwi_file']`` and
        ``row['dwi_affine']``.
    wm_labels : list
        The values within the segmentation that are considered white matter. We
        will use this part of the image both to seed tracking (seeding
        throughout), and for stopping.
    odf_model : str
        Only "DTI" is supported.
    directions : str
        Tracking direction mode passed to ``aft.track``; it is also part of
        the output file name, so the two must agree.
    force_recompute : bool
        Recompute even if the output file already exists.

    Returns
    -------
    str : path to the streamlines file.

    Raises
    ------
    NotImplementedError
        If `odf_model` is not "DTI".
    """
    streamlines_file = _get_fname(row,
                                  '%s_%s_streamlines.trk' % (odf_model,
                                                             directions))
    if not op.exists(streamlines_file) or force_recompute:
        if odf_model == "DTI":
            params_file = _dti(row)
        else:
            raise NotImplementedError(
                "odf_model %s is not implemented" % odf_model)

        seg_img = nib.load(row['seg_file'])
        dwi_img = nib.load(row['dwi_file'])
        # np.asanyarray(img.dataobj) replaces the removed get_data():
        seg_data_orig = np.asanyarray(seg_img.dataobj)

        # Voxels holding any of the white-matter labels:
        wm_mask = np.isin(seg_data_orig, wm_labels).astype(int)

        # Only the grid of the first DWI volume is needed as the resampling
        # target, so read just that volume:
        dwi_vol0 = np.asanyarray(dwi_img.dataobj[..., 0])
        resamp_wm = np.round(reg.resample(wm_mask, dwi_vol0,
                                          seg_img.affine,
                                          dwi_img.affine)).astype(int)

        # Use the requested `directions` (previously hard-coded to 'det',
        # which mislabelled the output file) and the `n_seeds` keyword of
        # `track`:
        streamlines = aft.track(params_file,
                                directions=directions,
                                n_seeds=2,
                                seed_mask=resamp_wm,
                                stop_mask=resamp_wm)

        aus.write_trk(streamlines_file, streamlines,
                      affine=row['dwi_affine'])

    return streamlines_file
Esempio n. 10
0
    def _wm_mask(self, row, wm_fa_thresh=0.2):
        """
        Compute (or reuse) a white-matter mask for `row` and save it.

        If ``row.index`` contains 'seg_file', the mask is made of the
        voxels whose segmentation value is in ``self.wm_labels``, resampled
        onto the DWI grid. Otherwise it is ``FA > wm_fa_thresh`` from the
        DTI FA map. In both cases the mask is dilated with
        ``binary_dilation`` and saved as an int nifti with
        ``row['dwi_affine']``. A JSON side-car describing the source is
        written next to it.

        Parameters
        ----------
        row : indexable record
            Must provide 'dwi_file' and 'dwi_affine'; 'seg_file' is used
            when present (presumably a pandas Series -- confirm).
        wm_fa_thresh : float, optional
            FA threshold used when no segmentation is available.
            Default: 0.2.

        Returns
        -------
        str : path to the saved ``_wm_mask.nii.gz`` file.
        """
        wm_mask_file = self._get_fname(row, '_wm_mask.nii.gz')
        if self.force_recompute or not op.exists(wm_mask_file):
            if 'seg_file' in row.index:
                # If we found a white matter segmentation in the
                # expected location:
                seg_img = nib.load(row['seg_file'])
                seg_data_orig = seg_img.get_fdata()
                # Voxels holding any of the white-matter labels:
                wm_mask = np.isin(seg_data_orig, self.wm_labels).astype(int)

                # Only the grid of the first DWI volume is needed as the
                # resampling target, so read just that volume (as float64,
                # like get_fdata) instead of the whole 4D series:
                dwi_img = nib.load(row['dwi_file'])
                dwi_vol0 = np.asarray(dwi_img.dataobj[..., 0],
                                      dtype=np.float64)
                # Resample to DWI data:
                wm_mask = np.round(
                    reg.resample(wm_mask, dwi_vol0, seg_img.affine,
                                 dwi_img.affine)).astype(int)
                meta = dict(source=row['seg_file'], wm_labels=self.wm_labels)
            else:
                # Otherwise, we'll identify the white matter based on FA:
                fa_fname = self._dti_fa(row)
                dti_fa = nib.load(fa_fname).get_fdata()
                wm_mask = dti_fa > wm_fa_thresh
                meta = dict(source=fa_fname, fa_threshold=wm_fa_thresh)

            # Dilate to be sure to reach the gray matter:
            wm_mask = binary_dilation(wm_mask) > 0

            nib.save(nib.Nifti1Image(wm_mask.astype(int), row['dwi_affine']),
                     wm_mask_file)

            meta_fname = self._get_fname(row, '_wm_mask.json')
            afd.write_json(meta_fname, meta)

        return wm_mask_file
Esempio n. 11
0
    def segment_afq(self, tg=None):
        """
        Assign streamlines to bundles using the waypoint ROI approach

        Parameters
        ----------
        tg : StatefulTractogram class instance, optional
            Tractogram to segment. If None, ``self.tg`` is used; otherwise
            ``self.tg`` is replaced by `tg`. It is converted to voxel space
            in place.

        Returns
        -------
        dict
            ``self.fiber_groups``, mapping each bundle name to a
            StatefulTractogram of its streamlines. When ``self.return_idx``
            is set, each value is instead a dict with keys 'sl' (the
            StatefulTractogram) and 'idx' (indices of those streamlines in
            the input tractogram). Bundles left without streamlines are
            handled by ``self._return_empty``.
        """
        if tg is None:
            tg = self.tg
        else:
            self.tg = tg

        self.tg.to_vox()
        # For expedience, we approximate each streamline as a 100 point curve:
        fgarray = np.array(_resample_tg(tg, 100))
        # NOTE(review): an earlier comment warned that the fixed 100-point
        # resampling can break the streamline-to-ROI contact check when an
        # ROI is traversed in only a few voxels and one hundredth of a
        # streamline is longer than a voxel; to be safe the number of points
        # would need to exceed the streamline length in native-space voxels.
        # Here `fgarray` only feeds the probability-map lookup, while the
        # inclusion/exclusion checks receive the original
        # `tg.streamlines` -- confirm whether the concern still applies.

        n_streamlines = fgarray.shape[0]
        n_bundles = len(self.bundle_dict)

        streamlines_in_bundles = np.zeros((n_streamlines, n_bundles))
        min_dist_coords = np.zeros((n_streamlines, n_bundles, 2), dtype=int)
        self.fiber_groups = {}

        if self.return_idx:
            out_idx = np.arange(n_streamlines, dtype=int)

        if self.filter_by_endpoints:
            # The AAL atlas is not yet aligned to template space. It has
            # multiple volumes (to represent overlapping areas separately),
            # so each volume is resampled into the template space and the
            # results are restacked with the template's affine.
            resample_to = self.reg_template
            if isinstance(resample_to, str):
                resample_to = nib.load(resample_to)
            atlas_img = afd.read_aal_atlas()['atlas']
            # Keep the image (its affine is needed) and load the data once,
            # rather than reloading the whole atlas for every volume:
            atlas_data = atlas_img.get_fdata()
            all_volumes = []
            for ii in range(atlas_data.shape[-1]):
                resampled = reg.resample(
                    atlas_data[..., ii],  # moving
                    resample_to,  # static
                    atlas_img.affine,  # moving affine
                    resample_to.affine)  # static affine
                all_volumes.append(np.asarray(resampled))
            aal_atlas = nib.Nifti1Image(np.stack(all_volumes, axis=3),
                                        resample_to.affine)

            # We need to calculate the size of a voxel, so we can transform
            # from mm to voxel units:
            R = self.img_affine[0:3, 0:3]
            vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R))))
            dist_to_aal = self.dist_to_aal / vox_dim

        self.logger.info("Assigning Streamlines to Bundles")
        # Tolerance is set to the square of the distance to the corner
        # because we are using the squared Euclidean distance in calls to
        # `cdist` to make those calls faster.
        tol = dts.dist_to_corner(self.img_affine)**2
        for bundle_idx, bundle in enumerate(self.bundle_dict):
            self.logger.info(f"Finding Streamlines for {bundle}")
            warped_prob_map, include_roi, exclude_roi = \
                self._get_bundle_info(bundle_idx, bundle)
            fiber_probabilities = dts.values_from_volume(
                warped_prob_map, fgarray, np.eye(4))
            fiber_probabilities = np.mean(fiber_probabilities, -1)
            idx_above_prob = np.where(
                fiber_probabilities > self.prob_threshold)
            self.logger.info((f"{len(idx_above_prob[0])} streamlines exceed"
                              " the probability threshold."))
            crosses_midline = self.bundle_dict[bundle]['cross_midline']
            for sl_idx in tqdm(idx_above_prob[0]):
                sl = tg.streamlines[sl_idx]
                # Reject streamlines that cross the midline when the bundle
                # must not cross it. None disables the check; streamlines
                # that do not cross are not rejected here.
                if (crosses_midline is not None and self.crosses[sl_idx]
                        and not crosses_midline):
                    continue

                is_close, dist = self._check_sl_with_inclusion(
                    sl, include_roi, tol)
                if not is_close:
                    continue
                is_far = self._check_sl_with_exclusion(sl, exclude_roi, tol)
                if not is_far:
                    continue
                min_dist_coords[sl_idx, bundle_idx, 0] = \
                    np.argmin(dist[0], 0)[0]
                min_dist_coords[sl_idx, bundle_idx, 1] = \
                    np.argmin(dist[1], 0)[0]
                streamlines_in_bundles[sl_idx, bundle_idx] = \
                    fiber_probabilities[sl_idx]
            self.logger.info(
                (f"{np.sum(streamlines_in_bundles[:, bundle_idx] > 0)} "
                 "streamlines selected with waypoint ROIs"))

        # Eliminate any fibers not selected using the plane ROIs:
        possible_fibers = np.sum(streamlines_in_bundles, -1) > 0
        tg = StatefulTractogram(tg.streamlines[possible_fibers], self.img,
                                Space.VOX)
        if self.return_idx:
            out_idx = out_idx[possible_fibers]

        streamlines_in_bundles = streamlines_in_bundles[possible_fibers]
        min_dist_coords = min_dist_coords[possible_fibers]
        # Each streamline goes to the bundle where its probability is highest:
        bundle_choice = np.argmax(streamlines_in_bundles, -1)

        # We do another round through, so that we can orient all the
        # streamlines within a bundle in the same orientation with respect to
        # the ROIs. This order is ARBITRARY but CONSISTENT (going from ROI0
        # to ROI1).
        self.logger.info("Re-orienting streamlines to consistent directions")
        for bundle_idx, bundle in enumerate(self.bundle_dict):
            self.logger.info(f"Processing {bundle}")

            select_idx = np.where(bundle_choice == bundle_idx)

            if len(select_idx[0]) == 0:
                # There's nothing here, set and move to the next bundle:
                self._return_empty(bundle)
                continue

            # Use a list here, because ArraySequence doesn't support item
            # assignment:
            select_sl = list(tg.streamlines[select_idx])
            # Sub-sample min_dist_coords:
            min_dist_coords_bundle = min_dist_coords[select_idx]
            for idx in range(len(select_sl)):
                min0 = min_dist_coords_bundle[idx, bundle_idx, 0]
                min1 = min_dist_coords_bundle[idx, bundle_idx, 1]
                if min0 > min1:
                    select_sl[idx] = select_sl[idx][::-1]

            # Set this to StatefulTractogram object for filtering/output:
            select_sl = StatefulTractogram(select_sl, self.img, Space.VOX)

            if self.filter_by_endpoints:
                self.logger.info("Filtering by endpoints")
                # Create binary masks and warp these into subject's DWI space:
                aal_targets = afd.bundles_to_aal([bundle], atlas=aal_atlas)[0]
                aal_idx = []
                for targ in aal_targets:
                    if targ is not None:
                        aal_roi = np.zeros(aal_atlas.shape[:3])
                        aal_roi[targ[:, 0], targ[:, 1], targ[:, 2]] = 1
                        warped_roi = self.mapping.transform_inverse(
                            aal_roi, interpolation='nearest')
                        aal_idx.append(np.array(np.where(warped_roi > 0)).T)
                    else:
                        aal_idx.append(None)

                self.logger.info("Before filtering "
                                 f"{len(select_sl)} streamlines")

                new_select_sl = clean_by_endpoints(select_sl.streamlines,
                                                   aal_idx[0],
                                                   aal_idx[1],
                                                   tol=dist_to_aal,
                                                   return_idx=self.return_idx)
                # Generate immediately:
                new_select_sl = list(new_select_sl)

                # We need to check this again:
                if len(new_select_sl) == 0:
                    # There's nothing here, set and move to the next bundle:
                    self._return_empty(bundle)
                    continue

                if self.return_idx:
                    temp_select_sl = []
                    temp_select_idx = np.empty(len(new_select_sl), int)
                    for ii, ss in enumerate(new_select_sl):
                        temp_select_sl.append(ss[0])
                        temp_select_idx[ii] = ss[1]
                    select_idx = select_idx[0][temp_select_idx]
                    new_select_sl = temp_select_sl

                # The filtered streamlines come from `select_sl`, which is
                # in voxel space (and were compared with voxel-index targets
                # using a voxel-unit tolerance), so they must be wrapped as
                # VOX rather than RASMM:
                select_sl = StatefulTractogram(new_select_sl, self.img,
                                               Space.VOX)

                self.logger.info("After filtering "
                                 f"{len(select_sl)} streamlines")

            if self.return_idx:
                self.fiber_groups[bundle] = {}
                self.fiber_groups[bundle]['sl'] = select_sl
                self.fiber_groups[bundle]['idx'] = out_idx[select_idx]
            else:
                self.fiber_groups[bundle] = select_sl
        return self.fiber_groups
Esempio n. 12
0
def track(params_file,
          directions="det",
          max_angle=30.,
          sphere=None,
          seed_mask=None,
          seed_threshold=0,
          n_seeds=1,
          random_seeds=False,
          rng_seed=None,
          stop_mask=None,
          stop_threshold=0,
          step_size=0.5,
          min_length=10,
          max_length=1000,
          odf_model="DTI",
          tracker="local"):
    """
    Tractography

    Parameters
    ----------
    params_file : str, nibabel img.
        Full path to a nifti file containing the model parameters, or a
        nibabel img with model params. For "CSD"/"MSMT" these are spherical
        harmonic coefficients; for "DTI"/"DKI" volumes 0-2 are used as
        eigenvalues and volumes 3-11 as a flattened 3x3 eigenvector matrix.
    directions : str
        How tracking directions are determined (case-insensitive).
        One of: {"det" | "prob"}
    max_angle : float, optional.
        The maximum turning angle in each step. Default: 30
    sphere : Sphere object, optional.
        The discretization of direction getting. default:
        dipy.data.default_sphere.
    seed_mask : array, optional.
        Float or binary mask describing the ROI within which we seed for
        tracking. A non-boolean mask is binarized as
        ``seed_mask > seed_threshold``.
        Default to the entire volume (all ones).
    seed_threshold : float, optional.
        Voxels of a non-boolean seed_mask must exceed this value to be used
        for seeding. Ignored when seed_mask is boolean or when n_seeds is
        an array of seed coordinates.
        Default to 0.
    n_seeds : int or 2D array, optional.
        The seeding density: if this is an int, it is how many seeds in each
        voxel on each dimension (for example, 2 => [2, 2, 2]). If this is a 2D
        array, these are the coordinates of the seeds. Unless random_seeds is
        set to True, in which case this is the total number of random seeds
        to generate within the mask. Default: 1
    random_seeds : bool
        Whether to generate a total of n_seeds random seeds in the mask.
        Default: False.
    rng_seed : int
        random seed used to generate random seeds if random_seeds is
        set to True. Default: None
    stop_mask : array or tuple, optional.
        If array: A float or binary mask that determines a stopping criterion
        (e.g. FA). A boolean mask is thresholded at 0.5 (stop_threshold is
        then ignored).
        If tuple: it contains a sequence that is interpreted as:
        (pve_wm, pve_gm, pve_csf), each item of which is either a string
        (full path) or a nibabel img to be used in particle filtering
        tractography.
        A tuple is required if tracker is set to "pft".
        Defaults to no stopping (all ones).
    stop_threshold : float or str, optional.
        If float, this a value of the stop_mask below which tracking is
        terminated (and stop_mask has to be an array).
        If str, "CMC" for Continuous Map Criterion [Girard2014]_.
                "ACT" for Anatomically-constrained tractography [Smith2012]_.
        A string is required if the tracker is set to "pft".
        Defaults to 0 (this means that if no stop_mask is passed,
        we will stop only at the edge of the image).
    step_size : float, optional.
        The size (in mm) of a step of tractography. Default: 0.5
    min_length: int, optional
        The minimal length (mm) in a streamline. Default: 10
    max_length: int, optional
        The maximal length (mm) in a streamline. Default: 1000
    odf_model : str, optional
        One of {"DTI", "CSD", "DKI", "MSMT"} (case-insensitive).
        Defaults to use "DTI"
    tracker : str, optional
        Which strategy to use in tracking. This can be the standard local
        tracking ("local") or Particle Filtering Tracking ([Girard2014]_).
        One of {"local", "pft"}. Default: "local"

    Returns
    -------
    streamlines
        Whatever ``_tracking`` returns for these arguments (the tracked
        streamlines; see ``_tracking`` for the exact container type).

    Raises
    ------
    ValueError
        If `directions`, `odf_model` or `tracker` is not one of the supported
        options, or if `tracker` is "pft" and `stop_threshold` is a string
        other than "CMC" or "ACT".
    RuntimeError
        If `tracker` is "pft" and `stop_threshold` is not a string, or
        `stop_mask` is not a length-3 iterable.

    References
    ----------
    .. [Girard2014] Girard, G., Whittingstall, K., Deriche, R., &
        Descoteaux, M. Towards quantitative connectivity analysis: reducing
        tractography biases. NeuroImage, 98, 266-278, 2014.
    .. [Smith2012] Smith, R. E., Tournier, J.-D., Calamante, F., &
        Connelly, A. Anatomically-constrained tractography: Improved
        diffusion MRI streamlines tractography through effective use of
        anatomical information. NeuroImage, 62(3), 1924-1938, 2012.
    """
    logger = logging.getLogger('AFQ.tractography')

    odf_model = odf_model.upper()
    directions = directions.lower()

    # Validate the option strings before doing any (potentially slow) I/O,
    # so that a typo fails fast with an informative message instead of an
    # UnboundLocalError/NameError further down.
    if directions not in ("det", "prob"):
        raise ValueError(
            f"directions must be one of 'det' or 'prob', got {directions!r}")
    if odf_model not in ("DTI", "DKI", "CSD", "MSMT"):
        raise ValueError(
            "odf_model must be one of 'DTI', 'DKI', 'CSD' or 'MSMT', "
            f"got {odf_model!r}")
    if tracker not in ("local", "pft"):
        raise ValueError(
            f"tracker must be one of 'local' or 'pft', got {tracker!r}")
    if tracker == "pft":
        if not isinstance(stop_threshold, str):
            # A single message string (not comma-separated fragments, which
            # would turn the exception args into a tuple of pieces).
            raise RuntimeError(
                "You are using PFT tracking, but did not provide a string "
                "'stop_threshold' input. "
                "Possible inputs are: 'CMC' or 'ACT'")
        if stop_threshold not in ("CMC", "ACT"):
            raise ValueError(
                "You are using PFT tracking, but 'stop_threshold' is "
                f"{stop_threshold!r}. Possible inputs are: 'CMC' or 'ACT'")
        if not (isinstance(stop_mask, Iterable) and len(stop_mask) == 3):
            raise RuntimeError(
                "You are using PFT tracking, but did not provide a length "
                "3 iterable for `stop_mask`. "
                "Expected a (pve_wm, pve_gm, pve_csf) tuple.")

    logger.info("Loading Image...")
    if isinstance(params_file, str):
        params_img = nib.load(params_file)
    else:
        params_img = params_file

    model_params = params_img.get_fdata()
    affine = params_img.affine

    logger.info("Generating Seeds...")
    if isinstance(n_seeds, int):
        if seed_mask is None:
            seed_mask = np.ones(params_img.shape[:3])
        elif seed_mask.dtype != 'bool':
            seed_mask = seed_mask > seed_threshold
        if random_seeds:
            seeds = dtu.random_seeds_from_mask(seed_mask,
                                               seeds_count=n_seeds,
                                               seed_count_per_voxel=False,
                                               affine=affine,
                                               random_seed=rng_seed)
        else:
            seeds = dtu.seeds_from_mask(seed_mask,
                                        density=n_seeds,
                                        affine=affine)
    else:
        # If user provided an array, we'll use n_seeds as the seeds:
        seeds = n_seeds
    if sphere is None:
        sphere = dpd.default_sphere

    logger.info("Getting Directions...")
    if directions == "det":
        dg = DeterministicMaximumDirectionGetter
    else:  # "prob" (validated above)
        dg = ProbabilisticDirectionGetter

    if odf_model in ("DTI", "DKI"):
        # Tensor models: volumes 0-2 hold the eigenvalues and 3-11 the
        # flattened 3x3 eigenvectors; convert them to an ODF on the sphere.
        evals = model_params[..., :3]
        evecs = model_params[..., 3:12].reshape(params_img.shape[:3] + (3, 3))
        odf = tensor_odf(evals, evecs, sphere)
        dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere)
    else:  # "CSD" or "MSMT": model_params are SH coefficients
        dg = dg.from_shcoeff(model_params, max_angle=max_angle, sphere=sphere)

    if tracker == "local":
        if stop_mask is None:
            stop_mask = np.ones(params_img.shape[:3])

        if stop_mask.dtype == 'bool':
            stopping_criterion = ThresholdStoppingCriterion(stop_mask, 0.5)
        else:
            stopping_criterion = ThresholdStoppingCriterion(
                stop_mask, stop_threshold)

        my_tracker = VerboseLocalTracking

    else:  # "pft" (validated above)
        pve_imgs = [nib.load(pve) if isinstance(pve, str) else pve
                    for pve in stop_mask]
        # Resample each PVE map into the space of the model parameters,
        # using the first parameter volume and params_img.affine as the
        # reference.
        pve_wm_data, pve_gm_data, pve_csf_data = [
            reg.resample(img.get_fdata(), model_params[..., 0],
                         img.affine, params_img.affine)
            for img in pve_imgs]

        my_tracker = VerboseParticleFilteringTracking
        if stop_threshold == "CMC":
            # After resampling the PVE maps share the parameter grid, so
            # that grid's mean voxel size is the one the criterion needs.
            average_voxel_size = np.mean(params_img.header.get_zooms()[:3])
            stopping_criterion = CmcStoppingCriterion.from_pve(
                pve_wm_data,
                pve_gm_data,
                pve_csf_data,
                step_size=step_size,
                average_voxel_size=average_voxel_size)
        else:  # "ACT" (validated above)
            stopping_criterion = ActStoppingCriterion.from_pve(
                pve_wm_data, pve_gm_data, pve_csf_data)

    logger.info("Tracking...")

    return _tracking(my_tracker,
                     seeds,
                     dg,
                     stopping_criterion,
                     params_img,
                     step_size=step_size,
                     min_length=min_length,
                     max_length=max_length,
                     random_seed=rng_seed)
Esempio n. 13
0
    for bundle in bundles:
        for idx, roi in enumerate(bundles[bundle]['ROIs']):
            warped_roi = transform_inverse_roi(roi,
                                               mapping,
                                               bundle_name=bundle)
            print(roi)
            nib.save(nib.Nifti1Image(warped_roi.astype(float), img.affine),
                     op.join(working_dir, f"{bundle}_{idx+1}.nii.gz"))

            # Add voxels that aren't there yet:
            if bundles[bundle]['rules'][idx]:
                seed_roi = np.logical_or(seed_roi, warped_roi)

        for ii, pp in enumerate(endpoint_spec[bundle].keys()):
            roi = endpoint_spec[bundle][pp]
            roi = reg.resample(roi.get_fdata(), MNI_T1w_img, roi.affine,
                               MNI_T1w_img.affine)

            warped_roi = transform_inverse_roi(roi,
                                               mapping,
                                               bundle_name=bundle)

            nib.save(nib.Nifti1Image(warped_roi.astype(float), img.affine),
                     op.join(working_dir, f"{bundle}_{pp}.nii.gz"))

    nib.save(nib.Nifti1Image(seed_roi.astype(float), img.affine),
             op.join(working_dir, 'seed_roi.nii.gz'))

    sft = aft.track(sh_coeff,
                    seed_mask=seed_roi,
                    n_seeds=5,
                    tracker="pft",