示例#1
0
def plot_bundles_with_metric(bundle_path,
                             endings_path,
                             brain_mask_path,
                             bundle,
                             metrics,
                             output_path,
                             tracking_format="trk_legacy",
                             show_color_bar=True):
    """Render a bundle coloured by a per-segment metric and save it as an image.

    Every streamline point is assigned to one of ``NR_SEGMENTS`` along-tract
    segments and coloured by the (rescaled) metric value of that segment. A
    translucent contour of the brain mask is added and the scene is written
    with ``window.record``.

    Args:
        bundle_path: Tractogram file. Read with ``trackvis.read`` when
            ``tracking_format == "trk_legacy"``, otherwise with
            ``nib.streamlines.load``.
        endings_path: Image whose non-zero voxels (after one binary dilation)
            mark the region in which every streamline should start. Its
            affine is also used for the brain mask contour.
        brain_mask_path: Image rendered as a light grey, translucent contour.
        bundle: Bundle identifier passed to
            ``dataset_specific_utils.get_optimal_orientation_for_bundle`` to
            pick the camera orientation.
        metrics: Non-empty sequence of per-segment values. The first and last
            value are duplicated, so ``NR_SEGMENTS - 2`` values are expected.
        output_path: Destination handed to ``window.record``.
        tracking_format: ``"trk_legacy"`` or anything else (see bundle_path).
        show_color_bar: Whether to add a scalar bar spanning the metric range.

    Raises:
        ValueError: If ``metrics`` is empty or the orientation returned for
            ``bundle`` is not "sagittal", "coronal" or "axial".
    """
    import seaborn as sns  # local import: seaborn is only needed by this function
    from dipy.viz import actor, window
    from tractseg.libs import vtk_utils

    def _add_extra_point_to_last_streamline(sl):
        # Coloring breaks as soon as all streamlines have the same number of
        # points (cause unknown). Duplicating the last point of the last
        # streamline makes its length differ from the others.
        sl[-1] = np.append(sl[-1], [sl[-1][-1]], axis=0)
        return sl

    # Settings
    NR_SEGMENTS = 100
    ANTI_INTERPOL_MULT = 1  # increase number of points to avoid interpolation to blur the colors
    algorithm = "distance_map"  # equal_dist | distance_map | cutting_plane
    # Red-only colormap: it matches the colour bar (a diverging map does not).
    colors = np.array(sns.light_palette("red", NR_SEGMENTS))
    img_size = (1000, 1000)
    points_per_streamline = NR_SEGMENTS * ANTI_INTERPOL_MULT

    # Tractometry skips first and last element. Therefore we only have 98 instead of 100 elements.
    # Here we duplicate the first and last element to get back to 100 elements
    metrics = list(metrics)
    if not metrics:
        raise ValueError("metrics must contain at least one value")
    metrics = np.array([metrics[0]] + metrics + [metrics[-1]])

    # Remember the original range for the colour bar before rescaling.
    metrics_max = metrics.max()
    metrics_min = metrics.min()
    if metrics_max == metrics_min:
        # Constant metric: nothing to scale, use the first colour everywhere.
        metrics = np.zeros(len(metrics))
    else:
        # The scaled values are used as indices into `colors`, so the range
        # has to match the number of colormap entries.
        metrics = img_utils.scale_to_range(metrics,
                                           range=(0, NR_SEGMENTS - 1))

    orientation = dataset_specific_utils.get_optimal_orientation_for_bundle(
        bundle)

    # Load mask of the start region.
    # `np.asanyarray(img.dataobj)` is the documented replacement for the
    # removed nibabel `get_data()`.
    beginnings_img = nib.load(endings_path)
    beginnings = binary_dilation(np.asanyarray(beginnings_img.dataobj))

    # Load trackings
    if tracking_format == "trk_legacy":
        streams, _ = trackvis.read(bundle_path)
        streamlines = [s[0] for s in streams]
    else:
        sl_file = nib.streamlines.load(bundle_path)
        streamlines = sl_file.streamlines

    # Reduce streamline count
    streamlines = streamlines[::2]

    # Reorder to make all streamlines have same start region.
    # The temporary +0.5 shift presumably makes the int() truncation below
    # round to the nearest voxel; it is undone afterwards.
    streamlines = fiber_utils.add_to_each_streamline(streamlines, 0.5)
    streamlines_new = []
    for sl in streamlines:
        startpoint = sl[0]
        # Flip streamline if it does not start inside the start region
        if beginnings[int(startpoint[0]),
                      int(startpoint[1]),
                      int(startpoint[2])] == 0:
            sl = sl[::-1, :]
        streamlines_new.append(sl)
    streamlines = fiber_utils.add_to_each_streamline(streamlines_new, -0.5)

    if algorithm == "distance_map" or algorithm == "equal_dist":
        streamlines = fiber_utils.resample_fibers(streamlines,
                                                  points_per_streamline)
    elif algorithm == "cutting_plane":
        streamlines = fiber_utils.resample_to_same_distance(
            streamlines,
            max_nr_points=NR_SEGMENTS,
            ANTI_INTERPOL_MULT=ANTI_INTERPOL_MULT)

    if algorithm == "equal_dist":
        # Point i of every streamline belongs to segment i.
        segment_idxs = np.tile(np.arange(points_per_streamline),
                               (len(streamlines), 1))

    elif algorithm == "distance_map":
        metric = AveragePointwiseEuclideanMetric()
        qb = QuickBundles(threshold=100., metric=metric)
        clusters = qb.cluster(streamlines)
        centroids = Streamlines(clusters.centroids)
        # Stack the points of all centroids; this avoids the expired
        # `ArraySequence.data` property and yields the same array.
        centroid_points = np.concatenate(list(centroids), axis=0)
        # For every streamline point: index of the nearest centroid point.
        _, segment_idxs = cKDTree(centroid_points, leafsize=1,
                                  copy_data=True).query(streamlines, k=1)
        # With more than one centroid the indices run past one centroid's
        # length; fold them back to the position along that centroid so the
        # metric lookup below stays in range.
        segment_idxs = segment_idxs % points_per_streamline

    elif algorithm == "cutting_plane":
        # NOTE(review): this branch yields indices centred on base_idx (1000),
        # which are larger than len(metrics) in the colour lookup below. It is
        # unreachable while `algorithm` is hard-coded above -- confirm the
        # intended index mapping before enabling it.
        streamlines_resamp = fiber_utils.resample_fibers(
            streamlines, points_per_streamline)
        metric = AveragePointwiseEuclideanMetric()
        qb = QuickBundles(threshold=100., metric=metric)
        clusters = qb.cluster(streamlines_resamp)
        centroid = Streamlines(clusters.centroids)[0]
        # index of the middle cluster
        middle_idx = int(NR_SEGMENTS / 2) * ANTI_INTERPOL_MULT
        middle_point = centroid[middle_idx]
        segment_idxs = fiber_utils.get_idxs_of_closest_points(
            streamlines, middle_point)
        # Align along the middle and assign indices
        base_idx = 1000  # use higher index to avoid negative numbers for area below middle
        segment_idxs_eqlen = []
        for sl, sl_middle_pos in zip(streamlines, segment_idxs):
            before_elems = sl_middle_pos
            after_elems = len(sl) - sl_middle_pos
            segment_idxs_eqlen.append(
                range(base_idx - before_elems, base_idx + after_elems))
        segment_idxs = segment_idxs_eqlen

    # Add extra point otherwise coloring BUG
    streamlines = _add_extra_point_to_last_streamline(streamlines)

    renderer = window.Renderer()
    colors_all = []  # final shape will be [nr_streamlines, nr_points, 3]
    for jdx, sl in enumerate(streamlines):
        seg_row = segment_idxs[jdx]
        last_seg_pos = len(seg_row) - 1
        colors_sl = []
        for point_idx in range(len(sl)):
            # The duplicated extra point has no segment index of its own;
            # it reuses the index of the point before it.
            seg_idx = seg_row[min(point_idx, last_seg_pos)]
            m = metrics[int(seg_idx / ANTI_INTERPOL_MULT)]
            colors_sl.append(colors[int(m)])
        # Kept as nested lists: this can not be converted to one numpy array
        # because the last streamline has one more element.
        colors_all.append(colors_sl)

    sl_actor = actor.streamtube(streamlines,
                                colors=colors_all,
                                linewidth=0.2,
                                opacity=1)
    renderer.add(sl_actor)

    # plot brain mask
    mask = np.asanyarray(nib.load(brain_mask_path).dataobj)
    cont_actor = vtk_utils.contour_from_roi_smooth(
        mask,
        affine=beginnings_img.affine,
        color=[.9, .9, .9],
        opacity=.2,
        smoothing=50)
    renderer.add(cont_actor)

    if show_color_bar:
        lut_cmap = actor.colormap_lookup_table(scale_range=(metrics_min,
                                                            metrics_max),
                                               hue_range=(0.0, 0.0),
                                               saturation_range=(0.0, 1.0))
        renderer.add(actor.scalar_bar(lut_cmap))

    if orientation == "sagittal":
        renderer.set_camera(position=(-412.95, -34.38, 80.15),
                            focal_point=(102.46, -16.96, -11.71),
                            view_up=(0.1806, 0.0, 0.9835))
    elif orientation == "coronal":
        renderer.set_camera(position=(-48.63, 360.31, 98.37),
                            focal_point=(-20.16, 92.89, 36.02),
                            view_up=(-0.0047, -0.2275, 0.9737))
    elif orientation == "axial":
        pass  # default camera is used
    else:
        raise ValueError("Invalid orientation provided")

    # Use this to interactively get a new camera angle:
    # window.show(renderer, size=img_size, reset_camera=False)
    # print(renderer.get_camera())

    window.record(renderer, out_path=output_path, size=img_size)
示例#2
0
def _mean_and_std_per_segment(values_by_segment):
    """Return (means, stds) as lists ordered by ascending segment index.

    A segment without any value contributes 0 to both lists and prints a
    warning.
    """
    results_mean = []
    results_std = []
    for key in sorted(values_by_segment):
        segment_values = values_by_segment[key]
        if len(segment_values) > 0:
            segment_values = np.array(segment_values)
            results_mean.append(segment_values.mean())
            results_std.append(segment_values.std())
        else:
            print("WARNING: empty segment")
            results_mean.append(0)
            results_std.append(0)
    return results_mean, results_std


def evaluate_along_streamlines(scalar_img,
                               streamlines,
                               beginnings,
                               nr_points,
                               dilate=0,
                               predicted_peaks=None,
                               affine=None):
    """Sample a scalar image along a bundle and aggregate it per segment.

    The streamlines are moved to voxel space, oriented so that they all start
    in the same region, resampled, and the image is sampled at their points.
    The samples are then grouped into along-tract segments and reduced to a
    mean and a standard deviation per segment.

    Args:
        scalar_img: Volume to sample, indexed in voxel coordinates. If
            ``predicted_peaks`` is given this may instead hold the original
            peaks (see below).
        streamlines: Streamlines of one bundle.
        beginnings: Mask of the start region, passed (after dilation and a
            cast to uint8) to ``_orient_to_same_start_region``.
        nr_points: Number of points each streamline is resampled to, and the
            number of segments normally returned.
        dilate: Number of binary dilations applied to ``beginnings``.
        predicted_peaks: If not None, ``scalar_img`` is replaced by the norm
            of the peaks chosen by ``fiber_utils.get_best_original_peaks``.
        affine: Voxel-to-world affine; its inverse is applied to the
            streamlines. ``None`` means the streamlines are already in voxel
            space and no transform is applied.

    Returns:
        Tuple ``(mean, std)`` with one entry per segment. With the active
        "distance_map" algorithm both are lists. If QuickBundles produces
        more than one centroid a warning is printed and segment indices refer
        to the concatenated centroid points, so more than ``nr_points``
        entries can be returned.
    """
    # Runtime:
    # - default:                2.7s (test),    56s (all),      10s (test 4 bundles, 100 points)
    # - map_coordinate order 1: 1.9s (test),    26s (all),       6s (test 4 bundles, 100 points)
    # - map_coordinate order 3: 2.2s (test),    33s (all),
    # - values_from_volume:     2.5s (test),    43s (all),
    # - AFQ:                      ?s (test),     ?s (all),      85s  (test 4 bundles, 100 points)
    # => AFQ a lot slower than others

    if affine is None:
        # The declared default used to crash inside np.linalg.inv(None).
        streamlines = list(streamlines)
    else:
        streamlines = list(
            transform_streamlines(streamlines, np.linalg.inv(affine)))

    for _ in range(dilate):
        beginnings = binary_dilation(beginnings)
    beginnings = beginnings.astype(np.uint8)
    streamlines = _orient_to_same_start_region(streamlines, beginnings)
    if predicted_peaks is not None:
        # scalar img can also be orig peaks
        best_orig_peaks = fiber_utils.get_best_original_peaks(
            predicted_peaks, scalar_img, peak_len_thr=0.00001)
        scalar_img = np.linalg.norm(best_orig_peaks, axis=-1)

    algorithm = "distance_map"  # equal_dist | distance_map | cutting_plane | afq

    if algorithm == "equal_dist":
        ### Sampling ###
        streamlines = fiber_utils.resample_fibers(streamlines,
                                                  nb_points=nr_points)
        values = map_coordinates(scalar_img, np.array(streamlines).T, order=1)
        ### Aggregation ###
        # Point i of every streamline belongs to segment i.
        values = np.array(values)
        return values.mean(axis=1), values.std(axis=1)

    if algorithm == "distance_map":  # cKDTree

        ### Sampling ###
        streamlines = fiber_utils.resample_fibers(streamlines,
                                                  nb_points=nr_points)
        values = map_coordinates(scalar_img, np.array(streamlines).T, order=1)

        ### Aggregating by cKDTree approach ###
        metric = AveragePointwiseEuclideanMetric()
        qb = QuickBundles(threshold=100., metric=metric)
        clusters = qb.cluster(streamlines)
        centroids = Streamlines(clusters.centroids)
        if len(centroids) > 1:
            print("WARNING: number clusters > 1 ({})".format(len(centroids)))
        # Stack the points of all centroids; this avoids the expired
        # `ArraySequence.data` property and yields the same array.
        centroid_points = np.concatenate(list(centroids), axis=0)
        # For every streamline point: index of the nearest centroid point.
        _, segment_idxs = cKDTree(centroid_points, leafsize=1,
                                  copy_data=True).query(streamlines,
                                                        k=1)  # (2000, 100)

        values_t = np.array(values).T  # (2000, 100)

        # Group the sampled values by the segment they were assigned to.
        results_dict = defaultdict(list)
        for idx, sl_values in enumerate(values_t):
            for jdx, value in enumerate(sl_values):
                results_dict[segment_idxs[idx, jdx]].append(value)

        if len(results_dict) < nr_points:
            print(
                "WARNING: found less than required points. Filling up with centroid values."
            )
            centroid_values = np.array(
                map_coordinates(scalar_img,
                                np.array([centroids[0]]).T,
                                order=1)).T
            for i in range(nr_points):
                if len(results_dict[i]) == 0:
                    results_dict[i].append(centroid_values[0, i])

        return _mean_and_std_per_segment(results_dict)

    elif algorithm == "cutting_plane":
        # This will resample all streamline to have equally distant points (resulting in a different number of points
        # in each streamline). Then the "middle" of the tract will be estimated taking the middle element of the
        # centroid (estimated with QuickBundles). Then each streamline the point closest to the "middle" will be
        # calculated and points will be indexed for each streamline starting from the middle. Then averaging across
        # all streamlines will be done by taking the mean for points with same indices.

        ### Sampling ###
        streamlines = fiber_utils.resample_to_same_distance(
            streamlines, max_nr_points=nr_points)
        # map_coordinates does not allow streamlines with different lengths -> use values_from_volume.
        # The per-streamline values are kept as returned: wrapping them in
        # np.array(...).T fails for ragged input on recent NumPy and swaps
        # the axes when all lengths happen to be equal.
        values = values_from_volume(scalar_img, streamlines, affine=np.eye(4))

        ### Aggregating by Cutting Plane approach ###
        # Resample to all fibers having same number of points -> needed for QuickBundles
        streamlines_resamp = fiber_utils.resample_fibers(streamlines,
                                                         nb_points=nr_points)
        metric = AveragePointwiseEuclideanMetric()
        qb = QuickBundles(threshold=100., metric=metric)
        clusters = qb.cluster(streamlines_resamp)
        centroids = Streamlines(clusters.centroids)

        # index of the middle cluster
        middle_idx = int(nr_points / 2)
        middle_point = centroids[0][middle_idx]
        # For each streamline get idx for the point which is closest to the middle
        segment_idxs = fiber_utils.get_idxs_of_closest_points(
            streamlines, middle_point)

        # Align along the middle and assign indices
        segment_idxs_eqlen = []
        base_idx = 1000  # use higher index to avoid negative numbers for area below middle
        for sl, sl_middle_pos in zip(streamlines, segment_idxs):
            before_elems = sl_middle_pos
            after_elems = len(sl) - sl_middle_pos
            # indices for one streamline e.g. [998, 999, 1000, 1001, 1002, 1003]; 1000 is middle
            segment_idxs_eqlen.append(
                range(base_idx - before_elems, base_idx + after_elems))
        segment_idxs = segment_idxs_eqlen

        # Calculate maximum number of indices to not result in more indices than nr_points.
        # (this could be case if one streamline is very off-center and therefore has a lot of points only on one
        # side. In this case the values too far out of this streamline will be cut off).
        max_idx = base_idx + int(nr_points / 2)
        min_idx = base_idx - int(nr_points / 2)

        # Group by segment indices
        results_dict = defaultdict(list)
        for idx, sl_values in enumerate(values):
            for jdx, value in enumerate(sl_values):
                current_idx = segment_idxs[idx][jdx]
                if min_idx <= current_idx < max_idx:
                    results_dict[current_idx].append(value)

        # If values missing fill up with centroid values
        if len(results_dict) < nr_points:
            print(
                "WARNING: found less than required points. Filling up with centroid values."
            )
            centroid_sl = np.array([centroids[0]]).T
            centroid_values = np.array(
                map_coordinates(scalar_img, centroid_sl, order=1)).T
            for idx, seg_idx in enumerate(range(min_idx, max_idx)):
                if len(results_dict[seg_idx]) == 0:
                    results_dict[seg_idx].append(centroid_values[0, idx])

        # Aggregate by mean
        return _mean_and_std_per_segment(results_dict)

    elif algorithm == "afq":
        ### sampling + aggregation ###
        streamlines = fiber_utils.resample_fibers(streamlines,
                                                  nb_points=nr_points)
        streamlines = Streamlines(streamlines)
        weights = dsa.gaussian_weights(streamlines)
        results_mean = dsa.afq_profile(scalar_img,
                                       streamlines,
                                       affine=np.eye(4),
                                       weights=weights)
        # AFQ yields a weighted profile only; no spread is computed here.
        results_std = np.zeros(nr_points)
        return results_mean, results_std