import numpy as np
import nibabel as nib
from nibabel import trackvis
from collections import defaultdict

from scipy.ndimage import binary_dilation, map_coordinates
from scipy.spatial import cKDTree

from dipy.segment.clustering import QuickBundles
from dipy.segment.metric import AveragePointwiseEuclideanMetric
from dipy.tracking.streamline import Streamlines, transform_streamlines, values_from_volume
import dipy.stats.analysis as dsa

from tractseg.libs import img_utils
from tractseg.libs import dataset_specific_utils
from tractseg.libs import fiber_utils


def plot_bundles_with_metric(bundle_path, endings_path, brain_mask_path, bundle, metrics,
                             output_path, tracking_format="trk_legacy", show_color_bar=True):
    import seaborn as sns  # import inside the function to avoid an error if seaborn is not installed (only needed here)
    from dipy.viz import actor, window
    from tractseg.libs import vtk_utils

    def _add_extra_point_to_last_streamline(sl):
        # Coloring breaks as soon as all streamlines have the same number of points -> why???
        # Duplicate the last point of the last streamline so it has a different number of points.
        sl[-1] = np.append(sl[-1], [sl[-1][-1]], axis=0)
        return sl

    # Settings
    NR_SEGMENTS = 100
    ANTI_INTERPOL_MULT = 1  # increase the number of points to avoid interpolation blurring the colors
    algorithm = "distance_map"  # equal_dist | distance_map | cutting_plane
    # colors = np.array(sns.color_palette("coolwarm", NR_SEGMENTS))  # colormap blue to red (does not fit the colorbar)
    colors = np.array(sns.light_palette("red", NR_SEGMENTS))  # red-only colormap, which fits the color bar
    img_size = (1000, 1000)

    # Tractometry skips the first and last element, so we only have 98 instead of 100 elements.
    # Duplicate the first and last element to get back to 100 elements.
    metrics = list(metrics)
    metrics = np.array([metrics[0]] + metrics + [metrics[-1]])

    metrics_max = metrics.max()
    metrics_min = metrics.min()
    if metrics_max == metrics_min:
        metrics = np.zeros(len(metrics))
    else:
        metrics = img_utils.scale_to_range(metrics, range=(0, 99))  # range must match the number of segments in the colormap

    orientation = dataset_specific_utils.get_optimal_orientation_for_bundle(bundle)

    # Load mask
    beginnings_img = nib.load(endings_path)
    beginnings = beginnings_img.get_data()
    for i in range(1):
        beginnings = binary_dilation(beginnings)

    # Load tracking
    if tracking_format == "trk_legacy":
        streams, hdr = trackvis.read(bundle_path)
        streamlines = [s[0] for s in streams]
    else:
        sl_file = nib.streamlines.load(bundle_path)
        streamlines = sl_file.streamlines

    # Reduce the streamline count
    streamlines = streamlines[::2]

    # Reorder so that all streamlines start in the same region
    streamlines = fiber_utils.add_to_each_streamline(streamlines, 0.5)
    streamlines_new = []
    for idx, sl in enumerate(streamlines):
        startpoint = sl[0]
        # Flip the streamline if it is not in the right order
        if beginnings[int(startpoint[0]), int(startpoint[1]), int(startpoint[2])] == 0:
            sl = sl[::-1, :]
        streamlines_new.append(sl)
    streamlines = fiber_utils.add_to_each_streamline(streamlines_new, -0.5)

    if algorithm == "distance_map" or algorithm == "equal_dist":
        streamlines = fiber_utils.resample_fibers(streamlines, NR_SEGMENTS * ANTI_INTERPOL_MULT)
    elif algorithm == "cutting_plane":
        streamlines = fiber_utils.resample_to_same_distance(streamlines, max_nr_points=NR_SEGMENTS,
                                                            ANTI_INTERPOL_MULT=ANTI_INTERPOL_MULT)

    # Cut start and end by percentage:
    # streamlines = fiber_utils.resample_fibers(streamlines, NR_SEGMENTS * ANTI_INTERPOL_MULT)
    # remove = int((NR_SEGMENTS * ANTI_INTERPOL_MULT) * 0.15)  # remove X% at the beginning and end
    # streamlines = np.array(streamlines)[:, remove:-remove, :]
    # streamlines = list(streamlines)

    if algorithm == "equal_dist":
        segment_idxs = []
        for i in range(len(streamlines)):
            segment_idxs.append(list(range(NR_SEGMENTS * ANTI_INTERPOL_MULT)))
        segment_idxs = np.array(segment_idxs)

    elif algorithm == "distance_map":
        metric = AveragePointwiseEuclideanMetric()
        qb = QuickBundles(threshold=100., metric=metric)
        clusters = qb.cluster(streamlines)
        centroids = Streamlines(clusters.centroids)
        _, segment_idxs = cKDTree(centroids.data, 1, copy_data=True).query(streamlines, k=1)

    elif algorithm == "cutting_plane":
        streamlines_resamp = fiber_utils.resample_fibers(streamlines, NR_SEGMENTS * ANTI_INTERPOL_MULT)
        metric = AveragePointwiseEuclideanMetric()
        qb = QuickBundles(threshold=100., metric=metric)
        clusters = qb.cluster(streamlines_resamp)
        centroid = Streamlines(clusters.centroids)[0]

        # index of the middle point of the centroid
        middle_idx = int(NR_SEGMENTS / 2) * ANTI_INTERPOL_MULT
        middle_point = centroid[middle_idx]
        segment_idxs = fiber_utils.get_idxs_of_closest_points(streamlines, middle_point)

        # Align along the middle and assign indices
        segment_idxs_eqlen = []
        for idx, sl in enumerate(streamlines):
            sl_middle_pos = segment_idxs[idx]
            before_elems = sl_middle_pos
            after_elems = len(sl) - sl_middle_pos
            base_idx = 1000  # use a high base index to avoid negative numbers for the area before the middle
            r = range((base_idx - before_elems), (base_idx + after_elems))
            segment_idxs_eqlen.append(r)
        segment_idxs = segment_idxs_eqlen

    # Add an extra point, otherwise the coloring bug appears
    streamlines = _add_extra_point_to_last_streamline(streamlines)

    renderer = window.Renderer()
    colors_all = []  # final shape will be [nr_streamlines, nr_points, 3]
    for jdx, sl in enumerate(streamlines):
        colors_sl = []
        for idx, p in enumerate(sl):
            if idx >= len(segment_idxs[jdx]):
                seg_idx = segment_idxs[jdx][idx - 1]  # the extra point reuses the color of the previous point
            else:
                seg_idx = segment_idxs[jdx][idx]
            m = metrics[int(seg_idx / ANTI_INTERPOL_MULT)]
            color = colors[int(m)]
            colors_sl.append(color)
        # cannot be converted to a numpy array because the last streamline has one more point
        colors_all.append(colors_sl)

    sl_actor = actor.streamtube(streamlines, colors=colors_all, linewidth=0.2, opacity=1)
    renderer.add(sl_actor)

    # Plot the brain mask
    mask = nib.load(brain_mask_path).get_data()
    cont_actor = vtk_utils.contour_from_roi_smooth(mask, affine=beginnings_img.affine,
                                                   color=[.9, .9, .9], opacity=.2, smoothing=50)
    renderer.add(cont_actor)

    if show_color_bar:
        lut_cmap = actor.colormap_lookup_table(scale_range=(metrics_min, metrics_max),
                                               hue_range=(0.0, 0.0),
                                               saturation_range=(0.0, 1.0))
        renderer.add(actor.scalar_bar(lut_cmap))

    if orientation == "sagittal":
        renderer.set_camera(position=(-412.95, -34.38, 80.15),
                            focal_point=(102.46, -16.96, -11.71),
                            view_up=(0.1806, 0.0, 0.9835))
    elif orientation == "coronal":
        renderer.set_camera(position=(-48.63, 360.31, 98.37),
                            focal_point=(-20.16, 92.89, 36.02),
                            view_up=(-0.0047, -0.2275, 0.9737))
    elif orientation == "axial":
        pass
    else:
        raise ValueError("Invalid orientation provided")

    # Use this to interactively find a new camera angle:
    # window.show(renderer, size=img_size, reset_camera=False)
    # print(renderer.get_camera())

    window.record(renderer, out_path=output_path, size=img_size)
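# Usage sketch only (not part of TractSeg): the file names and the FA profile below are
# made-up assumptions for illustration. `metrics` holds one scalar per tractometry segment
# (98 values, because Tractometry skips the first and last segment; the function pads the
# profile back to 100 elements internally) and the rendering is written to `output_path`.
def _demo_plot_bundles_with_metric():
    fa_profile = list(np.random.rand(98))  # e.g. mean FA per segment from evaluate_along_streamlines
    plot_bundles_with_metric("CST_right.trk",                             # assumed bundle tracking
                             "endings_segmentations/CST_right_b.nii.gz",  # assumed beginnings mask
                             "nodif_brain_mask.nii.gz",                   # assumed brain mask
                             "CST_right", fa_profile, "CST_right_fa.png",
                             tracking_format="trk_legacy", show_color_bar=True)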
def evaluate_along_streamlines(scalar_img, streamlines, beginnings, nr_points, dilate=0,
                               predicted_peaks=None, affine=None):
    # Runtime:
    # - default: 2.7s (test), 56s (all), 10s (test 4 bundles, 100 points)
    # - map_coordinates order 1: 1.9s (test), 26s (all), 6s (test 4 bundles, 100 points)
    # - map_coordinates order 3: 2.2s (test), 33s (all)
    # - values_from_volume: 2.5s (test), 43s (all)
    # - AFQ: ?s (test), ?s (all), 85s (test 4 bundles, 100 points)
    # => AFQ is a lot slower than the others

    streamlines = list(transform_streamlines(streamlines, np.linalg.inv(affine)))

    for i in range(dilate):
        beginnings = binary_dilation(beginnings)
    beginnings = beginnings.astype(np.uint8)
    streamlines = _orient_to_same_start_region(streamlines, beginnings)  # helper defined elsewhere in this module
    if predicted_peaks is not None:
        # scalar_img can also be the original peaks
        best_orig_peaks = fiber_utils.get_best_original_peaks(predicted_peaks, scalar_img,
                                                              peak_len_thr=0.00001)
        scalar_img = np.linalg.norm(best_orig_peaks, axis=-1)

    algorithm = "distance_map"  # equal_dist | distance_map | cutting_plane | afq

    if algorithm == "equal_dist":
        ### Sampling ###
        streamlines = fiber_utils.resample_fibers(streamlines, nb_points=nr_points)
        values = map_coordinates(scalar_img, np.array(streamlines).T, order=1)
        ### Aggregation ###
        values_mean = np.array(values).mean(axis=1)
        values_std = np.array(values).std(axis=1)
        return values_mean, values_std

    if algorithm == "distance_map":  # cKDTree
        ### Sampling ###
        streamlines = fiber_utils.resample_fibers(streamlines, nb_points=nr_points)
        values = map_coordinates(scalar_img, np.array(streamlines).T, order=1)

        ### Aggregating by cKDTree approach ###
        metric = AveragePointwiseEuclideanMetric()
        qb = QuickBundles(threshold=100., metric=metric)
        clusters = qb.cluster(streamlines)
        centroids = Streamlines(clusters.centroids)
        if len(centroids) > 1:
            print("WARNING: number of clusters > 1 ({})".format(len(centroids)))
        _, segment_idxs = cKDTree(centroids.data, 1, copy_data=True).query(streamlines, k=1)  # (2000, 100)

        values_t = np.array(values).T  # (2000, 100)

        # If we wanted to take a weighted mean like in AFQ:
        # weights = dsa.gaussian_weights(Streamlines(streamlines))
        # values_t = weights * values_t
        # return np.sum(values_t, 0), None

        results_dict = defaultdict(list)
        for idx, sl in enumerate(values_t):
            for jdx, seg in enumerate(sl):
                results_dict[segment_idxs[idx, jdx]].append(seg)

        if len(results_dict.keys()) < nr_points:
            print("WARNING: found fewer points than required. Filling up with centroid values.")
            centroid_values = map_coordinates(scalar_img, np.array([centroids[0]]).T, order=1)
            for i in range(nr_points):
                if len(results_dict[i]) == 0:
                    results_dict[i].append(np.array(centroid_values).T[0, i])

        results_mean = []
        results_std = []
        for key in sorted(results_dict.keys()):
            value = results_dict[key]
            if len(value) > 0:
                results_mean.append(np.array(value).mean())
                results_std.append(np.array(value).std())
            else:
                print("WARNING: empty segment")
                results_mean.append(0)
                results_std.append(0)

        return results_mean, results_std

    elif algorithm == "cutting_plane":
        # This resamples all streamlines to equally distant points (resulting in a different number of points
        # per streamline). The "middle" of the tract is then estimated as the middle element of the centroid
        # (computed with QuickBundles). For each streamline, the point closest to this "middle" is found and
        # the points of each streamline are indexed starting from the middle. Averaging across all streamlines
        # is then done by taking the mean over points with the same index.
        ### Sampling ###
        streamlines = fiber_utils.resample_to_same_distance(streamlines, max_nr_points=nr_points)
        # map_coordinates does not allow streamlines of different lengths -> use values_from_volume
        values = np.array(values_from_volume(scalar_img, streamlines, affine=np.eye(4))).T

        ### Aggregating by Cutting Plane approach ###
        # Resample so that all fibers have the same number of points -> needed for QuickBundles
        streamlines_resamp = fiber_utils.resample_fibers(streamlines, nb_points=nr_points)
        metric = AveragePointwiseEuclideanMetric()
        qb = QuickBundles(threshold=100., metric=metric)
        clusters = qb.cluster(streamlines_resamp)
        centroids = Streamlines(clusters.centroids)

        # index of the middle point of the centroid
        middle_idx = int(nr_points / 2)
        middle_point = centroids[0][middle_idx]
        # For each streamline get the index of the point closest to the middle
        segment_idxs = fiber_utils.get_idxs_of_closest_points(streamlines, middle_point)

        # Align along the middle and assign indices
        segment_idxs_eqlen = []
        base_idx = 1000  # use a high base index to avoid negative numbers for the area before the middle
        for idx, sl in enumerate(streamlines):
            sl_middle_pos = segment_idxs[idx]
            before_elems = sl_middle_pos
            after_elems = len(sl) - sl_middle_pos
            # indices for one streamline, e.g. [998, 999, 1000, 1001, 1002, 1003]; 1000 is the middle
            r = range((base_idx - before_elems), (base_idx + after_elems))
            segment_idxs_eqlen.append(r)
        segment_idxs = segment_idxs_eqlen

        # Calculate the valid index range so we do not end up with more indices than nr_points.
        # (This can happen if one streamline is very off-center and therefore has a lot of points on only one
        # side; the values of such a streamline that are too far out are cut off.)
        max_idx = base_idx + int(nr_points / 2)
        min_idx = base_idx - int(nr_points / 2)

        # Group by segment indices
        results_dict = defaultdict(list)
        for idx, sl in enumerate(values):
            for jdx, seg in enumerate(sl):
                current_idx = segment_idxs[idx][jdx]
                if min_idx <= current_idx < max_idx:
                    results_dict[current_idx].append(seg)

        # If values are missing, fill up with centroid values
        if len(results_dict.keys()) < nr_points:
            print("WARNING: found fewer points than required. Filling up with centroid values.")
            centroid_sl = [centroids[0]]
            centroid_sl = np.array(centroid_sl).T
            centroid_values = map_coordinates(scalar_img, centroid_sl, order=1)
            for idx, seg_idx in enumerate(range(min_idx, max_idx)):
                if len(results_dict[seg_idx]) == 0:
                    results_dict[seg_idx].append(np.array(centroid_values).T[0, idx])

        # Aggregate by mean
        results_mean = []
        results_std = []
        for key in sorted(results_dict.keys()):
            value = results_dict[key]
            if len(value) > 0:
                results_mean.append(np.array(value).mean())
                results_std.append(np.array(value).std())
            else:
                print("WARNING: empty segment")
                results_mean.append(0)
                results_std.append(0)

        return results_mean, results_std

    elif algorithm == "afq":
        ### Sampling + aggregation ###
        streamlines = fiber_utils.resample_fibers(streamlines, nb_points=nr_points)
        streamlines = Streamlines(streamlines)
        weights = dsa.gaussian_weights(streamlines)
        results_mean = dsa.afq_profile(scalar_img, streamlines, affine=np.eye(4), weights=weights)
        results_std = np.zeros(nr_points)
        return results_mean, results_std
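# Usage sketch only (not part of TractSeg): the paths below are made-up assumptions for
# illustration. Streamlines are passed in world space together with the scalar image's
# affine; the beginnings mask is used to orient all streamlines to the same start region.
# Returns one mean and one std value per segment along the bundle.
def _demo_evaluate_along_streamlines():
    img = nib.load("FA.nii.gz")                                      # assumed scalar map
    streamlines = nib.streamlines.load("CST_right.trk").streamlines  # assumed bundle
    beginnings = nib.load("endings_segmentations/CST_right_b.nii.gz").get_data() > 0  # assumed mask
    mean_profile, std_profile = evaluate_along_streamlines(img.get_data(), streamlines, beginnings,
                                                           nr_points=100, dilate=1, affine=img.affine)
    return mean_profile, std_profile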