def save_babel(dwi_data, dwi_header, b0_data, b0_header, bval_path,
               bvec_path, out_path, affine=None, flip=None, swap=None):
    """
    Save a loaded fdf DWI acquisition, together with its b0, to nifti.

    The b0 volume is prepended to the DWI volumes along the 4th axis and the
    gradient information is written to the bval/bvec files.

    Parameters
    ----------
    dwi_data: numpy.ndarray
        Diffusion-weighted volumes, concatenated after the b0 along axis 3.
    dwi_header: dict-like
        Raw header from the DWI fdf files (must provide 'voxel_dim').
    b0_data: numpy.ndarray
        b0 volume; a 4th axis is added to it before concatenation.
    b0_header: dict-like
        Raw header from the b0 fdf files.
    bval_path: str
        Path to the bval file to be saved.
    bvec_path: str
        Path to the bvec file to be saved.
    out_path: str
        Path of the nifti file to be saved.
    affine: numpy.ndarray, optional
        Affine transformation to save with the data. Replaced by the inverse
        of the header orientation when the formatted header provides one.
    flip: optional
        Forwarded as-is to write_gradient_information.
    swap: optional
        Forwarded as-is to write_gradient_information.

    Return
    ------
    None

    Raises
    ------
    Exception
        If the DWI and b0 headers are not of the same resolution/affine.
    """
    nifti1_dwi_header = format_raw_header(dwi_header)
    nifti1_b0_header = format_raw_header(b0_header)

    if not is_header_compatible(nifti1_dwi_header, nifti1_b0_header):
        raise Exception("Images are not of the same resolution/affine")

    nifti1_header = nifti1_dwi_header

    # NOTE(review): 'orientation' / 'origin' are not standard nifti1 header
    # fields; confirm format_raw_header exposes them, otherwise these
    # branches never run.
    if 'orientation' in nifti1_header:
        orientation = np.identity(4)
        orientation[:3, :3] = nifti1_header['orientation'].reshape(3, 3)
        affine = np.linalg.inv(orientation)

    write_gradient_information(dwi_header, b0_header, bval_path, bvec_path,
                               flip, swap)

    # The b0 becomes the first volume of the 4D output.
    data = np.concatenate([b0_data[:, :, :, np.newaxis], dwi_data], axis=3)
    nifti1_header.set_data_shape(data.shape)

    img = nib.nifti1.Nifti1Image(dataobj=data, header=nifti1_header,
                                 affine=affine)
    vox_dim = [round(num, 3) for num in dwi_header['voxel_dim'][0:4]]
    img.header.set_zooms(vox_dim)

    # Negate the first two rows of the 3x3 part of the qform (x and y world
    # axes), then use the negated origin as translation when available.
    qform = img.header.get_qform()
    qform[:2, :3] *= -1.
    if 'origin' in nifti1_header:
        qform[:len(nifti1_header['origin']), 3] = -nifti1_header['origin']

    # `img.get_header()` is no longer available in nibabel; the `header`
    # property returns the same header object.
    img.header.set_qform(qform)
    img.update_header()
    img.to_filename(out_path)
def are_compatible(sft_1, sft_2):
    """
    Check that two StatefulTractogram can safely be combined.

    Every criterion is evaluated, without short-circuit, so that each
    inconsistency is reported through the module logger. The criteria are:
    spatial attributes, space, origin, data_per_point keys and
    data_per_streamline keys.

    Returns
    -------
    bool
        True only when no inconsistency was found.
    """
    mismatch_checks = (
        (lambda: not is_header_compatible(sft_1, sft_2),
         'Inconsistent spatial attributes between both sft.'),
        (lambda: sft_1.space != sft_2.space,
         'Inconsistent space between both sft.'),
        (lambda: sft_1.origin != sft_2.origin,
         'Inconsistent origin between both sft.'),
        (lambda: (sft_1.get_data_per_point_keys() !=
                  sft_2.get_data_per_point_keys()),
         'Inconsistent data_per_point between both sft.'),
        (lambda: (sft_1.get_data_per_streamline_keys() !=
                  sft_2.get_data_per_streamline_keys()),
         'Inconsistent data_per_streamline between both sft.'),
    )

    compatible = True
    for has_mismatch, message in mismatch_checks:
        if has_mismatch():
            logger.warning(message)
            compatible = False
    return compatible
def main():
    """
    Fuse bundles by majority vote.

    Voxels (and, with --same_tractogram, streamlines) are kept when enough
    input bundles contain them, according to the given ratios.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    output_streamlines_filename = '{}streamlines.trk'.format(
        args.output_prefix)
    output_voxels_filename = '{}voxels.nii.gz'.format(args.output_prefix)
    assert_outputs_exist(parser, args, [output_voxels_filename,
                                        output_streamlines_filename])

    ratios_valid = (0 <= args.ratio_voxels <= 1
                    and 0 <= args.ratio_streamlines <= 1)
    if not ratios_valid:
        parser.error('Ratios must be between 0 and 1.')

    reference_file = args.reference if args.reference else args.in_bundles[0]

    # Load every bundle in voxel space / corner origin and gather their
    # streamlines for the union.
    sft_list = []
    fusion_streamlines = []
    for bundle_filename in args.in_bundles:
        current_sft = load_tractogram_with_reference(parser, args,
                                                     bundle_filename)
        current_sft.to_vox()
        current_sft.to_corner()
        if not is_header_compatible(reference_file, current_sft):
            raise ValueError('Headers are not compatible.')
        sft_list.append(current_sft)
        fusion_streamlines.append(current_sft.streamlines)

    fusion_streamlines, _ = union_robust(fusion_streamlines)

    transformation, dimensions, _, _ = get_reference_info(reference_file)
    volume = np.zeros(dimensions)
    nb_bundles = len(args.in_bundles)
    # One row per fused streamline, one column per input bundle.
    streamlines_vote = dok_matrix((len(fusion_streamlines), nb_bundles))

    for bundle_idx, current_sft in enumerate(sft_list):
        binary = compute_tract_counts_map(current_sft.streamlines, dimensions)
        volume[binary > 0] += 1

        if args.same_tractogram:
            _, indices = intersection_robust(
                [fusion_streamlines, current_sft.streamlines])
            streamlines_vote[list(indices), [bundle_idx]] += 1

    if args.same_tractogram:
        min_votes = int(args.ratio_streamlines * nb_bundles)
        votes_per_streamline = np.sum(streamlines_vote, axis=1)
        real_indices = np.where(votes_per_streamline >= min_votes)[0]
        new_sft = StatefulTractogram.from_sft(
            fusion_streamlines[real_indices], sft_list[0])
        save_tractogram(new_sft, output_streamlines_filename)

    volume[volume < int(args.ratio_voxels * nb_bundles)] = 0
    volume[volume > 0] = 1
    nib.save(nib.Nifti1Image(volume.astype(np.uint8), transformation),
             output_voxels_filename)
def load_node_nifti(directory, in_label, out_label, ref_img):
    """
    Load the '<in_label>_<out_label>.nii.gz' map stored in `directory`.

    Returns
    -------
    numpy.ndarray or None
        The image data as float64, or None when the file does not exist.

    Raises
    ------
    IOError
        If the file exists but is not header-compatible with `ref_img`.
    """
    node_filename = os.path.join(
        directory, '{}_{}.nii.gz'.format(in_label, out_label))

    if not os.path.isfile(node_filename):
        return None

    if not is_header_compatible(node_filename, ref_img):
        raise IOError('{} do not have a compatible header'.format(
            node_filename))

    return nib.load(node_filename).get_fdata(dtype=np.float64)
def main():
    """
    Label each point of a bundle with the index of its closest centroid
    point and save both the labels and the distances as compressed .npz.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_bundle, args.in_centroid])
    assert_outputs_exist(parser, args, [args.output_label,
                                        args.output_distance])

    sft_bundle = load_tractogram_with_reference(parser, args, args.in_bundle)
    sft_centroid = load_tractogram_with_reference(parser, args,
                                                  args.in_centroid)

    # The compatibility result used to be discarded; enforce it on the
    # loaded tractograms (loaded with the same args, so with any
    # --reference that was given).
    if not is_header_compatible(sft_bundle, sft_centroid):
        parser.error('Incompatible header between the bundle and centroid.')

    if not len(sft_bundle.streamlines):
        logging.error('Empty bundle file {}. Skipping'.format(args.in_bundle))
        raise ValueError

    if not len(sft_centroid.streamlines):
        # `args.centroid_streamline` does not exist; the input is
        # `args.in_centroid`.
        logging.error('Empty centroid streamline file {}. Skipping'.format(
            args.in_centroid))
        raise ValueError

    min_dist_label, min_dist = min_dist_to_centroid(
        sft_bundle.streamlines.data, sft_centroid.streamlines.data)
    # Labels are saved 1-based.
    min_dist_label += 1

    # Save assignment in a compressed numpy file
    # You can load this file and access its data using
    # f = np.load('someFile.npz')
    # assignment = f['arr_0']
    np.savez_compressed(args.output_label, min_dist_label)

    # Save distance in a compressed numpy file
    # You can load this file and access its data using
    # f = np.load('someFile.npz')
    # distance = f['arr_0']
    np.savez_compressed(args.output_distance, min_dist)
def main():
    """
    Cut streamlines with a binary mask.

    With one connected component the streamlines are cut outside of the
    mask; with two components they are cut between them. Optionally, only
    the biggest component is kept first. Empty masks and masks with more
    than two components are reported as errors.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, [args.in_tractogram, args.in_mask])
    assert_outputs_exist(parser, args, args.out_tractogram)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    if args.step_size is not None:
        sft = resample_streamlines_step_size(sft, args.step_size)

    mask_img = nib.load(args.in_mask)
    binary_mask = get_data_as_mask(mask_img)

    if not is_header_compatible(sft, mask_img):
        parser.error('Incompatible header between the tractogram and mask.')

    # ndi.label reports the number of connected components directly, which
    # does not depend on the mask containing background voxels.
    bundle_disjoint, nb_entities = ndi.label(binary_mask)

    if args.biggest_blob and nb_entities > 1:
        # Keep only the largest component (label 0 is background).
        labels, counts = np.unique(bundle_disjoint[bundle_disjoint > 0],
                                   return_counts=True)
        biggest_label = labels[np.argmax(counts)]
        binary_mask[bundle_disjoint != biggest_label] = 0
        nb_entities = 1

    if nb_entities == 0:
        logging.error('The provided mask is empty. Cannot cut streamlines.')
        return
    if nb_entities == 1:
        logging.info('The provided mask has 1 entity '
                     'cut_outside_of_mask_streamlines function selected.')
        new_sft = cut_outside_of_mask_streamlines(sft, binary_mask)
    elif nb_entities == 2:
        logging.info('The provided mask has 2 entity '
                     'cut_between_masks_streamlines function selected.')
        new_sft = cut_between_masks_streamlines(sft, binary_mask)
    else:
        logging.error('The provided mask has more than 2 entities. Cannot cut '
                      'between >2.')
        return

    if len(new_sft) == 0:
        logging.warning('No streamline intersected the provided mask. '
                        'Saving empty tractogram.')
    elif args.error_rate is not None:
        compressed_strs = [compress_streamlines(s, args.error_rate)
                           for s in new_sft.streamlines]
        new_sft = StatefulTractogram.from_sft(compressed_strs, sft)

    save_tractogram(new_sft, args.out_tractogram)
def assert_same_resolution(images):
    """
    Check that several images share the same resolution/affine.

    Parameters
    ----------
    images : array of string or string
        List of images, or a single image.

    Raises
    ------
    Exception
        If no image is given, or if any image is not header-compatible with
        the first one.
    """
    image_list = [images] if isinstance(images, str) else images

    if len(image_list) == 0:
        raise Exception("Can't check if images are of the same "
                        "resolution/affine. No image has been given")

    reference = image_list[0]
    # any() stops at the first incompatible image, like an early raise.
    mismatch = any(not is_header_compatible(reference, other)
                   for other in image_list[1:])
    if mismatch:
        raise Exception("Images are not of the same resolution/affine")
def main():
    """
    Report whether every input file has a header compatible with the first
    input. Only .trk, .nii and .nii.gz files are accepted.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_files)

    supported_extensions = ['.trk', '.nii', '.nii.gz']
    incompatible_found = False
    for current_file in args.in_files:
        _, extension = split_name_with_nii(current_file)
        if extension not in supported_extensions:
            parser.error(
                '{} does not have a supported extension'.format(current_file))

        reference_file = args.in_files[0]
        if is_header_compatible(reference_file, current_file):
            continue

        print('{} and {} do not have compatible header.'.format(
            reference_file, current_file))
        incompatible_found = True

    if not incompatible_found:
        print('All input files have compatible headers.')
def main():
    """
    Concatenate several DWI (with their bvals/bvecs) along the 4th axis.

    Each DWI must be header-compatible with the first one and contain as
    many volumes as its bval file lists.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if len(args.in_dwis) != len(args.in_bvals) \
            or len(args.in_dwis) != len(args.in_bvecs):
        parser.error('DWI, bvals and bvecs must have the same length')

    assert_inputs_exist(parser, args.in_dwis + args.in_bvals + args.in_bvecs)
    assert_outputs_exist(parser, args, [args.out_dwi, args.out_bval,
                                        args.out_bvec])

    all_bvals = []
    all_bvecs = []
    for in_bval, in_bvec in zip(args.in_bvals, args.in_bvecs):
        bvals, bvecs = read_bvals_bvecs(in_bval, in_bvec)
        if len(bvals) != len(bvecs):
            raise ValueError('Paired bvals and bvecs must have the same size.')
        all_bvals.append(bvals)
        all_bvecs.append(bvecs)

    # Number of volumes each DWI is expected to contribute.
    volumes_per_dwi = [len(bvals) for bvals in all_bvals]
    total_size = sum(volumes_per_dwi)

    all_bvals = np.concatenate(all_bvals)
    all_bvecs = np.concatenate(all_bvecs)

    ref_dwi = nib.load(args.in_dwis[0])
    all_dwi = np.zeros(ref_dwi.shape[0:3] + (total_size, ),
                       dtype=args.data_type)

    # Write each DWI right after the previous one; the offset must advance
    # after every file, otherwise later DWIs overwrite the same slot.
    last_count = 0
    for i, in_dwi in enumerate(args.in_dwis):
        if i == 0:
            curr_dwi = ref_dwi
        else:
            curr_dwi = nib.load(in_dwi)
            if not is_header_compatible(curr_dwi, ref_dwi):
                raise ValueError('All DWI must have the compatible header.')

        curr_size = curr_dwi.shape[-1]
        if curr_size != volumes_per_dwi[i]:
            raise ValueError(
                '{} has {} volumes but its bval/bvec files describe {}.'
                .format(in_dwi, curr_size, volumes_per_dwi[i]))

        all_dwi[..., last_count:last_count + curr_size] = curr_dwi.get_fdata()
        last_count += curr_size

    np.savetxt(args.out_bval, all_bvals, '%d')
    np.savetxt(args.out_bvec, all_bvecs.T, '%0.15f')
    nib.save(nib.Nifti1Image(all_dwi, ref_dwi.affine, header=ref_dwi.header),
             args.out_dwi)
def load_node_nifti(directory, in_label, out_label, ref_filename):
    """
    Load the map linking two labels, whichever order they are named in.

    '<in_label>_<out_label>.nii.gz' is tried first, then
    '<out_label>_<in_label>.nii.gz'. When neither exists in `directory`,
    an array of zeros with the dimensions of `ref_filename` is returned.

    Raises
    ------
    IOError
        If the file found is not header-compatible with `ref_filename`.
    """
    candidates = (
        os.path.join(directory, '{}_{}.nii.gz'.format(in_label, out_label)),
        os.path.join(directory, '{}_{}.nii.gz'.format(out_label, in_label)),
    )
    found = next((path for path in candidates if os.path.isfile(path)), None)

    if found is None:
        _, dims, _, _ = get_reference_info(ref_filename)
        return np.zeros(dims)

    if not is_header_compatible(found, ref_filename):
        logging.error('{} and {} do not have a compatible header'.format(
            found, ref_filename))
        raise IOError
    return nib.load(found).get_fdata()
def verify_compatibility_with_reference_sft(ref_sft, files_to_verify,
                                            parser, args):
    """
    Verifies the compatibility of a reference sft with a list of files.

    Params
    ------
    ref_sft: StatefulTractogram
        A tractogram to be used as reference.
    files_to_verify: List[str]
        List of files that should be compatible with the reference sft. Files
        can be either other tractograms or nifti files (ex: masks). None
        entries are skipped.
    parser: argument parser
        Will raise an error if a file is not compatible.
    args: Namespace
        Should contain a args.reference if any file is a .tck. Its
        `reference` attribute is temporarily modified while loading
        tractograms and restored before returning.
    """
    save_ref = args.reference
    try:
        for filename in files_to_verify:
            if filename is None:
                continue

            _, ext = os.path.splitext(filename)
            if ext in ['.trk', '.tck', '.fib', '.vtk', '.dpy']:
                # Cheating ref because it may send a lot of warning if loading
                # many trk with ref (reference was maybe added only for some
                # of these files)
                args.reference = None if ext == '.trk' else save_ref
                other = load_tractogram_with_reference(parser, args, filename,
                                                       bbox_check=False)
            else:
                # Should be a nifti file.
                other = filename

            if not is_header_compatible(ref_sft, other):
                parser.error(
                    "Reference tractogram incompatible with {}".format(
                        filename))
    finally:
        # Never leak the temporary reference override to the caller.
        args.reference = save_ref
def main():
    """
    Warp a tractogram with a deformation field and save it in the space of
    the target file, optionally removing or keeping invalid streamlines.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.moving_tractogram, args.target_file,
                                 args.deformation])
    assert_outputs_exist(parser, args, args.out_tractogram)

    sft = load_tractogram_with_reference(parser, args, args.moving_tractogram,
                                         bbox_check=False)

    deformation_img = nib.load(args.deformation)
    deformation_data = np.squeeze(deformation_img.get_fdata())

    if not is_header_compatible(sft, deformation_img):
        parser.error('Input tractogram/reference do not have the same spatial '
                     'attribute as the deformation field.')

    # Warning: Apply warp in-place
    moved_streamlines = warp_streamlines(sft, deformation_data)
    warped_sft = StatefulTractogram(
        moved_streamlines, args.target_file, Space.RASMM,
        data_per_point=sft.data_per_point,
        data_per_streamline=sft.data_per_streamline)

    if args.remove_invalid:
        count_before = len(warped_sft)
        warped_sft.remove_invalid_streamlines()
        logging.warning('Removed {} invalid streamlines.'.format(
            count_before - len(warped_sft)))
        save_tractogram(warped_sft, args.out_tractogram)
        return

    if args.keep_invalid:
        if not warped_sft.is_bbox_in_vox_valid():
            logging.warning('Saving tractogram with invalid streamlines.')
        save_tractogram(warped_sft, args.out_tractogram,
                        bbox_valid_check=False)
        return

    save_tractogram(warped_sft, args.out_tractogram)
def main():
    """
    Score a tractogram against ground-truth bundles and endpoint masks.

    Streamlines are split into true connections (tc), wrong-path
    connections (wpc), false connections (fc) and no connections (nc). Each
    category is saved for Q/C in the output directory, and global plus
    bundle-wise statistics are written to results.json.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram] + args.gt_bundles)
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       create_dir=True)

    if (args.gt_tails and not args.gt_heads) \
            or (args.gt_heads and not args.gt_tails):
        parser.error("Both --gt_heads and --gt_tails are needed.")
    if args.gt_endpoints and (args.gt_tails or args.gt_heads):
        parser.error("Can only provide --gt_endpoints or --gt_tails/gt_heads")
    if not args.gt_endpoints and (not args.gt_tails and not args.gt_heads):
        parser.error(
            "Either input --gt_endpoints or --gt_heads and --gt_tails.")

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    # Output files reuse the extension of the input tractogram.
    _, ext = os.path.splitext(args.in_tractogram)
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram,
                                         bbox_check=False)

    if args.remove_invalid:
        sft.remove_invalid_streamlines()

    initial_count = len(sft)

    logging.info("Verifying compatibility with ground-truth")
    for gt in args.gt_bundles:
        compatible = is_header_compatible(sft, gt)
        if not compatible:
            parser.error("Input tractogram incompatible with"
                         " {}".format(gt))

    logging.info("Computing ground-truth masks")
    gt_bundle_masks, gt_bundle_inv_masks, affine, dimensions, = \
        compute_gt_masks(args.gt_bundles, parser, args)

    # If endpoints without heads/tails are loaded, split them and continue
    # normally after. Q/C of the output is important
    if args.gt_endpoints:
        logging.info("Extracting ground-truth end and tail masks")
        gt_tails, gt_heads, affine, dimensions = \
            extract_tails_heads_from_endpoints(
                args.gt_endpoints, args.out_dir)
    else:
        gt_tails, gt_heads = args.gt_tails, args.gt_heads

    logging.info("Verifying compatibility with endpoints")
    for gt in gt_tails + gt_heads:
        compatible = is_header_compatible(sft, gt)
        if not compatible:
            parser.error("Input tractogram incompatible with"
                         " {}".format(gt))

    # Load the endpoints heads/tails, keep the correct combinations
    # separately from all the possible combinations
    tc_filenames = list(zip(gt_tails, gt_heads))

    length_dict = {}
    if args.gt_config:
        with open(args.gt_config, "r") as json_file:
            length_dict = json.load(json_file)

    tc_streamlines_list = []
    wpc_streamlines_list = []
    # Holds fc streamlines from BOTH loops below (used for global counts).
    fc_streamlines_list = []
    # Holds only the fc streamlines of the false-connection combinations,
    # aligned one-to-one with `comb_filename` (used for bundle-wise stats).
    comb_fc_streamlines_list = []
    nc_streamlines = []

    logging.info("Scoring true connections")
    for i, (mask_1_filename, mask_2_filename) in enumerate(tc_filenames):

        # Automatically generate filename for Q/C
        prefix_1 = extract_prefix(mask_1_filename)
        prefix_2 = extract_prefix(mask_2_filename)

        # The remaining (unassigned) streamlines are returned as `sft`.
        tc_sft, wpc_sft, fc_sft, nc, sft = extract_true_connections(
            sft, mask_1_filename, mask_2_filename, args.gt_config,
            length_dict, extract_prefix(args.gt_bundles[i]),
            gt_bundle_inv_masks[i], args.dilate_endpoints,
            args.wrong_path_as_separate)

        nc_streamlines.extend(nc)
        if len(tc_sft) > 0:
            save_tractogram(tc_sft, os.path.join(
                args.out_dir, "{}_{}_tc{}".format(prefix_1, prefix_2, ext)),
                bbox_valid_check=False)

        if len(wpc_sft) > 0:
            save_tractogram(wpc_sft, os.path.join(
                args.out_dir, "{}_{}_wpc{}".format(prefix_1, prefix_2, ext)),
                bbox_valid_check=False)

        if len(fc_sft) > 0:
            save_tractogram(fc_sft, os.path.join(
                args.out_dir, "{}_{}_fc{}".format(prefix_1, prefix_2, ext)),
                bbox_valid_check=False)

        tc_streamlines_list.append(tc_sft.streamlines)
        wpc_streamlines_list.append(wpc_sft.streamlines)
        fc_streamlines_list.append(fc_sft.streamlines)

        logging.info("Recognized {} streamlines between {} and {}".format(
            len(tc_sft.streamlines) + len(wpc_sft.streamlines) +
            len(fc_sft.streamlines) + len(nc),
            prefix_1, prefix_2))

    # Again keep the correct combinations
    comb_filename = list(
        itertools.combinations(itertools.chain(*zip(gt_tails, gt_heads)),
                               r=2))

    # Remove the true connections from all combinations, leaving only
    # false connections
    for tc_f in tc_filenames:
        comb_filename.remove(tc_f)

    logging.info("Scoring false connections")
    # Go through all the possible combinations of endpoints masks
    for mask_1_filename, mask_2_filename in comb_filename:

        # Automatically generate filename for Q/C
        prefix_1 = extract_prefix(mask_1_filename)
        prefix_2 = extract_prefix(mask_2_filename)

        fc_sft, sft = extract_false_connections(
            sft, mask_1_filename, mask_2_filename, args.dilate_endpoints)

        if len(fc_sft) > 0:
            save_tractogram(fc_sft, os.path.join(
                args.out_dir, "{}_{}_fc{}".format(prefix_1, prefix_2, ext)),
                bbox_valid_check=False)

        logging.info("Recognized {} streamlines between {} and {}".format(
            len(fc_sft.streamlines), prefix_1, prefix_2))

        fc_streamlines_list.append(fc_sft.streamlines)
        comb_fc_streamlines_list.append(fc_sft.streamlines)

    # Whatever was not assigned to any category is a no connection.
    nc_streamlines.extend(sft.streamlines)

    final_results = {}
    no_conn_sft = StatefulTractogram.from_sft(nc_streamlines, sft)
    save_tractogram(no_conn_sft, os.path.join(
        args.out_dir, "nc{}".format(ext)), bbox_valid_check=False)

    # Total number of streamlines for each category
    # and statistic that are not "bundle-wise"
    tc_streamlines_count = len(list(itertools.chain(*tc_streamlines_list)))
    fc_streamlines_count = len(list(itertools.chain(*fc_streamlines_list)))

    if args.wrong_path_as_separate:
        wpc_streamlines_count = len(
            list(itertools.chain(*wpc_streamlines_list)))
    else:
        wpc_streamlines_count = 0

    nc_streamlines_count = len(nc_streamlines)

    total_count = tc_streamlines_count + fc_streamlines_count + \
        wpc_streamlines_count + nc_streamlines_count

    # Internal invariant: every streamline ended up in exactly one category.
    assert total_count == initial_count

    final_results["tractogram_filename"] = str(args.in_tractogram)
    final_results["tractogram_overlap"] = 0.0
    final_results["tc_streamlines"] = tc_streamlines_count
    final_results["fc_streamlines"] = fc_streamlines_count
    final_results["nc_streamlines"] = nc_streamlines_count

    final_results["tc_bundle"] = len([x for x in tc_streamlines_list if x])
    final_results["fc_bundle"] = len([x for x in fc_streamlines_list if x])

    final_results["tc_streamlines_ratio"] = tc_streamlines_count / total_count
    final_results["fc_streamlines_ratio"] = fc_streamlines_count / total_count
    final_results["nc_streamlines_ratio"] = nc_streamlines_count / total_count

    if args.wrong_path_as_separate:
        final_results["wpc_streamlines"] = wpc_streamlines_count
        final_results["wpc_streamlines_ratio"] = \
            wpc_streamlines_count / total_count
        final_results["wpc_bundle"] = len(
            [x for x in wpc_streamlines_list if x])

    final_results["total_streamlines"] = total_count
    final_results["bundle_wise"] = {}
    final_results["bundle_wise"]["true_connections"] = {}
    final_results["bundle_wise"]["false_connections"] = {}
    tractogram_overlap = 0.0

    for i, filename in enumerate(tc_filenames):
        current_tc_streamlines = tc_streamlines_list[i]
        current_tc_voxels, current_tc_endpoints_voxels = get_binary_maps(
            current_tc_streamlines, sft)

        if args.wrong_path_as_separate:
            current_wpc_streamlines = wpc_streamlines_list[i]
            current_wpc_voxels, _ = get_binary_maps(current_wpc_streamlines,
                                                    sft)

        tmp_dict = {}
        tmp_dict["tc_streamlines"] = len(current_tc_streamlines)

        tmp_dict["tc_dice"] = compute_dice_voxel(gt_bundle_masks[i],
                                                 current_tc_voxels)[0]

        bundle_overlap = gt_bundle_masks[i] * current_tc_voxels
        bundle_overreach = np.zeros(dimensions)
        bundle_overreach[np.where((gt_bundle_masks[i] == 0) &
                                  (current_tc_voxels >= 1))] = 1
        bundle_lacking = np.zeros(dimensions)
        bundle_lacking[np.where((gt_bundle_masks[i] == 1) &
                                (current_tc_voxels == 0))] = 1

        if args.wrong_path_as_separate:
            tmp_dict["wpc_streamlines"] = len(current_wpc_streamlines)
            tmp_dict["wpc_dice"] = \
                compute_dice_voxel(gt_bundle_masks[i],
                                   current_wpc_voxels)[0]
            # Add wrong path to overreach
            bundle_overreach[np.where((gt_bundle_masks[i] == 0) &
                                      (current_wpc_voxels >= 1))] = 1

        tmp_dict["tc_bundle_overlap"] = np.count_nonzero(bundle_overlap)
        tmp_dict["tc_bundle_overreach"] = \
            np.count_nonzero(bundle_overreach)
        tmp_dict["tc_bundle_lacking"] = np.count_nonzero(bundle_lacking)
        tmp_dict["tc_bundle_overlap_PCT"] = \
            tmp_dict["tc_bundle_overlap"] / \
            (tmp_dict["tc_bundle_overlap"] + tmp_dict["tc_bundle_lacking"])
        tractogram_overlap += tmp_dict["tc_bundle_overlap_PCT"]

        endpoints_overlap = \
            gt_bundle_masks[i] * current_tc_endpoints_voxels
        endpoints_overreach = np.zeros(dimensions)
        endpoints_overreach[np.where((gt_bundle_masks[i] == 0) &
                                     (current_tc_endpoints_voxels >= 1))] = 1
        tmp_dict["tc_endpoints_overlap"] = np.count_nonzero(
            endpoints_overlap)
        tmp_dict["tc_endpoints_overreach"] = np.count_nonzero(
            endpoints_overreach)

        final_results["bundle_wise"]["true_connections"][str(filename)] = \
            tmp_dict

    # Bundle-wise statistics, useful for more complex phantom.
    # Iterate the combination-aligned list: `fc_streamlines_list` starts
    # with the entries of the true-connection loop, so indexing it by the
    # combination index would attribute stats to the wrong combination.
    for filename, current_fc_streamlines in zip(comb_filename,
                                                comb_fc_streamlines_list):
        current_fc_voxels, _ = get_binary_maps(current_fc_streamlines, sft)

        tmp_dict = {}
        if len(current_fc_streamlines):
            tmp_dict["fc_streamlines"] = len(current_fc_streamlines)
            tmp_dict["fc_voxels"] = np.count_nonzero(current_fc_voxels)

            final_results["bundle_wise"]["false_connections"][str(filename)] =\
                tmp_dict

    final_results["tractogram_overlap"] = \
        tractogram_overlap / len(gt_bundle_masks)

    with open(os.path.join(args.out_dir, "results.json"), "w") as f:
        json.dump(final_results, f,
                  indent=args.indent,
                  sort_keys=args.sort_keys)
def main():
    """
    Decompose a tractogram into label-to-label connections and save them
    to an HDF5 file.

    Each connection goes through optional length pruning, loop removal,
    outlier removal and QuickBundles-curvature filtering. Raw, intermediate
    and discarded sets can also be saved when requested.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_labels],
                        args.reference)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    if (args.save_raw_connections or args.save_intermediate
            or args.save_discarded) and not args.out_dir:
        parser.error('To save outputs in the streamlines form, provide the '
                     'output directory using --out_dir.')

    if args.out_dir:
        if os.path.abspath(args.out_dir) == os.getcwd():
            parser.error('Do not use the current path as output directory.')
        assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                           create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    # Drop the first (smallest) value, assumed to be the background label.
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram,
                                         bbox_check=False)
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info(' Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    if not is_header_compatible(sft, img_labels):
        raise IOError('{} and {} do not have a compatible header'.format(
            args.in_tractogram, args.in_labels))

    sft.to_vox()
    sft.to_corner()

    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info(' Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info(' Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info(' This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    iteration_counter = 0
    with h5py.File(args.out_hdf5, 'w') as hdf5_file:
        affine, dimensions, voxel_sizes, voxel_order = get_reference_info(sft)
        hdf5_file.attrs['affine'] = affine
        hdf5_file.attrs['dimensions'] = dimensions
        hdf5_file.attrs['voxel_sizes'] = voxel_sizes
        hdf5_file.attrs['voxel_order'] = voxel_order

        # Each connections is processed independently. Multiprocessing would
        # be a burden on the I/O of most SSD/HD
        for in_label, out_label in comb_list:
            if iteration_counter > 0 and iteration_counter % 100 == 0:
                logging.info('Split {} nodes out of {}'.format(
                    iteration_counter, len(comb_list)))
            iteration_counter += 1

            # Connections may be stored in either direction: gather both,
            # even when only one direction has entries, and do not add the
            # same entries twice for self-connections (in == out).
            pair_info = list(con_info.get(in_label, {}).get(out_label, []))
            if out_label != in_label:
                pair_info.extend(
                    con_info.get(out_label, {}).get(in_label, []))

            if not pair_info:
                continue

            connecting_streamlines = []
            connecting_ids = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                curr_streamlines = compute_streamline_segment(
                    sft.streamlines[strl_idx],
                    indices[strl_idx],
                    connection['in_idx'],
                    connection['out_idx'],
                    points_to_idx[strl_idx])
                connecting_streamlines.append(curr_streamlines)
                connecting_ids.append(strl_idx)

            # Each step is processed from the previous 'success'
            #   1. raw -> length pass/fail
            #   2. length pass -> loops pass/fail
            #   3. loops pass -> outlier detection pass/fail
            #   4. outlier detection pass -> qb curvature pass/fail
            #   5. qb curvature pass == final connections
            connecting_streamlines = ArraySequence(connecting_streamlines)
            raw_dps = sft.data_per_streamline[connecting_ids]
            raw_sft = StatefulTractogram.from_sft(
                connecting_streamlines, sft,
                data_per_streamline=raw_dps,
                data_per_point={})
            _save_if_needed(raw_sft, hdf5_file, args,
                            'raw', 'raw', in_label, out_label)

            # Doing all post-processing
            if not args.no_pruning:
                valid_length_ids, invalid_length_ids = _prune_segments(
                    raw_sft.streamlines,
                    args.min_length,
                    args.max_length,
                    vox_sizes[0])

                invalid_length_sft = raw_sft[invalid_length_ids]
                valid_length = connecting_streamlines[valid_length_ids]
                _save_if_needed(invalid_length_sft, hdf5_file, args,
                                'discarded', 'invalid_length',
                                in_label, out_label)
            else:
                valid_length = connecting_streamlines
                valid_length_ids = range(len(connecting_streamlines))

            if not len(valid_length):
                continue

            valid_length_sft = raw_sft[valid_length_ids]
            _save_if_needed(valid_length_sft, hdf5_file, args,
                            'intermediate', 'valid_length',
                            in_label, out_label)

            if not args.no_remove_loops:
                no_loop_ids = remove_loops_and_sharp_turns(
                    valid_length, args.loop_max_angle)
                loop_ids = np.setdiff1d(np.arange(len(valid_length)),
                                        no_loop_ids)

                loops_sft = valid_length_sft[loop_ids]
                no_loops = valid_length[no_loop_ids]
                _save_if_needed(loops_sft, hdf5_file, args,
                                'discarded', 'loops', in_label, out_label)
            else:
                no_loops = valid_length
                no_loop_ids = range(len(valid_length))

            if not len(no_loops):
                continue
            no_loops_sft = valid_length_sft[no_loop_ids]
            _save_if_needed(no_loops_sft, hdf5_file, args,
                            'intermediate', 'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                outliers_ids, inliers_ids = remove_outliers(
                    no_loops, args.outlier_threshold, nb_samplings=10,
                    fast_approx=True)

                outliers_sft = no_loops_sft[outliers_ids]
                inliers = no_loops[inliers_ids]
                _save_if_needed(outliers_sft, hdf5_file, args,
                                'discarded', 'outliers', in_label, out_label)
            else:
                inliers = no_loops
                inliers_ids = range(len(no_loops))

            if not len(inliers):
                continue

            inliers_sft = no_loops_sft[inliers_ids]
            _save_if_needed(inliers_sft, hdf5_file, args,
                            'intermediate', 'inliers', in_label, out_label)

            if not args.no_remove_curv_dev:
                no_qb_curv_ids = remove_loops_and_sharp_turns(
                    inliers, args.loop_max_angle, use_qb=True,
                    qb_threshold=args.curv_qb_distance)
                qb_curv_ids = np.setdiff1d(np.arange(len(inliers)),
                                           no_qb_curv_ids)

                qb_curv_sft = inliers_sft[qb_curv_ids]
                _save_if_needed(qb_curv_sft, hdf5_file, args,
                                'discarded', 'qb_curv', in_label, out_label)
            else:
                no_qb_curv_ids = range(len(inliers))

            no_qb_curv_sft = inliers_sft[no_qb_curv_ids]
            _save_if_needed(no_qb_curv_sft, hdf5_file, args,
                            'final', 'final', in_label, out_label)

    time2 = time.time()
    logging.info(
        ' Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
def concatenate_sft(sft_list, erase_metadata=False, metadata_fake_init=False):
    """
    Concatenate a list of StatefulTractogram together.

    Parameters
    ----------
    sft_list: list of StatefulTractogram
        Tractograms to fuse. The first one is the reference for spatial
        attributes and for the metadata keys kept in the output.
    erase_metadata: bool
        If True, data_per_point and data_per_streamline are cleared on every
        input tractogram (in place) before fusing.
    metadata_fake_init: bool
        If True, metadata keys absent from the first tractogram are deleted
        from the others (in place), and keys present in the first one but
        missing from another are filled with zeros on that one.

    Returns
    -------
    StatefulTractogram
        The fused tractogram, built with `from_sft` from the first input.

    Raises
    ------
    ValueError
        If `sft_list` is empty or if the tractograms are not compatible.
    """
    if len(sft_list) == 0:
        raise ValueError('Cannot concatenate an empty list of SFT.')

    if erase_metadata:
        sft_list[0].data_per_point = {}
        sft_list[0].data_per_streamline = {}

    for sft in sft_list[1:]:
        if erase_metadata:
            sft.data_per_point = {}
            sft.data_per_streamline = {}
        elif metadata_fake_init:
            # Drop keys unknown to the reference tractogram.
            for dps_key in list(sft.data_per_streamline.keys()):
                if dps_key not in sft_list[0].data_per_streamline.keys():
                    del sft.data_per_streamline[dps_key]
            for dpp_key in list(sft.data_per_point.keys()):
                if dpp_key not in sft_list[0].data_per_point.keys():
                    del sft.data_per_point[dpp_key]

            # Zero-fill keys present in the reference but missing here.
            for dps_key in sft_list[0].data_per_streamline.keys():
                if dps_key not in sft.data_per_streamline:
                    # `.shape` is a tuple; it must be converted before
                    # its first dimension can be replaced.
                    arr_shape = list(
                        sft_list[0].data_per_streamline[dps_key].shape)
                    arr_shape[0] = len(sft)
                    sft.data_per_streamline[dps_key] = np.zeros(arr_shape)
            for dpp_key in sft_list[0].data_per_point.keys():
                if dpp_key not in sft.data_per_point:
                    arr_seq = ArraySequence()
                    arr_seq_shape = list(
                        sft_list[0].data_per_point[dpp_key]._data.shape)
                    arr_seq_shape[0] = len(sft.streamlines._data)
                    arr_seq._data = np.zeros(arr_seq_shape)
                    arr_seq._offsets = sft.streamlines._offsets
                    arr_seq._lengths = sft.streamlines._lengths
                    sft.data_per_point[dpp_key] = arr_seq

        if not metadata_fake_init and \
                not StatefulTractogram.are_compatible(sft, sft_list[0]):
            raise ValueError('Incompatible SFT, check space attributes and '
                             'data_per_point/streamlines.')
        elif not is_header_compatible(sft, sft_list[0]):
            raise ValueError('Incompatible SFT, check space attributes.')

    total_streamlines = 0
    total_points = 0
    lengths = []
    for sft in sft_list:
        total_streamlines += len(sft.streamlines._offsets)
        total_points += len(sft.streamlines._data)
        lengths.extend(sft.streamlines._lengths)
    lengths = np.array(lengths, dtype=np.uint32)
    offsets = np.concatenate(([0], np.cumsum(lengths[:-1]))).astype(np.uint64)

    # Pre-allocate the fused containers, then copy each input in place.
    dpp = {}
    for dpp_key in sft_list[0].data_per_point.keys():
        arr_seq_shape = list(sft_list[0].data_per_point[dpp_key]._data.shape)
        arr_seq_shape[0] = total_points
        dpp[dpp_key] = ArraySequence()
        dpp[dpp_key]._data = np.zeros(arr_seq_shape)
        dpp[dpp_key]._lengths = lengths
        dpp[dpp_key]._offsets = offsets

    dps = {}
    for dps_key in sft_list[0].data_per_streamline.keys():
        arr_seq_shape = list(sft_list[0].data_per_streamline[dps_key].shape)
        arr_seq_shape[0] = total_streamlines
        dps[dps_key] = np.zeros(arr_seq_shape)

    streamlines = ArraySequence()
    streamlines._data = np.zeros((total_points, 3))
    streamlines._lengths = lengths
    streamlines._offsets = offsets

    pts_counter = 0
    strs_counter = 0
    for sft in sft_list:
        pts_curr_len = len(sft.streamlines._data)
        strs_curr_len = len(sft.streamlines._offsets)

        if strs_curr_len == 0 or pts_curr_len == 0:
            continue

        streamlines._data[pts_counter:pts_counter+pts_curr_len] = \
            sft.streamlines._data

        for dpp_key in sft_list[0].data_per_point.keys():
            dpp[dpp_key]._data[pts_counter:pts_counter+pts_curr_len] = \
                sft.data_per_point[dpp_key]._data
        for dps_key in sft_list[0].data_per_streamline.keys():
            dps[dps_key][strs_counter:strs_counter+strs_curr_len] = \
                sft.data_per_streamline[dps_key]
        pts_counter += pts_curr_len
        strs_counter += strs_curr_len

    fused_sft = StatefulTractogram.from_sft(streamlines, sft_list[0],
                                            data_per_point=dpp,
                                            data_per_streamline=dps)
    return fused_sft
def main():
    """Apply a set operation (or a lazy concatenation) to tractograms and
    save the result, optionally with the kept indices per input file."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, args.in_tractograms)
    assert_outputs_exist(parser, args, args.out_tractogram,
                         optional=args.save_indices)

    if args.operation == 'lazy_concatenate':
        logging.info('Using lazy_concatenate, no spatial or metadata related '
                     'checks are performed.\nMetadata will be lost, only '
                     'trk/tck file are supported.')

        def list_generator_from_nib(filenames):
            for in_file in filenames:
                tractogram_file = nib.streamlines.load(in_file,
                                                       lazy_load=True)
                for s in tractogram_file.streamlines:
                    yield s

        # Use the header of the first .trk as the output header and warn
        # when other .trk inputs do not match it.
        header = None
        for in_file in args.in_tractograms:
            _, ext = os.path.splitext(in_file)
            if ext == '.trk':
                if header is None:
                    header = nib.streamlines.load(
                        in_file, lazy_load=True).header
                elif not is_header_compatible(header, in_file):
                    logging.warning('Incompatible headers in the list.')

        # The callable given to LazyTractogram may be invoked more than once.
        # Build a fresh generator on every call so a second pass is not
        # silently empty.
        out_tractogram = LazyTractogram(
            lambda: list_generator_from_nib(args.in_tractograms),
            affine_to_rasmm=np.eye(4))
        nib.streamlines.save(out_tractogram, args.out_tractogram,
                             header=header)
        return

    # Load all input streamlines.
    sft_list = []
    for f in args.in_tractograms:
        sft_list.append(load_tractogram_with_reference(
            parser, args, f, bbox_check=not args.ignore_invalid))

    # Apply the requested operation to each input file.
    logging.info('Performing operation \'{}\'.'.format(args.operation))
    new_sft = concatenate_sft(sft_list, args.no_metadata, args.fake_metadata)
    if args.operation == 'concatenate':
        indices = np.arange(len(new_sft), dtype=np.uint32)
    else:
        streamlines_list = [sft.streamlines for sft in sft_list]
        op_name = args.operation
        if args.robust:
            op_name += '_robust'
            _, indices = OPERATIONS[op_name](streamlines_list,
                                             precision=args.precision)
        else:
            _, indices = perform_streamlines_operation(
                OPERATIONS[op_name], streamlines_list,
                precision=args.precision)

    # Save the indices to a file if requested. Indices refer to the
    # concatenated tractogram; convert them back to per-file indices.
    if args.save_indices:
        indices_arr = np.asarray(indices)
        nb_streamlines_per_file = [len(sft) for sft in sft_list]
        out_dict = {}
        start = 0
        for name, nb_streamlines in zip(args.in_tractograms,
                                        nb_streamlines_per_file):
            end = start + nb_streamlines
            in_file = indices_arr[(indices_arr >= start) & (indices_arr < end)]
            # Switch to int32 for json
            out_dict[name] = [int(i - start) for i in in_file]
            start = end

        with open(args.save_indices, 'wt') as f:
            json.dump(out_dict, f,
                      indent=args.indent, sort_keys=args.sort_keys)

    # Save the new streamlines (and metadata)
    logging.info('Saving {} streamlines to {}.'.format(len(indices),
                                                       args.out_tractogram))
    save_tractogram(new_sft[indices], args.out_tractogram,
                    bbox_valid_check=not args.ignore_invalid)
def main():
    """Render a mosaic of bundle thumbnails (six views per bundle) on top of
    a background volume and save it as an image."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_volume)
    assert_outputs_exist(parser, args, args.out_image)

    output_names = [
        'axial_superior', 'axial_inferior', 'coronal_posterior',
        'coronal_anterior', 'sagittal_left', 'sagittal_right'
    ]

    for filename in args.in_bundles:
        _, bundle_ext = split_name_with_nii(os.path.basename(filename))
        if bundle_ext not in ['.tck', '.trk', '.nii', '.nii.gz']:
            parser.error('{} is not a supported bundle format.'.format(
                filename))
        # Missing bundles are reported in the mosaic itself (see below), so
        # only check the headers of the files that exist.
        if not os.path.isfile(filename):
            continue

        _, ext = os.path.splitext(filename)
        if ext == '.tck':
            tractogram = load_tractogram_with_reference(parser, args,
                                                        filename)
        else:
            tractogram = filename
        if not is_header_compatible(args.in_volume, tractogram):
            parser.error('{} does not have a compatible header with {}'.format(
                filename, args.in_volume))
        # Delete temporary tractogram
        del tractogram

    output_dir = os.path.dirname(args.out_image)
    if output_dir:
        assert_output_dirs_exist_and_empty(parser, args, output_dir,
                                           create_dir=True)

    # ----------------------------------------------------------------------- #
    # Mosaic, column 0: orientation names and data description
    # ----------------------------------------------------------------------- #
    width = args.resolution_of_thumbnails
    height = args.resolution_of_thumbnails
    rows = 6
    cols = len(args.in_bundles)
    text_pos_x = 50
    text_pos_y = 50

    # Creates a new empty image, RGB mode
    mosaic = Image.new('RGB', ((cols + 1) * width, (rows + 1) * height))

    # Prepare draw and font objects to render text
    draw = ImageDraw.Draw(mosaic)
    font = get_font(args)

    # Data of the volume used as background
    ref_img = nib.load(args.in_volume)
    data = ref_img.get_fdata(dtype=np.float32)
    affine = ref_img.affine
    mean, std = data[data > 0].mean(), data[data > 0].std()
    value_range = (mean - 0.5 * std, mean + 1.5 * std)

    # First column with rows description
    draw_column_with_names(draw, output_names, text_pos_x,
                           text_pos_y, height, font)

    # ----------------------------------------------------------------------- #
    # Columns with bundles
    # ----------------------------------------------------------------------- #
    random.seed(args.random_coloring)
    for idx_bundle, bundle_file in enumerate(args.in_bundles):

        bundle_file_name = os.path.basename(bundle_file)
        _, bundle_ext = split_name_with_nii(bundle_file_name)

        # Left pixel coordinate of this bundle's column.
        i = (idx_bundle + 1) * width

        if not os.path.isfile(bundle_file):
            print('\nInput file {} doesn\'t exist.'.format(bundle_file))

            number_streamlines = 0

            view_number = 6
            j = height * view_number

            draw_bundle_information(draw, bundle_file_name,
                                    number_streamlines, i + text_pos_x,
                                    j + text_pos_y, font)
            continue

        if args.uniform_coloring is not None:
            colors = args.uniform_coloring
        elif args.random_coloring is not None:
            colors = random_rgb()

        # Select the streamlines to plot
        if bundle_ext in ['.tck', '.trk']:
            if (args.random_coloring is None
                    and args.uniform_coloring is None):
                colors = None
            bundle_tractogram_file = nib.streamlines.load(bundle_file)
            streamlines = bundle_tractogram_file.streamlines
            bundle_actor = actor.line(streamlines, colors)
            nbr_of_elem = len(streamlines)
        # Select the volume to plot
        else:
            # `is None` (not falsiness) so a random seed of 0 still counts as
            # random coloring, consistent with the streamline branch.
            if (args.random_coloring is None
                    and args.uniform_coloring is None):
                colors = [1.0, 1.0, 1.0]
            bundle_img_file = nib.load(bundle_file)
            roi = get_data_as_mask(bundle_img_file)
            bundle_actor = actor.contour_from_roi(roi,
                                                  bundle_img_file.affine,
                                                  colors)
            nbr_of_elem = np.count_nonzero(roi)

        # Render
        ren = window.Scene()
        zoom = args.zoom
        opacity = args.opacity_background

        # Structural data
        slice_actor = actor.slicer(data, affine, value_range)
        slice_actor.opacity(opacity)
        ren.add(slice_actor)

        # Streamlines
        ren.add(bundle_actor)
        ren.reset_camera()
        ren.zoom(zoom)
        view_number = 0
        set_img_in_cell(mosaic, ren, view_number,
                        width, height, i)

        ren.pitch(180)
        ren.reset_camera()
        ren.zoom(zoom)
        view_number = 1
        set_img_in_cell(mosaic, ren, view_number,
                        width, height, i)

        # Views 2 and 3: background shows the middle slice of the 2nd axis.
        ren.rm(slice_actor)
        slice_actor2 = slice_actor.copy()
        slice_actor2.display(None, slice_actor2.shape[1] // 2, None)
        slice_actor2.opacity(opacity)
        ren.add(slice_actor2)

        ren.pitch(90)
        ren.set_camera(view_up=(0, 0, 1))
        ren.reset_camera()
        ren.zoom(zoom)
        view_number = 2
        set_img_in_cell(mosaic, ren, view_number,
                        width, height, i)

        ren.pitch(180)
        ren.set_camera(view_up=(0, 0, 1))
        ren.reset_camera()
        ren.zoom(zoom)
        view_number = 3
        set_img_in_cell(mosaic, ren, view_number,
                        width, height, i)

        # Views 4 and 5: background shows the middle slice of the 1st axis.
        ren.rm(slice_actor2)
        slice_actor3 = slice_actor.copy()
        slice_actor3.display(slice_actor3.shape[0] // 2, None, None)
        slice_actor3.opacity(opacity)
        ren.add(slice_actor3)

        ren.yaw(90)
        ren.reset_camera()
        ren.zoom(zoom)
        view_number = 4
        set_img_in_cell(mosaic, ren, view_number,
                        width, height, i)

        ren.yaw(180)
        ren.reset_camera()
        ren.zoom(zoom)
        view_number = 5
        set_img_in_cell(mosaic, ren, view_number,
                        width, height, i)

        view_number = 6
        j = height * view_number
        draw_bundle_information(draw, bundle_file_name, nbr_of_elem,
                                i + text_pos_x, j + text_pos_y, font)

    # Save image to file
    mosaic.save(args.out_image)
def main():
    """Majority-vote fusion of bundles.

    Saves a binary voxel map of voxels covered by enough bundles and, with
    --same_tractogram, the streamlines present in enough bundles.
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    output_streamlines_filename = '{}streamlines.trk'.format(
        args.output_prefix)
    output_voxels_filename = '{}voxels.nii.gz'.format(args.output_prefix)
    assert_outputs_exist(parser, args, [output_voxels_filename,
                                        output_streamlines_filename])

    if not 0 <= args.ratio_voxels <= 1 or not 0 <= args.ratio_streamlines <= 1:
        parser.error('Ratios must be between 0 and 1.')

    # Union of every streamline of every bundle (duplicates removed).
    fusion_streamlines = []
    for name in args.in_bundles:
        fusion_streamlines.extend(
            load_tractogram_with_reference(parser, args, name).streamlines)

    fusion_streamlines, _ = perform_streamlines_operation(
        union, [fusion_streamlines], 0)
    fusion_streamlines = ArraySequence(fusion_streamlines)

    reference_file = args.reference if args.reference else args.in_bundles[0]

    transformation, dimensions, _, _ = get_reference_info(reference_file)
    nb_bundles = len(args.in_bundles)
    volume = np.zeros(dimensions)
    streamlines_vote = dok_matrix((len(fusion_streamlines), nb_bundles))

    for i, name in enumerate(args.in_bundles):
        if not is_header_compatible(reference_file, name):
            raise ValueError('Both headers are not the same')
        sft = load_tractogram_with_reference(parser, args, name)
        bundle = sft.get_streamlines_copy()
        sft.to_vox()
        bundle_vox_space = sft.get_streamlines_copy()
        # One vote per voxel touched by this bundle.
        binary = compute_tract_counts_map(bundle_vox_space, dimensions)
        volume[binary > 0] += 1

        if args.same_tractogram:
            _, indices = perform_streamlines_operation(
                intersection, [fusion_streamlines, bundle], 0)
            streamlines_vote[list(indices), i] += 1

    if args.same_tractogram:
        streamlines_threshold = int(args.ratio_streamlines * nb_bundles)
        real_indices = [idx for idx in range(len(fusion_streamlines))
                        if np.sum(streamlines_vote[idx]) >=
                        streamlines_threshold]

        new_streamlines = fusion_streamlines[real_indices]

        sft = StatefulTractogram(new_streamlines, reference_file, Space.RASMM)
        save_tractogram(sft, output_streamlines_filename)

    # The voxel map is thresholded with the voxel ratio (previously the
    # streamline ratio was used here by mistake and ratio_voxels was ignored).
    voxels_threshold = int(args.ratio_voxels * nb_bundles)
    volume[volume < voxels_threshold] = 0
    volume[volume > 0] = 1
    nib.save(nib.Nifti1Image(volume.astype(np.uint8), transformation),
             output_voxels_filename)
def main():
    """Sequentially filter a tractogram with ROI / atlas / plane / bdo
    criteria and save the surviving streamlines."""
    parser = _buildArgsParser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram])
    assert_outputs_exists(parser, args, [args.out_tractogram])

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    roi_opt_list = prepare_filtering_list(parser, args)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    for roi_opt in roi_opt_list:
        # Atlas needs an extra argument (value in the LUT)
        if roi_opt[0] == 'atlas_roi':
            filter_type, filter_arg_1, filter_arg_2, \
                filter_mode, filter_criteria = roi_opt
        else:
            filter_type, filter_arg, filter_mode, filter_criteria = roi_opt
        is_not = filter_criteria != 'include'

        if filter_type == 'drawn_roi':
            img = nib.load(filter_arg)
            if not is_header_compatible(img, sft):
                parser.error('Headers from the tractogram and the mask are '
                             'not compatible.')

            # Same values/dtype as the deprecated (removed) img.get_data().
            mask = np.asanyarray(img.dataobj)
            filtered_streamlines, indexes = filter_grid_roi(
                sft, mask, filter_mode, is_not)

        elif filter_type == 'atlas_roi':
            img = nib.load(filter_arg_1)
            if not is_header_compatible(img, sft):
                parser.error('Headers from the tractogram and the mask are '
                             'not compatible.')

            atlas = np.asanyarray(img.dataobj).astype(np.uint16)
            mask = np.zeros(atlas.shape, dtype=np.uint16)
            mask[atlas == int(filter_arg_2)] = 1

            filtered_streamlines, indexes = filter_grid_roi(
                sft, mask, filter_mode, is_not)

        # For every case, the input number must be greater or equal to 0 and
        # below the dimension, since this is a voxel space operation
        elif filter_type in ['x_plane', 'y_plane', 'z_plane']:
            filter_arg = int(filter_arg)
            _, dim, _, _ = sft.space_attribute
            mask = np.zeros(dim, dtype=np.int16)
            error_msg = None
            if filter_type == 'x_plane':
                if 0 <= filter_arg < dim[0]:
                    mask[filter_arg, :, :] = 1
                else:
                    error_msg = 'X plane ' + str(filter_arg)

            elif filter_type == 'y_plane':
                if 0 <= filter_arg < dim[1]:
                    mask[:, filter_arg, :] = 1
                else:
                    error_msg = 'Y plane ' + str(filter_arg)

            elif filter_type == 'z_plane':
                if 0 <= filter_arg < dim[2]:
                    mask[:, :, filter_arg] = 1
                else:
                    error_msg = 'Z plane ' + str(filter_arg)

            if error_msg:
                parser.error('{} is not valid according to the '
                             'tractogram header.'.format(error_msg))

            filtered_streamlines, indexes = filter_grid_roi(
                sft, mask, filter_mode, is_not)

        elif filter_type == 'bdo':
            geometry, radius, center = read_info_from_mb_bdo(filter_arg)
            if geometry == 'Ellipsoid':
                filtered_streamlines, indexes = filter_ellipsoid(
                    sft, radius, center, filter_mode, is_not)
            elif geometry == 'Cuboid':
                filtered_streamlines, indexes = filter_cuboid(
                    sft, radius, center, filter_mode, is_not)
            else:
                # Without this, the previous filter's result would be reused
                # (or an undefined name raised) silently.
                parser.error('Geometry {} from {} is not supported.'.format(
                    geometry, filter_arg))
        else:
            parser.error('Filter type {} is not supported.'.format(
                filter_type))

        logging.debug('The filtering options {0} resulted in '
                      '{1} streamlines'.format(roi_opt,
                                               len(filtered_streamlines)))

        data_per_streamline = sft.data_per_streamline[indexes]
        data_per_point = sft.data_per_point[indexes]

        sft = StatefulTractogram(filtered_streamlines, sft, Space.RASMM,
                                 data_per_streamline=data_per_streamline,
                                 data_per_point=data_per_point)

    # len(sft) works whatever container the filters return and also when no
    # filtering option was given.
    if len(sft) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written (0 streamline)".format(
                args.out_tractogram))

            return

        logging.debug('The file {} contains 0 streamline'.format(
            args.out_tractogram))

    save_tractogram(sft, args.out_tractogram)
def main():
    """Print, as JSON, per-point metric profiles (mean, std and full profile)
    along a bundle, after aligning every streamline with a centroid."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_bundle] + args.in_metrics,
                        optional=args.in_centroid)

    if args.nb_pts_per_streamline <= 1:
        parser.error('--nb_pts_per_streamline {} needs to be greater than '
                     '1'.format(args.nb_pts_per_streamline))

    assert_same_resolution(args.in_metrics + [args.in_bundle])
    sft = load_tractogram_with_reference(parser, args, args.in_bundle)

    metrics = [nib.load(m) for m in args.in_metrics]

    bundle_name, _ = os.path.splitext(os.path.basename(args.in_bundle))
    stats = {}
    if len(sft) == 0:
        stats[bundle_name] = None
        print(json.dumps(stats, indent=args.indent, sort_keys=args.sort_keys))
        return

    # Centroid - will be use as reference to reorient each streamline
    if args.in_centroid:
        # The result was previously discarded, so incompatible inputs were
        # never reported.
        if not is_header_compatible(args.in_bundle, args.in_centroid):
            parser.error('{} and {} do not have a compatible header.'.format(
                args.in_bundle, args.in_centroid))
        sft_centroid = load_tractogram_with_reference(parser, args,
                                                      args.in_centroid)
        centroid_streamlines = sft_centroid.streamlines[0]
        nb_pts_per_streamline = len(centroid_streamlines)
    else:
        centroid_streamlines = get_streamlines_centroid(
            sft.streamlines, args.nb_pts_per_streamline)
        nb_pts_per_streamline = args.nb_pts_per_streamline

    resampled_sft = resample_streamlines_num_points(sft,
                                                    nb_pts_per_streamline)

    # Make sure all streamlines go in the same direction. We want point #k
    # of streamline A to be matched with point #k of streamline B, so flip
    # every streamline that is closer to the centroid when reversed.
    for s in range(len(resampled_sft)):
        streamline = resampled_sft.streamlines[s]
        direct = average_euclidean(centroid_streamlines, streamline)
        flipped = average_euclidean(centroid_streamlines, streamline[::-1])

        if flipped < direct:
            resampled_sft.streamlines[s] = streamline[::-1]

    profiles = get_bundle_metrics_profiles(resampled_sft, metrics)
    t_profiles = np.expand_dims(profiles, axis=1)
    t_profiles = np.rollaxis(t_profiles, 3, 2)

    stats[bundle_name] = {}
    for metric, profile, t_profile in zip(metrics, profiles, t_profiles):
        metric_name, _ = split_name_with_nii(
            os.path.basename(metric.get_filename()))
        stats[bundle_name][metric_name] = {
            'mean': np.mean(profile, axis=0).tolist(),
            'std': np.std(profile, axis=0).tolist(),
            'bundleprofile': t_profile.tolist()
        }

    print(json.dumps(stats, indent=args.indent, sort_keys=args.sort_keys))
def main():
    """Apply a voxel-wise operation to images (and numbers) and save the
    result with the header of the first image input."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_outputs_exist(parser, args, args.out_image)

    # Binary operations require specific verifications
    binary_op = [
        'union', 'intersection', 'difference', 'invert',
        'dilation', 'erosion', 'closing', 'opening'
    ]

    if args.operation not in OPERATIONS.keys():
        parser.error('Operation {} not implemented.'.format(args.operation))

    # Find at least one image for reference
    ref_img = None
    for input_arg in args.in_images:
        if not is_float(input_arg):
            ref_img = nib.load(input_arg)
            mask = np.zeros(ref_img.shape)
            break
    if ref_img is None:
        # Without an image there is no header, shape or dtype to work with.
        parser.error('At least one input must be an image.')

    # Load all input masks.
    input_data = []
    for input_arg in args.in_images:
        if not is_float(input_arg) and \
                not is_header_compatible(ref_img, input_arg):
            parser.error('Inputs do not have a compatible header.')
        data = load_data(input_arg)

        if isinstance(data, np.ndarray) and \
                data.dtype != ref_img.get_data_dtype() and \
                not args.data_type:
            parser.error('Inputs do not have a compatible data type.\n'
                         'Use --data_type to specify output datatype.')
        if args.operation in binary_op and isinstance(data, np.ndarray):
            unique = np.unique(data)
            if len(unique) > 2:
                parser.error('Binary operations can only be performed with '
                             'binary masks')

            # Also catches single-valued inputs such as an all-5 volume.
            if not np.isin(unique, [0, 1]).all():
                logging.warning('Input data for binary operation are not '
                                'binary arrays, will be converted.\n'
                                'Non-zeros will be set to ones.')
                data[data != 0] = 1

        if isinstance(data, np.ndarray):
            data = data.astype(np.float64)
            # Union of the non-zero voxels of every image input.
            mask[data > 0] = 1
        input_data.append(data)

    if args.operation == 'convert' and not args.data_type:
        parser.error('Convert operation must be used with --data_type.')

    try:
        output_data = OPERATIONS[args.operation](input_data)
    except ValueError:
        logging.error('{} operation failed.'.format(
            args.operation.capitalize()))
        return

    if args.data_type:
        output_data = output_data.astype(args.data_type)
        ref_img.header.set_data_dtype(args.data_type)
    else:
        output_data = output_data.astype(ref_img.get_data_dtype())

    if args.exclude_background:
        output_data[mask == 0] = 0

    new_img = nib.Nifti1Image(output_data, ref_img.affine,
                              header=ref_img.header)
    nib.save(new_img, args.out_image)
def main():
    """Estimate, per label, the diameter of bundles by fitting circles to
    labelled streamline points; print JSON stats and optionally render."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    # The number of labels maps must be equal to the number of bundles.
    # argparse greedily assigns positional files to in_bundles, so the list
    # is split back in two halves here.
    tmp = args.in_bundles + args.in_labels
    if len(tmp) % 2:
        parser.error('The number of label maps must be equal to the number '
                     'of bundles.')
    args.in_labels = args.in_bundles[(len(tmp) // 2):] + args.in_labels
    args.in_bundles = args.in_bundles[0:len(tmp) // 2]
    assert_inputs_exist(parser, args.in_bundles + args.in_labels)
    assert_output_dirs_exist_and_empty(parser, args, [],
                                       optional=args.save_rendering)

    stats = {}
    num_digits_labels = 3
    scene = window.Scene()
    scene.background(tuple(map(int, args.background)))
    for bundle_idx, filename in enumerate(args.in_bundles):
        sft = load_tractogram_with_reference(parser, args, filename)
        sft.to_vox()
        sft.to_corner()
        img_labels = nib.load(args.in_labels[bundle_idx])

        # same subject: same header or coregistered subjects: same header
        if not is_header_compatible(sft, args.in_bundles[0]) \
                or not is_header_compatible(img_labels, args.in_bundles[0]):
            parser.error('All headers must be identical.')

        data_labels = img_labels.get_fdata()
        bundle_name, _ = os.path.splitext(os.path.basename(filename))
        # Background (0) is not a label; filtering by value (rather than
        # dropping the first unique value) keeps every label even when the
        # map has no background voxel.
        unique_labels = np.unique(data_labels)
        unique_labels = unique_labels[unique_labels > 0].astype(int)

        # Empty bundle should at least return a json
        if not len(sft):
            stats[bundle_name] = {'diameter': {
                '{}'.format(label).zfill(num_digits_labels):
                    {'mean': 0.0, 'std': 0.0}
                for label in unique_labels}}
            continue

        labels_dict = {label: ([], []) for label in unique_labels}
        pts_labels = map_coordinates(data_labels,
                                     sft.streamlines._data.T - 0.5,
                                     order=0)
        # For each label, all positions and directions are needed to get
        # a tube estimation per label.
        counter = 0
        for streamline in sft.streamlines:
            direction = np.gradient(streamline, axis=0).tolist()
            curr_labels = pts_labels[counter:counter +
                                     len(streamline)].tolist()

            for pt_idx, label in enumerate(curr_labels):
                if label > 0:
                    labels_dict[label][0].append(streamline[pt_idx])
                    labels_dict[label][1].append(direction[pt_idx])

            counter += len(streamline)

        # Rows are indexed by position in unique_labels (not `label - 1`) so
        # non-contiguous label values are handled.
        nb_labels = len(unique_labels)
        centroid = np.zeros((nb_labels, 3))
        radius = np.zeros((nb_labels, 1))
        error = np.zeros((nb_labels, 1))
        for pos, key in enumerate(unique_labels):
            key = int(key)
            centroid[pos], radius[pos], error[pos] = fit_circle_in_space(
                labels_dict[key][0], labels_dict[key][1], args.fitting_func)

        # Spatial smoothing to avoid degenerate estimation
        centroid_smooth = gaussian_filter(centroid, sigma=[1, 0],
                                          mode='nearest')
        # Keep both extremities unsmoothed (also valid with a single label).
        centroid_smooth[[0, -1]] = centroid[[0, -1]]
        radius = gaussian_filter(radius, sigma=1, mode='nearest')
        error = gaussian_filter(error, sigma=1, mode='nearest')

        stats[bundle_name] = {'diameter': {
            '{}'.format(label).zfill(num_digits_labels): {
                'mean': float(radius[pos, 0]) * 2,
                'std': float(error[pos, 0])}
            for pos, label in enumerate(unique_labels)}}

        if args.show_rendering or args.save_rendering:
            tube_actor = create_tube_with_radii(
                centroid_smooth, radius, error,
                wireframe=args.wireframe,
                error_coloring=args.error_coloring)
            scene.add(tube_actor)
            cmap = plt.get_cmap('jet')
            coloring = cmap(pts_labels / np.max(pts_labels))[:, 0:3]
            streamlines_actor = actor.streamtube(sft.streamlines,
                                                 linewidth=args.width,
                                                 opacity=args.opacity,
                                                 colors=coloring)
            scene.add(streamlines_actor)

            slice_actor = actor.slicer(data_labels, np.eye(4))
            slice_actor.opacity(0.0)
            scene.add(slice_actor)

    # If there's actually streamlines to display
    if args.show_rendering:
        showm = window.ShowManager(scene, reset_camera=True)
        showm.initialize()
        showm.start()
    elif args.save_rendering:
        scene.reset_camera()
        snapshot(scene, os.path.join(args.save_rendering, 'superior.png'),
                 size=(1920, 1080), offscreen=True)

        scene.pitch(180)
        scene.reset_camera()
        snapshot(scene, os.path.join(args.save_rendering, 'inferior.png'),
                 size=(1920, 1080), offscreen=True)

        scene.pitch(90)
        scene.set_camera(view_up=(0, 0, 1))
        scene.reset_camera()
        snapshot(scene, os.path.join(args.save_rendering, 'posterior.png'),
                 size=(1920, 1080), offscreen=True)

        scene.pitch(180)
        scene.set_camera(view_up=(0, 0, 1))
        scene.reset_camera()
        snapshot(scene, os.path.join(args.save_rendering, 'anterior.png'),
                 size=(1920, 1080), offscreen=True)

        scene.yaw(90)
        scene.reset_camera()
        snapshot(scene, os.path.join(args.save_rendering, 'right.png'),
                 size=(1920, 1080), offscreen=True)

        scene.yaw(180)
        scene.reset_camera()
        snapshot(scene, os.path.join(args.save_rendering, 'left.png'),
                 size=(1920, 1080), offscreen=True)

    print(json.dumps(stats, indent=args.indent, sort_keys=args.sort_keys))
def _processing_wrapper(args):
    """Compute the requested connectivity measures for one pair of labels.

    Parameters
    ----------
    args: sequence
        Packed arguments (packed so the function can be used with
        multiprocessing.Pool.map):
        [0] path to the hdf5 file holding the decomposed connectivity,
        [1] labels image, [2] (in_label, out_label) pair,
        [3] list of measures to compute, [4] similarity option or None
        (only its first element, the directory, is used here),
        [5] density weighting flag, [6] include_dps output directory or
        falsy, [7] minimum lesion volume.

    Return
    ------
    dict or None
        {(in_label, out_label): {measure: value}}, or None when the pair has
        no entry or no streamline in the hdf5 file.

    Raises
    ------
    ValueError
        If the hdf5 affine/dimensions do not match the labels image.
    IOError
        If a metric or lesion image is not compatible with the labels.
    """
    hdf5_filename = args[0]
    labels_img = args[1]
    in_label, out_label = args[2]
    measures_to_compute = copy.copy(args[3])
    similarity_directory = args[4][0] if args[4] is not None else None
    weighted = args[5]
    include_dps = args[6]
    min_lesion_vol = args[7]

    # The context manager guarantees the file is closed on every return path
    # (this runs once per node, often in pool workers).
    with h5py.File(hdf5_filename, 'r') as hdf5_file:
        key = '{}_{}'.format(in_label, out_label)
        if key not in hdf5_file:
            return None
        streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
        if len(streamlines) == 0:
            return None

        affine, dimensions, voxel_sizes, _ = get_reference_info(labels_img)
        measures_to_return = {}

        if not (np.allclose(hdf5_file.attrs['affine'], affine, atol=1e-03)
                and np.array_equal(hdf5_file.attrs['dimensions'],
                                   dimensions)):
            raise ValueError('Provided hdf5 have incompatible headers.')

        # Precompute to save one transformation, insert later
        if 'length' in measures_to_compute:
            streamlines_copy = list(streamlines)
            # scil_decompose_connectivity.py requires isotropic voxels
            mean_length = np.average(length(streamlines_copy)) * \
                voxel_sizes[0]

        # The density map is needed by every measure except 'length' and
        # 'streamline_count' (volume, similarity, maps, metrics, lesions).
        if any(measure not in ('length', 'streamline_count')
               for measure in measures_to_compute):
            density = compute_tract_counts_map(streamlines, dimensions)

        if 'volume' in measures_to_compute:
            measures_to_return['volume'] = np.count_nonzero(density) * \
                np.prod(voxel_sizes)
            measures_to_compute.remove('volume')
        if 'streamline_count' in measures_to_compute:
            measures_to_return['streamline_count'] = len(streamlines)
            measures_to_compute.remove('streamline_count')
        if 'length' in measures_to_compute:
            measures_to_return['length'] = mean_length
            measures_to_compute.remove('length')
        if 'similarity' in measures_to_compute and similarity_directory:
            density_sim = load_node_nifti(similarity_directory,
                                          in_label, out_label,
                                          labels_img)
            if density_sim is None:
                ba_vox = 0
            else:
                ba_vox = compute_bundle_adjacency_voxel(density, density_sim)

            measures_to_return['similarity'] = ba_vox
            measures_to_compute.remove('similarity')

        for measure in measures_to_compute:
            # Maps
            if isinstance(measure, str) and os.path.isdir(measure):
                map_dirname = measure
                map_data = load_node_nifti(map_dirname,
                                           in_label, out_label,
                                           labels_img)
                measures_to_return[map_dirname] = np.average(
                    map_data[map_data > 0])
            elif isinstance(measure, tuple):
                # Metrics: (filename, image)
                if not isinstance(measure[0], tuple) \
                        and os.path.isfile(measure[0]):
                    metric_filename = measure[0]
                    metric_img = measure[1]
                    if not is_header_compatible(metric_img, labels_img):
                        logging.error(
                            '{} do not have a compatible header'.format(
                                metric_filename))
                        raise IOError(
                            '{} do not have a compatible header'.format(
                                metric_filename))

                    metric_data = metric_img.get_fdata(dtype=np.float64)
                    if weighted:
                        avg_value = np.average(metric_data, weights=density)
                    else:
                        avg_value = np.average(metric_data[density > 0])
                    measures_to_return[metric_filename] = avg_value
                # lesion: ((filename, lesion labels), lesion image)
                else:
                    lesion_filename = measure[0][0]
                    computed_lesion_labels = measure[0][1]
                    lesion_img = measure[1]
                    if not is_header_compatible(lesion_img, labels_img):
                        logging.error(
                            '{} do not have a compatible header'.format(
                                lesion_filename))
                        raise IOError(
                            '{} do not have a compatible header'.format(
                                lesion_filename))

                    voxel_sizes = lesion_img.header.get_zooms()[0:3]
                    lesion_img.set_filename('tmp.nii.gz')
                    lesion_atlas = get_data_as_label(lesion_img)
                    tmp_dict = compute_lesion_stats(
                        density.astype(bool), lesion_atlas,
                        voxel_sizes=voxel_sizes, single_label=True,
                        min_lesion_vol=min_lesion_vol,
                        precomputed_lesion_labels=computed_lesion_labels)

                    tmp_ind = _streamlines_in_mask(
                        list(streamlines), lesion_atlas.astype(np.uint8),
                        np.eye(3), [0, 0, 0])
                    # Number of streamlines flagged (== 1) as in the lesions.
                    streamlines_count = int(
                        np.count_nonzero(np.asarray(tmp_ind) == 1))

                    if tmp_dict:
                        measures_to_return[lesion_filename + 'vol'] = \
                            tmp_dict['lesion_total_volume']
                        measures_to_return[lesion_filename + 'count'] = \
                            tmp_dict['lesion_count']
                        measures_to_return[lesion_filename + 'sc'] = \
                            streamlines_count
                    else:
                        measures_to_return[lesion_filename + 'vol'] = 0
                        measures_to_return[lesion_filename + 'count'] = 0
                        measures_to_return[lesion_filename + 'sc'] = 0

        if include_dps:
            for dps_key in hdf5_file[key].keys():
                if dps_key not in ['data', 'offsets', 'lengths']:
                    out_file = os.path.join(include_dps, dps_key)
                    if 'commit' in dps_key:
                        measures_to_return[out_file] = np.sum(
                            hdf5_file[key][dps_key])
                    else:
                        measures_to_return[out_file] = np.average(
                            hdf5_file[key][dps_key])

    return {(in_label, out_label): measures_to_return}
def main():
    """Combine selected labels from several label volumes into a single
    label volume, optionally renumbering them."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    image_files = []
    indices_per_volume = []
    # Separate argument per volume
    used_indices_all = False
    for v_args in args.volume_ids:
        if len(v_args) < 2:
            parser.error("No indices was given for a given volume")

        image_files.append(v_args[0])
        if "all" in v_args:
            used_indices_all = True
            indices_per_volume.append("all")
        else:
            # np.int was removed in NumPy 1.24; the builtin int is equivalent.
            indices_per_volume.append(np.asarray(v_args[1:], dtype=int))

    if used_indices_all and args.out_labels_ids:
        parser.error("'all' indices cannot be used with 'out_labels_ids'")

    # Check inputs / output
    assert_inputs_exist(parser, image_files)
    assert_outputs_exist(parser, args, args.output)

    # Load volume and do checks
    data_list = []
    first_img = nib.load(image_files[0])
    for i, image_file in enumerate(image_files):
        if not is_header_compatible(first_img, image_file):
            parser.error('{} does not have a compatible header with '
                         '{}'.format(image_file, image_files[0]))

        # Load images (np.asanyarray(dataobj) replaces the removed
        # get_data() with identical values)
        volume_nib = nib.load(image_file)
        data = np.round(np.asanyarray(volume_nib.dataobj)).astype(int)
        data_list.append(data)

        if (isinstance(indices_per_volume[i], str)
                and indices_per_volume[i] == "all"):
            indices_per_volume[i] = np.unique(data)

    # Remove background labels
    filtered_ids_per_vol = []
    for id_list in indices_per_volume:
        id_list = np.asarray(id_list)
        filtered_ids_per_vol.append(
            id_list[~np.isin(id_list, args.background)])

    # Prepare output indices
    if args.out_labels_ids:
        out_labels = args.out_labels_ids
        if len(out_labels) != len(np.hstack(indices_per_volume)):
            parser.error("--out_labels_ids, requires the same amount"
                         " of total given input indices")
    elif args.unique:
        stack = np.hstack(filtered_ids_per_vol)
        ids = np.arange(len(stack) + 1)
        out_labels = np.setdiff1d(ids, args.background)[:len(stack)]
    elif args.group_in_m:
        # Volume i gets the label prefix i * 1e6.
        m_list = [i * 1000000 + np.asarray(ids)
                  for i, ids in enumerate(filtered_ids_per_vol)]
        out_labels = np.hstack(m_list)
    else:
        out_labels = np.hstack(filtered_ids_per_vol)

    if len(np.unique(out_labels)) != len(out_labels):
        parser.error("The same output label number was used "
                     "for multiple inputs")

    # Create the resulting volume
    current_id = 0
    resulting_labels = (np.ones_like(data_list[0], dtype=int)
                        * args.background)
    for data, ids in zip(data_list, filtered_ids_per_vol):
        # Add Given labels for each volume
        for index in ids:
            mask = data == index
            resulting_labels[mask] = out_labels[current_id]
            current_id += 1

            if np.count_nonzero(mask) == 0:
                logging.warning("Label {} was not in the volume".format(index))

    # Save final combined volume
    nii = nib.Nifti1Image(resulting_labels, first_img.affine, first_img.header)
    nib.save(nii, args.output)
def main():
    """Compute connectivity matrices (one .npy per measure) from a
    decomposed-connectivity hdf5 file and a labels image."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_hdf5, args.in_labels],
                        args.force_labels_list)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)

    measures_to_compute = []
    measures_output_filename = []
    if args.volume:
        measures_to_compute.append('volume')
        measures_output_filename.append(args.volume)
    if args.streamline_count:
        measures_to_compute.append('streamline_count')
        measures_output_filename.append(args.streamline_count)
    if args.length:
        measures_to_compute.append('length')
        measures_output_filename.append(args.length)
    if args.similarity:
        measures_to_compute.append('similarity')
        measures_output_filename.append(args.similarity[1])

    dict_maps_out_name = {}
    if args.maps is not None:
        for in_folder, out_name in args.maps:
            measures_to_compute.append(in_folder)
            dict_maps_out_name[in_folder] = out_name
            measures_output_filename.append(out_name)

    dict_metrics_out_name = {}
    if args.metrics is not None:
        for in_name, out_name in args.metrics:
            # Verify that all metrics are compatible with each other
            if not is_header_compatible(args.metrics[0][0], in_name):
                raise IOError('Metrics {} and {} do not share a compatible '
                              'header'.format(args.metrics[0][0], in_name))

            # This is necessary to support more than one map for weighting
            measures_to_compute.append((in_name, nib.load(in_name)))
            dict_metrics_out_name[in_name] = out_name
            measures_output_filename.append(out_name)

    dict_lesion_out_name = {}
    if args.lesion_load is not None:
        in_name = args.lesion_load[0]
        lesion_img = nib.load(in_name)
        lesion_data = get_data_as_mask(lesion_img, dtype=bool)
        # Each connected component is an individual lesion.
        lesion_atlas, _ = ndi.label(lesion_data)
        measures_to_compute.append(((in_name, np.unique(lesion_atlas)[1:]),
                                    nib.Nifti1Image(lesion_atlas,
                                                    lesion_img.affine)))

        out_name_1 = os.path.join(args.lesion_load[1], 'lesion_vol.npy')
        out_name_2 = os.path.join(args.lesion_load[1], 'lesion_count.npy')
        out_name_3 = os.path.join(args.lesion_load[1], 'lesion_sc.npy')

        dict_lesion_out_name[in_name+'vol'] = out_name_1
        dict_lesion_out_name[in_name+'count'] = out_name_2
        dict_lesion_out_name[in_name+'sc'] = out_name_3
        measures_output_filename.extend([out_name_1, out_name_2, out_name_3])

    assert_outputs_exist(parser, args, measures_output_filename)
    if not measures_to_compute:
        parser.error('No connectivity measures were selected, nothing '
                     'to compute.')

    logging.info('The following measures will be computed and save: {}'.format(
        measures_output_filename))

    if args.include_dps:
        os.makedirs(args.include_dps, exist_ok=True)
        logging.info('data_per_streamline weighting is activated.')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    if not args.force_labels_list:
        labels_list = np.unique(data_labels)[1:].tolist()
    else:
        # atleast_1d: a file with a single label would otherwise give a
        # scalar from tolist().
        labels_list = np.atleast_1d(np.loadtxt(
            args.force_labels_list, dtype=np.int16)).tolist()

    comb_list = list(itertools.combinations(labels_list, r=2))
    if not args.no_self_connection:
        comb_list.extend(zip(labels_list, labels_list))

    nbr_cpu = validate_nbr_processes(parser, args)
    measures_dict_list = []
    if nbr_cpu == 1:
        for comb in comb_list:
            measures_dict_list.append(_processing_wrapper([args.in_hdf5,
                                                           img_labels, comb,
                                                           measures_to_compute,
                                                           args.similarity,
                                                           args.density_weighting,
                                                           args.include_dps,
                                                           args.min_lesion_vol]))
    else:
        pool = multiprocessing.Pool(nbr_cpu)
        measures_dict_list = pool.map(_processing_wrapper,
                                      zip(itertools.repeat(args.in_hdf5),
                                          itertools.repeat(img_labels),
                                          comb_list,
                                          itertools.repeat(
                                              measures_to_compute),
                                          itertools.repeat(args.similarity),
                                          itertools.repeat(
                                          args.density_weighting),
                                          itertools.repeat(args.include_dps),
                                          itertools.repeat(args.min_lesion_vol)))
        pool.close()
        pool.join()

    # Removing None entries (combinaisons that do not exist)
    # Fusing the multiprocessing output into a single dictionary
    measures_dict_list = [it for it in measures_dict_list if it is not None]
    if not measures_dict_list:
        raise ValueError('Empty matrix, no entries to save.')
    measures_dict = measures_dict_list[0]
    for dix in measures_dict_list[1:]:
        measures_dict.update(dix)

    # Each off-diagonal node fills two (symmetric) cells, a self-connection
    # fills one.
    nb_diag = sum(1 for in_label, out_label in measures_dict
                  if in_label == out_label)
    results_elem = len(measures_dict) * 2 - nb_diag
    if args.no_self_connection:
        total_elem = len(labels_list)**2 - len(labels_list)
    else:
        total_elem = len(labels_list)**2

    logging.info('Out of {} possible nodes, {} contain value'.format(
        total_elem, results_elem))

    # Position of each label in the matrices (first occurrence, as
    # list.index would give), computed once instead of per cell.
    label_to_pos = {}
    for pos, label in enumerate(labels_list):
        label_to_pos.setdefault(label, pos)

    # Filling out all the matrices (symmetric) in the order of labels_list
    nbr_of_measures = len(next(iter(measures_dict.values())))
    matrix = np.zeros((len(labels_list), len(labels_list), nbr_of_measures))

    for (in_label, out_label), curr_node_dict in measures_dict.items():
        measures_ordering = list(curr_node_dict.keys())
        in_pos = label_to_pos[in_label]
        out_pos = label_to_pos[out_label]
        for i, measure in enumerate(curr_node_dict):
            matrix[in_pos, out_pos, i] = curr_node_dict[measure]
            matrix[out_pos, in_pos, i] = curr_node_dict[measure]

    # Saving the matrices separatly with the specified name or dps
    for i, measure in enumerate(measures_ordering):
        if measure == 'volume':
            matrix_basename = args.volume
        elif measure == 'streamline_count':
            matrix_basename = args.streamline_count
        elif measure == 'length':
            matrix_basename = args.length
        elif measure == 'similarity':
            matrix_basename = args.similarity[1]
        elif measure in dict_metrics_out_name:
            matrix_basename = dict_metrics_out_name[measure]
        elif measure in dict_maps_out_name:
            matrix_basename = dict_maps_out_name[measure]
        elif measure in dict_lesion_out_name:
            matrix_basename = dict_lesion_out_name[measure]
        else:
            matrix_basename = measure

        np.save(matrix_basename, matrix[:, :, i])
def main():
    """Interactively sort bundle clusters into accepted and rejected ones.

    Clusters with fewer than ``args.min_cluster_size`` streamlines are
    rejected without being shown. The others are displayed one at a time,
    from the largest to the smallest, over a background made of all the
    streamlines of the clusters that passed the size threshold.
    Keys handled by the FURY callback:

    - 'a': accept the current cluster
    - 'r': reject the current cluster
    - 'z': undo the last choice
    - 'c': toggle the background streamlines
    - 'q': quit; clusters not yet classified are rejected

    The streamlines returned by ``save_clusters`` (which also receives the
    optional output directories) are saved to ``args.out_accepted`` and
    ``args.out_rejected``.
    """
    # Callback required for FURY
    def keypress_callback(obj, _):
        """Process one key press from the FURY interactor.

        Only the names rebound here need ``nonlocal``; every other name
        (actors, lists, linewidths) is read from ``main`` through the
        closure when the callback runs, i.e. after ``main`` defined them.
        """
        nonlocal curr_streamlines_actor, show_curr_actor
        key = obj.GetKeySym().lower()
        # Number of clusters already classified, which is also the index of
        # the cluster currently on display
        curr_index = len(accepted_streamlines) + len(rejected_streamlines)
        nb_clusters = len(sft_accepted_on_size)
        renwin = interactor_style.GetInteractor().GetRenderWindow()
        renderer = interactor_style.GetCurrentRenderer()

        if key == 'c' and curr_index < nb_clusters:
            # Toggle the background streamlines
            if show_curr_actor:
                renderer.rm(concat_streamlines_actor)
                renwin.Render()
                show_curr_actor = False
                logging.info('Streamlines rendering OFF')
            else:
                renderer.add(concat_streamlines_actor)
                # Re-added after the background, presumably so the current
                # cluster stays drawn on top of it
                renderer.rm(curr_streamlines_actor)
                renderer.add(curr_streamlines_actor)
                renwin.Render()
                show_curr_actor = True
                logging.info('Streamlines rendering ON')
            return

        if key == 'q':
            show_manager.exit()
            if curr_index < nb_clusters:
                logging.warning(
                    'Early exit, everything remaining to be rejected.')
            return

        if key in ['a', 'r'] and curr_index < nb_clusters:
            if key == 'a':
                accepted_streamlines.append(curr_index)
                choices.append('a')
                logging.info('Accepted file %s',
                             filename_accepted_on_size[curr_index])
            else:
                rejected_streamlines.append(curr_index)
                choices.append('r')
                logging.info('Rejected file %s',
                             filename_accepted_on_size[curr_index])
            curr_index += 1

        if key == 'z':
            if curr_index > 0:
                # Undo the last choice in whichever list received it
                last_choice = choices.pop()
                if last_choice == 'r':
                    rejected_streamlines.pop()
                else:
                    accepted_streamlines.pop()
                logging.info('Rewind on step.')
                curr_index -= 1
            else:
                logging.warning('Cannot rewind, first element.')

        # Replace the displayed cluster by the one that is now current
        if key in ['a', 'r', 'z'] and curr_index < nb_clusters:
            renderer.rm(curr_streamlines_actor)
            curr_streamlines = sft_accepted_on_size[curr_index].streamlines
            curr_streamlines_actor = actor.line(curr_streamlines,
                                                opacity=0.8,
                                                linewidth=clusters_linewidth)
            renderer.add(curr_streamlines_actor)

        if curr_index == nb_clusters:
            print('No more cluster, press q to exit')
            renderer.rm(curr_streamlines_actor)

        renwin.Render()

    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    assert_outputs_exist(parser, args, [args.out_accepted, args.out_rejected])
    if args.out_accepted_dir:
        assert_output_dirs_exist_and_empty(parser, args,
                                           args.out_accepted_dir,
                                           create_dir=True)
    if args.out_rejected_dir:
        assert_output_dirs_exist_and_empty(parser, args,
                                           args.out_rejected_dir,
                                           create_dir=True)

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    if args.min_cluster_size < 1:
        parser.error('Minimum cluster size must be at least 1.')

    clusters_linewidth = args.clusters_linewidth
    background_linewidth = args.background_linewidth

    # To accelerate procedure, clusters can be discarded based on size
    # Concatenation is to give spatial context
    sft_accepted_on_size, filename_accepted_on_size = [], []
    sft_rejected_on_size, filename_rejected_on_size = [], []
    concat_streamlines = []
    for filename in args.in_bundles:
        # Every bundle must share the spatial attributes of the first one;
        # report it instead of exiting silently without any output
        if not is_header_compatible(args.in_bundles[0], filename):
            parser.error('{} does not have a header compatible with '
                         '{}.'.format(filename, args.in_bundles[0]))
        basename = os.path.basename(filename)
        sft = load_tractogram_with_reference(parser, args, filename,
                                             bbox_check=False)
        if len(sft) >= args.min_cluster_size:
            sft_accepted_on_size.append(sft)
            filename_accepted_on_size.append(basename)
            concat_streamlines.extend(sft.streamlines)
        else:
            logging.info('File %s has %s streamlines, automatically rejected.',
                         filename, len(sft))
            sft_rejected_on_size.append(sft)
            filename_rejected_on_size.append(basename)

    if not filename_accepted_on_size:
        parser.error('No cluster survived the cluster_size threshold.')

    logging.info('%s clusters to be classified.', len(sft_accepted_on_size))
    # The clusters are sorted by size for simplicity/efficiency
    tuple_accepted = zip(
        *sorted(zip(sft_accepted_on_size, filename_accepted_on_size),
                key=lambda x: len(x[0]), reverse=True))
    sft_accepted_on_size, filename_accepted_on_size = tuple_accepted

    # Initialize the actors, scene, window, observer
    concat_streamlines_actor = actor.line(concat_streamlines,
                                          colors=(1, 1, 1),
                                          opacity=args.background_opacity,
                                          linewidth=background_linewidth)
    curr_streamlines_actor = actor.line(sft_accepted_on_size[0].streamlines,
                                        opacity=0.8,
                                        linewidth=clusters_linewidth)

    scene = window.Scene()
    interactor_style = interactor.CustomInteractorStyle()
    show_manager = window.ShowManager(scene, size=(800, 800),
                                      reset_camera=False,
                                      interactor_style=interactor_style)
    scene.add(concat_streamlines_actor)
    scene.add(curr_streamlines_actor)
    interactor_style.AddObserver('KeyPressEvent', keypress_callback)

    # Launch rendering and selection procedure
    choices, accepted_streamlines, rejected_streamlines = [], [], []
    show_curr_actor = True
    show_manager.start()

    # Early exit means everything else is rejected: the unclassified
    # clusters are the last `missing` ones of the (sorted) accepted list
    missing = len(args.in_bundles) - len(choices) - len(sft_rejected_on_size)
    len_accepted = len(sft_accepted_on_size)
    rejected_streamlines.extend(range(len_accepted - missing, len_accepted))
    if missing > 0:
        logging.info('%s clusters automatically rejected from early exit',
                     missing)

    # Save accepted clusters (by GUI)
    accepted_streamlines = save_clusters(sft_accepted_on_size,
                                         accepted_streamlines,
                                         args.out_accepted_dir,
                                         filename_accepted_on_size)

    accepted_sft = StatefulTractogram(accepted_streamlines,
                                      sft_accepted_on_size[0],
                                      Space.RASMM)
    save_tractogram(accepted_sft, args.out_accepted, bbox_valid_check=False)

    # Save rejected clusters (by GUI)
    rejected_streamlines = save_clusters(sft_accepted_on_size,
                                         rejected_streamlines,
                                         args.out_rejected_dir,
                                         filename_accepted_on_size)

    # Save rejected clusters (by size)
    rejected_streamlines.extend(save_clusters(sft_rejected_on_size,
                                              range(len(sft_rejected_on_size)),
                                              args.out_rejected_dir,
                                              filename_rejected_on_size))

    rejected_sft = StatefulTractogram(rejected_streamlines,
                                      sft_accepted_on_size[0],
                                      Space.RASMM)
    save_tractogram(rejected_sft, args.out_rejected, bbox_valid_check=False)
"""

# Spatial attributes of the reference anatomy
affine, dimensions, voxel_sizes, voxel_order = get_reference_info(
    reference_anatomy)
print(affine)
print(dimensions)
print(voxel_sizes)
print(voxel_order)

"""
If you have a Trk file that was generated using a particular anatomy,
all fields must correspond between the headers for it to be considered
valid. This can be easily verified using this function, which also accepts
the same variety of inputs as ``get_reference_info``.
"""

print(is_header_compatible(reference_anatomy, bundles_filename[0]))

"""
If a Trk file was generated with a valid header but the reference NIfTI was
lost, a header can be generated and used to create a fake NIfTI file.

If you wish to manually save Trk and Tck files using the nibabel streamlines
API for more freedom of action (not recommended for beginners), you can
create a valid header using ``create_tractogram_header``.
"""

# Zero-filled image that only carries the spatial attributes read above
nifti_header = create_nifti_header(affine, dimensions, voxel_sizes)
nib.save(nib.Nifti1Image(np.zeros(dimensions), affine, nifti_header),
         'fake.nii.gz')
# Copy of the reference anatomy, written to the current working directory
nib.save(reference_anatomy, os.path.basename(ref_anat_filename))

"""
Once loaded, no matter the original file format, the stateful tractogram is
def _processing_wrapper(args):
    """Compute the connectivity measures of a single pair of labels.

    Parameters
    ----------
    args: sequence
        Packed arguments, so the function can be used with ``Pool.map``:
        [0] path of the hdf5 file holding the connections,
        [1] labels image (reference for the spatial attributes),
        [2] (in_label, out_label) pair,
        [3] list of measures to compute (copied, never modified),
        [4] similarity option whose first item is a directory, or None,
        [5] whether metric averages are weighted by the streamline density,
        [6] output directory used to name data_per_streamline averages,
            or a falsy value to skip them.
        Items after index 6 are not read.

    Return
    ------
    dict or None
        ``{(in_label, out_label): {measure_name: value}}``, or None when the
        ``'{in_label}_{out_label}'`` key is absent from the hdf5 file.

    Raises
    ------
    ValueError
        If the hdf5 affine/dimensions do not match the labels image.
    IOError
        If a metric image header is not compatible with the labels image.
    """
    hdf5_filename = args[0]
    labels_img = args[1]
    in_label, out_label = args[2]
    # Copied because measures are removed from the list once computed
    measures_to_compute = copy.copy(args[3])
    # Always bound, so 'similarity' without a directory is simply skipped
    # instead of raising a NameError
    similarity_directory = args[4][0] if args[4] is not None else None
    weighted = args[5]
    include_dps = args[6]

    # The context manager closes the file on every exit path (missing key,
    # exceptions, normal return), which matters in long-lived pool workers
    with h5py.File(hdf5_filename, 'r') as hdf5_file:
        key = '{}_{}'.format(in_label, out_label)
        if key not in hdf5_file:
            return None
        streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)

        affine, dimensions, voxel_sizes, _ = get_reference_info(labels_img)
        measures_to_return = {}

        if not (np.allclose(hdf5_file.attrs['affine'], affine, atol=1e-03)
                and np.array_equal(hdf5_file.attrs['dimensions'],
                                   dimensions)):
            raise ValueError('Provided hdf5 have incompatible headers.')

        # Precompute to save one transformation, insert later
        if 'length' in measures_to_compute:
            streamlines_copy = list(streamlines)
            # scil_decompose_connectivity.py requires isotropic voxels
            mean_length = np.average(length(streamlines_copy)) * \
                voxel_sizes[0]

        # The density map is only needed by volume, similarity and
        # maps/metrics, i.e. as soon as anything other than 'length' or
        # 'streamline_count' is requested
        needs_density = any(measure not in ('length', 'streamline_count')
                            for measure in measures_to_compute)
        if needs_density:
            density = compute_tract_counts_map(streamlines, dimensions)

        if 'volume' in measures_to_compute:
            measures_to_return['volume'] = np.count_nonzero(density) * \
                np.prod(voxel_sizes)
            measures_to_compute.remove('volume')
        if 'streamline_count' in measures_to_compute:
            measures_to_return['streamline_count'] = len(streamlines)
            measures_to_compute.remove('streamline_count')
        if 'length' in measures_to_compute:
            measures_to_return['length'] = mean_length
            measures_to_compute.remove('length')
        if 'similarity' in measures_to_compute and similarity_directory:
            density_sim = load_node_nifti(similarity_directory,
                                          in_label, out_label,
                                          labels_img)
            # No node image returned: similarity is reported as 0
            if density_sim is None:
                ba_vox = 0
            else:
                ba_vox = compute_bundle_adjacency_voxel(density, density_sim)

            measures_to_return['similarity'] = ba_vox
            measures_to_compute.remove('similarity')

        # Density divided by its maximum, computed at most once and only if
        # a weighted metric average is requested
        normalized_density = None
        for measure in measures_to_compute:
            if isinstance(measure, str) and os.path.isdir(measure):
                map_dirname = measure
                map_data = load_node_nifti(map_dirname,
                                           in_label, out_label,
                                           labels_img)
                # NOTE(review): load_node_nifti may return None (handled for
                # similarity above); that would raise a TypeError here
                measures_to_return[map_dirname] = np.average(
                    map_data[map_data > 0])
            elif isinstance(measure, tuple) and os.path.isfile(measure[0]):
                metric_filename = measure[0]
                metric_img = measure[1]
                if not is_header_compatible(metric_img, labels_img):
                    error_msg = '{} do not have a compatible header'.format(
                        metric_filename)
                    logging.error(error_msg)
                    raise IOError(error_msg)

                metric_data = metric_img.get_fdata(dtype=np.float64)
                if weighted:
                    if normalized_density is None:
                        normalized_density = density / np.max(density)
                    voxels_value = metric_data * normalized_density
                    voxels_value = voxels_value[voxels_value > 0]
                else:
                    voxels_value = metric_data[density > 0]

                measures_to_return[metric_filename] = np.average(voxels_value)

        # Average every data_per_streamline entry stored with the connection
        if include_dps:
            for dps_key in hdf5_file[key].keys():
                if dps_key not in ['data', 'offsets', 'lengths']:
                    out_file = os.path.join(include_dps, dps_key)
                    measures_to_return[out_file] = np.average(
                        hdf5_file[key][dps_key])

    return {(in_label, out_label): measures_to_return}
def load_tractogram(filename, reference, to_space=Space.RASMM,
                    shifted_origin=False, bbox_valid_check=True,
                    trk_header_check=True):
    """ Load the stateful tractogram from any format (trk, tck, vtk, fib, dpy)

    Parameters
    ----------
    filename : string
        Filename with valid extension
    reference : Nifti or Trk filename, Nifti1Image or TrkFile, Nifti1Header or
        trk.header (dict), or 'same' if the input is a trk file.
        Reference that provides the spatial attribute.
        Typically a nifti-related object from the native diffusion used for
        streamlines generation
    to_space : Enum (dipy.io.stateful_tractogram.Space)
        Space to which the streamlines will be transformed after loading.
    shifted_origin : bool
        Information on the position of the origin,
        False is Trackvis standard, default (center of the voxel)
        True is NIFTI standard (corner of the voxel)
    bbox_valid_check : bool
        Verification for negative voxel coordinates or values above the
        volume dimensions. Default is True, to enforce valid file.
    trk_header_check : bool
        Verification that the reference has the same header as the spatial
        attributes as the input tractogram when a Trk is loaded

    Returns
    -------
    output : StatefulTractogram or bool
        The tractogram to load (must have been saved properly), or False
        when the extension, ``to_space``, ``reference`` or the Trk header
        check is invalid (the reason is logged as an error).

    Raises
    ------
    ValueError
        If ``bbox_valid_check`` is True and some coordinates are invalid in
        voxel space.
    """
    _, extension = os.path.splitext(filename)
    if extension not in ['.trk', '.tck', '.vtk', '.fib', '.dpy']:
        logging.error('Input filename is not one of the supported formats')
        return False

    # isinstance() rather than `in Space`: for non-members, the containment
    # test raises TypeError before Python 3.12, and from 3.12 it also accepts
    # raw member values that the conversions below would silently ignore
    if not isinstance(to_space, Space):
        logging.error('Space MUST be one of the 3 choices (Enum)')
        return False

    if reference == 'same':
        if extension == '.trk':
            reference = filename
        else:
            logging.error('Reference must be provided, "same" is only ' +
                          'available for Trk file.')
            return False

    if trk_header_check and extension == '.trk':
        if not is_header_compatible(filename, reference):
            logging.error('Trk file header does not match the provided ' +
                          'reference')
            return False

    timer = time.time()
    data_per_point = None
    data_per_streamline = None
    if extension in ['.trk', '.tck']:
        tractogram_obj = nib.streamlines.load(filename).tractogram
        streamlines = tractogram_obj.streamlines
        # Only Trk files carry per-point / per-streamline data here
        if extension == '.trk':
            data_per_point = tractogram_obj.data_per_point
            data_per_streamline = tractogram_obj.data_per_streamline

    elif extension in ['.vtk', '.fib']:
        streamlines = load_vtk_streamlines(filename)
    elif extension in ['.dpy']:
        dpy_obj = Dpy(filename, mode='r')
        # Close the file even if reading the tracks fails
        try:
            streamlines = list(dpy_obj.read_tracks())
        finally:
            dpy_obj.close()
    logging.debug('Load %s with %s streamlines in %s seconds',
                  filename, len(streamlines), round(time.time() - timer, 3))

    # Streamlines are always built in RASMM, then moved to the requested space
    sft = StatefulTractogram(streamlines, reference, Space.RASMM,
                             shifted_origin=shifted_origin,
                             data_per_point=data_per_point,
                             data_per_streamline=data_per_streamline)

    if to_space == Space.VOX:
        sft.to_vox()
    elif to_space == Space.VOXMM:
        sft.to_voxmm()

    if bbox_valid_check and not sft.is_bbox_in_vox_valid():
        raise ValueError('Bounding box is not valid in voxel space, cannot ' +
                         'load a valid file if some coordinates are invalid. ' +
                         'Please set bbox_valid_check to False and then use ' +
                         'the function remove_invalid_streamlines to discard ' +
                         'invalid streamlines.')

    return sft
def main():
    """Filter a tractogram with a sequence of ROI-based criteria.

    The options returned by ``prepare_filtering_list`` are applied in
    order, each filter receiving the output of the previous one. Handled
    filter types: 'drawn_roi' (binary mask), 'atlas_roi' (voxels equal to a
    label value), 'x_plane' / 'y_plane' / 'z_plane' (a single voxel slice)
    and 'bdo' (ellipsoid or cuboid read from a bdo file).

    Streamline counts are gathered in a dictionary, printed as JSON when
    ``args.display_counts`` is set. The filtered tractogram is saved unless
    it is empty and ``args.no_empty`` is set.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    roi_opt_list, only_filtering_list = prepare_filtering_list(parser, args)
    o_dict = {}

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    # Streamline count before filtering
    o_dict['streamline_count_before_filtering'] = len(sft.streamlines)

    for i, roi_opt in enumerate(roi_opt_list):
        curr_dict = {}
        # Atlas needs an extra argument (value in the LUT)
        if roi_opt[0] == 'atlas_roi':
            filter_type, filter_arg, filter_arg_2, \
                filter_mode, filter_criteria = roi_opt
        else:
            filter_type, filter_arg, filter_mode, filter_criteria = roi_opt

        curr_dict['filename'] = os.path.abspath(filter_arg)
        curr_dict['type'] = filter_type
        curr_dict['mode'] = filter_mode
        curr_dict['criteria'] = filter_criteria

        is_exclude = filter_criteria != 'include'

        if filter_type in ['drawn_roi', 'atlas_roi']:
            img = nib.load(filter_arg)
            if not is_header_compatible(img, sft):
                parser.error('Headers from the tractogram and the mask are '
                             'not compatible.')
            if filter_type == 'drawn_roi':
                mask = get_data_as_mask(img)
            else:
                atlas = get_data_as_label(img)
                mask = np.zeros(atlas.shape, dtype=np.uint16)
                mask[atlas == int(filter_arg_2)] = 1
            filtered_sft, indexes = filter_grid_roi(sft, mask,
                                                    filter_mode, is_exclude)

        # For every case, the input number must be greater or equal to 0 and
        # below the dimension, since this is a voxel space operation
        elif filter_type in ['x_plane', 'y_plane', 'z_plane']:
            filter_arg = int(filter_arg)
            _, dim, _, _ = sft.space_attributes
            axis = 'xyz'.index(filter_type[0])
            if not 0 <= filter_arg < dim[axis]:
                parser.error('{} plane {} is not valid according to the '
                             'tractogram header.'.format(
                                 filter_type[0].upper(), filter_arg))
            # Mask containing only the requested slice along that axis
            mask = np.zeros(dim, dtype=np.int16)
            plane_index = [slice(None)] * 3
            plane_index[axis] = filter_arg
            mask[tuple(plane_index)] = 1
            filtered_sft, indexes = filter_grid_roi(sft, mask,
                                                    filter_mode, is_exclude)

        elif filter_type == 'bdo':
            geometry, radius, center = read_info_from_mb_bdo(filter_arg)
            if geometry == 'Ellipsoid':
                filtered_sft, indexes = filter_ellipsoid(
                    sft, radius, center, filter_mode, is_exclude)
            elif geometry == 'Cuboid':
                filtered_sft, indexes = filter_cuboid(
                    sft, radius, center, filter_mode, is_exclude)
            else:
                # Otherwise filtered_sft would be stale or undefined below
                parser.error('Unsupported bdo geometry {} in {}.'.format(
                    geometry, filter_arg))
        else:
            parser.error('Unsupported filter type {}.'.format(filter_type))

        logging.debug('The filtering options {0} resulted in '
                      '{1} streamlines'.format(roi_opt, len(filtered_sft)))

        sft = filtered_sft

        if only_filtering_list:
            filtering_name = 'Filter_' + str(i)
            curr_dict['streamline_count_after_filtering'] = len(
                sft.streamlines)
            o_dict[filtering_name] = curr_dict

    # Streamline count after filtering
    o_dict['streamline_count_final_filtering'] = len(sft.streamlines)
    if args.display_counts:
        print(json.dumps(o_dict, indent=args.indent))

    # Test the final tractogram: filtered_sft is undefined when no
    # filtering option was given
    if len(sft) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written (0 streamline)".format(
                args.out_tractogram))
            return

        logging.debug('The file {} contains 0 streamline'.format(
            args.out_tractogram))

    save_tractogram(sft, args.out_tractogram)