Esempio n. 1
0
def plot_mask(renderer, mask_data, affine, x_current, y_current,
              orientation="axial", smoothing=10, brain_mask=None):
    """Add a smoothed contour of a mask (and optionally of a brain mask) to a renderer.

    Each volume is reoriented, turned into a contour actor via
    ``vtk_utils.contour_from_roi_smooth``, moved to ``(x_current, y_current, 0)``
    and added to ``renderer``. The brain mask, if given, is added first.

    Args:
        renderer: object exposing ``add(actor)``; receives the contour actors.
        mask_data: 3D array of the bundle mask; drawn in red with full opacity.
        affine: affine forwarded unchanged to ``contour_from_roi_smooth``.
        x_current: x position assigned to every created actor.
        y_current: y position assigned to every created actor.
        orientation: ``"sagittal"`` applies an additional axis swap and flip;
            any other value only gets the default reorientation.
        smoothing: smoothing value used for the bundle mask contour.
        brain_mask: optional 3D array; drawn as a light grey, mostly
            transparent contour with a fixed smoothing of 30.
    """
    from tractseg.libs import vtk_utils

    def _reorient(volume):
        # Swap the last two axes and reverse the first one; for the sagittal
        # view additionally swap first and last axis and reverse again.
        volume = volume.transpose(0, 2, 1)[::-1, :, :]
        if orientation == "sagittal":
            volume = volume.transpose(2, 1, 0)[::-1, :, :]
        return volume

    def _add_contour(volume, color, opacity, smoothing_level):
        # Build the contour actor, place it on the canvas and register it.
        contour = vtk_utils.contour_from_roi_smooth(volume, affine=affine,
                                                    color=color, opacity=opacity,
                                                    smoothing=smoothing_level)
        contour.SetPosition(x_current, y_current, 0)
        renderer.add(contour)

    # Optional brain outline as visual context
    if brain_mask is not None:
        _add_contour(_reorient(brain_mask), [.9, .9, .9], .1, 30)

    # 3D bundle (red)
    _add_contour(_reorient(mask_data), [1, .27, .18], 1, smoothing)
Esempio n. 2
0
def plot_bundles_with_metric(bundle_path,
                             endings_path,
                             brain_mask_path,
                             bundle,
                             metrics,
                             output_path,
                             tracking_format="trk_legacy",
                             show_color_bar=True):
    """Render a bundle with streamlines colored by a per-segment metric and save an image.

    The streamlines are flipped where needed so they all start in the same
    region, resampled, and every point is assigned to one of 100 segments.
    Each point is then colored by the (rescaled) metric value of its segment.
    A brain mask contour and, optionally, a color bar are added before the
    scene is written to ``output_path``.

    Args:
        bundle_path: tractogram file; read with ``trackvis.read`` when
            ``tracking_format == "trk_legacy"``, otherwise with
            ``nib.streamlines.load``.
        endings_path: image loaded with ``nib.load``; after one binary dilation
            it marks the region in which streamlines should start. Its affine
            is also used for the brain mask contour.
        brain_mask_path: image loaded with ``nib.load`` and drawn as a contour.
        bundle: passed to
            ``dataset_specific_utils.get_optimal_orientation_for_bundle`` to
            choose the camera orientation.
        metrics: sequence of metric values along the bundle. The first and last
            value are duplicated before use.
        output_path: forwarded to ``window.record`` as ``out_path``.
        tracking_format: see ``bundle_path``.
        show_color_bar: if true, add a scalar bar spanning the original
            (unscaled) metric range.

    Raises:
        ValueError: if the orientation returned for ``bundle`` is not one of
            "sagittal", "coronal" or "axial".
    """
    import seaborn as sns  # imported here so a missing seaborn only affects this function
    from dipy.viz import actor, window
    from tractseg.libs import vtk_utils

    # Settings
    NR_SEGMENTS = 100
    ANTI_INTERPOL_MULT = 1  # increase number of points to avoid interpolation blurring the colors
    algorithm = "distance_map"  # equal_dist | distance_map | cutting_plane
    points_per_streamline = NR_SEGMENTS * ANTI_INTERPOL_MULT
    # Red-only palette, which fits the color bar (a blue-to-red "coolwarm"
    # palette does not).
    colors = np.array(sns.light_palette("red", NR_SEGMENTS))
    img_size = (1000, 1000)

    # Tractometry skips first and last element. Therefore we only have 98
    # instead of 100 elements. Duplicate the first and last element to get
    # back to 100 elements.
    padded = list(metrics)
    padded.insert(0, padded[0])
    padded.append(padded[-1])
    metrics = np.array(padded)

    metrics_max = metrics.max()
    metrics_min = metrics.min()
    if metrics_max != metrics_min:
        # range needs to be the same as the number of segments in the colormap
        metrics = img_utils.scale_to_range(metrics, range=(0, 99))
    else:
        metrics = np.zeros(len(metrics))

    orientation = dataset_specific_utils.get_optimal_orientation_for_bundle(bundle)

    # Load mask of the start region and grow it by one dilation step.
    # NOTE(review): get_data() is deprecated in recent nibabel releases --
    # confirm which nibabel version this file targets before changing it.
    beginnings_img = nib.load(endings_path)
    beginnings = binary_dilation(beginnings_img.get_data())

    # Load trackings
    if tracking_format == "trk_legacy":
        streams, _hdr = trackvis.read(bundle_path)
        streamlines = [stream[0] for stream in streams]
    else:
        streamlines = nib.streamlines.load(bundle_path).streamlines

    # Reduce streamline count
    streamlines = streamlines[::2]

    # Reorder to make all streamlines have the same start region. The lookup
    # in the mask is done on coordinates shifted by +0.5; the shift is undone
    # afterwards.
    def _oriented(sl):
        start = sl[0]
        voxel = (int(start[0]), int(start[1]), int(start[2]))
        # Flip the streamline if it does not begin inside the start region
        return sl[::-1, :] if beginnings[voxel] == 0 else sl

    shifted = fiber_utils.add_to_each_streamline(streamlines, 0.5)
    streamlines = fiber_utils.add_to_each_streamline(
        [_oriented(sl) for sl in shifted], -0.5)

    if algorithm in ("distance_map", "equal_dist"):
        streamlines = fiber_utils.resample_fibers(streamlines,
                                                  points_per_streamline)
    elif algorithm == "cutting_plane":
        streamlines = fiber_utils.resample_to_same_distance(
            streamlines,
            max_nr_points=NR_SEGMENTS,
            ANTI_INTERPOL_MULT=ANTI_INTERPOL_MULT)

    # Assign a segment index to every point of every streamline
    if algorithm == "equal_dist":
        segment_idxs = np.array([list(range(points_per_streamline))
                                 for _ in range(len(streamlines))])

    elif algorithm == "distance_map":
        qb = QuickBundles(threshold=100.,
                          metric=AveragePointwiseEuclideanMetric())
        centroids = Streamlines(qb.cluster(streamlines).centroids)
        # For every point take the index of the nearest centroid point
        tree = cKDTree(centroids.data, 1, copy_data=True)
        _, segment_idxs = tree.query(streamlines, k=1)

    elif algorithm == "cutting_plane":
        streamlines_resamp = fiber_utils.resample_fibers(streamlines,
                                                         points_per_streamline)
        qb = QuickBundles(threshold=100.,
                          metric=AveragePointwiseEuclideanMetric())
        centroid = Streamlines(qb.cluster(streamlines_resamp).centroids)[0]
        # point in the middle of the first centroid
        middle_point = centroid[(NR_SEGMENTS // 2) * ANTI_INTERPOL_MULT]
        middle_positions = fiber_utils.get_idxs_of_closest_points(
            streamlines, middle_point)
        # Align along the middle and assign indices
        base_idx = 1000  # use higher index to avoid negative numbers for area below middle
        aligned = []
        for sl_no, sl in enumerate(streamlines):
            nr_before = middle_positions[sl_no]
            nr_after = len(sl) - nr_before
            aligned.append(range(base_idx - nr_before, base_idx + nr_after))
        segment_idxs = aligned

    # Coloring breaks as soon as all streamlines have the same number of
    # points (reason unknown). Duplicate the final point of the last
    # streamline so that it has a different number of points.
    last_sl = streamlines[-1]
    streamlines[-1] = np.append(last_sl, [last_sl[-1]], axis=0)

    renderer = window.Renderer()

    # Final shape is [nr_streamlines, nr_points, 3]. This can not be converted
    # to a numpy array because the last streamline has one more point.
    colors_all = []
    for sl_no, sl in enumerate(streamlines):
        sl_segments = segment_idxs[sl_no]
        nr_assigned = len(sl_segments)
        colors_sl = []
        for point_no in range(len(sl)):
            # The duplicated extra point has no segment of its own; it reuses
            # the segment of the point before it.
            lookup = point_no if point_no < nr_assigned else point_no - 1
            seg_idx = sl_segments[lookup]
            metric_value = metrics[int(seg_idx / ANTI_INTERPOL_MULT)]
            colors_sl.append(colors[int(metric_value)])
        colors_all.append(colors_sl)

    renderer.add(actor.streamtube(streamlines,
                                  colors=colors_all,
                                  linewidth=0.2,
                                  opacity=1))

    # plot brain mask
    mask = nib.load(brain_mask_path).get_data()
    renderer.add(vtk_utils.contour_from_roi_smooth(mask,
                                                   affine=beginnings_img.affine,
                                                   color=[.9, .9, .9],
                                                   opacity=.2,
                                                   smoothing=50))

    if show_color_bar:
        lut_cmap = actor.colormap_lookup_table(
            scale_range=(metrics_min, metrics_max),
            hue_range=(0.0, 0.0),
            saturation_range=(0.0, 1.0))
        renderer.add(actor.scalar_bar(lut_cmap))

    # Camera presets per orientation; None keeps the renderer's default camera.
    # To find a new camera angle interactively, temporarily call
    # window.show(renderer, size=img_size, reset_camera=False) and print
    # renderer.get_camera().
    camera_presets = {
        "sagittal": dict(position=(-412.95, -34.38, 80.15),
                         focal_point=(102.46, -16.96, -11.71),
                         view_up=(0.1806, 0.0, 0.9835)),
        "coronal": dict(position=(-48.63, 360.31, 98.37),
                        focal_point=(-20.16, 92.89, 36.02),
                        view_up=(-0.0047, -0.2275, 0.9737)),
        "axial": None,
    }
    if orientation not in camera_presets:
        raise ValueError("Invalid orientation provided")
    camera = camera_presets[orientation]
    if camera is not None:
        renderer.set_camera(**camera)

    window.record(renderer, out_path=output_path, size=img_size)