def main():
    """CLI entry point: apply a linear (affine) transform to a tractogram.

    Loads the moving tractogram, optionally inverts the transformation
    matrix, moves the streamlines into the target image's RASMM space and
    saves the result, with configurable handling of out-of-bbox streamlines.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(
        parser,
        [args.moving_tractogram, args.target_file, args.transformation])
    assert_outputs_exist(parser, args, args.out_tractogram)

    # bbox check is disabled on load: streamlines may only land inside the
    # target grid once the transformation has been applied.
    source_sft = load_tractogram_with_reference(
        parser, args, args.moving_tractogram, bbox_check=False)

    matrix = np.loadtxt(args.transformation)
    if args.inverse:
        matrix = np.linalg.inv(matrix)

    new_sft = StatefulTractogram(
        transform_streamlines(source_sft.streamlines, matrix),
        args.target_file, Space.RASMM,
        data_per_point=source_sft.data_per_point,
        data_per_streamline=source_sft.data_per_streamline)

    if args.remove_invalid:
        count_before = len(new_sft)
        new_sft.remove_invalid_streamlines()
        logging.warning('Removed {} invalid streamlines.'.format(
            count_before - len(new_sft)))
        save_tractogram(new_sft, args.out_tractogram)
    elif args.keep_invalid:
        if not new_sft.is_bbox_in_vox_valid():
            logging.warning('Saving tractogram with invalid streamlines.')
        save_tractogram(new_sft, args.out_tractogram, bbox_valid_check=False)
    else:
        save_tractogram(new_sft, args.out_tractogram)
def main():
    """CLI entry point: warp a tractogram with a dense deformation field.

    Validates that the deformation field and the tractogram share spatial
    attributes, applies the non-linear warp, then saves with configurable
    handling of out-of-bbox streamlines.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.moving_tractogram, args.target_file,
                                 args.deformation])
    assert_outputs_exist(parser, args, args.out_tractogram)

    in_sft = load_tractogram_with_reference(
        parser, args, args.moving_tractogram, bbox_check=False)

    deformation = nib.load(args.deformation)
    field = np.squeeze(deformation.get_fdata())

    if not is_header_compatible(in_sft, deformation):
        parser.error('Input tractogram/reference do not have the same spatial '
                     'attribute as the deformation field.')

    # Warning: Apply warp in-place
    warped = warp_streamlines(in_sft, field)

    new_sft = StatefulTractogram(
        warped, args.target_file, Space.RASMM,
        data_per_point=in_sft.data_per_point,
        data_per_streamline=in_sft.data_per_streamline)

    if args.remove_invalid:
        count_before = len(new_sft)
        new_sft.remove_invalid_streamlines()
        logging.warning('Removed {} invalid streamlines.'.format(
            count_before - len(new_sft)))
        save_tractogram(new_sft, args.out_tractogram)
    elif args.keep_invalid:
        if not new_sft.is_bbox_in_vox_valid():
            logging.warning('Saving tractogram with invalid streamlines.')
        save_tractogram(new_sft, args.out_tractogram, bbox_valid_check=False)
    else:
        save_tractogram(new_sft, args.out_tractogram)
def direct_streamline_norm(
    streams,
    fa_path,
    ap_path,
    dir_path,
    track_type,
    target_samples,
    conn_model,
    network,
    node_size,
    dens_thresh,
    ID,
    roi,
    min_span_tree,
    disp_filt,
    parc,
    prune,
    atlas,
    labels_im_file,
    uatlas,
    labels,
    coords,
    norm,
    binary,
    atlas_mni,
    basedir_path,
    curv_thr_list,
    step_list,
    directget,
    min_length,
    error_margin,
    t1_aligned_mni
):
    """
    A Function to perform normalization of streamlines tracked in native
    diffusion space to an MNI-space template.

    Parameters
    ----------
    streams : str
        File path to save streamline array sequence in .trk format.
    fa_path : str
        File path to FA Nifti1Image.
    ap_path : str
        File path to the anisotropic power Nifti1Image.
    dir_path : str
        Path to directory containing subject derivative data for a given
        pynets run.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    target_samples : int
        Total number of streamline samples specified to generate streams.
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming
        (e.g. 'Default') used to filter nodes in the study of brain subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis
        for thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    labels_im_file : str
        File path to atlas parcellation Nifti1Image aligned to dwi space.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    atlas_mni : str
        File path to atlas parcellation Nifti1Image in T1w-warped MNI space.
    basedir_path : str
        Path to directory to output direct-streamline normalized temp files
        and outputs.
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble
        tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    directget : str
        The statistical approach to tracking. Options are: det
        (deterministic), closest (clos), boot (bootstrapped), and prob
        (probabilistic).
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.
    error_margin : int
        NOTE(review): undocumented in the original docstring; it is only
        embedded in the '_tol-' tag of the output filenames below --
        presumably a streamline-to-node distance tolerance; confirm.
    t1_aligned_mni : str
        File path to the T1w Nifti1Image in template MNI space.

    Returns
    -------
    streams_warp : str
        File path to normalized streamline array sequence in .trk format.
    dir_path : str
        Path to directory containing subject derivative data for a given
        pynets run.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    target_samples : int
        Total number of streamline samples specified to generate streams.
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming
        (e.g. 'Default') used to filter nodes in the study of brain subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis
        for thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    atlas_mni : str
        File path to atlas parcellation Nifti1Image in T1w-warped MNI space.
    directget : str
        The statistical approach to tracking. Options are: det
        (deterministic), closest (clos), boot (bootstrapped), and prob
        (probabilistic).
    warped_fa : str
        File path to MNI-space warped FA Nifti1Image.
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.

    References
    ----------
    .. [1] Greene, C., Cieslak, M., & Grafton, S. T. (2017). Effect of
      different spatial normalization approaches on tractography and
      structural brain networks. Network Neuroscience, 1-19.
    """
    import sys
    import gc
    from dipy.tracking.streamline import transform_streamlines
    from pynets.registration import reg_utils as regutils
    # from pynets.plotting import plot_gen
    import pkg_resources
    import yaml
    import os.path as op
    from pynets.registration.reg_utils import vdc
    from nilearn.image import resample_to_img
    from dipy.io.streamline import load_tractogram
    from dipy.tracking import utils
    from dipy.tracking._utils import _mapping_to_voxel
    from dipy.io.stateful_tractogram import Space, StatefulTractogram, Origin
    from dipy.io.streamline import save_tractogram
    # from pynets.core.utils import missing_elements

    # NOTE(review): `os`, `nib` (nibabel), `np` (numpy) and `indexed_gzip`
    # are referenced below but not imported in this function -- presumably
    # module-level imports outside this chunk; confirm.

    # Read the DSN on/off switch from the packaged run configuration.
    with open(
        pkg_resources.resource_filename("pynets", "runconfig.yaml"), "r"
    ) as stream:
        try:
            # NOTE(review): yaml.load without an explicit Loader is
            # deprecated/unsafe, and FileNotFoundError cannot be raised here
            # (the open above already succeeded), so YAML/KeyError failures
            # are NOT actually caught by this handler -- flagging only.
            hardcoded_params = yaml.load(stream)
            run_dsn = hardcoded_params['tracking']["DSN"][0]
        except FileNotFoundError as e:
            import sys
            print(e, "Failed to parse runconfig.yaml")
            exit(1)
    # Redundant: the `with` block already closed the file.
    stream.close()

    if run_dsn is True:
        dsn_dir = f"{basedir_path}/dmri_reg/DSN"
        if not op.isdir(dsn_dir):
            os.mkdir(dsn_dir)

        namer_dir = f"{dir_path}/tractography"
        if not op.isdir(namer_dir):
            os.mkdir(namer_dir)

        atlas_img = nib.load(labels_im_file)

        # Run SyN and normalize streamlines
        fa_img = nib.load(fa_path)
        vox_size = fa_img.header.get_zooms()[0]
        template_path = pkg_resources.resource_filename(
            "pynets", f"templates/FA_{int(vox_size)}mm.nii.gz"
        )
        # Same load either way; only the caught exception type differs by
        # platform (indexed_gzip is unavailable on Windows).
        if sys.platform.startswith('win') is False:
            try:
                template_img = nib.load(template_path)
            except indexed_gzip.ZranError as e:
                print(e,
                      f"\nCannot load FA template. Do you have git-lfs "
                      f"installed?")
                sys.exit(1)
        else:
            try:
                template_img = nib.load(template_path)
            except ImportError as e:
                print(e,
                      f"\nCannot load FA template. Do you have git-lfs "
                      f"installed?")
                sys.exit(1)

        uatlas_mni_img = nib.load(atlas_mni)
        t1_aligned_mni_img = nib.load(t1_aligned_mni)
        brain_mask = np.asarray(t1_aligned_mni_img.dataobj).astype("bool")

        # Output file names encode the full tracking recipe.
        streams_mni = "%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s" % (
            namer_dir,
            "/streamlines_mni_",
            "%s" % (network + "_" if network is not None else ""),
            "%s" % (op.basename(roi).split(".")[0] + "_" if roi is not None
                    else ""),
            conn_model,
            "_",
            target_samples,
            "%s" % (
                "%s%s" % ("_" + str(node_size), "mm_")
                if ((node_size != "parc") and (node_size is not None))
                else "_"
            ),
            "curv",
            str(curv_thr_list).replace(", ", "_"),
            "step",
            str(step_list).replace(", ", "_"),
            "tracktype-",
            track_type,
            "_directget-",
            directget,
            "_minlength-",
            min_length,
            "_tol-",
            error_margin,
            ".trk",
        )

        density_mni = "%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s" % (
            namer_dir,
            "/density_map_mni_",
            "%s" % (network + "_" if network is not None else ""),
            "%s" % (op.basename(roi).split(".")[0] + "_" if roi is not None
                    else ""),
            conn_model,
            "_",
            target_samples,
            "%s" % (
                "%s%s" % ("_" + str(node_size), "mm_")
                if ((node_size != "parc") and (node_size is not None))
                else "_"
            ),
            "curv",
            str(curv_thr_list).replace(", ", "_"),
            "step",
            str(step_list).replace(", ", "_"),
            "tracktype-",
            track_type,
            "_directget-",
            directget,
            "_minlength-",
            min_length,
            "_tol-",
            error_margin,
            ".nii.gz",
        )

        # streams_warp_png = '/tmp/dsn.png'

        # SyN FA->Template
        [mapping, affine_map, warped_fa] = regutils.wm_syn(
            template_path, fa_path, t1_aligned_mni, ap_path, dsn_dir
        )

        tractogram = load_tractogram(
            streams,
            fa_img,
            to_origin=Origin.NIFTI,
            to_space=Space.VOXMM,
            bbox_valid_check=False,
        )

        fa_img.uncache()
        streamlines = tractogram.streamlines
        warped_fa_img = nib.load(warped_fa)
        warped_fa_affine = warped_fa_img.affine
        warped_fa_shape = warped_fa_img.shape

        streams_in_curr_grid = transform_streamlines(
            streamlines, warped_fa_affine)

        # Create isocenter mapping where we anchor the origin transformation
        # affine to the corner of the FOV by scaling x, y, z offsets according
        # to a multiplicative van der Corput sequence with a base value equal
        # to the voxel resolution
        [x_mul, y_mul, z_mul] = [vdc(i, vox_size) for i in range(1, 4)]

        ref_grid_aff = vox_size * np.eye(4)
        ref_grid_aff[3][3] = 1

        streams_final_filt = []
        i = 0
        # Test for various types of voxel-grid configurations
        combs = [(-x_mul, -y_mul, -z_mul), (-x_mul, -y_mul, z_mul),
                 (-x_mul, y_mul, -z_mul), (x_mul, -y_mul, -z_mul),
                 (x_mul, y_mul, z_mul)]
        # Retry the warp with each sign combination until at least 90% of
        # the streamlines survive; error out when all combinations fail.
        while len(streams_final_filt)/len(streams_in_curr_grid) < 0.90:
            print(f"Warping streamlines to MNI space. Attempt {i}...")
            print(len(streams_final_filt)/len(streams_in_curr_grid))
            adjusted_affine = affine_map.affine.copy()
            if i > len(combs) - 1:
                raise ValueError('DSN failed. Header orientation '
                                 'information may be corrupted. '
                                 'Is your dataset oblique?')
            adjusted_affine[0][3] = adjusted_affine[0][3] * combs[i][0]
            adjusted_affine[1][3] = adjusted_affine[1][3] * combs[i][1]
            adjusted_affine[2][3] = adjusted_affine[2][3] * combs[i][2]
            streams_final_filt = regutils.warp_streamlines(adjusted_affine,
                                                           ref_grid_aff,
                                                           mapping,
                                                           warped_fa_img,
                                                           streams_in_curr_grid,
                                                           brain_mask)
            i += 1

        # Remove streamlines with negative voxel indices
        lin_T, offset = _mapping_to_voxel(np.eye(4))
        streams_final_filt_final = []
        for sl in streams_final_filt:
            inds = np.dot(sl, lin_T)
            inds += offset
            if not inds.min().round(decimals=6) < 0:
                streams_final_filt_final.append(sl)

        # Save streamlines
        stf = StatefulTractogram(
            streams_final_filt_final,
            reference=uatlas_mni_img,
            space=Space.VOXMM,
            origin=Origin.NIFTI,
        )
        stf.remove_invalid_streamlines()
        streams_final_filt_final = stf.streamlines
        save_tractogram(stf, streams_mni, bbox_valid_check=True)
        warped_fa_img.uncache()

        # DSN QC plotting
        # plot_gen.show_template_bundles(streams_final_filt_final, atlas_mni,
        #                                streams_warp_png)
        # plot_gen.show_template_bundles(streamlines, fa_path,
        #                                streams_warp_png)

        # Create and save MNI density map
        nib.save(
            nib.Nifti1Image(
                utils.density_map(
                    streams_final_filt_final,
                    affine=np.eye(4),
                    vol_dims=warped_fa_shape),
                warped_fa_affine,
            ),
            density_mni,
        )

        # Map parcellation from native space back to MNI-space and create an
        # 'uncertainty-union' parcellation with original mni-space uatlas
        warped_uatlas = affine_map.transform_inverse(
            mapping.transform(
                np.asarray(atlas_img.dataobj).astype("int"),
                interpolation="nearestneighbour",
            ),
            interp="nearest",
        )
        atlas_img.uncache()
        warped_uatlas_img_res_data = np.asarray(
            resample_to_img(
                nib.Nifti1Image(warped_uatlas, affine=warped_fa_affine),
                uatlas_mni_img,
                interpolation="nearest",
                clip=False,
            ).dataobj
        )
        uatlas_mni_data = np.asarray(uatlas_mni_img.dataobj)
        uatlas_mni_img.uncache()
        overlap_mask = np.invert(
            warped_uatlas_img_res_data.astype("bool") *
            uatlas_mni_data.astype("bool"))
        os.makedirs(f"{dir_path}/parcellations", exist_ok=True)

        # NOTE(review): `atlas_mni` is rebound here from the input path to
        # the new 'liberal' parcellation path; the rebound value is what is
        # returned below.
        atlas_mni = f"{dir_path}/parcellations/" \
                    f"{op.basename(uatlas).split('.nii')[0]}_liberal.nii.gz"

        nib.save(
            nib.Nifti1Image(
                warped_uatlas_img_res_data * overlap_mask.astype("int")
                + uatlas_mni_data * overlap_mask.astype("int")
                + np.invert(overlap_mask).astype("int")
                * warped_uatlas_img_res_data.astype("int"),
                affine=uatlas_mni_img.affine,
            ),
            atlas_mni,
        )

        del (
            tractogram,
            streamlines,
            warped_uatlas_img_res_data,
            uatlas_mni_data,
            overlap_mask,
            stf,
            streams_final_filt_final,
            streams_final_filt,
            streams_in_curr_grid,
            brain_mask,
        )

        gc.collect()

        # Sanity check: one coordinate tuple per node label.
        assert len(coords) == len(labels)
    else:
        print(
            "Skipping Direct Streamline Normalization (DSN). Will proceed to "
            "define fiber connectivity in native diffusion space...")
        streams_mni = streams
        warped_fa = fa_path
        atlas_mni = labels_im_file

    return (
        streams_mni,
        dir_path,
        track_type,
        target_samples,
        conn_model,
        network,
        node_size,
        dens_thresh,
        ID,
        roi,
        min_span_tree,
        disp_filt,
        parc,
        prune,
        atlas,
        uatlas,
        labels,
        coords,
        norm,
        binary,
        atlas_mni,
        directget,
        warped_fa,
        min_length,
        error_margin
    )
def transform_warp_streamlines(sft, linear_transfo, target, inverse=False,
                               deformation_data=None, remove_invalid=True,
                               cut_invalid=False):
    # TODO rename transform_warp_sft
    """ Transform tractogram using a affine Subsequently apply a warp from
    antsRegistration (optional).
    Remove/Cut invalid streamlines to preserve sft validity.

    Note: mutates the input `sft` in place (moves it to RASMM space and
    center origin) before transforming.

    Parameters
    ----------
    sft: StatefulTractogram
        Stateful tractogram object containing the streamlines to transform.
    linear_transfo: numpy.ndarray
        Linear transformation matrix to apply to the tractogram.
    target: Nifti filepath, image object, header
        Final reference for the tractogram after registration.
    inverse: boolean
        Apply the inverse linear transformation.
    deformation_data: np.ndarray
        4D array containing a 3D displacement vector in each voxel.
    remove_invalid: boolean
        Remove the streamlines landing out of the bounding box.
    cut_invalid: boolean
        Cut invalid streamlines rather than removing them.
        Keep the longest segment only.

    Return
    ----------
    new_sft : StatefulTractogram
    """
    sft.to_rasmm()
    sft.to_center()
    if inverse:
        linear_transfo = np.linalg.inv(linear_transfo)

    streamlines = transform_streamlines(sft.streamlines,
                                        linear_transfo)

    if deformation_data is not None:
        affine, _, _, _ = get_reference_info(target)

        # Because of duplication, an iteration over chunks of points is
        # necessary for a big dataset (especially if not compressed)
        streamlines = ArraySequence(streamlines)
        nb_points = len(streamlines._data)
        cur_position = 0
        chunk_size = 1000000
        nb_iteration = int(np.ceil(nb_points / chunk_size))
        inv_affine = np.linalg.inv(affine)

        # Process the flat point buffer chunk-by-chunk; ArraySequence._data
        # holds all points of all streamlines as one (N, 3) array, so the
        # per-streamline structure is untouched by in-place point updates.
        while nb_iteration > 0:
            max_position = min(cur_position + chunk_size, nb_points)
            points = streamlines._data[cur_position:max_position]

            # To access the deformation information, we need to go in VOX space
            # No need for corner shift since we are doing interpolation
            cur_points_vox = np.array(transform_streamlines(points,
                                                            inv_affine)).T

            # Trilinear interpolation (order=1) of each displacement
            # component at the streamline points.
            x_def = map_coordinates(deformation_data[..., 0],
                                    cur_points_vox.tolist(), order=1)
            y_def = map_coordinates(deformation_data[..., 1],
                                    cur_points_vox.tolist(), order=1)
            z_def = map_coordinates(deformation_data[..., 2],
                                    cur_points_vox.tolist(), order=1)

            # ITK is in LPS and nibabel is in RAS, a flip is necessary for ANTs
            final_points = np.array([-1 * x_def, -1 * y_def, z_def])
            final_points += np.array(points).T

            streamlines._data[cur_position:max_position] = final_points.T
            cur_position = max_position
            nb_iteration -= 1

    new_sft = StatefulTractogram(streamlines, target, Space.RASMM,
                                 data_per_point=sft.data_per_point,
                                 data_per_streamline=sft.data_per_streamline)
    if cut_invalid:
        new_sft, _ = cut_invalid_streamlines(new_sft)
    elif remove_invalid:
        new_sft.remove_invalid_streamlines()

    return new_sft
def empty_remove_invalid():
    """Removing invalid streamlines from an empty tractogram keeps it empty."""
    empty_sft = StatefulTractogram([], filepath_dix['gs.nii'], Space.VOX)
    empty_sft.remove_invalid_streamlines()
    assert_array_equal([], empty_sft.streamlines.data)
def _run_interface(self, runtime):
    """Run ensemble tractography for one subject/atlas combination.

    Copies all inputs to temp files in the nipype working directory, then
    either (a) reuses a previously saved tractogram and only rebuilds the
    density map, or (b) fits the diffusion model (cached as .hdf5), runs
    ensemble tracking, optionally filters with LiFE, and saves the .trk and
    density map. Results are exposed through ``self._results``.

    NOTE(review): `nib` (nibabel) and `np` (numpy) are used below but not
    imported in this method -- presumably module-level imports outside this
    chunk; confirm.
    """
    import gc
    import os
    import time
    import os.path as op
    from dipy.io import load_pickle
    from colorama import Fore, Style
    from dipy.data import get_sphere
    from pynets.core import utils
    from pynets.core.utils import load_runconfig
    from pynets.dmri.estimation import reconstruction
    from pynets.dmri.track import (
        create_density_map,
        track_ensemble,
    )
    from dipy.io.stateful_tractogram import Space, StatefulTractogram, \
        Origin
    from dipy.io.streamline import save_tractogram
    from nipype.utils.filemanip import copyfile, fname_presuffix

    # Hardcoded tracking parameters from the packaged run configuration.
    hardcoded_params = load_runconfig()
    use_life = hardcoded_params['tracking']["use_life"][0]
    roi_neighborhood_tol = hardcoded_params['tracking'][
        "roi_neighborhood_tol"][0]
    sphere = hardcoded_params['tracking']["sphere"][0]
    target_samples = hardcoded_params['tracking']["tracking_samples"][0]

    dir_path = utils.do_dir_path(self.inputs.atlas,
                                 os.path.dirname(self.inputs.dwi_file))

    namer_dir = "{}/tractography".format(dir_path)
    if not os.path.isdir(namer_dir):
        os.makedirs(namer_dir, exist_ok=True)

    # Load diffusion data
    dwi_file_tmp_path = fname_presuffix(self.inputs.dwi_file, suffix="_tmp",
                                        newpath=runtime.cwd)
    copyfile(self.inputs.dwi_file, dwi_file_tmp_path, copy=True,
             use_hardlink=False)
    dwi_img = nib.load(dwi_file_tmp_path, mmap=True)
    dwi_data = dwi_img.get_fdata(dtype=np.float32)

    # Load FA data
    fa_file_tmp_path = fname_presuffix(self.inputs.fa_path, suffix="_tmp",
                                       newpath=runtime.cwd)
    copyfile(self.inputs.fa_path, fa_file_tmp_path, copy=True,
             use_hardlink=False)
    fa_img = nib.load(fa_file_tmp_path, mmap=True)

    labels_im_file_tmp_path = fname_presuffix(self.inputs.labels_im_file,
                                              suffix="_tmp",
                                              newpath=runtime.cwd)
    copyfile(self.inputs.labels_im_file, labels_im_file_tmp_path,
             copy=True, use_hardlink=False)

    # Load B0 mask
    B0_mask_tmp_path = fname_presuffix(self.inputs.B0_mask, suffix="_tmp",
                                       newpath=runtime.cwd)
    copyfile(self.inputs.B0_mask, B0_mask_tmp_path, copy=True,
             use_hardlink=False)

    # Output .trk filename encodes the full tracking recipe.
    streams = "%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s" % (
        runtime.cwd,
        "/streamlines_",
        "%s" % (self.inputs.subnet + "_" if self.inputs.subnet is not None
                else ""),
        "%s" % (op.basename(self.inputs.roi).split(".")[0] + "_"
                if self.inputs.roi is not None else ""),
        self.inputs.conn_model,
        "_",
        target_samples,
        "_",
        "%s" % ("%s%s" % (self.inputs.node_radius, "mm_")
                if ((self.inputs.node_radius != "parc")
                    and (self.inputs.node_radius is not None))
                else "parc_"),
        "curv-",
        str(self.inputs.curv_thr_list).replace(", ", "_"),
        "_step-",
        str(self.inputs.step_list).replace(", ", "_"),
        "_traversal-",
        self.inputs.traversal,
        "_minlength-",
        self.inputs.min_length,
        ".trk",
    )

    if os.path.isfile(f"{namer_dir}/{op.basename(streams)}"):
        # A tractogram for this recipe already exists: reuse it and only
        # rebuild the density map.
        from dipy.io.streamline import load_tractogram
        copyfile(
            f"{namer_dir}/{op.basename(streams)}",
            streams,
            copy=True,
            use_hardlink=False,
        )
        tractogram = load_tractogram(
            streams,
            fa_img,
            bbox_valid_check=False,
        )
        streamlines = tractogram.streamlines

        # Create streamline density map
        try:
            [dir_path, dm_path] = create_density_map(
                fa_img,
                dir_path,
                streamlines,
                self.inputs.conn_model,
                self.inputs.node_radius,
                self.inputs.curv_thr_list,
                self.inputs.step_list,
                self.inputs.subnet,
                self.inputs.roi,
                self.inputs.traversal,
                self.inputs.min_length,
                namer_dir,
            )
        except BaseException:
            print('Density map failed. Check tractography output.')
            dm_path = None

        del streamlines, tractogram
        fa_img.uncache()
        dwi_img.uncache()
        gc.collect()
        self._results["dm_path"] = dm_path
        self._results["streams"] = streams
        recon_path = None
    else:
        # Fit diffusion model
        # Save reconstruction to .npy
        recon_path = "%s%s%s%s%s%s%s%s" % (
            runtime.cwd,
            "/reconstruction_",
            "%s" % (self.inputs.subnet + "_" if self.inputs.subnet is not
                    None else ""),
            "%s" % (op.basename(self.inputs.roi).split(".")[0] + "_"
                    if self.inputs.roi is not None else ""),
            self.inputs.conn_model,
            "_",
            "%s" % ("%s%s" % (self.inputs.node_radius, "mm")
                    if ((self.inputs.node_radius != "parc")
                        and (self.inputs.node_radius is not None))
                    else "parc"),
            ".hdf5",
        )

        gtab_file_tmp_path = fname_presuffix(self.inputs.gtab_file,
                                             suffix="_tmp",
                                             newpath=runtime.cwd)
        copyfile(self.inputs.gtab_file, gtab_file_tmp_path, copy=True,
                 use_hardlink=False)
        gtab = load_pickle(gtab_file_tmp_path)

        # Only re-run the reconstruction if we have to
        if not os.path.isfile(f"{namer_dir}/{op.basename(recon_path)}"):
            import h5py
            model = reconstruction(
                self.inputs.conn_model,
                gtab,
                dwi_data,
                B0_mask_tmp_path,
            )[0]
            with h5py.File(recon_path, 'w') as hf:
                hf.create_dataset("reconstruction",
                                  data=model.astype('float32'), dtype='f4')
            hf.close()

            copyfile(
                recon_path,
                f"{namer_dir}/{op.basename(recon_path)}",
                copy=True,
                use_hardlink=False,
            )
            time.sleep(2)
            del model
        elif os.path.getsize(f"{namer_dir}/{op.basename(recon_path)}") > 0:
            print(f"Found existing reconstruction with "
                  f"{self.inputs.conn_model}. Loading...")
            copyfile(
                f"{namer_dir}/{op.basename(recon_path)}",
                recon_path,
                copy=True,
                use_hardlink=False,
            )
            time.sleep(5)
        else:
            # Cached file exists but is empty: rebuild it.
            import h5py
            model = reconstruction(
                self.inputs.conn_model,
                gtab,
                dwi_data,
                B0_mask_tmp_path,
            )[0]
            with h5py.File(recon_path, 'w') as hf:
                hf.create_dataset("reconstruction",
                                  data=model.astype('float32'), dtype='f4')
            hf.close()

            copyfile(
                recon_path,
                f"{namer_dir}/{op.basename(recon_path)}",
                copy=True,
                use_hardlink=False,
            )
            time.sleep(5)
            del model
        dwi_img.uncache()
        del dwi_data

        # Load atlas wm-gm interface reduced version for seeding
        labels_im_file_tmp_path_wm_gm_int = fname_presuffix(
            self.inputs.labels_im_file_wm_gm_int, suffix="_tmp",
            newpath=runtime.cwd)
        copyfile(self.inputs.labels_im_file_wm_gm_int,
                 labels_im_file_tmp_path_wm_gm_int, copy=True,
                 use_hardlink=False)

        t1w2dwi_tmp_path = fname_presuffix(self.inputs.t1w2dwi,
                                           suffix="_tmp",
                                           newpath=runtime.cwd)
        copyfile(self.inputs.t1w2dwi, t1w2dwi_tmp_path, copy=True,
                 use_hardlink=False)

        gm_in_dwi_tmp_path = fname_presuffix(self.inputs.gm_in_dwi,
                                             suffix="_tmp",
                                             newpath=runtime.cwd)
        copyfile(self.inputs.gm_in_dwi, gm_in_dwi_tmp_path, copy=True,
                 use_hardlink=False)

        vent_csf_in_dwi_tmp_path = fname_presuffix(
            self.inputs.vent_csf_in_dwi, suffix="_tmp",
            newpath=runtime.cwd)
        copyfile(self.inputs.vent_csf_in_dwi, vent_csf_in_dwi_tmp_path,
                 copy=True, use_hardlink=False)

        wm_in_dwi_tmp_path = fname_presuffix(self.inputs.wm_in_dwi,
                                             suffix="_tmp",
                                             newpath=runtime.cwd)
        copyfile(self.inputs.wm_in_dwi, wm_in_dwi_tmp_path, copy=True,
                 use_hardlink=False)

        if self.inputs.waymask:
            waymask_tmp_path = fname_presuffix(self.inputs.waymask,
                                               suffix="_tmp",
                                               newpath=runtime.cwd)
            copyfile(self.inputs.waymask, waymask_tmp_path, copy=True,
                     use_hardlink=False)
        else:
            waymask_tmp_path = None

        # Iteratively build a list of streamlines for each ROI while
        # tracking
        print(f"{Fore.GREEN}Target streamlines per iteration: "
              f"{Fore.BLUE} "
              f"{target_samples}")
        print(Style.RESET_ALL)
        print(f"{Fore.GREEN}Curvature threshold(s): {Fore.BLUE} "
              f"{self.inputs.curv_thr_list}")
        print(Style.RESET_ALL)
        print(f"{Fore.GREEN}Step size(s): {Fore.BLUE} "
              f"{self.inputs.step_list}")
        print(Style.RESET_ALL)
        print(f"{Fore.GREEN}Tracking type: {Fore.BLUE} "
              f"{self.inputs.track_type}")
        print(Style.RESET_ALL)
        if self.inputs.traversal == "prob":
            print(f"{Fore.GREEN}Direction-getting type: {Fore.BLUE}"
                  f"Probabilistic")
        elif self.inputs.traversal == "clos":
            print(f"{Fore.GREEN}Direction-getting type: "
                  f"{Fore.BLUE}Closest Peak")
        elif self.inputs.traversal == "det":
            print(f"{Fore.GREEN}Direction-getting type: "
                  f"{Fore.BLUE}Deterministic Maximum")
        else:
            raise ValueError("Direction-getting type not recognized!")
        print(Style.RESET_ALL)

        # Commence Ensemble Tractography
        try:
            streamlines = track_ensemble(
                target_samples,
                labels_im_file_tmp_path_wm_gm_int,
                labels_im_file_tmp_path,
                recon_path,
                get_sphere(sphere),
                self.inputs.traversal,
                self.inputs.curv_thr_list,
                self.inputs.step_list,
                self.inputs.track_type,
                self.inputs.maxcrossing,
                int(roi_neighborhood_tol),
                self.inputs.min_length,
                waymask_tmp_path,
                B0_mask_tmp_path,
                t1w2dwi_tmp_path, gm_in_dwi_tmp_path,
                vent_csf_in_dwi_tmp_path, wm_in_dwi_tmp_path,
                self.inputs.tiss_class)
            gc.collect()
        except BaseException as w:
            print(f"\n{Fore.RED}Tractography failed: {w}")
            print(Style.RESET_ALL)
            streamlines = None

        if streamlines is not None:
            # import multiprocessing
            # from pynets.core.utils import kill_process_family
            # return kill_process_family(int(
            # multiprocessing.current_process().pid))

            # Linear Fascicle Evaluation (LiFE)
            if use_life is True:
                print('Using LiFE to evaluate streamline plausibility...')
                from pynets.dmri.utils import \
                    evaluate_streamline_plausibility
                dwi_img = nib.load(dwi_file_tmp_path)
                dwi_data = dwi_img.get_fdata(dtype=np.float32)
                orig_count = len(streamlines)
                if self.inputs.waymask:
                    mask_data = nib.load(waymask_tmp_path).get_fdata(
                    ).astype('bool').astype('int')
                else:
                    mask_data = nib.load(wm_in_dwi_tmp_path).get_fdata(
                    ).astype('bool').astype('int')
                try:
                    streamlines = evaluate_streamline_plausibility(
                        dwi_data, gtab, mask_data, streamlines,
                        sphere=sphere)
                except BaseException:
                    print(f"Linear Fascicle Evaluation failed. "
                          f"Visually checking streamlines output "
                          f"{namer_dir}/{op.basename(streams)} is "
                          f"recommended.")
                if len(streamlines) < 0.5 * orig_count:
                    raise ValueError('LiFE revealed no plausible '
                                     'streamlines in the tractogram!')
                del dwi_data, mask_data

            # Save streamlines to trk
            stf = StatefulTractogram(streamlines,
                                     fa_img,
                                     origin=Origin.NIFTI,
                                     space=Space.VOXMM)
            stf.remove_invalid_streamlines()

            save_tractogram(
                stf,
                streams,
            )

            del stf

            copyfile(
                streams,
                f"{namer_dir}/{op.basename(streams)}",
                copy=True,
                use_hardlink=False,
            )

            # Create streamline density map
            try:
                [dir_path, dm_path] = create_density_map(
                    dwi_img,
                    dir_path,
                    streamlines,
                    self.inputs.conn_model,
                    self.inputs.node_radius,
                    self.inputs.curv_thr_list,
                    self.inputs.step_list,
                    self.inputs.subnet,
                    self.inputs.roi,
                    self.inputs.traversal,
                    self.inputs.min_length,
                    namer_dir,
                )
            except BaseException:
                print('Density map failed. Check tractography output.')
                dm_path = None

            del streamlines
            dwi_img.uncache()
            gc.collect()
            self._results["dm_path"] = dm_path
            self._results["streams"] = streams
        else:
            self._results["streams"] = None
            self._results["dm_path"] = None

        # Clean up temp copies only needed for the tracking branch.
        tmp_files = [gtab_file_tmp_path, wm_in_dwi_tmp_path,
                     gm_in_dwi_tmp_path, vent_csf_in_dwi_tmp_path,
                     t1w2dwi_tmp_path]
        for j in tmp_files:
            if j is not None:
                if os.path.isfile(j):
                    os.system(f"rm -f {j} &")

    # Pass-through inputs exposed as interface outputs.
    self._results["track_type"] = self.inputs.track_type
    self._results["conn_model"] = self.inputs.conn_model
    self._results["dir_path"] = dir_path
    self._results["subnet"] = self.inputs.subnet
    self._results["node_radius"] = self.inputs.node_radius
    self._results["dens_thresh"] = self.inputs.dens_thresh
    self._results["ID"] = self.inputs.ID
    self._results["roi"] = self.inputs.roi
    self._results["min_span_tree"] = self.inputs.min_span_tree
    self._results["disp_filt"] = self.inputs.disp_filt
    self._results["parc"] = self.inputs.parc
    self._results["prune"] = self.inputs.prune
    self._results["atlas"] = self.inputs.atlas
    self._results["parcellation"] = self.inputs.parcellation
    self._results["labels"] = self.inputs.labels
    self._results["coords"] = self.inputs.coords
    self._results["norm"] = self.inputs.norm
    self._results["binary"] = self.inputs.binary
    self._results["atlas_t1w"] = self.inputs.atlas_t1w
    self._results["curv_thr_list"] = self.inputs.curv_thr_list
    self._results["step_list"] = self.inputs.step_list
    self._results["fa_path"] = fa_file_tmp_path
    self._results["traversal"] = self.inputs.traversal
    self._results["labels_im_file"] = labels_im_file_tmp_path
    self._results["min_length"] = self.inputs.min_length

    tmp_files = [B0_mask_tmp_path, dwi_file_tmp_path]
    for j in tmp_files:
        if j is not None:
            if os.path.isfile(j):
                os.system(f"rm -f {j} &")

    # Exercise caution when deleting copied recon_path
    # if recon_path is not None:
    #     if os.path.isfile(recon_path):
    #         os.remove(recon_path)
    return runtime
def direct_streamline_norm(streams, fa_path, ap_path, dir_path, track_type,
                           conn_model, subnet, node_radius, dens_thresh, ID,
                           roi, min_span_tree, disp_filt, parc, prune, atlas,
                           labels_im_file, parcellation, labels, coords, norm,
                           binary, atlas_t1w, basedir_path, curv_thr_list,
                           step_list, traversal, min_length, t1w_brain,
                           run_dsn=False):
    """
    A Function to perform normalization of streamlines tracked in native
    diffusion space to an MNI-space template.

    Parameters
    ----------
    streams : str
        File path to save streamline array sequence in .trk format.
    fa_path : str
        File path to FA Nifti1Image.
    ap_path : str
        File path to the anisotropic power Nifti1Image.
    dir_path : str
        Path to directory containing subject derivative data for a given
        pynets run.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    subnet : str
        Resting-state subnet based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_radius : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis
        for thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone subnet' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    labels_im_file : str
        File path to atlas parcellation Nifti1Image aligned to dwi space.
    parcellation : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    atlas_t1w : str
        File path to atlas parcellation Nifti1Image in T1w-conformed space.
    basedir_path : str
        Path to directory to output direct-streamline normalized temp files
        and outputs.
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble
        tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    traversal : str
        The statistical approach to tracking. Options are: det
        (deterministic), closest (clos), boot (bootstrapped), and prob
        (probabilistic).
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.
    t1w_brain : str
        File path to the T1w Nifti1Image.
    run_dsn : bool
        Whether to actually perform DSN. When False (default), the inputs are
        passed through unchanged and connectivity is defined in native
        diffusion space.

    Returns
    -------
    streams_warp : str
        File path to normalized streamline array sequence in .trk format.
    dir_path : str
        Path to directory containing subject derivative data for a given
        pynets run.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    subnet : str
        Resting-state subnet based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_radius : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis
        for thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone subnet' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    parcellation : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    atlas_for_streams : str
        File path to atlas parcellation Nifti1Image in the same morphological
        space as the streamlines.
    traversal : str
        The statistical approach to tracking. Options are: det
        (deterministic), closest (clos), boot (bootstrapped), and prob
        (probabilistic).
    warped_fa : str
        File path to MNI-space warped FA Nifti1Image.
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.

    References
    ----------
    .. [1] Greene, C., Cieslak, M., & Grafton, S. T. (2017). Effect of
      different spatial normalization approaches on tractography and
      structural brain networks. Network Neuroscience, 1-19.

    """
    import gc
    from dipy.tracking.streamline import transform_streamlines
    from pynets.registration import utils as regutils
    from pynets.plotting.brain import show_template_bundles
    import os.path as op
    from dipy.io.streamline import load_tractogram
    from dipy.tracking._utils import _mapping_to_voxel
    from dipy.tracking.utils import density_map
    from dipy.io.stateful_tractogram import Space, StatefulTractogram, Origin
    from dipy.io.streamline import save_tractogram

    if run_dsn is True:
        # Working directory for SyN registration intermediates.
        dsn_dir = f"{basedir_path}/dmri_reg/DSN"
        if not op.isdir(dsn_dir):
            os.mkdir(dsn_dir)

        namer_dir = f"{dir_path}/tractography"
        if not op.isdir(namer_dir):
            os.mkdir(namer_dir)

        # NOTE(review): atlas_img and atlas_t1w_img below are loaded but never
        # referenced afterwards in this function — likely vestigial.
        atlas_img = nib.load(labels_im_file)

        # Run SyN and normalize streamlines
        fa_img = nib.load(fa_path)
        atlas_t1w_img = nib.load(atlas_t1w)
        t1w_brain_img = nib.load(t1w_brain)
        # Boolean brain mask used later to filter warped streamlines.
        brain_mask = np.asarray(t1w_brain_img.dataobj).astype("bool")

        # Output .trk path encoding the full tracking recipe (subnet, ROI,
        # model, node size, ensemble curvature/step lists, etc.).
        streams_t1w = "%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s" % (
            namer_dir,
            "/streamlines_t1w_",
            "%s" % (subnet + "_" if subnet is not None else ""),
            "%s" % (op.basename(roi).split(".")[0] + "_" if roi is not None
                    else ""),
            conn_model,
            "%s" % ("%s%s" % ("_" + str(node_radius), "mm_")
                    if ((node_radius != "parc") and (node_radius is not None))
                    else "_"),
            "curv",
            str(curv_thr_list).replace(", ", "_"),
            "step",
            str(step_list).replace(", ", "_"),
            "tracktype-",
            track_type,
            "_traversal-",
            traversal,
            "_minlength-",
            min_length,
            ".trk",
        )

        # Matching density-map output path (same naming recipe, .nii.gz).
        density_t1w = "%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s" % (
            namer_dir,
            "/density_map_t1w_",
            "%s" % (subnet + "_" if subnet is not None else ""),
            "%s" % (op.basename(roi).split(".")[0] + "_" if roi is not None
                    else ""),
            conn_model,
            "%s" % ("%s%s" % ("_" + str(node_radius), "mm_")
                    if ((node_radius != "parc") and (node_radius is not None))
                    else "_"),
            "curv",
            str(curv_thr_list).replace(", ", "_"),
            "step",
            str(step_list).replace(", ", "_"),
            "tracktype-",
            track_type,
            "_traversal-",
            traversal,
            "_minlength-",
            min_length,
            ".nii.gz",
        )

        # NOTE(review): hard-coded /tmp QC path — concurrent runs will
        # overwrite each other's plot. Confirm whether this is intentional.
        streams_warp_png = '/tmp/dsn.png'

        # SyN FA->Template
        [mapping, affine_map, warped_fa] = regutils.wm_syn(t1w_brain, ap_path,
                                                           dsn_dir)

        # Load streamlines in voxel-mm space with NIFTI origin; skip the
        # bounding-box check since invalid streamlines are filtered below.
        tractogram = load_tractogram(
            streams,
            fa_img,
            to_origin=Origin.NIFTI,
            to_space=Space.VOXMM,
            bbox_valid_check=False,
        )

        fa_img.uncache()
        streamlines = tractogram.streamlines
        warped_fa_img = nib.load(warped_fa)
        warped_fa_affine = warped_fa_img.affine
        warped_fa_shape = warped_fa_img.shape

        # Bring streamlines into the current (registration) grid via the
        # inverse of the affine pre-alignment, then apply the SyN warp.
        streams_in_curr_grid = transform_streamlines(
            streamlines, affine_map.affine_inv)

        streams_final_filt = regutils.warp_streamlines(t1w_brain_img.affine,
                                                       fa_img.affine, mapping,
                                                       warped_fa_img,
                                                       streams_in_curr_grid,
                                                       brain_mask)

        # Remove streamlines with negative voxel indices
        lin_T, offset = _mapping_to_voxel(np.eye(4))
        streams_final_filt_final = []
        for sl in streams_final_filt:
            inds = np.dot(sl, lin_T)
            inds += offset
            # Round to 6 decimals so tiny negative float noise is tolerated.
            if not inds.min().round(decimals=6) < 0:
                streams_final_filt_final.append(sl)

        # Save streamlines
        stf = StatefulTractogram(
            streams_final_filt_final,
            reference=t1w_brain_img,
            space=Space.VOXMM,
            origin=Origin.NIFTI,
        )
        stf.remove_invalid_streamlines()
        streams_final_filt_final = stf.streamlines
        save_tractogram(stf, streams_t1w, bbox_valid_check=True)
        warped_fa_img.uncache()

        # DSN QC plotting
        show_template_bundles(streams_final_filt_final, atlas_t1w,
                              streams_warp_png)

        # Save a voxel-wise streamline density map alongside the tractogram.
        nib.save(
            nib.Nifti1Image(
                density_map(streams_final_filt_final, affine=np.eye(4),
                            vol_dims=warped_fa_shape),
                warped_fa_affine,
            ),
            density_t1w,
        )

        # Free the large intermediates before returning.
        del (
            tractogram,
            streamlines,
            stf,
            streams_final_filt_final,
            streams_final_filt,
            streams_in_curr_grid,
            brain_mask,
        )
        gc.collect()

        # Sanity check: node coordinates and labels must stay paired.
        assert len(coords) == len(labels)

        atlas_for_streams = atlas_t1w
    else:
        print(
            "Skipping Direct Streamline Normalization (DSN). Will proceed to "
            "define fiber connectivity in native diffusion space...")
        # Pass-through: native-space tractogram, FA, and dwi-space atlas.
        streams_t1w = streams
        warped_fa = fa_path
        atlas_for_streams = labels_im_file

    return (streams_t1w, dir_path, track_type, conn_model, subnet,
            node_radius, dens_thresh, ID, roi, min_span_tree, disp_filt, parc,
            prune, atlas, parcellation, labels, coords, norm, binary,
            atlas_for_streams, traversal, warped_fa, min_length)
def tractography_estimation_data(dmri_estimation_data):
    """
    Fixture that runs a fast CSA + local-tracking pass over the estimation
    data and yields paths to the resulting tractogram and tracking mask.

    Yields
    ------
    dict
        {'trk': path to saved .trk tractogram,
         'mask': path to the binary tracking-mask Nifti file}
    """
    # Reserve a uniquely-named .trk path; the handle's directory also hosts
    # the final tractogram.
    tmp_handle = tempfile.NamedTemporaryFile(mode='w+', suffix='.trk',
                                             delete=False)
    tmp_trk = str(tmp_handle.name)
    out_dir = os.path.dirname(tmp_trk)

    gtab = dmri_estimation_data['gtab']
    wm_img = nib.load(dmri_estimation_data['f_pve_wm'])
    dwi_img = nib.load(dmri_estimation_data['dwi_file'])
    B0_mask_img = nib.load(dmri_estimation_data['B0_mask'])
    dwi_data = dwi_img.get_fdata()

    # Tracking mask = voxelwise intersection of the binarized WM partial
    # volume map and the B0 brain mask.
    wm_bin = nib.Nifti1Image(
        np.asarray(wm_img.dataobj).astype('bool').astype('int'),
        affine=wm_img.affine)
    b0_bin = nib.Nifti1Image(
        np.asarray(B0_mask_img.dataobj).astype('bool').astype('int'),
        affine=B0_mask_img.affine)
    mask_img = intersect_masks([wm_bin, b0_bin], threshold=1, connected=False)
    mask_data = mask_img.get_fdata()

    mask_file = fname_presuffix(dmri_estimation_data['B0_mask'],
                                suffix="tracking_mask", use_ext=True)
    mask_img.to_filename(mask_file)

    # Fit CSA ODFs and extract peaks restricted to the tracking mask.
    csa_model = CsaOdfModel(gtab, sh_order=6)
    csa_peaks = peaks_from_model(csa_model, dwi_data, default_sphere,
                                 relative_peak_threshold=.8,
                                 min_separation_angle=45,
                                 mask=mask_data)
    stopping_criterion = BinaryStoppingCriterion(mask_data)

    # Seed one streamline per masked voxel and track deterministically.
    seed_mask = (mask_data == 1)
    seeds = seeds_from_mask(seed_mask, dwi_img.affine, density=[1, 1, 1])
    streamline_gen = LocalTracking(csa_peaks, stopping_criterion, seeds,
                                   affine=dwi_img.affine, step_size=.5)
    streamlines = Streamlines(streamline_gen)

    sft = StatefulTractogram(streamlines, B0_mask_img, origin=Origin.NIFTI,
                             space=Space.VOXMM)
    sft.remove_invalid_streamlines()
    trk = f"{out_dir}/tractogram.trk"
    # Recycle the reserved temp file as the final output path.
    os.rename(tmp_trk, trk)
    save_tractogram(sft, trk, bbox_valid_check=False)

    # Drop the large in-memory intermediates before yielding.
    del streamlines, sft, streamline_gen, seeds, seed_mask, csa_peaks, \
        csa_model, dwi_data, mask_data
    dwi_img.uncache()
    mask_img.uncache()
    gc.collect()

    yield {'trk': trk, 'mask': mask_file}
track_moving_warped = np.zeros([n_, N_points, 3]) for idx in range(n_): track_moving_warped[idx] = moving_warped[idx * N_points:N_points * (idx + 1)] else: track_moving_warped = track_moving.copy() for i, streamline in enumerate(track_moving): streamline_warp = warpNeigh.predict(streamline) track_moving_warped[i] += streamline_warp warped_filename = WarpedShot_dir + '/track_warped' + suffix + '.trk' track_moving_warped_sft = StatefulTractogram(track_moving_warped, FA_nib, Space.RASMM) idx_toremove, idx_tokeep = track_moving_warped_sft.remove_invalid_streamlines( ) save_tractogram(track_moving_warped_sft, warped_filename) print("save warped tracts as: " + warped_filename) #show_both_bundles((track_moving_warped,track_fixed,track_moving),colors=[window.colors.cyan,window.colors.green,window.colors.red],fname=ScreenShot_dir+'/after_Warp.png') #%% PLOT if plot_flag: import matplotlib.pyplot as plt #from mpl_toolkits.mplot3d import Axes3D fig = plt.figure() #ax = fig.gca(projection='3d') it_arr = range(100000) #ax = fig.gca()