Code Example #1
File: test_registration.py  Project: jyeatman/pyAFQ
def test_syn_registration():
    with nbtmp.InTemporaryDirectory() as tmpdir:
        warped_moving, mapping = syn_registration(subset_b0,
                                                  subset_t2,
                                                  moving_affine=hardi_affine,
                                                  static_affine=MNI_T2_affine,
                                                  step_length=0.1,
                                                  metric='CC',
                                                  dim=3,
                                                  level_iters=[10, 10, 5],
                                                  sigma_diff=2.0,
                                                  prealign=None)

        npt.assert_equal(warped_moving.shape, subset_t2.shape)
        mapping_fname = op.join(tmpdir, 'mapping.nii.gz')
        write_mapping(mapping, mapping_fname)
        file_mapping = read_mapping(mapping_fname,
                                    subset_b0_img,
                                    subset_t2_img)

        # Test that it has the same effect on the data:
        warped_from_file = file_mapping.transform(subset_b0)
        npt.assert_equal(warped_from_file, warped_moving)

        # Test that it is, attribute by attribute, identical:
        for k in mapping.__dict__:
            assert (np.all(mapping.__getattribute__(k) ==
                           file_mapping.__getattribute__(k)))
Code Example #2
File: api.py  Project: jyeatman/pyAFQ
def _bundles(row, wm_labels, odf_model="DTI", directions="det",
             force_recompute=False):
    bundles_file = _get_fname(row,
                              '%s_%s_bundles.trk' % (odf_model,
                                                     directions))
    if not op.exists(bundles_file) or force_recompute:
        streamlines_file = _streamlines(row, wm_labels,
                                        odf_model=odf_model,
                                        directions=directions,
                                        force_recompute=force_recompute)
        tg = nib.streamlines.load(streamlines_file).tractogram
        sl = tg.apply_affine(np.linalg.inv(row['dwi_affine'])).streamlines
        bundle_dict = make_bundle_dict()
        reg_template = dpd.read_mni_template()
        mapping = reg.read_mapping(_mapping(row), row['dwi_file'],
                                   reg_template)
        bundles = seg.segment(row['dwi_file'],
                              row['bval_file'],
                              row['bvec_file'],
                              sl,
                              bundle_dict,
                              reg_template=reg_template,
                              mapping=mapping)
        tgram = _tgramer(bundles, bundle_dict, row['dwi_affine'])
        nib.streamlines.save(tgram, bundles_file)
    return bundles_file
Code Example #3
File: api.py  Project: jhlegarreta/pyAFQ
def _template_xform(row, reg_template, force_recompute=False):
    template_xform_file = _get_fname(row, "_template_xform.nii.gz")
    if not op.exists(template_xform_file) or force_recompute:
        reg_prealign = np.load(
            _reg_prealign(row, force_recompute=force_recompute))
        mapping = reg.read_mapping(_mapping(row,
                                            reg_template,
                                            force_recompute=force_recompute),
                                   row['dwi_file'],
                                   reg_template,
                                   prealign=np.linalg.inv(reg_prealign))

        template_xform = mapping.transform_inverse(reg_template.get_data())
        nib.save(nib.Nifti1Image(template_xform, row['dwi_affine']),
                 template_xform_file)
    return template_xform_file
Code Example #4
File: segmentation.py  Project: richford/pyAFQ
    def prepare_map(self, mapping=None, reg_prealign=None, reg_template=None):
        """
        Set mapping between DWI space and a template.

        Parameters
        ----------
        mapping : DiffeomorphicMap object, str or nib.Nifti1Image, optional.
            A mapping between DWI space and a template.
            If None, mapping will be registered from data used in prepare_img.
            Default: None.
        reg_prealign : array, optional.
            The linear transformation to be applied to align input images to
            the reference space before warping under the deformation field.
            Default: None.
        reg_template : str or nib.Nifti1Image, optional.
            Template to use for registration (defaults to the MNI T2)
            Default: None.
        """
        if reg_template is None:
            reg_template = afd.read_mni_template()

        self.reg_template = reg_template

        if mapping is None:
            if self.seg_algo == "afq" or self.reg_algo == "syn":
                gtab = dpg.gradient_table(self.fbval, self.fbvec)
                self.mapping = reg.syn_register_dwi(self.fdata,
                                                    gtab,
                                                    template=reg_template)[1]
            else:
                self.mapping = None
        elif isinstance(mapping, str) or isinstance(mapping, nib.Nifti1Image):
            if reg_prealign is None:
                reg_prealign = np.eye(4)
            if self.img is None:
                self.img, _, _, _ = \
                    ut.prepare_data(self.fdata,
                                    self.fbval,
                                    self.fbvec,
                                    b0_threshold=self.b0_threshold)
            self.mapping = reg.read_mapping(
                mapping,
                self.img,
                reg_template,
                prealign=np.linalg.inv(reg_prealign))
        else:
            self.mapping = mapping
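A rough usage sketch for prepare_map follows. The helper name, the file names, and the assumption that an already-prepared AFQ.segmentation.Segmentation instance is passed in are all illustrative, not taken from the snippet above:

import numpy as np
import AFQ.data as afd


def reuse_saved_mapping(segmentation, mapping_file='mapping.nii.gz',
                        prealign_file='reg_prealign.npy'):
    # `segmentation` is assumed to be an AFQ.segmentation.Segmentation
    # instance whose input images have already been prepared; the file
    # names above are placeholders.
    template = afd.read_mni_template()
    reg_prealign = np.load(prealign_file)
    # prepare_map inverts the prealign internally, so the forward affine
    # is passed here:
    segmentation.prepare_map(mapping=mapping_file,
                             reg_prealign=reg_prealign,
                             reg_template=template)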
Code Example #5
File: dask_real.py  Project: ValHayot/rollingprefetch
def seg_setup():
    MNI_T2_img = afd.read_mni_template()
    img = nib.load("dwi.nii")
    mapping = reg.read_mapping("mapping.nii.gz", img, MNI_T2_img)

    bundles = api.make_bundle_dict(
        bundle_names=[
            "CST", "UF", "CC_ForcepsMajor", "CC_ForcepsMinor", "OR", "VOF"
        ],
        seg_algo="reco80",
        resample_to=MNI_T2_img,
    )  # CST ARC

    return {
        "MNI_T2_img": MNI_T2_img,
        "img": img,
        "mapping": mapping,
        "bundles": bundles,
    }
Code Example #6
File: api.py  Project: jhlegarreta/pyAFQ
def _bundles(row,
             wm_labels,
             bundle_dict,
             reg_template,
             odf_model="DTI",
             directions="det",
             n_seeds=2,
             random_seeds=False,
             force_recompute=False):
    bundles_file = _get_fname(row,
                              '%s_%s_bundles.trk' % (odf_model, directions))
    if not op.exists(bundles_file) or force_recompute:
        streamlines_file = _streamlines(row,
                                        wm_labels,
                                        odf_model=odf_model,
                                        directions=directions,
                                        n_seeds=n_seeds,
                                        random_seeds=random_seeds,
                                        force_recompute=force_recompute)
        tg = nib.streamlines.load(streamlines_file).tractogram
        sl = tg.apply_affine(np.linalg.inv(row['dwi_affine'])).streamlines

        reg_prealign = np.load(
            _reg_prealign(row, force_recompute=force_recompute))

        mapping = reg.read_mapping(_mapping(row, reg_template),
                                   row['dwi_file'],
                                   reg_template,
                                   prealign=np.linalg.inv(reg_prealign))

        bundles = seg.segment(row['dwi_file'],
                              row['bval_file'],
                              row['bvec_file'],
                              sl,
                              bundle_dict,
                              reg_template=reg_template,
                              mapping=mapping)

        tgram = aus.bundles_to_tgram(bundles, bundle_dict, row['dwi_affine'])
        nib.streamlines.save(tgram, bundles_file)
    return bundles_file
Code Example #7
File: api.py  Project: Anneef/pyAFQ-1
    def _export_rois(self, row):
        if self.use_prealign:
            reg_prealign = np.load(self._reg_prealign(row))
            reg_prealign_inv = np.linalg.inv(reg_prealign)
        else:
            reg_prealign_inv = None

        mapping = reg.read_mapping(self._mapping(row),
                                   row['dwi_file'],
                                   self.reg_template,
                                   prealign=reg_prealign_inv)

        rois_dir = op.join(row['results_dir'], 'ROIs')
        os.makedirs(rois_dir, exist_ok=True)

        for bundle in self.bundle_dict:
            for ii, roi in enumerate(self.bundle_dict[bundle]['ROIs']):

                if self.bundle_dict[bundle]['rules'][ii]:
                    inclusion = 'include'
                else:
                    inclusion = 'exclude'

                warped_roi = auv.patch_up_roi((mapping.transform_inverse(
                    roi.get_fdata(), interpolation='linear')) > 0).astype(int)

                fname = op.split(
                    self._get_fname(
                        row,
                        f'_desc-ROI-{bundle}-{ii + 1}-{inclusion}.nii.gz'))

                fname = op.join(fname[0], rois_dir, fname[1])

                # Cast to float32, so that it can be read in by MI-Brain:
                nib.save(
                    nib.Nifti1Image(warped_roi.astype(np.float32),
                                    row['dwi_affine']), fname)
                meta = dict()
                meta_fname = fname.split('.')[0] + '.json'
                afd.write_json(meta_fname, meta)
Code Example #8
File: test_registration.py  Project: richford/pyAFQ
def test_slr_registration():
    # have to import subject sls
    file_dict = afd.read_stanford_hardi_tractography()
    streamlines = file_dict['tractography_subsampled.trk']

    # have to import sls atlas
    afd.fetch_hcp_atlas_16_bundles()
    atlas_fname = op.join(afd.afq_home, 'hcp_atlas_16_bundles',
                          'Atlas_in_MNI_Space_16_bundles', 'whole_brain',
                          'whole_brain_MNI.trk')
    hcp_atlas = load_tractogram(atlas_fname, 'same', bbox_valid_check=False)

    with nbtmp.InTemporaryDirectory() as tmpdir:
        mapping = slr_registration(streamlines,
                                   hcp_atlas.streamlines,
                                   moving_affine=subset_b0_img.affine,
                                   static_affine=subset_t2_img.affine,
                                   moving_shape=subset_b0_img.shape,
                                   static_shape=subset_t2_img.shape,
                                   progressive=False,
                                   greater_than=10,
                                   rm_small_clusters=1,
                                   rng=np.random.RandomState(seed=8))
        warped_moving = mapping.transform(subset_b0)

        npt.assert_equal(warped_moving.shape, subset_t2.shape)
        mapping_fname = op.join(tmpdir, 'mapping.npy')
        write_mapping(mapping, mapping_fname)
        file_mapping = read_mapping(mapping_fname, subset_b0_img,
                                    subset_t2_img)

        # Test that it has the same effect on the data:
        warped_from_file = file_mapping.transform(subset_b0)
        npt.assert_equal(warped_from_file, warped_moving)

        # Test that it is, attribute by attribute, identical:
        for k in mapping.__dict__:
            assert (np.all(
                mapping.__getattribute__(k) == file_mapping.__getattribute__(
                    k)))
Code Example #9
if not op.exists(op.join(working_dir, 'mapping.nii.gz')):
    import dipy.core.gradients as dpg
    gtab = dpg.gradient_table(hardi_fbval, hardi_fbvec)
    b0 = np.mean(img.get_fdata()[..., gtab.b0s_mask], -1)
    # Prealign using affine registration
    _, prealign = affine_registration(b0, MNI_T2_img.get_fdata(), img.affine,
                                      MNI_T2_img.affine)

    # Then register using a non-linear registration using the affine for
    # prealignment
    warped_hardi, mapping = reg.syn_register_dwi(hardi_fdata,
                                                 gtab,
                                                 prealign=prealign)
    reg.write_mapping(mapping, op.join(working_dir, 'mapping.nii.gz'))
else:
    mapping = reg.read_mapping(op.join(working_dir, 'mapping.nii.gz'), img,
                               MNI_T2_img)

##########################################################################
# Read in bundle specification
# -------------------------------------------
# The waypoint ROIs, as well as the bundle probability maps, are stored in
# this data structure. The templates are first resampled into MNI space before
# they are brought into the subject's individual native space.
# For speed, we only segment two bundles here.

bundles = api.BundleDict(["CST", "ARC"], resample_to=MNI_T2_img)

##########################################################################
# Tracking
# --------
# Streamlines are generated using DTI and deterministic tractography
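The tracking code that this last comment introduces is not part of the excerpt. A minimal sketch of what such a step might look like, modeled on the aft.track call visible in Code Example #11 below (the FA threshold and seed density come from that example, the file names are placeholders, and the full script also sets a step size and minimum length that are omitted here):

import numpy as np
import nibabel as nib
import dipy.tracking.streamline as dts
import AFQ.tractography as aft

FA = nib.load('dti_FA.nii.gz').get_fdata()
wm_mask = np.zeros_like(FA)
wm_mask[FA > 0.2] = 1   # seed only in plausible white matter

streamlines = dts.Streamlines(
    aft.track('dti_params.nii.gz',
              directions="det",      # deterministic tracking
              seed_mask=wm_mask,
              seeds=2,
              stop_mask=FA,
              stop_threshold=0.2))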
Code Example #10
    dti_params = {'FA': './dti_FA.nii.gz', 'params': './dti_params.nii.gz'}

FA_img = nib.load(dti_params['FA'])
FA_data = FA_img.get_fdata()

print("Registering to template...")
MNI_T2_img = afd.read_mni_template()
if not op.exists('mapping.nii.gz'):
    import dipy.core.gradients as dpg
    gtab = dpg.gradient_table(hardi_fbval, hardi_fbvec)
    warped_hardi, mapping = reg.syn_register_dwi(hardi_fdata,
                                                 gtab,
                                                 template=MNI_T2_img)
    reg.write_mapping(mapping, './mapping.nii.gz')
else:
    mapping = reg.read_mapping('./mapping.nii.gz', img, MNI_T2_img)

bundle_names = ["CST", "UF", "CC_ForcepsMajor", "CC_ForcepsMinor"]
bundles = api.make_bundle_dict(bundle_names=bundle_names, seg_algo="reco")

print("Tracking...")
if not op.exists('dti_streamlines_reco.trk'):
    seed_roi = np.zeros(img.shape[:-1])
    for bundle in bundles:
        if bundle != 'whole_brain':
            sl_xform = dts.Streamlines(
                dtu.transform_tracking_output(bundles[bundle]['sl'],
                                              MNI_T2_img.affine))

            delta = dts.values_from_volume(mapping.backward, sl_xform,
                                           np.eye(4))
Code Example #11
File: pipeline.py  Project: jyeatman/pyAFQ
else:
    dti_params = {'FA': './dti_FA.nii.gz',
                  'MD': './dti_MD.nii.gz',
                  'RD': './dti_RD.nii.gz',
                  'AD': './dti_AD.nii.gz',
                  'params': './dti_params.nii.gz'}

print("Registering to template...")
MNI_T2_img = dpd.read_mni_template()
if not op.exists('mapping.nii.gz'):
    import dipy.core.gradients as dpg
    gtab = dpg.gradient_table(fbval, fbvec)
    mapping = reg.syn_register_dwi(fdata, gtab)
    reg.write_mapping(mapping, './mapping.nii.gz')
else:
    mapping = reg.read_mapping('./mapping.nii.gz', img, MNI_T2_img)

print("Tracking...")
if not op.exists('dti_streamlines.trk'):
    FA = nib.load(dti_params["FA"]).get_data()
    wm_mask = np.zeros_like(FA)
    wm_mask[FA > 0.2] = 1
    step_size = 1
    min_length_mm = 50
    streamlines = dts.Streamlines(
        aft.track(dti_params['params'],
                  directions="det",
                  seed_mask=wm_mask,
                  seeds=2,
                  stop_mask=FA,
                  stop_threshold=0.2,
Code Example #12
File: segmentation.py  Project: soichih/pyAFQ
def segment(fdata,
            fbval,
            fbvec,
            streamlines,
            bundles,
            reg_template=None,
            mapping=None,
            as_generator=True,
            clip_to_roi=True,
            **reg_kwargs):
    """
    Segment streamlines into bundles.

    Parameters
    ----------
    fdata, fbval, fbvec : str
        Full path to data, bvals, bvecs

    streamlines : list of 2D arrays
        Each array is a streamline, shape (3, N).

    bundles: dict
        The format is something like::

             {'name': {'ROIs':[img, img], 'rules':[True, True]}}

    reg_template : str or nib.Nifti1Image, optional.
        Template to use for registration (defaults to the MNI T2)

    mapping : DiffeomorphicMap object, str or nib.Nifti1Image, optional
        A mapping between DWI space and a template. Defaults to generate this.

    as_generator : bool, optional
        Whether to return generators (True) or fully realized lists of
        streamlines (False). Default: True.

    clip_to_roi : bool, optional
        Whether to clip the streamlines between the ROIs. Default: True.
    """
    img, data, gtab, mask = ut.prepare_data(fdata, fbval, fbvec)
    xform_sl = [
        s for s in dtu.move_streamlines(streamlines, np.linalg.inv(img.affine))
    ]

    if reg_template is None:
        reg_template = dpd.read_mni_template()

    if mapping is None:
        mapping = reg.syn_register_dwi(fdata,
                                       gtab,
                                       template=reg_template,
                                       **reg_kwargs)

    if isinstance(mapping, str) or isinstance(mapping, nib.Nifti1Image):
        mapping = reg.read_mapping(mapping, img, reg_template)

    fiber_groups = {}
    for bundle in bundles:
        select_sl = xform_sl
        for ROI, rule in zip(bundles[bundle]['ROIs'],
                             bundles[bundle]['rules']):
            data = ROI.get_data()
            warped_ROI = patch_up_roi(
                mapping.transform_inverse(data, interpolation='nearest'))
            # This function requires lists as inputs:
            select_sl = dts.select_by_rois(select_sl,
                                           [warped_ROI.astype(bool)], [rule])
        # Next, we reorient each streamline according to an ARBITRARY, but
        # CONSISTENT order. To do this, we use the first ROI for which the rule
        # is True as the first one to pass through, and the last ROI for which
        # the rule is True as the last one to pass through:

        # Indices where the 'rule' is True:
        idx = np.where(bundles[bundle]['rules'])

        orient_ROIs = [
            bundles[bundle]['ROIs'][idx[0][0]],
            bundles[bundle]['ROIs'][idx[0][-1]]
        ]

        select_sl = dts.orient_by_rois(select_sl,
                                       orient_ROIs[0].get_data(),
                                       orient_ROIs[1].get_data(),
                                       as_generator=True)

        #  XXX Implement clipping to the ROIs
        #  if clip_to_roi:
        #    dts.clip()

        if as_generator:
            fiber_groups[bundle] = select_sl
        else:
            fiber_groups[bundle] = list(select_sl)

    return fiber_groups
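To make the documented bundles format concrete, here is a minimal sketch of a call to this function, assuming it is imported as in Code Example #13 below (import AFQ.segmentation as seg); the ROI, data, and mapping file names are placeholders:

import nibabel as nib
import AFQ.segmentation as seg

# Two inclusion ROIs per bundle, in the format documented above:
bundles = {
    'CST_L': {
        'ROIs': [nib.load('CST_roi1_L.nii.gz'),   # placeholder waypoint ROIs
                 nib.load('CST_roi2_L.nii.gz')],
        'rules': [True, True],
    },
}

# Streamlines from a previously computed tractogram (placeholder file):
streamlines = list(nib.streamlines.load('track.trk').streamlines)

fiber_groups = seg.segment('dwi.nii.gz', 'dwi.bval', 'dwi.bvec',
                           streamlines,
                           bundles,
                           mapping='mapping.nii.gz',   # reuse a saved mapping
                           as_generator=False)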
Code Example #13
File: main.py  Project: brainlife/app-AFQ-seg
def main():
    with open('config.json') as config_json:
        config = json.load(config_json)

    #Paths to data
    data_file = str(config['data_file'])
    data_bval = str(config['data_bval'])
    data_bvec = str(config['data_bvec'])

    img = nib.load(data_file)
    """
	print("Calculating DTI...")
	if not op.exists('./dti_FA.nii.gz'):
	    dti_params = dti.fit_dti(data_file, data_bval, data_bvec, out_dir='.')
	else:
	    dti_params = {'FA': './dti_FA.nii.gz',
			  'params': './dti_params.nii.gz'}
	"""
    #tg = nib.streamlines.load('track.trk').tractogram

    tg = nib.streamlines.load(config['tck_data']).tractogram
    streamlines = tg.apply_affine(np.linalg.inv(img.affine)).streamlines

    # Use only a small portion of the streamlines, for expedience:
    #streamlines = streamlines[::100]

    templates = afd.read_templates()
    bundle_names = ["CST", "ILF"]

    bundles = {}
    for name in bundle_names:
        for hemi in ['_R', '_L']:
            bundles[name + hemi] = {
                'ROIs': [
                    templates[name + '_roi1' + hemi],
                    templates[name + '_roi2' + hemi]
                ],
                'rules': [True, True]
            }

    print("Registering to template...")
    if not op.exists('mapping.nii.gz'):
        gtab = gradient_table(data_bval, data_bvec)
        mapping = reg.syn_register_dwi(data_file, gtab)
        reg.write_mapping(mapping, './mapping.nii.gz')
    else:
        mapping = reg.read_mapping('./mapping.nii.gz', img, MNI_T2_img)
    """
	MNI_T2_img = dpd.read_mni_template()
	bvals, bvecs = read_bvals_bvecs(data_bval, data_bvec)
	gtab = gradient_table(bvals, bvecs, b0_threshold=100)
	mapping = reg.syn_register_dwi(data_file, gtab)
	reg.write_mapping(mapping, './mapping.nii.gz')
	"""

    print("Segmenting fiber groups...")
    fiber_groups = seg.segment(data_file,
                               data_bval,
                               data_bvec,
                               streamlines,
                               bundles,
                               reg_template=MNI_T2_img,
                               mapping=mapping,
                               as_generator=False,
                               affine=img.affine)

    path = os.getcwd() + '/tract1/'
    if not os.path.exists(path):
        os.makedirs(path)

    for fg in fiber_groups:
        streamlines = fiber_groups[fg]
        fname = fg + ".tck"
        trg = nib.streamlines.Tractogram(streamlines,
                                         affine_to_rasmm=img.affine)
        nib.streamlines.save(trg, path + fname)
Code Example #14
File: segmentation.py  Project: akeshavan/pyAFQ
def segment(fdata, fbval, fbvec, streamlines, bundle_dict, mapping,
            reg_prealign=None, b0_threshold=0, reg_template=None,
            prob_threshold=0):
    """
    Segment streamlines into bundles based on inclusion ROIs.

    Parameters
    ----------
    fdata, fbval, fbvec : str
        Full path to data, bvals, bvecs

    streamlines : list of 2D arrays
        Each array is a streamline, shape (3, N).

    bundle_dict: dict
        The format is something like::

            {'name': {'ROIs': [img1, img2],
                      'rules': [True, True],
                      'prob_map': img3,
                      'cross_midline': False}}

    mapping : DiffeomorphicMap object, str or nib.Nifti1Image
        A mapping between DWI space and a template, used to align the ROIs
        to the data. If a string or an image is given, the mapping is read
        from file, applying reg_prealign if provided.

    reg_template : str or nib.Nifti1Image, optional.
        Template to use for registration (defaults to the MNI T2)

    prob_threshold : float.
        Initial cleaning of fiber groups is done using probability maps from
        [Hua2008]_. Here, we choose an average probability that needs to be
        exceeded for an individual streamline to be retained. Default: 0.

    References
    ----------
    .. [Hua2008] Hua K, Zhang J, Wakana S, Jiang H, Li X, et al. (2008)
       Tract probability maps in stereotaxic spaces: analyses of white
       matter anatomy and tract-specific quantification. Neuroimage 39:
       336-347
    """
    img, _, gtab, _ = ut.prepare_data(fdata, fbval, fbvec,
                                      b0_threshold=b0_threshold)

    tol = dts.dist_to_corner(img.affine)

    if reg_template is None:
        reg_template = dpd.read_mni_template()

    # Classify the streamlines and split those that: 1) cross the
    # midline, and 2) pass under 10 mm below the mid-point of their
    # representation in the template space:
    xform_sl, crosses = split_streamlines(streamlines, img)

    if isinstance(mapping, str) or isinstance(mapping, nib.Nifti1Image):
        if reg_prealign is None:
            reg_prealign = np.eye(4)
        mapping = reg.read_mapping(mapping, img, reg_template,
                                   prealign=reg_prealign)

    fiber_probabilities = np.zeros((len(xform_sl), len(bundle_dict)))

    # For expedience, we approximate each streamline as a 100 point curve:
    fgarray = _resample_bundle(xform_sl, 100)
    streamlines_in_bundles = np.zeros((len(xform_sl), len(bundle_dict)))
    min_dist_coords = np.zeros((len(xform_sl), len(bundle_dict), 2))

    fiber_groups = {}

    for bundle_idx, bundle in enumerate(bundle_dict):
        rules = bundle_dict[bundle]['rules']
        include_rois = []
        exclude_rois = []
        for rule_idx, rule in enumerate(rules):
            roi = bundle_dict[bundle]['ROIs'][rule_idx]
            if not isinstance(roi, np.ndarray):
                roi = roi.get_data()
            warped_roi = auv.patch_up_roi(
                (mapping.transform_inverse(
                    roi,
                    interpolation='linear')) > 0)

            if rule:
                # include ROI:
                include_rois.append(np.array(np.where(warped_roi)).T)
            else:
                # Exclude ROI:
                exclude_rois.append(np.array(np.where(warped_roi)).T)

        crosses_midline = bundle_dict[bundle]['cross_midline']

        # If a probability map is not provided, use an all-ones map with the
        # same shape as the ROIs:
        prob_map = bundle_dict[bundle].get('prob_map', np.ones(roi.shape))

        if not isinstance(prob_map, np.ndarray):
            prob_map = prob_map.get_data()
        warped_prob_map = mapping.transform_inverse(prob_map,
                                                    interpolation='nearest')
        fiber_probabilities = dts.values_from_volume(warped_prob_map,
                                                     fgarray)
        fiber_probabilities = np.mean(fiber_probabilities, -1)

        for sl_idx, sl in enumerate(xform_sl):
            if fiber_probabilities[sl_idx] > prob_threshold:
                if crosses_midline is not None:
                    if crosses[sl_idx]:
                        # This means that the streamline does
                        # cross the midline:
                        if crosses_midline:
                            # This is what we want, keep going
                            pass
                        else:
                            # This is not what we want, skip to next streamline
                            continue

                is_close, dist = _check_sl_with_inclusion(sl, include_rois,
                                                          tol)
                if is_close:
                    is_far = _check_sl_with_exclusion(sl, exclude_rois,
                                                      tol)
                    if is_far:
                        min_dist_coords[sl_idx, bundle_idx, 0] =\
                            np.argmin(dist[0], 0)[0]
                        min_dist_coords[sl_idx, bundle_idx, 1] =\
                            np.argmin(dist[1], 0)[0]
                        streamlines_in_bundles[sl_idx, bundle_idx] =\
                            fiber_probabilities[sl_idx]

    # Eliminate any fibers not selected using the plane ROIs:
    possible_fibers = np.sum(streamlines_in_bundles, -1) > 0
    xform_sl = xform_sl[possible_fibers]
    streamlines_in_bundles = streamlines_in_bundles[possible_fibers]
    min_dist_coords = min_dist_coords[possible_fibers]
    bundle_choice = np.argmax(streamlines_in_bundles, -1)

    # We do another round through, so that we can orient all the
    # streamlines within a bundle in the same orientation with respect to
    # the ROIs. This order is ARBITRARY but CONSISTENT (going from ROI0
    # to ROI1).
    for bundle_idx, bundle in enumerate(bundle_dict):
        select_idx = np.where(bundle_choice == bundle_idx)
        # Use a list here, because Streamlines don't support item assignment:
        select_sl = list(xform_sl[select_idx])
        if len(select_sl) == 0:
            fiber_groups[bundle] = dts.Streamlines([])
            # There's nothing here, move to the next bundle:
            continue

        # Sub-sample min_dist_coords:
        min_dist_coords_bundle = min_dist_coords[select_idx]
        for idx in range(len(select_sl)):
            min0 = min_dist_coords_bundle[idx, bundle_idx, 0]
            min1 = min_dist_coords_bundle[idx, bundle_idx, 1]
            if min0 > min1:
                select_sl[idx] = select_sl[idx][::-1]
        # Set this to nibabel.Streamlines object for output:
        select_sl = dts.Streamlines(select_sl)
        fiber_groups[bundle] = select_sl

    return fiber_groups
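For reference, a sketch of a single bundle_dict entry in the per-bundle format this function expects, including the optional probability map and the midline flag; the template file names are placeholders:

import nibabel as nib

bundle_dict = {
    'ARC_L': {
        'ROIs': [nib.load('ARC_roi1_L.nii.gz'),        # placeholder waypoint ROIs
                 nib.load('ARC_roi2_L.nii.gz')],
        'rules': [True, True],                          # both are inclusion ROIs
        'prob_map': nib.load('ARC_L_prob_map.nii.gz'),  # optional; defaults to all ones
        'cross_midline': False,                         # a left-hemisphere bundle
    },
}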
Code Example #15
File: segmentation.py  Project: yeatmanlab/pyAFQ
def segment(fdata, fbval, fbvec, streamlines, bundles,
            reg_template=None, mapping=None, as_generator=True, **reg_kwargs):
    """

    generate : bool
        Whether to generate the streamlines here, or return generators.

    reg_template : template to use for registration (defaults to the MNI T2)

    bundles: dict
        The format is something like::

             {'name': {'ROIs':[img, img], 'rules':[True, True]}}


    """
    img, data, gtab, mask = ut.prepare_data(fdata, fbval, fbvec)
    xform_sl = [s for s in dtu.move_streamlines(streamlines,
                                                np.linalg.inv(img.affine))]

    if reg_template is None:
        reg_template = dpd.read_mni_template()

    if mapping is None:
        mapping = reg.syn_register_dwi(fdata, gtab, template=reg_template,
                                       **reg_kwargs)

    if isinstance(mapping, str) or isinstance(mapping, nib.Nifti1Image):
        mapping = reg.read_mapping(mapping, img, reg_template)

    fiber_groups = {}
    for bundle in bundles:
        select_sl = xform_sl
        for ROI, rule in zip(bundles[bundle]['ROIs'],
                             bundles[bundle]['rules']):
            data = ROI.get_data()
            warped_ROI = patch_up_roi(mapping.transform_inverse(
                data,
                interpolation='nearest'))
            # This function requires lists as inputs:
            select_sl = dts.select_by_rois(select_sl,
                                           [warped_ROI.astype(bool)],
                                           [rule])
        # Next, we reorient each streamline according to an ARBITRARY, but
        # CONSISTENT order. To do this, we use the first ROI for which the rule
        # is True as the first one to pass through, and the last ROI for which
        # the rule is True as the last one to pass through:

        # Indices where the 'rule' is True:
        idx = np.where(bundles[bundle]['rules'])

        orient_ROIs = [bundles[bundle]['ROIs'][idx[0][0]],
                       bundles[bundle]['ROIs'][idx[0][-1]]]

        select_sl = dts.orient_by_rois(select_sl,
                                       orient_ROIs[0].get_data(),
                                       orient_ROIs[1].get_data(),
                                       in_place=True)
        if as_generator:
            fiber_groups[bundle] = select_sl
        else:
            fiber_groups[bundle] = list(select_sl)

    return fiber_groups
Code Example #16
File: segmentation.py  Project: jyeatman/pyAFQ
def segment(fdata, fbval, fbvec, streamlines, bundles,
            reg_template=None, mapping=None, prob_threshold=0,
            **reg_kwargs):
    """
    Segment streamlines into bundles based on inclusion ROIs.

    Parameters
    ----------
    fdata, fbval, fbvec : str
        Full path to data, bvals, bvecs

    streamlines : list of 2D arrays
        Each array is a streamline, shape (3, N).

    bundles: dict
        The format is something like::

            {'name': {'ROIs': [img1, img2],
                      'rules': [True, True],
                      'prob_map': img3,
                      'cross_midline': False}}

    reg_template : str or nib.Nifti1Image, optional.
        Template to use for registration (defaults to the MNI T2)

    mapping : DiffeomorphicMap object, str or nib.Nifti1Image, optional
        A mapping between DWI space and a template. Defaults to generate
        this.

    prob_threshold : float.
        Initial cleaning of fiber groups is done using probability maps from
        [Hua2008]_. Here, we choose an average probability that needs to be
        exceeded for an individual streamline to be retained. Default: 0.

    References
    ----------
    .. [Hua2008] Hua K, Zhang J, Wakana S, Jiang H, Li X, et al. (2008)
       Tract probability maps in stereotaxic spaces: analyses of white
       matter anatomy and tract-specific quantification. Neuroimage 39:
       336-347
    """
    img, _, gtab, _ = ut.prepare_data(fdata, fbval, fbvec)
    tol = dts.dist_to_corner(img.affine)

    xform_sl = dts.Streamlines(dtu.move_streamlines(streamlines,
                                                    np.linalg.inv(img.affine)))

    if reg_template is None:
        reg_template = dpd.read_mni_template()

    if mapping is None:
        mapping = reg.syn_register_dwi(fdata, gtab, template=reg_template,
                                       **reg_kwargs)

    if isinstance(mapping, str) or isinstance(mapping, nib.Nifti1Image):
        mapping = reg.read_mapping(mapping, img, reg_template)

    fiber_probabilities = np.zeros((len(xform_sl), len(bundles)))

    # For expedience, we approximate each streamline as a 100 point curve:
    fgarray = _resample_bundle(xform_sl, 100)
    streamlines_in_bundles = np.zeros((len(xform_sl), len(bundles)))
    min_dist_coords = np.zeros((len(xform_sl), len(bundles), 2))

    fiber_groups = {}

    for bundle_idx, bundle in enumerate(bundles):
        # Get the ROI coordinates:
        ROI0 = bundles[bundle]['ROIs'][0]
        ROI1 = bundles[bundle]['ROIs'][1]
        if not isinstance(ROI0, np.ndarray):
            ROI0 = ROI0.get_data()

        warped_ROI0 = patch_up_roi(
            mapping.transform_inverse(
                ROI0,
                interpolation='nearest')).astype(bool)

        if not isinstance(ROI1, np.ndarray):
            ROI1 = ROI1.get_data()

        warped_ROI1 = patch_up_roi(
            mapping.transform_inverse(
                ROI1,
                interpolation='nearest')).astype(bool)

        roi_coords0 = np.array(np.where(warped_ROI0)).T
        roi_coords1 = np.array(np.where(warped_ROI1)).T

        crosses_midline = bundles[bundle]['cross_midline']

        # If a probability map is not provided, use an all-ones map with the
        # same shape as the ROIs:
        prob_map = bundles[bundle].get('prob_map', np.ones(ROI0.shape))
        if not isinstance(prob_map, np.ndarray):
            prob_map = prob_map.get_data()
        warped_prob_map = mapping.transform_inverse(prob_map,
                                                    interpolation='nearest')
        fiber_probabilities = dts.values_from_volume(warped_prob_map,
                                                     fgarray)
        fiber_probabilities = np.mean(fiber_probabilities, -1)

        for sl_idx, sl in enumerate(xform_sl):
            if fiber_probabilities[sl_idx] > prob_threshold:
                if crosses_midline is not None:
                    if (np.any(sl[:, 0] > img.shape[0] // 2) and
                            np.any(sl[:, 0] < img.shape[0] // 2)):
                        # This means that the streamline does
                        # cross the midline:
                        if crosses_midline:
                            # This is what we want, keep going
                            pass
                        else:
                            # This is not what we want, skip to next streamline
                            continue
                dist0 = cdist(sl, roi_coords0, 'euclidean')
                if np.min(dist0) <= tol:
                    dist1 = cdist(sl, roi_coords1, 'euclidean')
                    if np.min(dist1) <= tol:
                        min_dist_coords[sl_idx, bundle_idx, 0] =\
                            np.argmin(dist0, 0)[0]
                        min_dist_coords[sl_idx, bundle_idx, 1] =\
                            np.argmin(dist1, 0)[0]
                        streamlines_in_bundles[sl_idx, bundle_idx] =\
                            fiber_probabilities[sl_idx]

    # Eliminate any fibers not selected using the plane ROIs:
    possible_fibers = np.sum(streamlines_in_bundles, -1) > 0
    xform_sl = xform_sl[possible_fibers]
    streamlines_in_bundles = streamlines_in_bundles[possible_fibers]
    min_dist_coords = min_dist_coords[possible_fibers]
    bundle_choice = np.argmax(streamlines_in_bundles, -1)

    for bundle_idx, bundle in enumerate(bundles):
        print(bundle)
        select_idx = np.where(bundle_choice == bundle_idx)
        # Use a list here, because Streamlines don't support item assignment:
        select_sl = list(xform_sl[select_idx])
        # Sub-sample min_dist_coords:
        min_dist_coords_bundle = min_dist_coords[select_idx]
        if len(select_sl) == 0:
            fiber_groups[bundle] = dts.Streamlines([])
            # There's nothing here, move to the next bundle:
            continue

        for idx in range(len(select_sl)):
            min0 = min_dist_coords_bundle[idx, bundle_idx, 0]
            min1 = min_dist_coords_bundle[idx, bundle_idx, 1]
            if min0 > min1:
                select_sl[idx] = select_sl[idx][::-1]
        # We'll set this to Streamlines object for the next steps (e.g.,
        # cleaning) because these objects support indexing with arrays:
        select_sl = dts.Streamlines(select_sl)
        fiber_groups[bundle] = select_sl

    return fiber_groups