Exemplo n.º 1
0
def test_dti_tracking():
    """
    Track on a DTI fit serially and in parallel and compare the results.

    For each ``directions`` setting ("det", "prob"), tracking is run once
    with ``n_jobs=1`` and once with ``n_jobs=2`` using the "dask" engine and
    the "threading" backend. The first streamline of every result must have
    a last dimension of 3. For "det", the first parallel streamline must
    also be almost equal to the first serial one.
    """
    params_file = fit_dti(fdata, fbval, fbvec)['params']
    # Settings shared by the serial and the parallel runs:
    shared_kwargs = dict(max_angle=30.,
                         sphere=None,
                         seed_mask=None,
                         seeds=seeds,
                         stop_mask=None,
                         stop_threshold=0.2,
                         step_size=0.5)
    for directions in ("det", "prob"):
        sl_serial = track(params_file, directions, n_jobs=1, **shared_kwargs)
        npt.assert_equal(sl_serial[0].shape[-1], 3)
        for engine, backend in [("dask", "threading")]:
            sl_parallel = track(params_file,
                                directions,
                                n_jobs=2,
                                engine=engine,
                                backend=backend,
                                **shared_kwargs)
            npt.assert_equal(sl_parallel[0].shape[-1], 3)

            # Only the deterministic result is compared across runs:
            if directions == 'det':
                npt.assert_almost_equal(sl_parallel[0], sl_serial[0])
Exemplo n.º 2
0
def test_csd_tracking():
    """
    Track on CSD fits of several spherical-harmonic orders.

    For each ``sh_order`` and each ``directions`` setting ("det", "prob"),
    tracking is run with ``n_jobs=1`` and with ``n_jobs=2`` using the "dask"
    engine and the "threading" backend. The first streamline of every result
    must have a last dimension of 3. For "det", the first parallel
    streamline must also be almost equal to the first serial one.
    """
    for sh_order in [4, 8, 10]:
        # Fit with the order under test. A fixed order here would make every
        # iteration of this loop exercise the same model.
        fname = fit_csd(fdata, fbval, fbvec,
                        response=((0.0015, 0.0003, 0.0003), 100),
                        sh_order=sh_order, lambda_=1, tau=0.1, mask=None,
                        out_dir=tmpdir.name)
        for directions in ["det", "prob"]:
            sl_serial = track(fname, directions,
                              max_angle=30., sphere=None,
                              seed_mask=None,
                              seeds=seeds,
                              stop_mask=None,
                              stop_threshold=0.2,
                              step_size=0.5,
                              n_jobs=1)
            npt.assert_equal(sl_serial[0].shape[-1], 3)
            for engine in ["dask"]:
                for backend in ["threading"]:
                    sl_parallel = track(fname, directions,
                                        max_angle=30., sphere=None,
                                        seed_mask=None,
                                        seeds=seeds,
                                        stop_mask=None,
                                        stop_threshold=0.2,
                                        step_size=0.5,
                                        n_jobs=2,
                                        engine=engine,
                                        backend=backend)
                    npt.assert_equal(sl_parallel[0].shape[-1], 3)

                    # Only the deterministic result is compared across runs:
                    if directions == 'det':
                        npt.assert_almost_equal(sl_parallel[0], sl_serial[0])
Exemplo n.º 3
0
def test_dti_tracking():
    """
    Track on a DTI fit with the serial engine and with parallel engines.

    For each ``directions`` setting ("det", "prob"), tracking is run once
    with ``engine="serial"`` and once per parallel engine ("dask", "joblib")
    with the "threading" backend and ``n_jobs=2``. The first streamline of
    every result must have a last dimension of 3. For "det", the first
    parallel streamline must also be almost equal to the first serial one.
    """
    fdict = fit_dti(fdata, fbval, fbvec)
    for directions in ["det", "prob"]:
        sl_serial = track(fdict['params'],
                          directions,
                          max_angle=30.,
                          sphere=None,
                          seed_mask=None,
                          seeds=seeds,
                          stop_mask=None,
                          stop_threshold=0.2,
                          step_size=0.5,
                          engine="serial")
        npt.assert_equal(sl_serial[0].shape[-1], 3)
        for engine in ["dask", "joblib"]:
            for backend in ["threading"]:
                # Same settings as the serial run (including the stopping
                # threshold), so that the comparison below is like-for-like.
                sl_parallel = track(fdict['params'],
                                    directions,
                                    max_angle=30.,
                                    sphere=None,
                                    seed_mask=None,
                                    seeds=seeds,
                                    stop_mask=None,
                                    stop_threshold=0.2,
                                    step_size=0.5,
                                    n_jobs=2,
                                    engine=engine,
                                    backend=backend)
                npt.assert_equal(sl_parallel[0].shape[-1], 3)

                # Only the deterministic result is compared across runs:
                if directions == 'det':
                    npt.assert_almost_equal(sl_parallel[0], sl_serial[0])
Exemplo n.º 4
0
def test_pft_tracking():
    """
    Run ``tracker="pft"`` tracking on a DTI fit and on a CSD fit.

    The stop mask is a tuple of three images with the shape of the first
    three dimensions of the data: an all-ones image (named for white matter)
    and two all-zeros images (named for gray matter and CSF). For every
    combination of ``directions`` ("det", "prob") and ``stop_threshold``
    ("ACT", "CMC"), the first streamline must have at least
    ``min_length * step_size`` points.

    The test then checks that `track` raises ``RuntimeError`` when the stop
    mask is not a tuple, and when the stop threshold is not a string.
    """
    param_files = [fit_dti(fdata, fbval, fbvec)['params'],
                   fit_csd(fdata, fbval, fbvec,
                           response=((0.0015, 0.0003, 0.0003), 100),
                           sh_order=8, lambda_=1, tau=0.1, mask=None,
                           out_dir=tmpdir.name)]

    # The masks depend only on the input data, so build them once rather
    # than once per parameter file.
    img = nib.load(fdata)
    vol_shape = img.shape[:3]
    pve_wm_data = nib.Nifti1Image(np.ones(vol_shape), img.affine)
    pve_gm_data = nib.Nifti1Image(np.zeros(vol_shape), img.affine)
    pve_csf_data = nib.Nifti1Image(np.zeros(vol_shape), img.affine)
    stop_mask = (pve_wm_data, pve_gm_data, pve_csf_data)

    for fname in param_files:
        for directions in ["det", "prob"]:
            for stop_threshold in ["ACT", "CMC"]:
                sl = track(
                    fname,
                    directions,
                    max_angle=30.,
                    sphere=None,
                    seed_mask=None,
                    stop_mask=stop_mask,
                    stop_threshold=stop_threshold,
                    n_seeds=1,
                    step_size=step_size,
                    min_length=min_length,
                    tracker="pft").streamlines
                npt.assert_(len(sl[0]) >= min_length * step_size)

    # Test error handling. The arguments are pinned explicitly (to the last
    # values the loops above used) instead of relying on leaked loop
    # variables:
    err_fname = param_files[-1]
    err_directions = "prob"

    with pytest.raises(RuntimeError):
        track(
            err_fname,
            err_directions,
            max_angle=30.,
            sphere=None,
            seed_mask=None,
            stop_mask=0,  # Stop mask needs to be a tuple!
            stop_threshold="CMC",
            n_seeds=1,
            step_size=step_size,
            min_length=min_length,
            tracker="pft")

    with pytest.raises(RuntimeError):
        track(
            err_fname,
            err_directions,
            max_angle=30.,
            sphere=None,
            seed_mask=None,
            stop_mask=stop_mask,
            stop_threshold=None,  # Stop threshold needs to be a string!
            n_seeds=1,
            step_size=step_size,
            min_length=min_length,
            tracker="pft")
Exemplo n.º 5
0
def _streamlines(row,
                 wm_labels,
                 odf_model="DTI",
                 directions="det",
                 n_seeds=2,
                 random_seeds=False,
                 force_recompute=False,
                 wm_fa_thresh=0.2):
    """
    Generate the streamlines file for one row, unless it already exists.

    Parameters
    ----------
    row : mapping with an ``index`` attribute (presumably a pandas Series)
        Must provide 'dwi_file' and may provide 'seg_file'. It is also
        passed to the `_get_fname`, `_dti`, `_csd` and `_dti_fa` helpers.
    wm_labels : list
        The values within the segmentation that are considered white matter. We
        will use this part of the image both to seed tracking (seeding
        throughout), and for stopping. Only used when `row` has a
        'seg_file' entry.
    odf_model : str, optional
        "DTI" or "CSD"; selects the parameters file that is tracked on.
    directions : str, optional
        Passed on to `aft.track`; also part of the output file name.
    n_seeds : optional
        Passed on to `aft.track`.
    random_seeds : optional
        Passed on to `aft.track`.
    force_recompute : bool, optional
        If True, track again even when the output file already exists.
    wm_fa_thresh : float, optional
        When `row` has no 'seg_file', voxels with FA above this value form
        the white matter mask.

    Returns
    -------
    The streamlines file name, as returned by `_get_fname`.

    Raises
    ------
    NotImplementedError
        If tracking is needed and `odf_model` is neither "DTI" nor "CSD".
    """
    streamlines_file = _get_fname(
        row, '%s_%s_streamlines.trk' % (odf_model, directions))
    if not op.exists(streamlines_file) or force_recompute:
        if odf_model == "DTI":
            params_file = _dti(row)
        elif odf_model == "CSD":
            params_file = _csd(row)
        else:
            # Fail here with a clear message, rather than later with an
            # unbound `params_file`.
            raise NotImplementedError(
                "Unsupported odf_model: %r (expected 'DTI' or 'CSD')"
                % (odf_model,))

        dwi_img = nib.load(row['dwi_file'])
        # NOTE(review): `get_data` is deprecated in recent nibabel releases;
        # confirm the supported nibabel version before switching to
        # `get_fdata`.
        dwi_data = dwi_img.get_data()

        if 'seg_file' in row.index:
            # If we found a white matter segmentation in the
            # expected location:
            seg_img = nib.load(row['seg_file'])
            seg_data_orig = seg_img.get_data()
            # For different sets of labels, extract all the voxels that
            # have any of these values:
            wm_mask = np.sum(
                np.concatenate([(seg_data_orig == label)[..., None]
                                for label in wm_labels], -1), -1)

            # Resample to DWI data:
            wm_mask = np.round(
                reg.resample(wm_mask, dwi_data[..., 0], seg_img.affine,
                             dwi_img.affine)).astype(int)
        else:
            # Otherwise, we'll identify the white matter based on FA:
            dti_fa = nib.load(_dti_fa(row)).get_data()
            wm_mask = dti_fa > wm_fa_thresh

        # The same mask is used for seeding and for stopping:
        streamlines = aft.track(params_file,
                                directions=directions,
                                n_seeds=n_seeds,
                                random_seeds=random_seeds,
                                seed_mask=wm_mask,
                                stop_mask=wm_mask)

        # The streamlines are moved by the inverse of the DWI affine before
        # being written out together with that affine.
        aus.write_trk(streamlines_file,
                      dtu.move_streamlines(streamlines,
                                           np.linalg.inv(dwi_img.affine)),
                      affine=dwi_img.affine)

    return streamlines_file
Exemplo n.º 6
0
def test_csd_tracking():
    """
    Track on CSD fits of several spherical-harmonic orders.

    For each ``sh_order`` and each ``directions`` setting ("det", "prob"),
    tracking is run once with ``engine="serial"`` and once per parallel
    engine ("dask", "joblib") with the "threading" backend and ``n_jobs=2``.
    The first streamline of every result must have a last dimension of 3.
    For "det", the first parallel streamline must also be almost equal to
    the first serial one.
    """
    for sh_order in [4, 8, 10]:
        # Fit with the order under test. A fixed order here would make every
        # iteration of this loop exercise the same model.
        fname = fit_csd(fdata,
                        fbval,
                        fbvec,
                        response=((0.0015, 0.0003, 0.0003), 100),
                        sh_order=sh_order,
                        lambda_=1,
                        tau=0.1,
                        mask=None,
                        out_dir=tmpdir.name)
        for directions in ["det", "prob"]:
            sl_serial = track(fname,
                              directions,
                              max_angle=30.,
                              sphere=None,
                              seed_mask=None,
                              seeds=seeds,
                              stop_mask=None,
                              stop_threshold=0.2,
                              step_size=0.5,
                              n_jobs=1,
                              engine="serial")
            npt.assert_equal(sl_serial[0].shape[-1], 3)
            for engine in ["dask", "joblib"]:
                for backend in ["threading"]:
                    sl_parallel = track(fname,
                                        directions,
                                        max_angle=30.,
                                        sphere=None,
                                        seed_mask=None,
                                        seeds=seeds,
                                        stop_mask=None,
                                        stop_threshold=0.2,
                                        step_size=0.5,
                                        n_jobs=2,
                                        engine=engine,
                                        backend=backend)
                    npt.assert_equal(sl_parallel[0].shape[-1], 3)

                    # Only the deterministic result is compared across runs:
                    if directions == 'det':
                        npt.assert_almost_equal(sl_parallel[0], sl_serial[0])
Exemplo n.º 7
0
def test_dti_tracking():
    """
    Track on a DTI fit with ``n_seeds=1`` for "det" and "prob" directions.

    In both cases the first streamline must have more than 10 points.
    """
    params_file = fit_dti(fdata, fbval, fbvec)['params']
    track_kwargs = dict(max_angle=30.,
                        sphere=None,
                        seed_mask=None,
                        n_seeds=1,
                        step_size=0.5)
    for directions in ("det", "prob"):
        streamlines = track(params_file, directions, **track_kwargs)
        first_streamline = streamlines[0]
        npt.assert_(len(first_streamline) > 10)
Exemplo n.º 8
0
def test_dti_tracking():
    """
    Track on a DTI fit with ``seeds=1`` for "det" and "prob" directions.

    In both cases the first streamline must have more than 10 points.
    """
    params_file = fit_dti(fdata, fbval, fbvec)['params']
    track_kwargs = dict(max_angle=30.,
                        sphere=None,
                        seed_mask=None,
                        seeds=1,
                        step_size=0.5)
    for directions in ("det", "prob"):
        streamlines = track(params_file, directions, **track_kwargs)
        first_streamline = streamlines[0]
        npt.assert_(len(first_streamline) > 10)
Exemplo n.º 9
0
def test_dti_tracking():
    """
    Track on a DTI fit and check the length of the first streamline.

    For "det" and "prob" directions, the first streamline of the returned
    object's ``streamlines`` must have at least ``min_length * step_size``
    points.
    """
    params_file = fit_dti(fdata, fbval, fbvec)['params']
    for directions in ("det", "prob"):
        tractogram = track(params_file,
                           directions,
                           max_angle=30.,
                           sphere=None,
                           seed_mask=None,
                           n_seeds=1,
                           step_size=step_size,
                           min_length=min_length)
        sl = tractogram.streamlines
        npt.assert_(len(sl[0]) >= min_length * step_size)
Exemplo n.º 10
0
def streamlines(subses_dict, data_imap, seed_file, stop_file,
                tracking_params):
    """
    Run tractography and build a metadata record describing the run.

    Parameters
    ----------
    subses_dict : object
        Not used by this function.
    data_imap : mapping
        Read for "dti_params_file", "csd_params_file" or "dki_params_file",
        depending on the ODF model.
    seed_file : str
        Loaded with nibabel; its data becomes the seed mask.
    stop_file : str or other
        If a string, loaded with nibabel and its data becomes the stop mask.
        Anything else is passed on as the stop mask unchanged.
    tracking_params : dict
        Keyword arguments for `aft.track`. Read for "odf_model",
        "directions", "n_seeds", "random_seeds", "step_size", "min_length"
        and "max_length". A copy receives the masks; the dict that is passed
        in is not assigned to.

    Returns
    -------
    sft
        The object returned by `aft.track`, after its ``to_vox()`` call.
    meta : dict
        Description of the tracking run, including its timing.

    Raises
    ------
    TypeError
        If the ODF model is not one of "DTI", "CSD", "MSMT" or "DKI".
    """
    track_kwargs = tracking_params.copy()

    # Pick the parameters file that matches the requested ODF model:
    odf_model = track_kwargs["odf_model"]
    if odf_model == "DTI":
        params_key = "dti_params_file"
    elif odf_model in ("CSD", "MSMT"):
        params_key = "csd_params_file"
    elif odf_model == "DKI":
        params_key = "dki_params_file"
    else:
        raise TypeError((
            f"The ODF model you gave ({odf_model}) was not recognized"))
    params_file = data_imap[params_key]

    # Seed and stop masks go into the copy only:
    track_kwargs['seed_mask'] = nib.load(seed_file).get_fdata()
    track_kwargs['stop_mask'] = (
        nib.load(stop_file).get_fdata()
        if isinstance(stop_file, str) else stop_file)

    # Tractography; the timer also covers the conversion and the metadata
    # assembly below:
    start_time = time()
    sft = aft.track(params_file, **track_kwargs)
    sft.to_vox()

    method_names = {
        "det": "deterministic",
        "prob": "probabilistic"}
    meta = {}
    meta["TractographyClass"] = "local"
    meta["TractographyMethod"] = method_names[tracking_params["directions"]]
    meta["Count"] = len(sft.streamlines)
    meta["Seeding"] = {
        "ROI": seed_file,
        "n_seeds": tracking_params["n_seeds"],
        "random_seeds": tracking_params["random_seeds"]}
    meta["Constraints"] = {"ROI": stop_file}
    meta["Parameters"] = {
        "Units": "mm",
        "StepSize": tracking_params["step_size"],
        "MinimumLength": tracking_params["min_length"],
        "MaximumLength": tracking_params["max_length"],
        "Unidirectional": False}
    meta["Timing"] = time() - start_time

    return sft, meta
Exemplo n.º 11
0
def test_csd_tracking():
    """
    Track on CSD fits of several spherical-harmonic orders.

    For each ``sh_order`` and each ``directions`` setting ("det", "prob"),
    the first streamline must have more than 10 points.
    """
    for sh_order in [4, 8, 10]:
        # Fit with the order under test. A fixed order here would make every
        # iteration of this loop exercise the same model.
        fname = fit_csd(fdata, fbval, fbvec,
                        response=((0.0015, 0.0003, 0.0003), 100),
                        sh_order=sh_order, lambda_=1, tau=0.1, mask=None,
                        out_dir=tmpdir.name)
        for directions in ["det", "prob"]:
            sl = track(fname, directions,
                       max_angle=30.,
                       sphere=None,
                       seed_mask=None,
                       seeds=seeds,
                       stop_mask=None,
                       step_size=0.5)

            npt.assert_(len(sl[0]) > 10)
Exemplo n.º 12
0
def test_csd_tracking():
    """
    Track on CSD fits of several spherical-harmonic orders.

    For each ``sh_order`` and each ``directions`` setting ("det", "prob"),
    tracking is run with ``min_length=10`` and the first streamline must
    have more than 10 points.
    """
    for sh_order in [4, 8, 10]:
        # Fit with the order under test. A fixed order here would make every
        # iteration of this loop exercise the same model.
        fname = fit_csd(fdata, fbval, fbvec,
                        response=((0.0015, 0.0003, 0.0003), 100),
                        sh_order=sh_order, lambda_=1, tau=0.1, mask=None,
                        out_dir=tmpdir.name)
        for directions in ["det", "prob"]:
            sl = track(fname, directions,
                       max_angle=30.,
                       sphere=None,
                       seed_mask=None,
                       n_seeds=seeds,
                       stop_mask=None,
                       step_size=0.5,
                       min_length=10)

            npt.assert_(len(sl[0]) > 10)
Exemplo n.º 13
0
    def _streamlines(self, row):
        odf_model = self.tracking_params["odf_model"]
        directions = self.tracking_params["directions"]

        streamlines_file = self._get_fname(
            row, f'_space-RASMM_model-{odf_model}_desc-{directions}' +
            '_tractography.trk')

        if self.force_recompute or not op.exists(streamlines_file):
            if odf_model == "DTI":
                params_file = self._dti(row)
            elif odf_model == "CSD":
                params_file = self._csd(row)
            wm_mask_fname = self._wm_mask(row)
            wm_mask = nib.load(wm_mask_fname).get_fdata().astype(bool)
            self.tracking_params['seed_mask'] = wm_mask
            self.tracking_params['stop_mask'] = wm_mask
            sft = aft.track(params_file, **self.tracking_params)
            sft.to_vox()
            meta_directions = {"det": "deterministic", "prob": "probabilistic"}

            meta = dict(TractographyClass="local",
                        TractographyMethod=meta_directions[
                            self.tracking_params["directions"]],
                        Count=len(sft.streamlines),
                        Seeding=dict(
                            ROI=wm_mask_fname,
                            n_seeds=self.tracking_params["n_seeds"],
                            random_seeds=self.tracking_params["random_seeds"]),
                        Constraints=dict(AnatomicalImage=wm_mask_fname),
                        Parameters=dict(
                            Units="mm",
                            StepSize=self.tracking_params["step_size"],
                            MinimumLength=self.tracking_params["min_length"],
                            MaximumLength=self.tracking_params["max_length"],
                            Unidirectional=False))

            meta_fname = self._get_fname(
                row, f'_space-RASMM_model-{odf_model}_desc-'
                f'{directions}_tractography.json')
            afd.write_json(meta_fname, meta)
            save_tractogram(sft, streamlines_file, bbox_valid_check=False)

        return streamlines_file
Exemplo n.º 14
0
def _streamlines(row,
                 wm_labels,
                 odf_model="DTI",
                 directions="det",
                 force_recompute=False):
    """
    Generate the streamlines file for one row, unless it already exists.

    Parameters
    ----------
    row : mapping
        Must provide 'seg_file', 'dwi_file' and 'dwi_affine'. It is also
        passed to the `_get_fname` and `_dti` helpers.
    wm_labels : list
        The values within the segmentation that are considered white matter. We
        will use this part of the image both to seed tracking (seeding
        throughout), and for stopping.
    odf_model : str, optional
        Only "DTI" is supported.
    directions : str, optional
        Passed on to `aft.track`; also part of the output file name.
    force_recompute : bool, optional
        If True, track again even when the output file already exists.

    Returns
    -------
    The streamlines file name, as returned by `_get_fname`.

    Raises
    ------
    NotImplementedError
        If tracking is needed and `odf_model` is not "DTI".
    """
    streamlines_file = _get_fname(
        row, '%s_%s_streamlines.trk' % (odf_model, directions))
    if not op.exists(streamlines_file) or force_recompute:
        if odf_model == "DTI":
            params_file = _dti(row)
        else:
            raise NotImplementedError(
                "Unsupported odf_model: %r (only 'DTI' is implemented)"
                % (odf_model,))

        seg_img = nib.load(row['seg_file'])
        dwi_img = nib.load(row['dwi_file'])
        seg_data_orig = seg_img.get_data()

        # For different sets of labels, extract all the voxels that have any
        # of these values:
        wm_mask = np.sum(
            np.concatenate([(seg_data_orig == label)[..., None]
                            for label in wm_labels], -1), -1)

        # Resample the mask to the first DWI volume:
        dwi_data = dwi_img.get_data()
        resamp_wm = np.round(
            reg.resample(wm_mask, dwi_data[..., 0], seg_img.affine,
                         dwi_img.affine)).astype(int)

        # Track with the requested `directions`, so that the contents of the
        # file agree with the `directions` part of its name.
        streamlines = aft.track(params_file,
                                directions=directions,
                                seeds=2,
                                seed_mask=resamp_wm,
                                stop_mask=resamp_wm)

        aus.write_trk(streamlines_file, streamlines, affine=row['dwi_affine'])

    return streamlines_file
Exemplo n.º 15
0
def _streamlines(row, wm_labels, odf_model="DTI", directions="det",
                 force_recompute=False):
    """
    Generate the streamlines file for one row, unless it already exists.

    Parameters
    ----------
    row : mapping
        Must provide 'seg_file', 'dwi_file' and 'dwi_affine'. It is also
        passed to the `_get_fname` and `_dti` helpers.
    wm_labels : list
        The values within the segmentation that are considered white matter. We
        will use this part of the image both to seed tracking (seeding
        throughout), and for stopping.
    odf_model : str, optional
        Only "DTI" is supported.
    directions : str, optional
        Passed on to `aft.track`; also part of the output file name.
    force_recompute : bool, optional
        If True, track again even when the output file already exists.

    Returns
    -------
    The streamlines file name, as returned by `_get_fname`.

    Raises
    ------
    NotImplementedError
        If tracking is needed and `odf_model` is not "DTI".
    """
    streamlines_file = _get_fname(row,
                                  '%s_%s_streamlines.trk' % (odf_model,
                                                             directions))
    if not op.exists(streamlines_file) or force_recompute:
        if odf_model == "DTI":
            params_file = _dti(row)
        else:
            raise NotImplementedError(
                "Unsupported odf_model: %r (only 'DTI' is implemented)"
                % (odf_model,))

        seg_img = nib.load(row['seg_file'])
        dwi_img = nib.load(row['dwi_file'])
        seg_data_orig = seg_img.get_data()

        # For different sets of labels, extract all the voxels that have any
        # of these values:
        wm_mask = np.sum(np.concatenate([(seg_data_orig == label)[..., None]
                                         for label in wm_labels], -1), -1)

        # Resample the mask to the first DWI volume:
        dwi_data = dwi_img.get_data()
        resamp_wm = np.round(reg.resample(wm_mask, dwi_data[..., 0],
                             seg_img.affine,
                             dwi_img.affine)).astype(int)

        # Track with the requested `directions`, so that the contents of the
        # file agree with the `directions` part of its name.
        streamlines = aft.track(params_file,
                                directions=directions,
                                seeds=2,
                                seed_mask=resamp_wm,
                                stop_mask=resamp_wm)

        aus.write_trk(streamlines_file, streamlines,
                      affine=row['dwi_affine'])

    return streamlines_file
Exemplo n.º 16
0
def test_csd_local_tracking():
    """
    Run ``tracker="local"`` tracking on CSD fits of several SH orders.

    For each ``sh_order`` and each ``directions`` setting ("det", "prob"),
    the first streamline of the returned object's ``streamlines`` must have
    at least ``step_size * min_length`` points.
    """
    # Settings that do not change between iterations:
    track_kwargs = dict(odf_model="CSD",
                        max_angle=30.,
                        sphere=None,
                        seed_mask=None,
                        n_seeds=seeds,
                        stop_mask=None,
                        step_size=step_size,
                        min_length=min_length,
                        tracker="local")
    for sh_order in (4, 8, 10):
        fname = fit_csd(fdata, fbval, fbvec,
                        response=((0.0015, 0.0003, 0.0003), 100),
                        sh_order=sh_order, lambda_=1, tau=0.1, mask=None,
                        out_dir=tmpdir.name)
        for directions in ("det", "prob"):
            tractogram = track(fname, directions, **track_kwargs)
            sl = tractogram.streamlines

            npt.assert_(len(sl[0]) >= step_size * min_length)
Exemplo n.º 17
0
def _streamlines(row, odf_model="DTI", directions="det",
                 force_recompute=False):
    """
    Return the streamlines file for `row`, tracking first if needed.

    Tracking runs when the file does not exist yet or when
    `force_recompute` is set. Seeds are placed where FA is above 0.2, and
    the FA array itself is passed as the stop mask. Only the "DTI" model is
    implemented; any other `odf_model` raises NotImplementedError when
    tracking is needed.
    """
    trk_name = '%s_%s_streamlines.trk' % (odf_model, directions)
    streamlines_file = _get_fname(row, trk_name)

    # Nothing to do if the file is on disk and no recompute was requested:
    if op.exists(streamlines_file) and not force_recompute:
        return streamlines_file

    if odf_model != "DTI":
        raise(NotImplementedError)
    params_file = _dti(row)

    fa = nib.load(_dti_fa(row)).get_data()
    # Seed mask: 1 where FA exceeds 0.2, 0 elsewhere.
    wm_mask = np.zeros_like(fa)
    wm_mask[fa > 0.2] = 1

    tracked = aft.track(params_file,
                        directions=directions,
                        seeds=1,
                        seed_mask=wm_mask,
                        stop_mask=fa)
    aus.write_trk(streamlines_file, tracked, affine=row['dwi_affine'])
    return streamlines_file
Exemplo n.º 18
0
    # Build a seed mask as the union of the warped ROIs selected below.
    # The mask has the shape of `img` without its last dimension.
    seed_roi = np.zeros(img.shape[:-1])
    for bundle in bundles:
        for idx, roi in enumerate(bundles[bundle]['ROIs']):
            # Only ROIs whose rule is truthy are used (presumably these are
            # the inclusion ROIs -- TODO confirm):
            if bundles[bundle]['rules'][idx]:
                warped_roi = transform_inverse_roi(roi,
                                                   mapping,
                                                   bundle_name=bundle)

                # Save each warped ROI, numbered from 1 within its bundle:
                nib.save(nib.Nifti1Image(warped_roi.astype(float), img.affine),
                         op.join(working_dir, f"{bundle}_{idx+1}.nii.gz"))
                # Add voxels that aren't there yet:
                seed_roi = np.logical_or(seed_roi, warped_roi)
    # Save the combined seed mask as well:
    nib.save(nib.Nifti1Image(seed_roi.astype(float), img.affine),
             op.join(working_dir, 'seed_roi.nii.gz'))
    # Track from the seed mask, with FA_data and a threshold of 0.1 as the
    # stopping criterion:
    sft = aft.track(dti_params['params'],
                    seed_mask=seed_roi,
                    stop_mask=FA_data,
                    stop_threshold=0.1)
    save_tractogram(sft,
                    op.join(working_dir, 'dti_streamlines.trk'),
                    bbox_valid_check=False)
else:
    # Otherwise, load the tractogram from the working directory:
    sft = load_tractogram(op.join(working_dir, 'dti_streamlines.trk'), img)

sft.to_vox()

##########################################################################
# Segmentation
# ------------
# In this stage, streamlines are tested for several criteria: whether the
# probability that they belong to a bundle is larger than a threshold (set to
# 0 by default), whether they pass through inclusion ROIs and whether they do
Exemplo n.º 19
0
            # Save the transformed atlas streamlines for this bundle:
            sft = StatefulTractogram(sl_xform, img, Space.RASMM)
            save_tractogram(sft, f'./{bundle}_atlas.trk')

            # Apply the inverse of the image affine, so that the coordinates
            # can be used as array indices below:
            sl_xform = dts.Streamlines(
                dtu.transform_tracking_output(sl_xform,
                                              np.linalg.inv(img.affine)))

            # Mark every voxel that a streamline point falls in
            # (coordinates are truncated to integers):
            for sl in sl_xform:
                sl_as_idx = sl.astype(int)
                seed_roi[sl_as_idx[:, 0], sl_as_idx[:, 1], sl_as_idx[:, 2]] = 1

    nib.save(nib.Nifti1Image(seed_roi, img.affine), 'seed_roi.nii.gz')
    # Deterministic tracking from the seed mask, with FA_data and a
    # threshold of 0.1 as the stopping criterion:
    sft = aft.track(dti_params['params'],
                    seed_mask=seed_roi,
                    directions='det',
                    stop_mask=FA_data,
                    stop_threshold=0.1)
    print(len(sft.streamlines))
    save_tractogram(sft, './dti_streamlines_reco.trk', bbox_valid_check=False)
else:
    # Otherwise, load the tractogram from the current directory:
    sft = load_tractogram('./dti_streamlines_reco.trk', img)

print("Segmenting fiber groups...")
# The random state is seeded with a fixed value (2):
segmentation = seg.Segmentation(seg_algo='reco',
                                rng=np.random.RandomState(2),
                                greater_than=50,
                                rm_small_clusters=10,
                                model_clust_thr=5,
                                reduction_thr=20,
                                refine=True)
Exemplo n.º 20
0
hardi_fbvec = op.join(hardi_dir, "HARDI150.bvec")

img = nib.load(hardi_fdata)

print("Calculating DTI...")
# Fit DTI only if the FA file is not in the current directory yet.
# Otherwise, point at the existing files (presumably written by an earlier
# run of this script).
if not op.exists('./dti_FA.nii.gz'):
    dti_params = dti.fit_dti(hardi_fdata,
                             hardi_fbval,
                             hardi_fbvec,
                             out_dir='.')
else:
    dti_params = {'FA': './dti_FA.nii.gz', 'params': './dti_params.nii.gz'}

print("Tracking...")
# Track only if there is no streamlines file yet; otherwise load it and
# apply the inverse of the image affine.
if not op.exists('dti_streamlines.trk'):
    streamlines = list(aft.track(dti_params['params']))
    aus.write_trk('./dti_streamlines.trk', streamlines, affine=img.affine)
else:
    tg = nib.streamlines.load('./dti_streamlines.trk').tractogram
    streamlines = tg.apply_affine(np.linalg.inv(img.affine)).streamlines

# Use only a small portion of the streamlines, for expedience:
streamlines = streamlines[::100]

templates = afd.read_templates()
bundle_names = ["CST", "ILF"]

bundles = {}
for name in bundle_names:
    for hemi in ['_R', '_L']:
        bundles[name + hemi] = {
Exemplo n.º 21
0
    reg.write_mapping(mapping, './mapping.nii.gz')
else:
    # Otherwise, read the mapping back from disk:
    mapping = reg.read_mapping('./mapping.nii.gz', img, MNI_T2_img)

print("Tracking...")
# Track only if there is no streamlines file yet; otherwise load it and
# apply the inverse of the image affine.
if not op.exists('dti_streamlines.trk'):
    FA = nib.load(dti_params["FA"]).get_data()
    # Seed mask: 1 where FA exceeds 0.2, 0 elsewhere.
    wm_mask = np.zeros_like(FA)
    wm_mask[FA > 0.2] = 1
    step_size = 1
    min_length_mm = 50
    # `min_length` is given as the length in mm divided by the step size
    # (presumably a number of steps -- TODO confirm against `aft.track`).
    streamlines = dts.Streamlines(
        aft.track(dti_params['params'],
                  directions="det",
                  seed_mask=wm_mask,
                  seeds=2,
                  stop_mask=FA,
                  stop_threshold=0.2,
                  step_size=step_size,
                  min_length=min_length_mm / step_size))
    aus.write_trk('./dti_streamlines.trk', streamlines, affine=img.affine)
else:
    tg = nib.streamlines.load('./dti_streamlines.trk').tractogram
    streamlines = tg.apply_affine(np.linalg.inv(img.affine)).streamlines

print("We're looking at: %s streamlines" % len(streamlines))

templates = afd.read_templates()
# The ARC template entries are aliases of the SLF / SLFt entries:
templates['ARC_roi1_L'] = templates['SLF_roi1_L']
templates['ARC_roi1_R'] = templates['SLF_roi1_R']
templates['ARC_roi2_L'] = templates['SLFt_roi2_L']
templates['ARC_roi2_R'] = templates['SLFt_roi2_R']
Exemplo n.º 22
0
                               MNI_T1w_img.affine)

            warped_roi = transform_inverse_roi(roi,
                                               mapping,
                                               bundle_name=bundle)

            # Save each warped ROI in the working directory:
            nib.save(nib.Nifti1Image(warped_roi.astype(float), img.affine),
                     op.join(working_dir, f"{bundle}_{pp}.nii.gz"))

    # Save the seed mask as well:
    nib.save(nib.Nifti1Image(seed_roi.astype(float), img.affine),
             op.join(working_dir, 'seed_roi.nii.gz'))

    # Probabilistic tracking on the CSD coefficients with the "pft" tracker.
    # The stop mask is a tuple of three maps (named for white matter, gray
    # matter and CSF) and the stop threshold is the string "ACT".
    sft = aft.track(sh_coeff,
                    seed_mask=seed_roi,
                    n_seeds=5,
                    tracker="pft",
                    stop_mask=(pve_wm, pve_gm, pve_csf),
                    stop_threshold="ACT",
                    directions="prob",
                    odf_model="CSD")

    save_tractogram(sft,
                    op.join(working_dir, 'pft_streamlines.trk'),
                    bbox_valid_check=False)
else:
    # Otherwise, load the tractogram from the working directory:
    sft = load_tractogram(op.join(working_dir, 'pft_streamlines.trk'), img)

sft.to_vox()

##########################################################################
# Segmentation
# ------------
Exemplo n.º 23
0
hardi_fbval = op.join(hardi_dir, "HARDI150.bval")
hardi_fbvec = op.join(hardi_dir, "HARDI150.bvec")

img = nib.load(hardi_fdata)

print("Calculating DTI...")
# Fit DTI only if the FA file is not in the current directory yet.
# Otherwise, point at the existing files (presumably written by an earlier
# run of this script).
if not op.exists('./dti_FA.nii.gz'):
    dti_params = dti.fit_dti(hardi_fdata, hardi_fbval, hardi_fbvec,
                             out_dir='.')
else:
    dti_params = {'FA': './dti_FA.nii.gz',
                  'params': './dti_params.nii.gz'}

print("Tracking...")
# Track only if there is no streamlines file yet; otherwise load it and
# apply the inverse of the image affine.
if not op.exists('dti_streamlines.trk'):
    streamlines = list(aft.track(dti_params['params']))
    aus.write_trk('./dti_streamlines.trk', streamlines, affine=img.affine)
else:
    tg = nib.streamlines.load('./dti_streamlines.trk').tractogram
    streamlines = tg.apply_affine(np.linalg.inv(img.affine)).streamlines

# Use only a small portion of the streamlines, for expedience:
streamlines = streamlines[::100]

templates = afd.read_templates()
bundle_names = ["CST", "ILF"]

bundles = {}
for name in bundle_names:
    for hemi in ['_R', '_L']:
        bundles[name + hemi] = {