def main():
    """Compute a per-voxel label map of a bundle by assigning each bundle
    voxel the label of the closest point on the centroid streamline, on a
    grid upsampled by ``args.upsample``.

    Reads ``args.in_bundle`` and ``args.in_centroid`` tractograms, writes the
    label image to ``args.out_map``. Raises ValueError when either input
    tractogram is empty.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_bundle, args.in_centroid],
                        optional=args.reference)
    assert_outputs_exist(parser, args, args.out_map)

    sft_bundle = load_tractogram_with_reference(parser, args, args.in_bundle)
    sft_centroid = load_tractogram_with_reference(parser, args,
                                                  args.in_centroid)

    if not len(sft_bundle.streamlines):
        logging.error('Empty bundle file {}. '
                      'Skipping'.format(args.in_bundle))
        raise ValueError

    if not len(sft_centroid.streamlines):
        logging.error('Centroid file {} should contain one streamline. '
                      'Skipping'.format(args.in_centroid))
        raise ValueError

    # Work in voxel space and scale point coordinates by the upsampling
    # factor so the density map is computed on the finer grid.
    sft_bundle.to_vox()
    bundle_streamlines_vox = sft_bundle.streamlines
    bundle_streamlines_vox._data *= args.upsample

    sft_centroid.to_vox()
    centroid_streamlines_vox = sft_centroid.streamlines
    centroid_streamlines_vox._data *= args.upsample

    upsampled_shape = [s * args.upsample for s in sft_bundle.dimensions]
    tdi_mask = compute_tract_counts_map(bundle_streamlines_vox,
                                        upsampled_shape) > 0

    tdi_mask_nzr = np.nonzero(tdi_mask)
    tdi_mask_nzr_ind = np.transpose(tdi_mask_nzr)

    # Index of the closest centroid point for every nonzero bundle voxel.
    min_dist_ind, _ = min_dist_to_centroid(tdi_mask_nzr_ind,
                                           centroid_streamlines_vox[0])

    # Save the (upscaled) labels mask
    labels_mask = np.zeros(tdi_mask.shape)
    labels_mask[tdi_mask_nzr] = min_dist_ind + 1  # 0 is background value

    # BUGFIX: copy before the in-place division — if the `affine` property
    # returns a reference rather than a copy, `/=` would silently corrupt
    # the tractogram's affine. `.copy()` is harmless either way.
    rescaled_affine = sft_bundle.affine.copy()
    rescaled_affine[:3, :3] /= args.upsample
    labels_img = nib.Nifti1Image(labels_mask, rescaled_affine)
    upsampled_spacing = sft_bundle.voxel_sizes / args.upsample
    labels_img.header.set_zooms(upsampled_spacing)
    nib.save(labels_img, args.out_map)
def main():
    """Assign every bundle streamline point to its closest centroid point.

    Reads ``args.in_bundle`` and ``args.in_centroid``, checks that their
    headers are compatible, then saves the per-point label indices to
    ``args.output_label`` and the per-point distances to
    ``args.output_distance`` as compressed ``.npz`` files. Raises ValueError
    when either input tractogram is empty.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_bundle, args.in_centroid])
    assert_outputs_exist(parser, args,
                         [args.output_label, args.output_distance])
    is_header_compatible(args.in_bundle, args.in_centroid)

    sft_bundle = load_tractogram_with_reference(parser, args, args.in_bundle)
    sft_centroid = load_tractogram_with_reference(parser, args,
                                                  args.in_centroid)

    if not len(sft_bundle.streamlines):
        logging.error('Empty bundle file {}. Skipping'.format(args.in_bundle))
        raise ValueError

    if not len(sft_centroid.streamlines):
        # BUGFIX: this message previously formatted args.centroid_streamline,
        # an attribute that does not exist (the parsed option is in_centroid),
        # so the error path itself crashed with AttributeError.
        logging.error('Empty centroid streamline file {}. Skipping'.format(
            args.in_centroid))
        raise ValueError

    min_dist_label, min_dist = min_dist_to_centroid(
        sft_bundle.streamlines.data,
        sft_centroid.streamlines.data)
    min_dist_label += 1  # 0 is reserved so labels start at 1

    # Save assignment in a compressed numpy file
    # You can load this file and access its data using
    # f = np.load('someFile.npz')
    # assignment = f['arr_0']
    np.savez_compressed(args.output_label, min_dist_label)

    # Save distance in a compressed numpy file
    # You can load this file and access its data using
    # f = np.load('someFile.npz')
    # distance = f['arr_0']
    np.savez_compressed(args.output_distance, min_dist)
def main():
    """Compute bundle voxel label maps and per-point label/distance metadata.

    Pipeline: load bundle and centroid, clean the bundle mask down to a
    single connected component, cut streamlines to that mask, align the
    centroid to the bundle (affine SLR), assign each streamline point to its
    closest centroid point, enforce monotonically increasing labels along
    each streamline, then save the voxel label map (``args.out_labels_map``)
    and the optional npz / colored-tractogram outputs.

    Raises ValueError on empty/invalid inputs and IOError on incompatible
    headers.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    set_sft_logger_level('ERROR')

    assert_inputs_exist(parser, [args.in_bundle, args.in_centroid],
                        optional=args.reference)
    assert_outputs_exist(parser, args, args.out_labels_map,
                         optional=[args.out_labels_npz,
                                   args.out_distances_npz,
                                   args.labels_color_dpp,
                                   args.distances_color_dpp])

    sft_bundle = load_tractogram_with_reference(parser, args, args.in_bundle)
    sft_centroid = load_tractogram_with_reference(parser, args,
                                                  args.in_centroid)

    if not len(sft_bundle.streamlines):
        logging.error('Empty bundle file {}. '
                      'Skipping'.format(args.in_bundle))
        raise ValueError

    # BUGFIX (readability): was `< 1 or > 1`, which is just `!= 1`.
    if len(sft_centroid.streamlines) != 1:
        logging.error('Centroid file {} should contain one streamline. '
                      'Skipping'.format(args.in_centroid))
        raise ValueError

    if not is_header_compatible(sft_centroid, sft_bundle):
        # BUGFIX: message was missing the space before 'do not'.
        raise IOError('{} and {} do not have a compatible header'.format(
            args.in_centroid, args.in_bundle))

    sft_bundle.to_vox()
    sft_bundle.to_corner()

    # Slightly cut the bundle at the edge to clean up single streamline
    # voxels with no neighbor. Remove isolated voxels to keep a single 'blob'.
    binary_bundle = compute_tract_counts_map(
        sft_bundle.streamlines, sft_bundle.dimensions).astype(bool)

    structure = ndi.generate_binary_structure(3, 1)
    if np.count_nonzero(binary_bundle) > args.min_voxel_count \
            and len(sft_bundle) > args.min_streamline_count:
        # Morphological close-then-erode, then keep only the largest
        # connected component of the bundle mask.
        binary_bundle = ndi.binary_dilation(binary_bundle,
                                            structure=np.ones((3, 3, 3)))
        binary_bundle = ndi.binary_erosion(binary_bundle,
                                           structure=structure, iterations=2)

        bundle_disjoint, _ = ndi.label(binary_bundle)
        unique, count = np.unique(bundle_disjoint, return_counts=True)
        # count[0] is the background component; skip it when picking the max.
        val = unique[np.argmax(count[1:]) + 1]
        binary_bundle[bundle_disjoint != val] = 0

        # Chop off streamline parts that leave the cleaned mask.
        cut_sft = cut_outside_of_mask_streamlines(sft_bundle, binary_bundle)
    else:
        cut_sft = sft_bundle

    if args.nb_pts is not None:
        sft_centroid = resample_streamlines_num_points(sft_centroid,
                                                       args.nb_pts)
    else:
        args.nb_pts = len(sft_centroid.streamlines[0])

    # Generate a centroids labels mask for the centroid alone.
    sft_centroid.to_vox()
    sft_centroid.to_corner()
    sft_centroid = _affine_slr(sft_bundle, sft_centroid)

    # Map every streamline point to its closest centroid point.
    binary_centroid = compute_tract_counts_map(
        sft_centroid.streamlines, sft_centroid.dimensions).astype(bool)

    # TODO N^2 growth in RAM, should split it if we want to do nb_pts = 100
    min_dist_label, min_dist = min_dist_to_centroid(
        cut_sft.streamlines._data, sft_centroid.streamlines._data)
    min_dist_label += 1  # 0 means no labels

    # Labels along a streamline must stay continuous (no jumps); when a jump
    # occurs, keep only the longest contiguous chunk of the streamline.
    curr_ind = 0
    final_streamlines = []
    final_label = []
    final_dist = []
    for i, streamline in enumerate(cut_sft.streamlines):
        next_ind = curr_ind + len(streamline)
        curr_labels = min_dist_label[curr_ind:next_ind]
        curr_dist = min_dist[curr_ind:next_ind]
        curr_ind = next_ind

        # Flip streamlines so the labels increase (facilitate if/else)
        # Should always be ordered in nextflow pipeline
        gradient = np.gradient(curr_labels)
        if len(np.argwhere(gradient < 0)) > len(np.argwhere(gradient > 0)):
            streamline = streamline[::-1]
            curr_labels = curr_labels[::-1]
            curr_dist = curr_dist[::-1]

        # Find jumps, cut them and keep the longest chunk.
        gradient = np.ediff1d(curr_labels)
        max_jump = max(args.nb_pts // 5, 1)
        if len(np.argwhere(np.abs(gradient) > max_jump)) > 0:
            pos_jump = np.where(np.abs(gradient) > max_jump)[0] + 1
            split_chunk = np.split(curr_labels, pos_jump)

            max_len = 0
            max_pos = 0
            for j, chunk in enumerate(split_chunk):
                if len(chunk) > max_len:
                    max_len = len(chunk)
                    max_pos = j

            curr_labels = split_chunk[max_pos]

            # NOTE(review): `chunk` here is the LAST chunk from the loop
            # above, not `split_chunk[max_pos]` — looks suspicious, but
            # preserved as-is since changing it would alter which
            # streamlines are skipped. TODO confirm intent.
            gradient_chunk = np.ediff1d(chunk)
            if len(np.unique(np.sign(gradient_chunk))) > 1:
                continue
            streamline = np.split(streamline, pos_jump)[max_pos]
            curr_dist = np.split(curr_dist, pos_jump)[max_pos]

        final_streamlines.append(streamline)
        final_label.append(curr_labels)
        final_dist.append(curr_dist)

    # Re-arrange the new cut streamlines and their metadata.
    # Compute the voxels equivalent of the labels maps.
    new_sft = StatefulTractogram.from_sft(final_streamlines, sft_bundle)

    tdi_mask_nzr = np.nonzero(binary_bundle)
    tdi_mask_nzr_ind = np.transpose(tdi_mask_nzr)
    min_dist_ind, _ = min_dist_to_centroid(tdi_mask_nzr_ind,
                                           sft_centroid.streamlines[0])
    img_labels = np.zeros(binary_centroid.shape, dtype=np.int16)
    img_labels[tdi_mask_nzr] = min_dist_ind + 1  # 0 is background value

    nib.save(nib.Nifti1Image(img_labels, sft_bundle.affine),
             args.out_labels_map)

    if args.labels_color_dpp or args.distances_color_dpp \
            or args.out_labels_npz or args.out_distances_npz:
        labels_array = ArraySequence(final_label)
        dist_array = ArraySequence(final_dist)
        # WARNING: WILL NOT WORK WITH THE INPUT TRK !
        # These will fit only with the TRK saved below.
        if args.out_labels_npz:
            np.savez_compressed(args.out_labels_npz, labels_array._data)
        if args.out_distances_npz:
            np.savez_compressed(args.out_distances_npz, dist_array._data)

        cmap = plt.get_cmap(args.colormap)
        # Reuse the streamlines' ArraySequence layout as a container for the
        # per-point RGB data; the _data buffer is overwritten below.
        new_sft.data_per_point['color'] = ArraySequence(new_sft.streamlines)

        # Nicer visualisation for MI-Brain
        if args.labels_color_dpp:
            new_sft.data_per_point['color']._data = cmap(
                labels_array._data / np.max(labels_array._data))[:, 0:3] * 255
            save_tractogram(new_sft, args.labels_color_dpp)

        if args.distances_color_dpp:
            new_sft.data_per_point['color']._data = cmap(
                dist_array._data / np.max(dist_array._data))[:, 0:3] * 255
            save_tractogram(new_sft, args.distances_color_dpp)