Esempio n. 1
0
def test_volume_registration():
    """Check that a known shift is recovered and that pipelines are validated.

    A copy of the T1 image is resampled with a 10-unit offset along the first
    axis, then registered back to the original with two different pipelines.
    The estimated affine must match the applied one, and the re-aligned image
    must correlate almost perfectly with the original. Afterwards the
    pipeline validator is exercised with every in-order combination of steps
    and with two invalid pipelines.
    """
    import nibabel as nib
    from dipy.align import resample

    t1_img = nib.load(fname_t1)

    # Ground-truth transform: identity plus a shift in the first column.
    known_shift = np.eye(4)
    known_shift[0, 3] = 10
    shifted = resample(
        moving=t1_img.get_fdata(),
        static=t1_img.get_fdata(),
        moving_affine=t1_img.affine,
        static_affine=t1_img.affine,
        between_affine=np.linalg.inv(known_shift),
    )

    pipelines = ('rigids', ('translation', 'sdr'))
    for pipeline in pipelines:
        estimated, sdr_morph = mne.transforms.compute_volume_registration(
            shifted, t1_img, pipeline=pipeline, zooms=10, niter=[5])
        assert_allclose(known_shift, estimated, atol=0.25)
        realigned = mne.transforms.apply_volume_registration(
            shifted, t1_img, estimated, sdr_morph)
        r2 = _compute_r2(_get_img_fdata(realigned), _get_img_fdata(t1_img))
        assert r2 > 99.9

    # check that all orders of the pipeline work
    ordered_steps = ('translation', 'rigid', 'affine', 'sdr')
    for n_steps in range(1, len(ordered_steps) + 1):
        for combo in itertools.combinations(ordered_steps, n_steps):
            _validate_pipeline(combo)
            _validate_pipeline(list(combo))

    out_of_order = 'Steps in pipeline are out of order'
    with pytest.raises(ValueError, match=out_of_order):
        _validate_pipeline(('sdr', 'affine'))

    repeated = 'Steps in pipeline should not be repeated'
    with pytest.raises(ValueError, match=repeated):
        _validate_pipeline(('affine', 'affine'))
Esempio n. 2
0
    def get_data(self, subses_dict, dwi_affine, reg_template, mapping):
        """Return this scalar's data passed through ``mapping``'s inverse.

        On the first call the stored image is resampled onto ``reg_template``
        and ``self.img`` is replaced by the resulting data; later calls reuse
        that cached result.

        Parameters
        ----------
        subses_dict, dwi_affine
            Not used by this method.
        reg_template
            Resampling target; its ``affine`` attribute is read.
        mapping
            Object whose ``transform_inverse`` is applied to the data.

        Returns
        -------
        tuple
            ``(scalar_data, meta)`` where ``meta`` is ``{"source": self.path}``.
        """
        if not self.is_resampled:
            original_img = self.img
            resampled_img = resample(
                original_img.get_fdata(),
                reg_template,
                original_img.affine,
                reg_template.affine)
            self.img = resampled_img.get_fdata()
            self.is_resampled = True

        warped = mapping.transform_inverse(self.img)
        return warped, {"source": self.path}
Esempio n. 3
0
def _resample_mask(mask_data, dwi_data, mask_affine, dwi_affine):
    '''
    Helper function
    Resamples mask to dwi if necessary
    '''
    mask_type = mask_data.dtype
    if ((dwi_data is not None) and (dwi_affine is not None)
            and (dwi_data[..., 0].shape != mask_data.shape)):
        return np.round(
            resample(mask_data.astype(float), dwi_data[..., 0], mask_affine,
                     dwi_affine).get_fdata()).astype(mask_type)
    else:
        return mask_data
Esempio n. 4
0
    def get_data(self, afq_object, row):
        """Return this scalar's data passed through the row's inverse mapping.

        On the first call the stored image is resampled onto
        ``afq_object.reg_template_img`` and ``self.img`` is replaced by the
        resulting data; later calls reuse that cached result.

        Parameters
        ----------
        afq_object
            Object providing ``reg_template_img`` and ``_mapping(row)``.
        row
            Passed through to ``afq_object._mapping``.

        Returns
        -------
        tuple
            ``(scalar_data, meta)`` where ``meta`` is ``dict(source=self.path)``.
        """
        if not self.is_resampled:
            self.img = resample(
                self.img.get_fdata(), afq_object.reg_template_img,
                self.img.affine,
                afq_object.reg_template_img.affine).get_fdata()
            self.is_resampled = True

        # Look the mapping up once and reuse it; previously the result was
        # stored in an unused local and `_mapping(row)` was called again.
        mapping = afq_object._mapping(row)
        scalar_data = mapping.transform_inverse(self.img)

        return scalar_data, dict(source=self.path)
Esempio n. 5
0
def test_volume_registration():
    """Check that registration recovers a known shift applied to the T1.

    A copy of the T1 image is resampled with a 10-unit offset along the first
    axis, then registered back to the original with two different pipelines.
    The estimated affine must match the applied one, and the re-aligned image
    must correlate almost perfectly with the original.
    """
    import nibabel as nib
    from dipy.align import resample

    t1_img = nib.load(fname_t1)

    # Ground-truth transform: identity plus a shift in the first column.
    known_shift = np.eye(4)
    known_shift[0, 3] = 10
    shifted = resample(
        moving=t1_img.get_fdata(),
        static=t1_img.get_fdata(),
        moving_affine=t1_img.affine,
        static_affine=t1_img.affine,
        between_affine=np.linalg.inv(known_shift),
    )

    pipelines = ('rigids', ('translation', 'sdr'))
    for pipeline in pipelines:
        estimated, sdr_morph = mne.transforms.compute_volume_registration(
            shifted, t1_img, pipeline=pipeline, zooms=10, niter=[5])
        assert_allclose(known_shift, estimated, atol=0.25)
        realigned = mne.transforms.apply_volume_registration(
            shifted, t1_img, estimated, sdr_morph)
        r2 = _compute_r2(_get_img_fdata(realigned), _get_img_fdata(t1_img))
        assert r2 > 99.9
Esempio n. 6
0
def track(params_file,
          directions="det",
          max_angle=30.,
          sphere=None,
          seed_mask=None,
          seed_threshold=0,
          n_seeds=1,
          random_seeds=False,
          rng_seed=None,
          stop_mask=None,
          stop_threshold=0,
          step_size=0.5,
          min_length=10,
          max_length=1000,
          odf_model="DTI",
          tracker="local"):
    """
    Tractography

    Parameters
    ----------
    params_file : str, nibabel img.
        Full path to a nifti file containing CSD spherical harmonic
        coefficients, or nibabel img with model params.
    directions : str
        How tracking directions are determined (case-insensitive).
        One of: {"det" | "prob"}
    max_angle : float, optional.
        The maximum turning angle in each step. Default: 30
    sphere : Sphere object, optional.
        The discretization of direction getting. default:
        dipy.data.default_sphere.
    seed_mask : array, optional.
        Float or binary mask describing the ROI within which we seed for
        tracking.
        Default to the entire volume (all ones).
    seed_threshold : float, optional.
        Voxels of a non-boolean seed_mask whose value is greater than this
        are used for seeding. Ignored for boolean masks.
        Default to 0.
    n_seeds : int or 2D array, optional.
        The seeding density: if this is an int, it is how many seeds in each
        voxel on each dimension (for example, 2 => [2, 2, 2]). If this is a 2D
        array, these are the coordinates of the seeds. Unless random_seeds is
        set to True, in which case this is the total number of random seeds
        to generate within the mask.
    random_seeds : bool
        Whether to generate a total of n_seeds random seeds in the mask.
        Default: False.
    rng_seed : int
        random seed used to generate random seeds if random_seeds is
        set to True. Default: None
    stop_mask : array or tuple, optional.
        If array: A float or binary mask that determines a stopping criterion
        (e.g. FA).
        If tuple: it contains a sequence that is interpreted as:
        (pve_wm, pve_gm, pve_csf), each item of which is either a string
        (full path) or a nibabel img to be used in particle filtering
        tractography.
        A tuple is required if tracker is set to "pft".
        Defaults to no stopping (all ones).
    stop_threshold : float or str, optional.
        If float, this a value of the stop_mask below which tracking is
        terminated (and stop_mask has to be an array).
        If str, "CMC" for Continuous Map Criterion [Girard2014]_.
                "ACT" for Anatomically-constrained tractography [Smith2012]_.
        A string is required if the tracker is set to "pft".
        Defaults to 0 (this means that if no stop_mask is passed,
        we will stop only at the edge of the image).
    step_size : float, optional.
        The size (in mm) of a step of tractography. Default: 0.5
    min_length: int, optional
        The minimal length (mm) in a streamline. Default: 10
    max_length: int, optional
        The maximal length (mm) in a streamline. Default: 1000
    odf_model : str, optional
        Case-insensitive. "DTI", "DKI" and "FWDTI" are read as tensor
        eigenvalues/eigenvectors; anything else (e.g. "CSD", "MSMT") is
        read as spherical harmonic coefficients. Defaults to use "DTI"
    tracker : str, optional
        Which strategy to use in tracking. This can be the standard local
        tracking ("local") or Particle Filtering Tracking ([Girard2014]_).
        One of {"local", "pft"}. Default: "local"

    Returns
    -------
    The result of ``_tracking`` (the generated streamlines).

    Raises
    ------
    ValueError
        If ``directions`` or ``tracker`` is not one of the supported values.
    RuntimeError
        If ``tracker`` is "pft" and ``stop_threshold`` is not "CMC"/"ACT" or
        ``stop_mask`` is not a length-3 iterable.

    References
    ----------
    .. [Girard2014] Girard, G., Whittingstall, K., Deriche, R., &
        Descoteaux, M. Towards quantitative connectivity analysis: reducing
        tractography biases. NeuroImage, 98, 266-278, 2014.
    .. [Smith2012] Smith, R. E., Tournier, J.-D., Calamante, F., &
        Connelly, A. Anatomically-constrained tractography: Improved
        diffusion MRI streamlines tractography through effective use of
        anatomical information. NeuroImage, 62(3), 1795-1808, 2012.
    """
    logger = logging.getLogger('AFQ.tractography')

    odf_model = odf_model.upper()
    directions = directions.lower()

    # Validate the cheap string options up front so that a typo fails with a
    # clear message before any image loading or seeding work is done (and
    # instead of an unbound-variable error much later).
    if directions not in ("det", "prob"):
        raise ValueError(
            f"Unknown `directions` value: {directions!r}. "
            "Possible inputs are: 'det' or 'prob'")
    if tracker not in ("local", "pft"):
        raise ValueError(
            f"Unknown `tracker` value: {tracker!r}. "
            "Possible inputs are: 'local' or 'pft'")
    if tracker == "pft":
        if not isinstance(stop_threshold, str):
            # Adjacent literals (no commas) so the exception carries one
            # readable message rather than a tuple of fragments.
            raise RuntimeError(
                "You are using PFT tracking, but did not provide a string "
                "'stop_threshold' input. "
                "Possible inputs are: 'CMC' or 'ACT'")
        if stop_threshold not in ("CMC", "ACT"):
            raise RuntimeError(
                "You are using PFT tracking, but 'stop_threshold' is "
                f"{stop_threshold!r}. "
                "Possible inputs are: 'CMC' or 'ACT'")
        if not (isinstance(stop_mask, Iterable) and len(stop_mask) == 3):
            raise RuntimeError(
                "You are using PFT tracking, but did not provide a length "
                "3 iterable for `stop_mask`. "
                "Expected a (pve_wm, pve_gm, pve_csf) tuple.")

    logger.info("Loading Image...")
    if isinstance(params_file, str):
        params_img = nib.load(params_file)
    else:
        params_img = params_file

    model_params = params_img.get_fdata()
    affine = params_img.affine

    logger.info("Generating Seeds...")
    if isinstance(n_seeds, int):
        if seed_mask is None:
            seed_mask = np.ones(params_img.shape[:3])
        elif seed_mask.dtype != 'bool':
            seed_mask = seed_mask > seed_threshold
        if random_seeds:
            seeds = dtu.random_seeds_from_mask(seed_mask,
                                               seeds_count=n_seeds,
                                               seed_count_per_voxel=False,
                                               affine=affine,
                                               random_seed=rng_seed)
        else:
            seeds = dtu.seeds_from_mask(seed_mask,
                                        density=n_seeds,
                                        affine=affine)
    else:
        # If user provided an array, we'll use n_seeds as the seeds:
        seeds = n_seeds
    if sphere is None:
        sphere = dpd.default_sphere

    logger.info("Getting Directions...")
    if directions == "det":
        dg = DeterministicMaximumDirectionGetter
    else:
        dg = ProbabilisticDirectionGetter

    if odf_model in ("DTI", "DKI", "FWDTI"):
        evals = model_params[..., :3]
        evecs = model_params[..., 3:12].reshape(params_img.shape[:3] + (3, 3))
        odf = tensor_odf(evals, evecs, sphere)
        dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere)
    else:
        # The previous condition (`odf_model == "CSD" or "MSMT"`) was always
        # truthy, so every non-tensor model was already treated as spherical
        # harmonic coefficients. Keep that behavior, but say so explicitly.
        dg = dg.from_shcoeff(model_params, max_angle=max_angle, sphere=sphere)

    if tracker == "local":
        if stop_mask is None:
            stop_mask = np.ones(params_img.shape[:3])

        if stop_mask.dtype == 'bool':
            stopping_criterion = ThresholdStoppingCriterion(stop_mask, 0.5)
        else:
            stopping_criterion = ThresholdStoppingCriterion(
                stop_mask, stop_threshold)

        my_tracker = VerboseLocalTracking

    else:  # tracker == "pft"; arguments were validated above
        # Bring each partial volume estimate onto the grid of the model
        # parameters before building the stopping criterion.
        static = model_params[..., 0]
        pve_data = []
        for pve in stop_mask:
            pve_img = nib.load(pve) if isinstance(pve, str) else pve
            pve_data.append(
                resample(pve_img.get_fdata(), static, pve_img.affine,
                         params_img.affine).get_fdata())
        pve_wm_data, pve_gm_data, pve_csf_data = pve_data

        my_tracker = VerboseParticleFilteringTracking
        if stop_threshold == "CMC":
            # The PVE maps now live on the params_img grid, so that image's
            # zooms give the voxel size. Previously the mean was taken over
            # an empty list (NaN) because the zooms were appended too late.
            average_voxel_size = np.mean(params_img.header.get_zooms()[:3])
            stopping_criterion = CmcStoppingCriterion.from_pve(
                pve_wm_data,
                pve_gm_data,
                pve_csf_data,
                step_size=step_size,
                average_voxel_size=average_voxel_size)
        else:  # "ACT"
            stopping_criterion = ActStoppingCriterion.from_pve(
                pve_wm_data, pve_gm_data, pve_csf_data)

    logger.info("Tracking...")

    return _tracking(my_tracker,
                     seeds,
                     dg,
                     stopping_criterion,
                     params_img,
                     step_size=step_size,
                     min_length=min_length,
                     max_length=max_length,
                     random_seed=rng_seed)
Esempio n. 7
0
        ax.imshow(np.take(image, [image.shape[i] // 2], axis=i).squeeze().T,
                  cmap='gray')
        ax.imshow(np.take(compare, [compare.shape[i] // 2],
                          axis=i).squeeze().T,
                  cmap='gist_heat',
                  alpha=0.5)
        ax.invert_yaxis()
        ax.axis('off')
    fig.tight_layout()


# Load the sample CT volume from the 'seeg' folder under ``misc_path``.
CT_orig = nib.load(op.join(misc_path, 'seeg', 'sample_seeg_CT.mgz'))

# resample to T1's definition of world coordinates
# Only the two images' own affines are supplied here, so no registration has
# been applied yet; the overlay below is titled "Unaligned" for that reason.
CT_resampled = resample(moving=np.asarray(CT_orig.dataobj),
                        static=np.asarray(T1.dataobj),
                        moving_affine=CT_orig.affine,
                        static_affine=T1.affine)
plot_overlay(T1, CT_resampled, 'Unaligned CT Overlaid on T1', thresh=0.95)
# The resampled volume is only needed for the plot above.
# NOTE(review): presumably deleted to free memory -- confirm.
del CT_resampled

# %%
# Now we need to align our CT image to the T1 image.
#
# We want this to be a rigid transformation (just rotation + translation),
# so we don't do a full affine registration (that includes shear) here.
# This takes a while (~10 minutes) to execute so we skip actually running it
# here::
#
#    reg_affine, _ = mne.transforms.compute_volume_registration(
#         CT_orig, T1, pipeline='rigids')
#
Esempio n. 8
0
                roi,
                mapping,
                bundle_name=bundle)
            print(roi)
            nib.save(nib.Nifti1Image(warped_roi.astype(float), img.affine),
                     op.join(working_dir, f"{bundle}_{idx+1}.nii.gz"))

            # Add voxels that aren't there yet:
            if bundles[bundle]['rules'][idx]:
                seed_roi = np.logical_or(seed_roi, warped_roi)

        for ii, pp in enumerate(endpoint_spec[bundle].keys()):
            roi = endpoint_spec[bundle][pp]
            roi = resample(
                roi.get_fdata(),
                MNI_T1w_img,
                roi.affine,
                MNI_T1w_img.affine).get_fdata()

            warped_roi = transform_inverse_roi(
                roi,
                mapping,
                bundle_name=bundle)

            nib.save(nib.Nifti1Image(warped_roi.astype(float), img.affine),
                     op.join(working_dir, f"{bundle}_{pp}.nii.gz"))

    nib.save(nib.Nifti1Image(seed_roi.astype(float), img.affine),
             op.join(working_dir, 'seed_roi.nii.gz'))

    sft = aft.track(sh_coeff,
Esempio n. 9
0
def viz_indivBundle(subses_dict,
                    dwi_affine,
                    viz_backend,
                    bundle_dict,
                    data_imap,
                    mapping_imap,
                    segmentation_imap,
                    tracking_params,
                    segmentation_params,
                    reg_template,
                    best_scalar,
                    xform_volume_indiv=False,
                    cbv_lims_indiv=[None, None],
                    xform_color_by_volume_indiv=False,
                    volume_opacity_indiv=0.3,
                    n_points_indiv=40):
    """
    Render one figure per bundle and save it under 'viz_bundles'.

    For every key of ``bundle_dict`` a figure is built from the background
    volume, the bundle's streamlines, optional endpoint ROIs (only when
    ``segmentation_params["filter_by_endpoints"]`` is truthy) and the ROIs
    listed in ``mapping_imap["rois_file"]``. The figure is written as a gif
    unless the backend name contains "no_gif", and as html when the backend
    name contains "plotly". A JSON file with the elapsed time is written at
    the end.

    Parameters
    ----------
    subses_dict : dict
        Must contain 'results_dir'; also forwarded to ``get_fname``.
    dwi_affine : array
        Its first three diagonal entries decide which axes are flipped
        (negative entry => flip).
    viz_backend : object
        Provides ``visualize_volume``, ``visualize_bundles``,
        ``visualize_roi``, ``create_gif`` and a ``backend`` name.
    bundle_dict : dict
        Maps bundle names to dicts holding at least 'uid'.
    data_imap : dict
        Read for "b0_file" and ``best_scalar + "_file"``.
    mapping_imap : dict
        Read for "mapping" and "rois_file".
    segmentation_imap : dict
        Read for "scalar_dict" and "clean_bundles_file".
    tracking_params, segmentation_params : dict
        Forwarded to ``get_fname``; ``segmentation_params`` is also read
        for "filter_by_endpoints" and "endpoint_info".
    reg_template : nibabel img
        Target grid for resampling endpoint images, or the template handed
        to ``afd.read_aal_atlas`` when no endpoint info is given.
    best_scalar : str
        Name of the scalar used to color the bundles.
    xform_volume_indiv, xform_color_by_volume_indiv : bool, optional
        Forwarded to ``_viz_prepare_vol`` for the background and the
        coloring volume respectively.
    cbv_lims_indiv : list, optional
        Forwarded as ``cbv_lims`` to ``visualize_bundles``.
    volume_opacity_indiv : float, optional
        Forwarded as ``opacity`` to ``visualize_volume``.
    n_points_indiv : int, optional
        Forwarded as ``n_points`` to ``visualize_bundles``.

    Returns
    -------
    bool
        Always True.
    """
    mapping = mapping_imap["mapping"]
    scalar_dict = segmentation_imap["scalar_dict"]
    volume = data_imap["b0_file"]
    color_by_volume = data_imap[best_scalar + "_file"]

    start_time = time()
    volume = _viz_prepare_vol(
        volume, xform_volume_indiv, mapping, scalar_dict)
    color_by_volume = _viz_prepare_vol(
        color_by_volume, xform_color_by_volume_indiv, mapping, scalar_dict)

    # An axis is flipped when its diagonal entry in the affine is negative.
    flip_axes = [dwi_affine[axis, axis] < 0 for axis in range(3)]

    for bundle_name in bundle_dict.keys():
        logger.info(f"Generating {bundle_name} visualization...")
        uid = bundle_dict[bundle_name]['uid']

        # Start from the background volume, then layer everything on top.
        figure = viz_backend.visualize_volume(
            volume,
            opacity=volume_opacity_indiv,
            flip_axes=flip_axes,
            interact=False,
            inline=False)
        try:
            figure = viz_backend.visualize_bundles(
                segmentation_imap["clean_bundles_file"],
                color_by_volume=color_by_volume,
                cbv_lims=cbv_lims_indiv,
                bundle_dict=bundle_dict,
                bundle=uid,
                n_points=n_points_indiv,
                flip_axes=flip_axes,
                interact=False,
                inline=False,
                figure=figure)
        except ValueError:
            # Keep going with the bare volume figure for this bundle.
            logger.info(
                "No streamlines found to visualize for " + bundle_name)

        if segmentation_params["filter_by_endpoints"]:
            endpoint_rois = []
            endpoint_info = segmentation_params["endpoint_info"]
            if endpoint_info is None:
                # No explicit endpoints: fall back to the AAL atlas targets.
                aal_atlas = afd.read_aal_atlas(reg_template)
                atlas = aal_atlas['atlas'].get_fdata()
                aal_targets = afd.bundles_to_aal(
                    [bundle_name], atlas=atlas)[0]
                for target in aal_targets:
                    if target is None:
                        continue
                    target_mask = np.zeros(atlas.shape[:3])
                    target_mask[target[:, 0],
                                target[:, 1],
                                target[:, 2]] = 1
                    endpoint_rois.append(auv.transform_inverse_roi(
                        target_mask,
                        mapping,
                        bundle_name=bundle_name))
            else:
                endpoint_imgs = [
                    endpoint_info[bundle_name][key]
                    for key in ('startpoint', 'endpoint')]
                for endpoint_img in endpoint_imgs:
                    endpoint_data = resample(
                        endpoint_img.get_fdata(),
                        reg_template,
                        endpoint_img.affine,
                        reg_template.affine).get_fdata()

                    # Binarize: every positive voxel belongs to the ROI.
                    binary_roi = np.zeros(endpoint_data.shape)
                    binary_roi[np.where(endpoint_data > 0)] = 1
                    endpoint_rois.append(auv.transform_inverse_roi(
                        binary_roi,
                        mapping,
                        bundle_name=bundle_name))

            for roi_idx, endpoint_roi in enumerate(endpoint_rois):
                figure = viz_backend.visualize_roi(
                    endpoint_roi,
                    name=f"{bundle_name} endpoint ROI {roi_idx}",
                    flip_axes=flip_axes,
                    inline=False,
                    interact=False,
                    figure=figure)

        bundle_rois = mapping_imap["rois_file"][bundle_name]
        for roi_idx, bundle_roi in enumerate(bundle_rois):
            figure = viz_backend.visualize_roi(
                bundle_roi,
                name=f"{bundle_name} ROI {roi_idx}",
                flip_axes=flip_axes,
                inline=False,
                interact=False,
                figure=figure)

        roi_dir = op.join(subses_dict['results_dir'], 'viz_bundles')
        os.makedirs(roi_dir, exist_ok=True)
        if "no_gif" not in viz_backend.backend:
            # Only the file name from get_fname is kept; the file itself
            # goes into the 'viz_bundles' directory.
            _, gif_name = op.split(
                get_fname(
                    subses_dict,
                    f'_{bundle_name}_viz.gif',
                    tracking_params=tracking_params,
                    segmentation_params=segmentation_params))
            viz_backend.create_gif(figure, op.join(roi_dir, gif_name))
        if "plotly" in viz_backend.backend:
            roi_dir = op.join(subses_dict['results_dir'], 'viz_bundles')
            os.makedirs(roi_dir, exist_ok=True)
            _, html_name = op.split(
                get_fname(
                    subses_dict,
                    f'_{bundle_name}_viz.html',
                    tracking_params=tracking_params,
                    segmentation_params=segmentation_params))
            figure.write_html(op.join(roi_dir, html_name))

    meta_fname = get_fname(
        subses_dict, '_vizIndiv.json',
        tracking_params=tracking_params,
        segmentation_params=segmentation_params)
    afd.write_json(meta_fname, dict(Timing=time() - start_time))
    return True