def _load_streamlines_from_hdf(hdf_group: h5py.Group):
    """Load every streamline stored in ``hdf_group`` into an ArraySequence.

    The group must hold three datasets -- ``data``, ``offsets`` and
    ``lengths`` -- which are read fully into memory and assigned directly
    to the matching private buffers of a fresh ArraySequence.
    """
    sequence = ArraySequence()
    # Read order (data, offsets, lengths) is kept identical to the datasets'
    # natural layout.
    for attribute, dataset_name in (("_data", "data"),
                                    ("_offsets", "offsets"),
                                    ("_lengths", "lengths")):
        setattr(sequence, attribute, np.array(hdf_group[dataset_name]))
    return sequence
def read_tracksi(self, indices):
    """Read only the tracks whose positions appear in ``indices``.

    Track ``k`` spans ``self.tracks[offsets[k]:offsets[k + 1]]``; the
    selected tracks are returned in the order given by ``indices``.
    """
    selected = Streamlines()
    for track_idx in indices:
        start, stop = self.offsets[track_idx:track_idx + 2]
        selected.append(self.tracks[start:stop])
    return selected
def read_tracks(self):
    """Read the entire tractography.

    All offsets and all points are loaded once; consecutive offset pairs
    then delimit each track inside the point buffer.
    """
    all_offsets = self.offsets[:]
    all_points = self.tracks[:]
    tracks = Streamlines()
    # Pair every offset with its successor: (o0, o1), (o1, o2), ...
    for start, stop in zip(all_offsets[:-1], all_offsets[1:]):
        tracks.append(all_points[start:stop])
    return tracks
def get_seeds_from_wm(wm_path, threshold=0):
    """Save one seed per voxel of a white-matter map as a ``.trk`` file.

    Every voxel whose value is strictly greater than ``threshold`` becomes a
    single-point "streamline" located at the voxel's world (RAS+ mm)
    coordinate, obtained through the image affine. The result is written
    next to the input as ``seeds_from_wm.trk``. The header carries the WM
    image's affine, dimensions, voxel sizes and voxel order.

    Parameters
    ----------
    wm_path : str
        Path to the white-matter image, readable by ``nib.load``.
    threshold : float, optional
        Voxels with values ``> threshold`` are used as seeds (default 0).
    """
    wm_file = nib.load(wm_path)
    wm_img = wm_file.get_fdata()

    # Voxel indices -> homogeneous coordinates -> world coordinates.
    seeds = np.argwhere(wm_img > threshold)
    seeds = np.hstack([seeds, np.ones([len(seeds), 1])])
    # Each seed becomes a one-point streamline, hence the (N, 1, 3) shape.
    seeds = (wm_file.affine.dot(seeds.T).T)[:, :3].reshape(-1, 1, 3)

    header = TrkFile.create_empty_header()
    header["voxel_to_rasmm"] = wm_file.affine
    header["dimensions"] = wm_file.header["dim"][1:4]
    header["voxel_sizes"] = wm_file.header["pixdim"][1:4]
    header["voxel_order"] = get_reference_info(wm_file)[3]

    # Points are already in world space, so the tractogram affine is identity.
    tractogram = Tractogram(streamlines=ArraySequence(seeds), affine_to_rasmm=np.eye(4))

    save_path = os.path.join(os.path.dirname(wm_path), "seeds_from_wm.trk")
    print("Saving {}".format(save_path))
    TrkFile(tractogram, header).save(save_path)
def _compute_streamline_mean(cur_ind, cur_min, cur_max, data):
    """Return the density-weighted average of ``data`` along one streamline.

    ``cur_ind`` are the streamline's precomputed voxel indices, presumably
    expressed relative to ``cur_min`` (the streamline's bounding-box origin).
    A visit-count map of that single streamline is built over the box
    ``[cur_min, cur_max)``. It is used as weights for the matching crop of
    ``data``.
    """
    box_shape = tuple(cur_max - cur_min)
    density_map = compute_tract_counts_map(ArraySequence([cur_ind]), box_shape)

    # Crop the metric to the same bounding box so it lines up with density_map.
    crop = tuple(slice(cur_min[axis], cur_max[axis]) for axis in range(3))
    return np.average(data[crop], weights=density_map)
def multiprocess_subsampling(args):
    """Cluster a batch of streamlines and subsample each cluster.

    Intended as a single-argument worker (e.g. for ``Pool.map``): ``args``
    packs, in order, ``streamlines``, ``min_distance``, ``cluster_thr``,
    ``min_cluster_size`` and ``average_streamlines``. Clustering uses
    ``qbx_and_merge`` with thresholds ``[40, 30, 20, cluster_thr]`` and
    20 resampling points. ``min_cluster_size`` is clamped to at least 1.
    """
    (streamlines, min_distance, cluster_thr,
     min_cluster_size, average_streamlines) = (args[k] for k in range(5))

    cluster_map = qbx_and_merge(ArraySequence(streamlines),
                                [40, 30, 20, cluster_thr],
                                nb_pts=20, verbose=False)
    return subsample_clusters(cluster_map, streamlines, min_distance,
                              max(min_cluster_size, 1), average_streamlines)
def _process_streamlines(streamlines):
    """Compute per-streamline bounding boxes and their voxel indices.

    Returns
    -------
    mins : list of ndarray
        Per-streamline minimum of the integer-cast coordinates.
    maxs : list of ndarray
        Per-streamline maximum of the integer-cast coordinates, plus one
        (exclusive upper bound).
    indices
        Result of ``uncompress`` on the streamlines shifted by their own
        minimum. Working inside each bounding box keeps later maps small.
    """
    mins, maxs, shifted = [], [], []
    for streamline in streamlines:
        as_int = streamline.astype(int)
        lower = np.min(as_int, 0)
        mins.append(lower)
        maxs.append(np.max(as_int, 0) + 1)
        shifted.append((streamline - lower).astype(np.float32))

    indices = uncompress(ArraySequence(shifted))
    return mins, maxs, indices
def main():
    """Color every streamline of a tractogram by its stored loss and save it.

    Reads ``data_per_streamline['loss']`` from ``args.tractogram`` and removes
    it from the tractogram. The loss is negated and mean-centered, then mapped
    through the ``RdYlGn`` colormap using the normalization chosen by
    ``args.normalization`` ("norm", "log" or "symlog"). The 8-bit RGB colors
    are stored as ``data_per_point['color']``. The result is saved to
    ``args.out``, which defaults to ``<input>_color_<normalization><ext>``.

    Raises
    ------
    ValueError
        If ``args.normalization`` is not one of the supported values.
    """
    parser = build_parser()
    args = parser.parse_args()
    print(args)

    with Timer("Loading streamlines"):
        trk = nib.streamlines.load(args.tractogram)
        losses = trk.tractogram.data_per_streamline['loss']
        del trk.tractogram.data_per_streamline['loss']  # Not supported in MI-Brain for my version.

    with Timer("Coloring streamlines"):
        cmap = plt.get_cmap('RdYlGn')
        # Negate so that lower losses map to higher colormap values.
        losses = -losses[:, 0]
        losses -= losses.mean()
        vmin = losses.min()
        vmax = losses.max()

        if args.normalization == "norm":
            norm = colors.Normalize(vmin=vmin, vmax=vmax)
        elif args.normalization == "log":
            # NOTE(review): after mean-centering vmin <= 0, which LogNorm
            # cannot represent -- this option likely fails as written; confirm
            # the intended preprocessing for log scaling.
            norm = colors.LogNorm(vmin=vmin, vmax=vmax)
        elif args.normalization == "symlog":
            norm = colors.SymLogNorm(linthresh=0.03, linscale=1, vmin=vmin, vmax=vmax)
        else:
            raise ValueError("Unknown normalization: {}".format(args.normalization))

        scalar_map = cm.ScalarMappable(norm=norm, cmap=cmap)
        print(scalar_map.get_clim())
        # bytes=True gives uint8 RGBA; drop the alpha channel.
        streamlines_colors = scalar_map.to_rgba(losses, bytes=True)[:, :-1]

        # Repeat each streamline's color on all of its points.
        colors_per_point = ArraySequence([np.tile(c, (len(s), 1))
                                          for s, c in zip(trk.tractogram.streamlines, streamlines_colors)])
        trk.tractogram.data_per_point['color'] = colors_per_point

    with Timer("Saving streamlines"):
        if args.out is None:
            args.out = args.tractogram[:-4] + "_color_" + args.normalization + args.tractogram[-4:]

        # When writing the same format as the input, reuse its header so the
        # output keeps the original reference geometry (dimensions, voxel
        # sizes, affine) instead of nibabel's default empty header.
        header = trk.header if args.out.endswith(args.tractogram[-4:]) else None
        nib.streamlines.save(trk.tractogram, args.out, header=header)
def tracking(image, bvecs, bvals, wm, seeds, fibers, prune_length=3, rseed=42,
             plot=False, proba=False, verbose=False):
    """Run local tractography on a DWI volume and save the streamlines.

    Parameters
    ----------
    image : str
        Path to the diffusion-weighted image (loaded with ``nib.load``).
    bvecs, bvals : str
        Paths to the b-vectors / b-values files (``read_bvals_bvecs``).
    wm : str
        Path to the white-matter mask; used as fit mask for the models.
    seeds : str
        Path to the seed mask. Seeds are generated with
        ``utils.seeds_from_mask(..., density=2)`` using the DWI affine.
    fibers : str
        Output prefix: streamlines go to ``fibers + ".trk"`` and, when
        plotting, a snapshot to ``fibers + '.png'``.
    prune_length : int, optional
        Streamlines with ``prune_length`` points or fewer are discarded.
    rseed : int, optional
        Passed to ``LocalTracking`` as ``random_seed``.
    plot : bool, optional
        Render a PNG snapshot (only if ``has_fury`` is true).
    proba : bool, optional
        If true, track probabilistically on a CSD model's ODF. Otherwise
        track deterministically on CSA peaks.
    verbose : bool, optional
        Accepted but not used in this function's body.

    Tracking stops where the CSA generalized FA drops below 0.25. Step size
    is 0.5 in the DWI affine's world units.
    """
    # Pipelines transcribed from:
    # https://dipy.org/documentation/1.1.1./examples_built/tracking_introduction_eudx/#example-tracking-introduction-eudx
    # https://dipy.org/documentation/1.1.1./examples_built/tracking_probabilistic/
    # Load Images
    dwi_loaded = nib.load(image)
    dwi_data = dwi_loaded.get_fdata()

    wm_loaded = nib.load(wm)
    wm_data = wm_loaded.get_fdata()

    seeds_loaded = nib.load(seeds)
    seeds_data = seeds_loaded.get_fdata()
    # From here on, `seeds` holds seed coordinates, not the input path.
    seeds = utils.seeds_from_mask(seeds_data, dwi_loaded.affine, density=2)

    # Load B-values & B-vectors
    # NB. Use aligned b-vecs if providing eddy-aligned data
    bvals, bvecs = read_bvals_bvecs(bvals, bvecs)
    gtab = gradient_table(bvals, bvecs)
    csa_model = CsaOdfModel(gtab, sh_order=6)

    # Set stopping criterion: stop where the CSA GFA falls below 0.25.
    gfa = csa_model.fit(dwi_data, mask=wm_data).gfa
    stop_criterion = ThresholdStoppingCriterion(gfa, .25)

    if proba:
        # Establish ODF model (CSD with an automatically estimated response)
        response, ratio = auto_response(gtab, dwi_data, roi_radius=10, fa_thr=0.7)
        csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=6)
        csd_fit = csd_model.fit(dwi_data, mask=wm_data)

        # Create Probabilistic direction getter; negative ODF values are
        # clipped to zero so they can serve as a PMF.
        fod = csd_fit.odf(default_sphere)
        pmf = fod.clip(min=0)
        prob_dg = ProbabilisticDirectionGetter.from_pmf(pmf, max_angle=30., sphere=default_sphere)
        # Use the probabilistic direction getter as the dg
        dg = prob_dg

    else:
        # Establish ODF model
        csa_peaks = peaks_from_model(csa_model, dwi_data, default_sphere, relative_peak_threshold=0.8,
                                     min_separation_angle=45, mask=wm_data)

        # Use the CSA peaks as the dg
        dg = csa_peaks

    # Create generator and perform tracing (step size 0.5)
    s_generator = LocalTracking(dg, stop_criterion, seeds, dwi_loaded.affine, 0.5, random_seed=rseed)
    streamlines = Streamlines(s_generator)

    # Prune streamlines
    streamlines = ArraySequence(
        [strline for strline in streamlines if len(strline) > prune_length])
    sft = StatefulTractogram(streamlines, dwi_loaded,
                             Space.RASMM)

    # Save streamlines
    save_trk(sft, fibers + ".trk")

    # Visualize fibers
    if plot and has_fury:
        from dipy.viz import window, actor, colormap as cmap

        # Create the 3D display.
        # NOTE(review): window.Renderer is the older dipy/fury name; newer
        # releases use window.Scene -- confirm against the pinned version.
        r = window.Renderer()
        r.add(actor.line(streamlines, cmap.line_colors(streamlines)))
        window.record(r, out_path=fibers + '.png', size=(800, 800))
def extract_true_connections(
    sft, mask_1_filename, mask_2_filename, gt_config, length_dict,
    gt_bundle, gt_bundle_inv_mask, dilate_endpoints, wrong_path_as_separate
):
    """
    Extract true connections based on two regions from a tractogram.
    May extract false and no connections if the config is passed.

    Parameters
    ----------
    sft: StatefulTractogram
        Tractogram containing the streamlines to be extracted.
    mask_1_filename: str
        Filename of the "head" of the bundle.
    mask_2_filename: str
        Filename of the "tail" of the bundle.
    gt_config: dict or None
        Dictionary containing the bundle's parameters. When truthy, length
        and loop filtering are applied using ``length_dict``.
    length_dict: dict or None
        Dictionary containing the bundle's length parameters; read as
        ``length_dict[gt_bundle]['length']`` (a ``(min_len, max_len)`` pair)
        and ``length_dict[gt_bundle]['angle']``.
    gt_bundle: str
        Bundle's name.
    gt_bundle_inv_mask: np.ndarray
        Inverse mask of the bundle.
    dilate_endpoints: int or None
        If set, dilate the masks for n iterations.
    wrong_path_as_separate: bool
        If true, save the WPCs as separate from TCs.

    Returns
    -------
    tc_sft: StatefulTractogram
        SFT of true connections.
    wpc_sft: StatefulTractogram
        SFT of wrong-path-connections (empty unless wrong_path_as_separate).
    fc_sft: StatefulTractogram
        SFT of false connections: streamlines whose length (in RAS mm) is not
        strictly between min_len and max_len.
    nc_streamlines: ArraySequence or list
        NOT an SFT: the raw looping streamlines (no connections), in voxel
        space, when gt_config is set and loops were found; otherwise an
        empty list.
    sft: StatefulTractogram
        SFT of remaining streamlines.
    """
    mask_1_img = nib.load(mask_1_filename)
    mask_2_img = nib.load(mask_2_filename)
    mask_1 = get_data_as_mask(mask_1_img)
    mask_2 = get_data_as_mask(mask_2_img)

    if dilate_endpoints:
        mask_1 = binary_dilation(mask_1, iterations=dilate_endpoints)
        mask_2 = binary_dilation(mask_2, iterations=dilate_endpoints)

    # TODO: Handle streamline IDs instead of streamlines
    # tmp_sft: streamlines connecting both masks; sft: everything else.
    tmp_sft, sft = extract_streamlines(mask_1, mask_2, sft)

    streamlines = tmp_sft.streamlines
    tc_streamlines = streamlines
    wpc_streamlines = []
    fc_streamlines = []
    nc_streamlines = []

    # Config file for each 'bundle'
    # Loops => no connection (nc) # TODO Is this legit ?
    # Length => false connection (fc) # TODO Is this legit ?
    if gt_config:
        min_len, max_len = \
            length_dict[gt_bundle]['length']

        # Bring streamlines to world coordinates so proper length
        # is calculated. Both conversions act on tmp_sft in place, so the
        # order of these statements matters.
        tmp_sft.to_rasmm()
        streamlines = tmp_sft.streamlines
        lengths = np.array(list(length(streamlines)))
        tmp_sft.to_vox()
        streamlines = tmp_sft.streamlines

        # Strict bounds on both sides.
        valid_min_length_mask = lengths > min_len
        valid_max_length_mask = lengths < max_len
        valid_length_mask = np.logical_and(valid_min_length_mask,
                                           valid_max_length_mask)
        streamlines = ArraySequence(streamlines)

        val_len_streamlines = streamlines[valid_length_mask]
        fc_streamlines = streamlines[~valid_length_mask]

        angle = length_dict[gt_bundle]['angle']
        tc_streamlines_ids = remove_loops_and_sharp_turns(
            val_len_streamlines, angle)

        # Everything rejected by the loop/sharp-turn filter is a loop.
        loop_ids = np.setdiff1d(
            range(len(val_len_streamlines)), tc_streamlines_ids)

        loops = val_len_streamlines[list(loop_ids)]
        tc_streamlines = val_len_streamlines[list(tc_streamlines_ids)]

        if loops:
            nc_streamlines = loops

    # Streamlines getting out of the bundle mask can be considered
    # separately as wrong path connection (wpc)
    # TODO: Maybe only consider if they cross another GT bundle ?
    if wrong_path_as_separate:
        tmp_sft = StatefulTractogram.from_sft(tc_streamlines, sft)
        # wp_ids: streamlines touching the inverse bundle mask ('any' point).
        _, wp_ids = filter_grid_roi(
            tmp_sft, gt_bundle_inv_mask, 'any', False)
        wpc_streamlines = tmp_sft.streamlines[list(wp_ids)]
        tc_ids = np.setdiff1d(range(len(tmp_sft)), wp_ids)
        tc_streamlines = tmp_sft.streamlines[list(tc_ids)]

    tc_sft = StatefulTractogram.from_sft(tc_streamlines, sft)
    wpc_sft = StatefulTractogram.from_sft([], sft)
    fc_sft = StatefulTractogram.from_sft(fc_streamlines, sft)
    if wrong_path_as_separate and len(wpc_streamlines):
        wpc_sft = StatefulTractogram.from_sft(wpc_streamlines, sft)

    return tc_sft, wpc_sft, fc_sft, nc_streamlines, sft
def evaluation_tractogram(hyperparams, model, dataset, batch_size_override, metric):
    """Build a tractogram whose streamlines are colored by the model's loss.

    Parameters
    ----------
    hyperparams : dict
        Experiment hyperparameters. ``hyperparams['model'] == 'ffnn_regression'``
        selects regrouping by ``dataset.streamlines._lengths``. Any other model
        is regrouped from per-batch outputs using the dataset masks.
    model
        Model passed to ``loss_factory`` / ``log_variables``.
    dataset
        Dataset providing the symbolic inputs, targets and masks.
    batch_size_override : int or None
        Forwarded to ``batch_scheduler_factory``.
    metric : {'sequence', 'timestep', 'cumul_avg'}
        'sequence' colors each streamline uniformly by its sequence loss.
        'timestep' colors each step by its loss. 'cumul_avg' colors each step
        by the running mean of the step losses. For the last two, the first
        point is marked green.

    Returns
    -------
    nibabel.streamlines.Tractogram
        Streamlines with per-point colors stored under ``"colors"``. The
        color scale spans the 5th to 95th percentile of the relevant losses
        (``bwr`` colormap).

    Raises
    ------
    ValueError
        If ``metric`` is not supported (checked before any evaluation).
    """
    # Fail fast: the evaluation below can be expensive.
    if metric not in ('sequence', 'timestep', 'cumul_avg'):
        raise ValueError("Unrecognized metric: {}".format(metric))

    loss = loss_factory(hyperparams, model, dataset, loss_type=None)
    batch_scheduler = batch_scheduler_factory(
        hyperparams, dataset, train_mode=False,
        batch_size_override=batch_size_override, use_data_augment=False)
    _ = loss.losses  # Hack to generate update dict in loss :(

    if hyperparams['model'] == 'ffnn_regression':
        timestep_losses, inputs, targets = log_variables(
            batch_scheduler, model, loss.loss_per_time_step,
            dataset.symb_inputs * 1, dataset.symb_targets * 1)

        # Regrouping data into streamlines will only work if the original
        # streamlines were NOT shuffled, resampled or augmented.
        # Each [start, end) slice is ONE streamline and must be stored as a
        # single element. ArraySequence.extend would split it row by row, and
        # list.extend on a scalar mean raises TypeError.
        lengths = np.asarray(dataset.streamlines._lengths)
        ends = np.cumsum(lengths)
        bounds = list(zip(ends - lengths, ends))
        timesteps_loss = ArraySequence([timestep_losses[s:e] for s, e in bounds])
        seq_loss = [np.mean(timestep_losses[s:e]) for s, e in bounds]
        timesteps_inputs = ArraySequence([inputs[s:e] for s, e in bounds])
        timesteps_targets = ArraySequence([targets[s:e] for s, e in bounds])
    else:
        timestep_losses, seq_losses, inputs, targets, masks = log_variables(
            batch_scheduler, model, loss.loss_per_time_step, loss.loss_per_seq,
            dataset.symb_inputs * 1, dataset.symb_targets * 1, dataset.symb_mask * 1)
        # Outputs come per batch; chain the batches and keep only the first
        # mask.sum() steps of each sequence (presumably the rest is padding).
        timesteps_loss = ArraySequence([
            step_loss[:int(m.sum())]
            for step_loss, m in zip(chain(*timestep_losses), chain(*masks))])
        seq_loss = np.array(list(chain(*seq_losses)))
        timesteps_inputs = ArraySequence([
            inp[:int(m.sum())] for inp, m in zip(chain(*inputs), chain(*masks))])
        # Use np.squeeze in case gru_multistep is used to remove the empty k=1 dimension
        timesteps_targets = ArraySequence([
            np.squeeze(tgt[:int(m.sum())])
            for tgt, m in zip(chain(*targets), chain(*masks))])

    if metric == 'sequence':
        # Color is based on sequence loss
        values = seq_loss
    else:
        # Color is based on timestep loss ('timestep' or 'cumul_avg')
        values = np.concatenate(timesteps_loss)

    cmap = cm.get_cmap('bwr')
    vmin = np.percentile(values, 5)
    vmax = np.percentile(values, 95)
    scalar_map = cm.ScalarMappable(norm=mplcolors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap)

    streamlines = []
    colors = []
    for inp, tgt, step_loss, seq_l in zip(timesteps_inputs, timesteps_targets, timesteps_loss, seq_loss):
        # The first 3 input columns are used as point coordinates. The last
        # target is added to the last point to close the streamline.
        pts = np.r_[inp[:, :3], [inp[-1, :3] + tgt[-1]]]
        color = np.zeros_like(pts)
        if metric == 'sequence':
            # Streamline color is based on sequence loss
            color[:, :] = scalar_map.to_rgba(seq_l, bytes=True)[:3]
        elif metric == 'timestep':
            # Streamline color is based on timestep loss
            # Identify first point with green
            color[0, :] = [0, 255, 0]
            color[1:, :] = scalar_map.to_rgba(step_loss, bytes=True)[:, :3]
        else:  # 'cumul_avg'
            # Streamline color is based on the running mean of timestep loss
            cumul_avg = np.cumsum(step_loss) / np.arange(1, len(step_loss) + 1)
            # Identify first point with green
            color[0, :] = [0, 255, 0]
            color[1:, :] = scalar_map.to_rgba(cumul_avg, bytes=True)[:, :3]

        streamlines.append(pts)
        colors.append(color)

    tractogram = nib.streamlines.Tractogram(streamlines, data_per_point={"colors": colors})
    return tractogram
def prediction_tractogram(hyperparams, model, dataset, batch_size_override, prediction_method):
    """Build a tractogram that shows the model's predicted step at every point.

    Each step k of a streamline is drawn as the spike
    ``p_k -> p_k + prediction_k -> p_k -> p_{k+1}``, so predictions appear as
    short offshoots of the input path. Points of the input path use the color
    of the lowest loss on the scale. Prediction tips are colored by that
    step's loss with the ``bwr`` colormap; the scale spans the 5th to 95th
    percentile of all step losses.

    Parameters
    ----------
    hyperparams : dict
        Experiment hyperparameters. ``hyperparams['model'] == 'ffnn_regression'``
        selects regrouping by ``dataset.streamlines._lengths``. Any other model
        is regrouped from per-batch outputs using the dataset masks.
    model
        Model passed to ``loss_factory`` / ``log_variables``.
    dataset
        Dataset providing the symbolic inputs, targets and masks.
    batch_size_override : int or None
        Forwarded to ``batch_scheduler_factory``.
    prediction_method
        Forwarded to ``loss_factory`` as ``loss_type``.

    Returns
    -------
    nibabel.streamlines.Tractogram
        Streamlines with per-point colors stored under ``"colors"``.
    """
    loss = loss_factory(hyperparams, model, dataset, loss_type=prediction_method)
    batch_scheduler = batch_scheduler_factory(
        hyperparams, dataset, train_mode=False,
        batch_size_override=batch_size_override, use_data_augment=False)
    _ = loss.losses  # Hack to generate update dict in loss :(
    predictions = loss.samples

    predict, timestep_losses, inputs, targets, masks = log_variables(
        batch_scheduler, model, predictions, loss.loss_per_time_step,
        dataset.symb_inputs * 1, dataset.symb_targets * 1, dataset.symb_mask * 1)

    if hyperparams['model'] == 'ffnn_regression':
        # Regrouping data into streamlines will only work if the original
        # streamlines were NOT shuffled, resampled or augmented.
        # Each [start, end) slice is ONE streamline and is stored as a single
        # element; ArraySequence.extend would split it row by row.
        lengths = np.asarray(dataset.streamlines._lengths)
        ends = np.cumsum(lengths)
        bounds = list(zip(ends - lengths, ends))
        timesteps_prediction = ArraySequence([predict[s:e] for s, e in bounds])
        timesteps_loss = ArraySequence([timestep_losses[s:e] for s, e in bounds])
        timesteps_inputs = ArraySequence([inputs[s:e] for s, e in bounds])
        timesteps_targets = ArraySequence([targets[s:e] for s, e in bounds])
    else:
        # Outputs come per batch; chain the batches and keep only the first
        # mask.sum() steps of each sequence (presumably the rest is padding).
        timesteps_prediction = ArraySequence([
            pred[:int(m.sum())] for pred, m in zip(chain(*predict), chain(*masks))])
        timesteps_loss = ArraySequence([
            step_loss[:int(m.sum())]
            for step_loss, m in zip(chain(*timestep_losses), chain(*masks))])
        timesteps_inputs = ArraySequence([
            inp[:int(m.sum())] for inp, m in zip(chain(*inputs), chain(*masks))])
        # Use np.squeeze in case gru_multistep is used to remove the empty k=1 dimension
        timesteps_targets = ArraySequence([
            np.squeeze(tgt[:int(m.sum())])
            for tgt, m in zip(chain(*targets), chain(*masks))])

    # Color is based on timestep loss
    cmap = cm.get_cmap('bwr')
    values = np.concatenate(timesteps_loss)
    vmin = np.percentile(values, 5)
    vmax = np.percentile(values, 95)
    scalar_map = cm.ScalarMappable(norm=mplcolors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap)
    # Base color of streamlines is minimum value (best score); loop-invariant.
    base_rgb = scalar_map.to_rgba(vmin, bytes=True)[:3]

    streamlines = []
    colors = []
    for inp, tgt, pred, step_loss in zip(timesteps_inputs, timesteps_targets,
                                         timesteps_prediction, timesteps_loss):
        # The first 3 input columns are used as point coordinates. The last
        # target is added to the last point to close the streamline.
        pts = np.r_[inp[:, :3], [inp[-1, :3] + tgt[-1]]]

        # Layout: index 3k = p_k, 3k+1 = p_k + pred_k, 3k+2 = p_k (return).
        streamline = np.zeros(((len(pts) - 1) * 3 + 1, 3))
        streamline[::3] = pts
        streamline[1:-1:3] = pts[:-1] + pred
        streamline[2:-1:3] = pts[:-1]
        streamlines.append(streamline)

        # Uniform base color, then color each prediction tip by its loss.
        color = np.zeros_like(streamline)
        color[:] = base_rgb
        color[1:-1:3, :] = scalar_map.to_rgba(step_loss, bytes=True)[:, :3]
        colors.append(color)

    tractogram = nib.streamlines.Tractogram(streamlines, data_per_point={"colors": colors})
    return tractogram
def get_array_sequence(self, item=None):
    """Return streamlines from ``self.hdf_group`` as an ArraySequence.

    Parameters
    ----------
    item : None, int, list, np.ndarray or slice, optional
        ``None`` loads every streamline at once. An integer (Python or NumPy
        integer) loads one streamline. A list or array of indices loads those
        streamlines in the given order. A slice selects streamlines through
        the stored ``offsets`` / ``lengths`` datasets.

    Returns
    -------
    ArraySequence

    Raises
    ------
    ValueError
        If ``item`` is of an unsupported type.
    """
    if item is None:
        return _load_streamlines_from_hdf(self.hdf_group)

    streamlines = ArraySequence()
    # np.integer covers indices taken from NumPy arrays (e.g. np.int64),
    # which are not instances of the builtin int.
    if isinstance(item, (int, np.integer)):
        streamline = self._get_one_streamline(item)
        streamlines.append(streamline)
    elif isinstance(item, (list, np.ndarray)):
        for i in item:
            streamline = self._get_one_streamline(i)
            streamlines.append(streamline, cache_build=True)
        streamlines.finalize_append()
    elif isinstance(item, slice):
        # Look the dataset up once instead of once per streamline.
        data = self.hdf_group['data']
        offsets = self.hdf_group['offsets'][item]
        lengths = self.hdf_group['lengths'][item]
        for offset, length in zip(offsets, lengths):
            streamlines.append(data[offset:offset + length], cache_build=True)
        streamlines.finalize_append()
    else:
        raise ValueError(
            'Item should be either a int, list, '
            'np.ndarray or slice but we received {}'.format(
                type(item)))
    return streamlines
def dwi_deterministic_tracing(image, bvecs, bvals, wm, seeds, fibers, prune_length=3, plot=False, verbose=False):
    """Run deterministic CSA-peak tractography and save the streamlines.

    Parameters
    ----------
    image : str
        Path to the diffusion-weighted image (loaded with ``nib.load``).
    bvecs, bvals : str
        Paths to the b-vectors / b-values files (``read_bvals_bvecs``).
    wm : str
        Path to the white-matter mask; used as the CSA peak-extraction mask
        and as the reference shape/voxel size for the saved file.
    seeds : str
        Path to the seed mask. Seeds are generated with density ``[2, 2, 2]``
        and an identity affine, i.e. in voxel coordinates.
    fibers : str
        Output prefix: ``fibers + ".trk"`` and, when plotting,
        ``fibers + ".png"``.
    prune_length : int, optional
        Streamlines with ``prune_length`` points or fewer are discarded.
    plot : bool, optional
        Render a PNG snapshot (only if ``have_fury`` is true).
    verbose : bool, optional
        Forwarded to ``wrap_fuzzy_failures``.

    Seed generation, tracking setup and streamline generation are each
    wrapped in ``wrap_fuzzy_failures``. Judging by its arguments, it retries
    up to ``failure_threshold=5`` times on the given error type.
    """
    # Pipeline transcribed from:
    # http://nipy.org/dipy/examples_built/introduction_to_basic_tracking.html
    # Load Images
    # NOTE(review): get_data() is deprecated in recent nibabel releases;
    # get_fdata() is the replacement -- confirm against the pinned version.
    dwi_loaded = nib.load(image)
    dwi_data = dwi_loaded.get_data()

    wm_loaded = nib.load(wm)
    wm_data = wm_loaded.get_data()

    seeds_loaded = nib.load(seeds)
    seeds_data = seeds_loaded.get_data()

    # Load B-values & B-vectors
    # NB. Use aligned b-vecs if providing eddy-aligned data
    bvals, bvecs = read_bvals_bvecs(bvals, bvecs)
    gtab = gradient_table(bvals, bvecs)

    # Establish ODF model
    csa_model = CsaOdfModel(gtab, sh_order=6)
    csa_peaks = peaks_from_model(csa_model, dwi_data, default_sphere,
                                 relative_peak_threshold=0.8,
                                 min_separation_angle=45,
                                 mask=wm_data)

    # Classify tissue by CSA GFA (threshold 0.25) and create seeds
    # (Putting this inside a looped try-block to handle fuzzy failures)
    classifier = ThresholdTissueClassifier(csa_peaks.gfa, 0.25)
    # From here on, `seeds` holds seed coordinates (voxel space), not a path.
    seeds = wrap_fuzzy_failures(utils.seeds_from_mask,
                                args=[seeds_data],
                                kwargs={"density": [2, 2, 2],
                                        "affine": np.eye(4)},
                                errortype=ValueError,
                                failure_threshold=5,
                                verbose=verbose)

    # Perform deterministic tracing in voxel space (identity affine, step 0.5)
    # (Putting this inside a looped try-block to handle fuzzy failures)
    streamlines_generator = wrap_fuzzy_failures(LocalTracking,
                                                args=[csa_peaks,
                                                      classifier,
                                                      seeds],
                                                kwargs={"affine": np.eye(4),
                                                        "step_size": 0.5},
                                                errortype=ValueError,
                                                failure_threshold=5,
                                                verbose=verbose)
    streamlines = wrap_fuzzy_failures(Streamlines,
                                      args=[streamlines_generator],
                                      kwargs={},
                                      errortype=IndexError,
                                      failure_threshold=5,
                                      verbose=verbose)

    # Prune streamlines
    streamlines = ArraySequence([strline
                                 for strline in streamlines
                                 if len(strline) > prune_length])

    # Save streamlines. Uses the legacy dipy save_trk(fname, streamlines,
    # affine, ...) signature; the DWI affine maps the voxel-space points.
    save_trk(fibers + ".trk", streamlines, dwi_loaded.affine,
             shape=wm_data.shape,
             vox_size=wm_loaded.header.get_zooms())

    # Visualize fibers
    if plot and have_fury:
        from dipy.viz import window, actor, colormap as cmap

        color = cmap.line_colors(streamlines)
        streamlines_actor = actor.line(streamlines, color)

        # Create the 3D display.
        r = window.Renderer()
        r.add(streamlines_actor)

        # Save still image.
        window.record(r, n_frames=1, out_path=fibers + ".png",
                      size=(800, 800))
def prediction_tractogram(hyperparams, model, dataset, batch_size_override, prediction_method):
    """Build a tractogram that shows the model's predicted step at every point.

    Each step k of a streamline is drawn as the spike
    ``p_k -> p_k + prediction_k -> p_k -> p_{k+1}``, so predictions appear as
    short offshoots of the input path. Points of the input path use the color
    of the lowest loss on the scale. Prediction tips are colored by that
    step's loss with the ``bwr`` colormap; the scale spans the 5th to 95th
    percentile of all step losses.

    Parameters
    ----------
    hyperparams : dict
        Experiment hyperparameters. ``hyperparams['model'] == 'ffnn_regression'``
        selects regrouping by ``dataset.streamlines._lengths``. Any other model
        is regrouped from per-batch outputs using the dataset masks.
    model
        Model passed to ``loss_factory`` / ``log_variables``.
    dataset
        Dataset providing the symbolic inputs, targets and masks.
    batch_size_override : int or None
        Forwarded to ``batch_scheduler_factory``.
    prediction_method
        Forwarded to ``loss_factory`` as ``loss_type``.

    Returns
    -------
    nibabel.streamlines.Tractogram
        Streamlines with per-point colors stored under ``"colors"``.
    """
    loss = loss_factory(hyperparams, model, dataset, loss_type=prediction_method)
    batch_scheduler = batch_scheduler_factory(
        hyperparams, dataset, train_mode=False,
        batch_size_override=batch_size_override, use_data_augment=False)
    _ = loss.losses  # Hack to generate update dict in loss :(
    predictions = loss.samples

    predict, timestep_losses, inputs, targets, masks = log_variables(
        batch_scheduler, model, predictions, loss.loss_per_time_step,
        dataset.symb_inputs * 1, dataset.symb_targets * 1, dataset.symb_mask * 1)

    if hyperparams['model'] == 'ffnn_regression':
        # Regrouping data into streamlines will only work if the original
        # streamlines were NOT shuffled, resampled or augmented.
        # Each [start, end) slice is ONE streamline and is stored as a single
        # element; ArraySequence.extend would split it row by row.
        lengths = np.asarray(dataset.streamlines._lengths)
        ends = np.cumsum(lengths)
        bounds = list(zip(ends - lengths, ends))
        timesteps_prediction = ArraySequence([predict[s:e] for s, e in bounds])
        timesteps_loss = ArraySequence([timestep_losses[s:e] for s, e in bounds])
        timesteps_inputs = ArraySequence([inputs[s:e] for s, e in bounds])
        timesteps_targets = ArraySequence([targets[s:e] for s, e in bounds])
    else:
        # Outputs come per batch; chain the batches and keep only the first
        # mask.sum() steps of each sequence (presumably the rest is padding).
        timesteps_prediction = ArraySequence([
            pred[:int(m.sum())] for pred, m in zip(chain(*predict), chain(*masks))])
        timesteps_loss = ArraySequence([
            step_loss[:int(m.sum())]
            for step_loss, m in zip(chain(*timestep_losses), chain(*masks))])
        timesteps_inputs = ArraySequence([
            inp[:int(m.sum())] for inp, m in zip(chain(*inputs), chain(*masks))])
        # Use np.squeeze in case gru_multistep is used to remove the empty k=1 dimension
        timesteps_targets = ArraySequence([
            np.squeeze(tgt[:int(m.sum())])
            for tgt, m in zip(chain(*targets), chain(*masks))])

    # Color is based on timestep loss
    cmap = cm.get_cmap('bwr')
    values = np.concatenate(timesteps_loss)
    vmin = np.percentile(values, 5)
    vmax = np.percentile(values, 95)
    scalar_map = cm.ScalarMappable(norm=mplcolors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap)
    # Base color of streamlines is minimum value (best score); loop-invariant.
    base_rgb = scalar_map.to_rgba(vmin, bytes=True)[:3]

    streamlines = []
    colors = []
    for inp, tgt, pred, step_loss in zip(timesteps_inputs, timesteps_targets,
                                         timesteps_prediction, timesteps_loss):
        # The first 3 input columns are used as point coordinates. The last
        # target is added to the last point to close the streamline.
        pts = np.r_[inp[:, :3], [inp[-1, :3] + tgt[-1]]]

        # Layout: index 3k = p_k, 3k+1 = p_k + pred_k, 3k+2 = p_k (return).
        streamline = np.zeros(((len(pts) - 1) * 3 + 1, 3))
        streamline[::3] = pts
        streamline[1:-1:3] = pts[:-1] + pred
        streamline[2:-1:3] = pts[:-1]
        streamlines.append(streamline)

        # Uniform base color, then color each prediction tip by its loss.
        color = np.zeros_like(streamline)
        color[:] = base_rgb
        color[1:-1:3, :] = scalar_map.to_rgba(step_loss, bytes=True)[:, :3]
        colors.append(color)

    tractogram = nib.streamlines.Tractogram(streamlines, data_per_point={"colors": colors})
    return tractogram
def get_ismrm_seeds(data_dir, source, keep, weighted, threshold, voxel):
    """Build a seed tractogram for the ISMRM data and save it as a ``.trk`` file.

    Parameters
    ----------
    data_dir : str
        Dataset root. Reads ``masks/wm.nii.gz`` and ``bundles/``. Writes the
        result under ``seeds/``.
    source : {'trk', 'wm', 'brain'}
        'trk': seeds are the first and last points of every streamline of
        the bundles in ``data_dir/bundles`` (converted with
        WalkingTractConverter ``-vtk2trk``).
        'wm': one seed per voxel of the resized WM mask above ``threshold``.
        'brain': same, from ``subjects/ismrm_gt/dwi_brain_mask.nii.gz``
        (relative to the current directory).
    keep : float
        Fraction of seeds to keep. Values below 1 trigger random subsampling
        without replacement (global NumPy seed set to 42).
    weighted : bool
        'trk' source only. When subsampling, every bundle gets the same total
        probability, split evenly among its seeds. Forced to False for
        'wm'/'brain'.
    threshold : float
        Mask threshold for the 'wm' / 'brain' sources.
    voxel : int
        Target voxel size in hundredths: ``voxel / 100`` is passed to
        ``mrresize -voxel`` (e.g. 125 -> "1.25").

    Raises
    ------
    ValueError
        If ``source`` is unknown, or if no ``.trk`` bundle is found for
        ``source='trk'``.
    """
    if source not in ("trk", "wm", "brain"):
        raise ValueError("Unknown seed source: {}".format(source))

    trk_dir = os.path.join(data_dir, "bundles")
    if source in ["wm", "trk"]:
        anat_path = os.path.join(data_dir, "masks", "wm.nii.gz")
        resized_path = os.path.join(data_dir, "masks", "wm_{}.nii.gz".format(voxel))
    else:  # source == "brain"
        anat_path = os.path.join("subjects", "ismrm_gt", "dwi_brain_mask.nii.gz")
        resized_path = os.path.join("subjects", "ismrm_gt", "dwi_brain_mask_125.nii.gz")

    # Resample the reference mask to the requested voxel size.
    sp.call([
        "mrresize", "-voxel", "{:1.2f}".format(voxel / 100), anat_path, resized_path
    ])

    if source == "trk":
        print("Running Tractconverter...")
        sp.call([
            "python", "tractconverter/scripts/WalkingTractConverter.py",
            "-i", trk_dir, "-a", resized_path, "-vtk2trk"
        ])
        print("Loading seed bundles...")
        seed_bundles = []
        header = None
        for trk_path in glob.glob(os.path.join(trk_dir, "*.trk")):
            trk_file = nib.streamlines.load(trk_path)
            endpoints = []
            for fiber in trk_file.tractogram.streamlines:
                endpoints.append(fiber[0])
                endpoints.append(fiber[-1])
            seed_bundles.append(endpoints)
            # The header of the first bundle is reused for the output file.
            if header is None:
                header = trk_file.header
        if header is None:
            raise ValueError("No .trk bundle found in {}".format(trk_dir))

        n_seeds = sum(len(bundle) for bundle in seed_bundles)
        n_bundles = len(seed_bundles)
        print("Loaded {} seeds from {} bundles.".format(n_seeds, n_bundles))
        # One single-point "streamline" per seed: shape (n_seeds, 1, 3).
        seeds = np.array([[seed] for bundle in seed_bundles for seed in bundle])
        if keep < 1:
            if weighted:
                # Each bundle gets total probability 1 / n_bundles, split
                # evenly between its own seeds.
                p = np.zeros(n_seeds)
                offset = 0
                for bundle in seed_bundles:
                    bundle_size = len(bundle)
                    p[offset:offset + bundle_size] = 1 / (bundle_size * n_bundles)
                    offset += bundle_size
            else:
                p = np.ones(n_seeds) / n_seeds
    else:  # source in ["brain", "wm"]
        weighted = False
        wm_file = nib.load(resized_path)
        wm_img = wm_file.get_fdata()

        # Voxel indices -> homogeneous coordinates -> world coordinates.
        seeds = np.argwhere(wm_img > threshold)
        seeds = np.hstack([seeds, np.ones([len(seeds), 1])])
        seeds = (wm_file.affine.dot(seeds.T).T)[:, :3].reshape(-1, 1, 3)
        n_seeds = len(seeds)
        if keep < 1:
            p = np.ones(n_seeds) / n_seeds

        header = TrkFile.create_empty_header()
        header["voxel_to_rasmm"] = wm_file.affine
        header["dimensions"] = wm_file.header["dim"][1:4]
        header["voxel_sizes"] = wm_file.header["pixdim"][1:4]
        header["voxel_order"] = get_reference_info(wm_file)[3]

    if keep < 1:
        keep_n = int(keep * n_seeds)
        print("Subsampling from {} seeds to {} seeds".format(n_seeds, keep_n))
        np.random.seed(42)
        keep_idx = np.random.choice(len(seeds), size=keep_n, replace=False, p=p)
        seeds = seeds[keep_idx]

    tractogram = Tractogram(streamlines=ArraySequence(seeds), affine_to_rasmm=np.eye(4))

    save_dir = os.path.join(data_dir, "seeds")
    os.makedirs(save_dir, exist_ok=True)
    # Format the file name before joining so braces in data_dir are harmless.
    # NOTE(review): unweighted runs are named "all" whatever `keep` is, so
    # subsampled and full unweighted outputs share a file name -- confirm.
    file_name = "seeds_from_{}_{}_vox{:03d}.trk".format(
        source, "W" + str(int(100 * keep)) if weighted else "all", voxel)
    save_path = os.path.join(save_dir, file_name)
    print("Saving {}".format(save_path))
    TrkFile(tractogram, header).save(save_path)

    os.remove(resized_path)
    # Only the "trk" source runs the -vtk2trk conversion that writes .trk
    # files into trk_dir; never delete bundle files for the other sources.
    if source == "trk":
        for trk_path in glob.glob(os.path.join(trk_dir, "*.trk")):
            os.remove(trk_path)
def evaluation_tractogram(hyperparams, model, dataset, batch_size_override, metric):
    """Build a tractogram whose streamlines are colored by the model's loss.

    Parameters
    ----------
    hyperparams : dict
        Experiment hyperparameters. ``hyperparams['model'] == 'ffnn_regression'``
        selects regrouping by ``dataset.streamlines._lengths``. Any other model
        is regrouped from per-batch outputs using the dataset masks.
    model
        Model passed to ``loss_factory`` / ``log_variables``.
    dataset
        Dataset providing the symbolic inputs, targets and masks.
    batch_size_override : int or None
        Forwarded to ``batch_scheduler_factory``.
    metric : {'sequence', 'timestep', 'cumul_avg'}
        'sequence' colors each streamline uniformly by its sequence loss.
        'timestep' colors each step by its loss. 'cumul_avg' colors each step
        by the running mean of the step losses. For the last two, the first
        point is marked green.

    Returns
    -------
    nibabel.streamlines.Tractogram
        Streamlines with per-point colors stored under ``"colors"``. The
        color scale spans the 5th to 95th percentile of the relevant losses
        (``bwr`` colormap).

    Raises
    ------
    ValueError
        If ``metric`` is not supported (checked before any evaluation).
    """
    # Fail fast: the evaluation below can be expensive.
    if metric not in ('sequence', 'timestep', 'cumul_avg'):
        raise ValueError("Unrecognized metric: {}".format(metric))

    loss = loss_factory(hyperparams, model, dataset, loss_type=None)
    batch_scheduler = batch_scheduler_factory(
        hyperparams, dataset, train_mode=False,
        batch_size_override=batch_size_override, use_data_augment=False)
    _ = loss.losses  # Hack to generate update dict in loss :(

    if hyperparams['model'] == 'ffnn_regression':
        timestep_losses, inputs, targets = log_variables(
            batch_scheduler, model, loss.loss_per_time_step,
            dataset.symb_inputs * 1, dataset.symb_targets * 1)

        # Regrouping data into streamlines will only work if the original
        # streamlines were NOT shuffled, resampled or augmented.
        # Each [start, end) slice is ONE streamline and must be stored as a
        # single element. ArraySequence.extend would split it row by row, and
        # list.extend on a scalar mean raises TypeError.
        lengths = np.asarray(dataset.streamlines._lengths)
        ends = np.cumsum(lengths)
        bounds = list(zip(ends - lengths, ends))
        timesteps_loss = ArraySequence([timestep_losses[s:e] for s, e in bounds])
        seq_loss = [np.mean(timestep_losses[s:e]) for s, e in bounds]
        timesteps_inputs = ArraySequence([inputs[s:e] for s, e in bounds])
        timesteps_targets = ArraySequence([targets[s:e] for s, e in bounds])
    else:
        timestep_losses, seq_losses, inputs, targets, masks = log_variables(
            batch_scheduler, model, loss.loss_per_time_step, loss.loss_per_seq,
            dataset.symb_inputs * 1, dataset.symb_targets * 1, dataset.symb_mask * 1)
        # Outputs come per batch; chain the batches and keep only the first
        # mask.sum() steps of each sequence (presumably the rest is padding).
        timesteps_loss = ArraySequence([
            step_loss[:int(m.sum())]
            for step_loss, m in zip(chain(*timestep_losses), chain(*masks))])
        seq_loss = np.array(list(chain(*seq_losses)))
        timesteps_inputs = ArraySequence([
            inp[:int(m.sum())] for inp, m in zip(chain(*inputs), chain(*masks))])
        # Use np.squeeze in case gru_multistep is used to remove the empty k=1 dimension
        timesteps_targets = ArraySequence([
            np.squeeze(tgt[:int(m.sum())])
            for tgt, m in zip(chain(*targets), chain(*masks))])

    if metric == 'sequence':
        # Color is based on sequence loss
        values = seq_loss
    else:
        # Color is based on timestep loss ('timestep' or 'cumul_avg')
        values = np.concatenate(timesteps_loss)

    cmap = cm.get_cmap('bwr')
    vmin = np.percentile(values, 5)
    vmax = np.percentile(values, 95)
    scalar_map = cm.ScalarMappable(norm=mplcolors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap)

    streamlines = []
    colors = []
    for inp, tgt, step_loss, seq_l in zip(timesteps_inputs, timesteps_targets, timesteps_loss, seq_loss):
        # The first 3 input columns are used as point coordinates. The last
        # target is added to the last point to close the streamline.
        pts = np.r_[inp[:, :3], [inp[-1, :3] + tgt[-1]]]
        color = np.zeros_like(pts)
        if metric == 'sequence':
            # Streamline color is based on sequence loss
            color[:, :] = scalar_map.to_rgba(seq_l, bytes=True)[:3]
        elif metric == 'timestep':
            # Streamline color is based on timestep loss
            # Identify first point with green
            color[0, :] = [0, 255, 0]
            color[1:, :] = scalar_map.to_rgba(step_loss, bytes=True)[:, :3]
        else:  # 'cumul_avg'
            # Streamline color is based on the running mean of timestep loss
            cumul_avg = np.cumsum(step_loss) / np.arange(1, len(step_loss) + 1)
            # Identify first point with green
            color[0, :] = [0, 255, 0]
            color[1:, :] = scalar_map.to_rgba(cumul_avg, bytes=True)[:, :3]

        streamlines.append(pts)
        colors.append(color)

    tractogram = nib.streamlines.Tractogram(streamlines, data_per_point={"colors": colors})
    return tractogram