Example #1
def test_bundles_to_aal():
    atlas = np.zeros((20, 20, 20, 5))

    atlas[0, 0, 0, 0] = 1

    targets = afd.bundles_to_aal(["ATR_L"], atlas)
    npt.assert_equal(targets,
                     [[np.array(np.where(atlas[..., 0] == 1)).T, None]])

    atlas[0, 0, 1, 0] = 2

    targets = afd.bundles_to_aal(["ATR_L", "ATR_R"], atlas)
    npt.assert_equal(
        targets,
        [
            [np.array(np.where(atlas[..., 0] == 1)).T, None],
            [np.array(np.where(atlas[..., 0] == 2)).T, None],
        ],
    )

    targets = afd.bundles_to_aal([], atlas)
    assert len(targets) == 0

    targets = afd.bundles_to_aal(["HCC_L"], atlas)
    assert len(targets) == 1
    npt.assert_equal(targets, [[None, None]])

    targets = afd.bundles_to_aal(["VOF"], atlas)
    assert len(targets) == 1
    npt.assert_equal(targets, [[None, None]])
Example #2
def test_bundles_to_aal():
    atlas = np.zeros((20, 20, 20, 5))

    atlas[0, 0, 0, 0] = 1

    targets = afd.bundles_to_aal(["ATR_L"], atlas)
    npt.assert_equal(targets,
                     [[None, np.array(np.where(atlas[..., 0] == 1)).T]])

    atlas[0, 0, 1, 0] = 2

    targets = afd.bundles_to_aal(["ATR_L", "ATR_R"], atlas)
    npt.assert_equal(targets,
                     [[None, np.array(np.where(atlas[..., 0] == 1)).T],
                      [None, np.array(np.where(atlas[..., 0] == 2)).T]])
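Both test versions exercise the same return structure of afd.bundles_to_aal: one [startpoint_targets, endpoint_targets] pair per requested bundle, where each element is either an N x 3 array of voxel indices or None; only the slot that holds the coordinate array differs between the two pyAFQ versions shown above. Below is a minimal sketch of turning such a target array into a binary ROI mask, as the segmentation code in the later examples does (the atlas here is a toy array standing in for the AAL atlas, and the usual pyAFQ alias import AFQ.data as afd is assumed):

import numpy as np
import AFQ.data as afd

atlas = np.zeros((20, 20, 20, 5))   # toy stand-in for the 4D AAL atlas
atlas[0, 0, 0, 0] = 1               # a single labeled voxel

# One [startpoint_targets, endpoint_targets] pair per bundle name:
start_targ, end_targ = afd.bundles_to_aal(["ATR_L"], atlas)[0]

for targ in (start_targ, end_targ):
    if targ is not None:            # targ is an N x 3 array of voxel indices
        roi = np.zeros(atlas.shape[:3])
        roi[targ[:, 0], targ[:, 1], targ[:, 2]] = 1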
Example #3
    def segment_afq(self, tg=None):
        """
        Assign streamlines to bundles using the waypoint ROI approach

        Parameters
        ----------
        tg : StatefulTractogram class instance
        """
        if tg is None:
            tg = self.tg
        else:
            self.tg = tg

        self.tg.to_vox()
        # For expedience, we approximate each streamline as a 100 point curve:
        fgarray = np.array(_resample_tg(tg, 100))

        # comment _aNNe
        # In general, this approximation can cause errors:
        # if an ROI is traversed by streamlines in only a few voxels,
        # and the streamlines are so long (or the resolution so high) that
        # one hundredth of a streamline's length spans more than a voxel,
        # then the contact check below (closest distance from streamline to
        # ROI < voxel width) can fail after resampling to 100 points.
        # To be certain that the resampling does not cause problems, the
        # number of resampling points has to be larger than the length of
        # the streamline in voxels in native space!
        # end comment
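        # Illustrative guard (not part of the original function): following
        # the concern above, warn if any streamline is longer than 100
        # voxels, in which case 100 resampled points may be too coarse for
        # the ROI contact check. Streamlines are in voxel units here, after
        # the to_vox() call above.
        sl_lengths_vox = [np.linalg.norm(np.diff(sl, axis=0), axis=1).sum()
                          for sl in tg.streamlines]
        if np.max(sl_lengths_vox) > 100:
            self.logger.warning(
                "Some streamlines are longer than 100 voxels; consider "
                "resampling with more points for the waypoint ROI checks.")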

        n_streamlines = fgarray.shape[0]

        streamlines_in_bundles = np.zeros(
            (n_streamlines, len(self.bundle_dict)))
        min_dist_coords = np.zeros((n_streamlines, len(self.bundle_dict), 2),
                                   dtype=int)
        self.fiber_groups = {}

        if self.return_idx:
            out_idx = np.arange(n_streamlines, dtype=int)

        if self.filter_by_endpoints:
            # Keep the Nifti1Image: its data and affine are needed below.
            aal_atlas = afd.read_aal_atlas()['atlas']
            # This atlas is not yet aligned to template space
            resample_to = self.reg_template
            if isinstance(resample_to, str):
                resample_to = nib.load(resample_to)
            allVolumes = []
            # The AAL atlas (and others) uses multiple volumes to represent
            # overlapping areas separately. Move through all volumes,
            # resample each one to the template, stack them back together,
            # and save the result with the affine of the template.
            # This puts the AAL atlas in the same space as the template
            # before it is warped to native space. _aNNe
            for ii in range(aal_atlas.get_fdata().shape[-1]):
                vol = aal_atlas.get_fdata()
                vol = vol[..., ii]
                trafo = reg.resample(
                    vol,  # moving (according to reg.resample)
                    resample_to,  # static
                    aal_atlas.affine,  # moving affine
                    resample_to.affine)  # static affine
                allVolumes.append(np.asarray(trafo))
            aal_atlas = np.stack(allVolumes, axis=3)
            aal_atlas = nib.Nifti1Image(aal_atlas, resample_to.affine)
            ####### for debugging: save the AAL atlas after registering to the template #######
            #            path_for_debugging = '/debugpath/'
            #            nib.save(aal_atlas,
            #                     path_for_debugging + 'AAL_registered_to_template.nii.gz')
            ###################################################################################

            # We need to calculate the size of a voxel, so we can transform
            # from mm to voxel units:
            R = self.img_affine[0:3, 0:3]
            vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R))))
            dist_to_aal = self.dist_to_aal / vox_dim
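            # Illustrative note (not part of the original code): with a 2 mm
            # isotropic affine, vox_dim == 2.0, so e.g. a dist_to_aal of
            # 4 mm becomes a tolerance of 2 voxels here.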

        self.logger.info("Assigning Streamlines to Bundles")
        # Tolerance is set to the square of the distance to the corner
        # because we are using the squared Euclidean distance in calls to
        # `cdist` to make those calls faster.
        tol = dts.dist_to_corner(self.img_affine)**2
        for bundle_idx, bundle in enumerate(self.bundle_dict):
            self.logger.info(f"Finding Streamlines for {bundle}")
            warped_prob_map, include_roi, exclude_roi = \
                self._get_bundle_info(bundle_idx, bundle)
            ####### for debugging: save the warped probability map actually used in segment_afq() #######
            #            path_for_debugging = '/debugpath/'
            #            nib.save(nib.Nifti1Image(warped_prob_map.astype(np.float32),
            #                                     self.img_affine),
            #                     path_for_debugging + 'warpedprobmap_' + bundle + '_as_used.nii.gz')
            ##############################################################################################
            fiber_probabilities = dts.values_from_volume(
                warped_prob_map, fgarray, np.eye(4))
            fiber_probabilities = np.mean(fiber_probabilities, -1)
            idx_above_prob = np.where(
                fiber_probabilities > self.prob_threshold)
            self.logger.info((f"{len(idx_above_prob[0])} streamlines exceed"
                              " the probability threshold."))
            crosses_midline = self.bundle_dict[bundle]['cross_midline']
            for sl_idx in tqdm(idx_above_prob[0]):
                sl = tg.streamlines[sl_idx]
                if fiber_probabilities[sl_idx] > self.prob_threshold:
                    if crosses_midline is not None:
                        if self.crosses[sl_idx]:
                            # This means that the streamline does
                            # cross the midline:
                            if crosses_midline:
                                # This is what we want, keep going
                                pass
                            else:
                                # This is not what we want,
                                # skip to next streamline
                                continue

                    is_close, dist = \
                        self._check_sl_with_inclusion(sl,
                                                      include_roi,
                                                      tol)
                    if is_close:
                        is_far = \
                            self._check_sl_with_exclusion(sl,
                                                          exclude_roi,
                                                          tol)
                        if is_far:
                            min_dist_coords[sl_idx, bundle_idx, 0] =\
                                np.argmin(dist[0], 0)[0]
                            min_dist_coords[sl_idx, bundle_idx, 1] =\
                                np.argmin(dist[1], 0)[0]
                            streamlines_in_bundles[sl_idx, bundle_idx] =\
                                fiber_probabilities[sl_idx]
            self.logger.info(
                (f"{np.sum(streamlines_in_bundles[:, bundle_idx] > 0)} "
                 "streamlines selected with waypoint ROIs"))

        # Eliminate any fibers not selected using the plane ROIs:
        possible_fibers = np.sum(streamlines_in_bundles, -1) > 0
        tg = StatefulTractogram(tg.streamlines[possible_fibers], self.img,
                                Space.VOX)
        if self.return_idx:
            out_idx = out_idx[possible_fibers]

        streamlines_in_bundles = streamlines_in_bundles[possible_fibers]
        min_dist_coords = min_dist_coords[possible_fibers]
        bundle_choice = np.argmax(streamlines_in_bundles, -1)

        # We do another round through, so that we can orient all the
        # streamlines within a bundle in the same orientation with respect to
        # the ROIs. This order is ARBITRARY but CONSISTENT (going from ROI0
        # to ROI1).
        self.logger.info("Re-orienting streamlines to consistent directions")
        for bundle_idx, bundle in enumerate(self.bundle_dict):
            self.logger.info(f"Processing {bundle}")

            select_idx = np.where(bundle_choice == bundle_idx)

            if len(select_idx[0]) == 0:
                # There's nothing here, set and move to the next bundle:
                self._return_empty(bundle)
                continue

            # Use a list here, because ArraySequence doesn't support item
            # assignment:
            select_sl = list(tg.streamlines[select_idx])
            # Sub-sample min_dist_coords:
            min_dist_coords_bundle = min_dist_coords[select_idx]
            for idx in range(len(select_sl)):
                min0 = min_dist_coords_bundle[idx, bundle_idx, 0]
                min1 = min_dist_coords_bundle[idx, bundle_idx, 1]
                if min0 > min1:
                    select_sl[idx] = select_sl[idx][::-1]

            # Set this to StatefulTractogram object for filtering/output:
            select_sl = StatefulTractogram(select_sl, self.img, Space.VOX)

            if self.filter_by_endpoints:
                self.logger.info("Filtering by endpoints")
                # Create binary masks and warp these into subject's DWI space:
                # aal_atlas is a Nifti1Image here; bundles_to_aal indexes
                # the data array directly:
                aal_targets = afd.bundles_to_aal(
                    [bundle], atlas=aal_atlas.get_fdata())[0]
                aal_idx = []
                for targ in aal_targets:
                    if targ is not None:
                        aal_roi = np.zeros(aal_atlas.shape[:3])
                        aal_roi[targ[:, 0], targ[:, 1], targ[:, 2]] = 1
                        warped_roi = self.mapping.transform_inverse(
                            aal_roi, interpolation='nearest')
                        aal_idx.append(np.array(np.where(warped_roi > 0)).T)
                    else:
                        aal_idx.append(None)

                self.logger.info("Before filtering "
                                 f"{len(select_sl)} streamlines")

                new_select_sl = clean_by_endpoints(select_sl.streamlines,
                                                   aal_idx[0],
                                                   aal_idx[1],
                                                   tol=dist_to_aal,
                                                   return_idx=self.return_idx)
                # Generate immediately:
                new_select_sl = list(new_select_sl)

                # We need to check this again:
                if len(new_select_sl) == 0:
                    # There's nothing here, set and move to the next bundle:
                    self._return_empty(bundle)
                    continue

                if self.return_idx:
                    temp_select_sl = []
                    temp_select_idx = np.empty(len(new_select_sl), int)
                    for ii, ss in enumerate(new_select_sl):
                        temp_select_sl.append(ss[0])
                        temp_select_idx[ii] = ss[1]
                    select_idx = select_idx[0][temp_select_idx]
                    new_select_sl = temp_select_sl

                select_sl = StatefulTractogram(new_select_sl, self.img,
                                               Space.RASMM)

                self.logger.info("After filtering "
                                 f"{len(select_sl)} streamlines")

            if self.return_idx:
                self.fiber_groups[bundle] = {}
                self.fiber_groups[bundle]['sl'] = select_sl
                self.fiber_groups[bundle]['idx'] = out_idx[select_idx]
            else:
                self.fiber_groups[bundle] = select_sl
        return self.fiber_groups
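Based on the structure built in the loop above, the returned fiber_groups maps each bundle name either to a StatefulTractogram or, when return_idx is set, to a dict with 'sl' and 'idx' entries. A minimal sketch of consuming that return value (the already-configured segmentation object seg is hypothetical):

fiber_groups = seg.segment_afq(tg)
for bundle_name, group in fiber_groups.items():
    if isinstance(group, dict):       # return_idx was True
        sft = group['sl']             # StatefulTractogram for this bundle
        idx = group['idx']            # indices into the original tractogram
    else:                             # return_idx was False
        sft = group
    print(bundle_name, len(sft.streamlines))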
Example #4
    def segment_afq(self, tg=None):
        """
        Assign streamlines to bundles using the waypoint ROI approach

        Parameters
        ----------
        tg : StatefulTractogram class instance
        """
        if tg is None:
            tg = self.tg
        else:
            self.tg = tg

        self.tg.to_vox()
        # For expedience, we approximate each streamline as a 100 point curve:
        fgarray = np.array(_resample_tg(tg, 100))

        n_streamlines = fgarray.shape[0]

        streamlines_in_bundles = np.zeros(
            (n_streamlines, len(self.bundle_dict)))
        min_dist_coords = np.zeros((n_streamlines, len(self.bundle_dict), 2),
                                   dtype=int)
        self.fiber_groups = {}

        if self.return_idx:
            out_idx = np.arange(n_streamlines, dtype=int)

        if self.filter_by_endpoints:
            aal_atlas = afd.read_aal_atlas()['atlas'].get_fdata()
            # We need to calculate the size of a voxel, so we can transform
            # from mm to voxel units:
            R = self.img_affine[0:3, 0:3]
            vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R))))
            dist_to_aal = self.dist_to_aal / vox_dim

        self.logger.info("Assigning Streamlines to Bundles")
        # Tolerance is set to the square of the distance to the corner
        # because we are using the squared Euclidean distance in calls to
        # `cdist` to make those calls faster.
        tol = dts.dist_to_corner(self.img_affine)**2
        for bundle_idx, bundle in enumerate(self.bundle_dict):
            self.logger.info(f"Finding Streamlines for {bundle}")
            warped_prob_map, include_roi, exclude_roi = \
                self._get_bundle_info(bundle_idx, bundle)
            fiber_probabilities = dts.values_from_volume(
                warped_prob_map, fgarray, np.eye(4))
            fiber_probabilities = np.mean(fiber_probabilities, -1)
            idx_above_prob = np.where(
                fiber_probabilities > self.prob_threshold)
            self.logger.info((f"{len(idx_above_prob[0])} streamlines exceed"
                              " the probability threshold."))
            crosses_midline = self.bundle_dict[bundle]['cross_midline']
            for sl_idx in tqdm(idx_above_prob[0]):
                sl = tg.streamlines[sl_idx]
                if fiber_probabilities[sl_idx] > self.prob_threshold:
                    if crosses_midline is not None:
                        if self.crosses[sl_idx]:
                            # This means that the streamline does
                            # cross the midline:
                            if crosses_midline:
                                # This is what we want, keep going
                                pass
                            else:
                                # This is not what we want,
                                # skip to next streamline
                                continue

                    is_close, dist = \
                        self._check_sl_with_inclusion(sl,
                                                      include_roi,
                                                      tol)
                    if is_close:
                        is_far = \
                            self._check_sl_with_exclusion(sl,
                                                          exclude_roi,
                                                          tol)
                        if is_far:
                            min_dist_coords[sl_idx, bundle_idx, 0] =\
                                np.argmin(dist[0], 0)[0]
                            min_dist_coords[sl_idx, bundle_idx, 1] =\
                                np.argmin(dist[1], 0)[0]
                            streamlines_in_bundles[sl_idx, bundle_idx] =\
                                fiber_probabilities[sl_idx]
            self.logger.info(
                (f"{np.sum(streamlines_in_bundles[:, bundle_idx] > 0)} "
                 "streamlines selected with waypoint ROIs"))

        # Eliminate any fibers not selected using the plane ROIs:
        possible_fibers = np.sum(streamlines_in_bundles, -1) > 0
        tg = StatefulTractogram(tg.streamlines[possible_fibers], self.img,
                                Space.VOX)
        if self.return_idx:
            out_idx = out_idx[possible_fibers]

        streamlines_in_bundles = streamlines_in_bundles[possible_fibers]
        min_dist_coords = min_dist_coords[possible_fibers]
        bundle_choice = np.argmax(streamlines_in_bundles, -1)

        # We do another round through, so that we can orient all the
        # streamlines within a bundle in the same orientation with respect to
        # the ROIs. This order is ARBITRARY but CONSISTENT (going from ROI0
        # to ROI1).
        self.logger.info("Re-orienting streamlines to consistent directions")
        for bundle_idx, bundle in enumerate(self.bundle_dict):
            self.logger.info(f"Processing {bundle}")

            select_idx = np.where(bundle_choice == bundle_idx)

            if len(select_idx[0]) == 0:
                # There's nothing here, set and move to the next bundle:
                self._return_empty(bundle)
                continue

            # Use a list here, because ArraySequence doesn't support item
            # assignment:
            select_sl = list(tg.streamlines[select_idx])
            # Sub-sample min_dist_coords:
            min_dist_coords_bundle = min_dist_coords[select_idx]
            for idx in range(len(select_sl)):
                min0 = min_dist_coords_bundle[idx, bundle_idx, 0]
                min1 = min_dist_coords_bundle[idx, bundle_idx, 1]
                if min0 > min1:
                    select_sl[idx] = select_sl[idx][::-1]

            # Set this to StatefulTractogram object for filtering/output:
            select_sl = StatefulTractogram(select_sl, self.img, Space.VOX)

            if self.filter_by_endpoints:
                self.logger.info("Filtering by endpoints")
                # Create binary masks and warp these into subject's DWI space:
                aal_targets = afd.bundles_to_aal([bundle], atlas=aal_atlas)[0]
                aal_idx = []
                for targ in aal_targets:
                    if targ is not None:
                        aal_roi = np.zeros(aal_atlas.shape[:3])
                        aal_roi[targ[:, 0], targ[:, 1], targ[:, 2]] = 1
                        warped_roi = self.mapping.transform_inverse(
                            aal_roi, interpolation='nearest')
                        aal_idx.append(np.array(np.where(warped_roi > 0)).T)
                    else:
                        aal_idx.append(None)

                self.logger.info("Before filtering "
                                 f"{len(select_sl)} streamlines")

                new_select_sl = clean_by_endpoints(select_sl.streamlines,
                                                   aal_idx[0],
                                                   aal_idx[1],
                                                   tol=dist_to_aal,
                                                   return_idx=self.return_idx)
                # Generate immediately:
                new_select_sl = list(new_select_sl)

                # We need to check this again:
                if len(new_select_sl) == 0:
                    # There's nothing here, set and move to the next bundle:
                    self._return_empty(bundle)
                    continue

                if self.return_idx:
                    temp_select_sl = []
                    temp_select_idx = np.empty(len(new_select_sl), int)
                    for ii, ss in enumerate(new_select_sl):
                        temp_select_sl.append(ss[0])
                        temp_select_idx[ii] = ss[1]
                    select_idx = select_idx[0][temp_select_idx]
                    new_select_sl = temp_select_sl

                select_sl = StatefulTractogram(new_select_sl, self.img,
                                               Space.RASMM)

                self.logger.info("After filtering "
                                 f"{len(select_sl)} streamlines")

            if self.return_idx:
                self.fiber_groups[bundle] = {}
                self.fiber_groups[bundle]['sl'] = select_sl
                self.fiber_groups[bundle]['idx'] = out_idx[select_idx]
            else:
                self.fiber_groups[bundle] = select_sl
        return self.fiber_groups
Example #5
def viz_indivBundle(subses_dict,
                    dwi_affine,
                    viz_backend,
                    bundle_dict,
                    data_imap,
                    mapping_imap,
                    segmentation_imap,
                    tracking_params,
                    segmentation_params,
                    reg_template,
                    best_scalar,
                    xform_volume_indiv=False,
                    cbv_lims_indiv=[None, None],
                    xform_color_by_volume_indiv=False,
                    volume_opacity_indiv=0.3,
                    n_points_indiv=40):
    mapping = mapping_imap["mapping"]
    scalar_dict = segmentation_imap["scalar_dict"]
    volume = data_imap["b0_file"]
    color_by_volume = data_imap[best_scalar + "_file"]

    start_time = time()
    volume = _viz_prepare_vol(
        volume, xform_volume_indiv, mapping, scalar_dict)
    color_by_volume = _viz_prepare_vol(
        color_by_volume, xform_color_by_volume_indiv, mapping, scalar_dict)

    flip_axes = [False, False, False]
    for i in range(3):
        flip_axes[i] = (dwi_affine[i, i] < 0)
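    # For example, an affine with a negative first diagonal entry (an
    # L-oriented x axis) yields flip_axes == [True, False, False].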

    bundle_names = bundle_dict.keys()

    for bundle_name in bundle_names:
        logger.info(f"Generating {bundle_name} visualization...")
        uid = bundle_dict[bundle_name]['uid']
        figure = viz_backend.visualize_volume(
            volume,
            opacity=volume_opacity_indiv,
            flip_axes=flip_axes,
            interact=False,
            inline=False)
        try:
            figure = viz_backend.visualize_bundles(
                segmentation_imap["clean_bundles_file"],
                color_by_volume=color_by_volume,
                cbv_lims=cbv_lims_indiv,
                bundle_dict=bundle_dict,
                bundle=uid,
                n_points=n_points_indiv,
                flip_axes=flip_axes,
                interact=False,
                inline=False,
                figure=figure)
        except ValueError:
            logger.info(
                "No streamlines found to visualize for "
                + bundle_name)

        if segmentation_params["filter_by_endpoints"]:
            warped_rois = []
            endpoint_info = segmentation_params["endpoint_info"]
            if endpoint_info is not None:
                start_p = endpoint_info[bundle_name]['startpoint']
                end_p = endpoint_info[bundle_name]['endpoint']
                for pp in [start_p, end_p]:
                    pp = resample(
                        pp.get_fdata(),
                        reg_template,
                        pp.affine,
                        reg_template.affine).get_fdata()

                    atlas_roi = np.zeros(pp.shape)
                    atlas_roi[np.where(pp > 0)] = 1
                    warped_roi = auv.transform_inverse_roi(
                        atlas_roi,
                        mapping,
                        bundle_name=bundle_name)
                    warped_rois.append(warped_roi)
            else:
                aal_atlas = afd.read_aal_atlas(reg_template)
                atlas = aal_atlas['atlas'].get_fdata()
                aal_targets = afd.bundles_to_aal(
                    [bundle_name], atlas=atlas)[0]
                for targ in aal_targets:
                    if targ is not None:
                        aal_roi = np.zeros(atlas.shape[:3])
                        aal_roi[targ[:, 0],
                                targ[:, 1],
                                targ[:, 2]] = 1
                        warped_roi = auv.transform_inverse_roi(
                            aal_roi,
                            mapping,
                            bundle_name=bundle_name)
                        warped_rois.append(warped_roi)
            for i, roi in enumerate(warped_rois):
                figure = viz_backend.visualize_roi(
                    roi,
                    name=f"{bundle_name} endpoint ROI {i}",
                    flip_axes=flip_axes,
                    inline=False,
                    interact=False,
                    figure=figure)

        for i, roi in enumerate(mapping_imap["rois_file"][bundle_name]):
            figure = viz_backend.visualize_roi(
                roi,
                name=f"{bundle_name} ROI {i}",
                flip_axes=flip_axes,
                inline=False,
                interact=False,
                figure=figure)

        roi_dir = op.join(subses_dict['results_dir'], 'viz_bundles')
        os.makedirs(roi_dir, exist_ok=True)
        if "no_gif" not in viz_backend.backend:
            fname = op.split(
                get_fname(
                    subses_dict,
                    f'_{bundle_name}'
                    f'_viz.gif',
                    tracking_params=tracking_params,
                    segmentation_params=segmentation_params))

            fname = op.join(roi_dir, fname[1])
            viz_backend.create_gif(figure, fname)
        if "plotly" in viz_backend.backend:
            roi_dir = op.join(subses_dict['results_dir'], 'viz_bundles')
            os.makedirs(roi_dir, exist_ok=True)
            fname = op.split(
                get_fname(
                    subses_dict,
                    f'_{bundle_name}'
                    f'_viz.html',
                    tracking_params=tracking_params,
                    segmentation_params=segmentation_params))

            fname = op.join(roi_dir, fname[1])
            figure.write_html(fname)
    meta_fname = get_fname(
        subses_dict, '_vizIndiv.json',
        tracking_params=tracking_params,
        segmentation_params=segmentation_params)
    meta = dict(Timing=time() - start_time)
    afd.write_json(meta_fname, meta)
    return True