def reg_mri_pngs(
    mri, atlas, outdir, loc=0, mean=False, minthr=2, maxthr=95, edge=False
):
    """Create and save a QA figure overlaying a registered image on its reference.

    Parameters
    ----------
    mri : str
        Path to the registered brain image (NIfTI) generated by a registration step.
    atlas : str
        Path to the reference brain image (NIfTI) used in that registration step.
    outdir : str or os.PathLike
        Directory where the output png file is saved.
    loc : int
        Index along the 4th (volume) axis that is plotted when ``mri`` is 4D
        and ``mean`` is False, i.e. ``mri_data[:, :, :, loc]``.
    mean : bool
        If True and ``mri`` is 4D, plot the mean over the 4th axis instead of
        a single volume.
    minthr : int
        Lower percentile threshold, forwarded to ``plot_overlays``.
    maxthr : int
        Upper percentile threshold, forwarded to ``plot_overlays``.
    edge : bool
        Edge option, forwarded to ``plot_overlays``.

    Notes
    -----
    The figure is written to ``{outdir}/{mri name}_2_{atlas name}.png``.
    """
    import numpy as np

    # np.asanyarray(img.dataobj) is nibabel's documented equivalent of the
    # deprecated get_data() (removed in nibabel >= 5.0).
    atlas_data = np.asanyarray(nb.load(atlas).dataobj)
    mri_data = np.asanyarray(nb.load(mri).dataobj)

    if mri_data.ndim == 4:
        # 4D data: reduce the volume axis so a 3D image can be plotted.
        mr_data = mri_data.mean(axis=3) if mean else mri_data[:, :, :, loc]
    else:  # 3D data is used as-is
        mr_data = mri_data

    cmap1 = LinearSegmentedColormap.from_list("mycmap1", ["white", "magenta"])
    cmap2 = LinearSegmentedColormap.from_list("mycmap2", ["white", "green"])
    fig = plot_overlays(atlas_data, mr_data, [cmap1, cmap2], minthr, maxthr, edge)

    # An f-string (rather than "+") also accepts a pathlib.Path for outdir.
    out_png = f"{outdir}/{get_filename(mri)}_2_{get_filename(atlas)}.png"
    fig.savefig(out_png, format="png")
    # Close this specific figure; plt.close() with no argument only closes the
    # *current* figure, which may not be the one plot_overlays returned.
    plt.close(fig)
def gen_overlay_pngs(
    brain, original, outdir, loc=0, mean=False, minthr=2, maxthr=95, edge=False
):
    """Generate and save a QA image for skull stripping.

    The skull-stripped brain is overlaid on the original image by calling
    ``plot_overlays_skullstrip``.

    Parameters
    ----------
    brain : str
        Path to the skull-stripped NIfTI brain.
    original : str
        Path to the original T1w image, with the skull included.
    outdir : str or os.PathLike
        Directory where the QA image is saved.
    loc : int
        Index along the 4th (volume) axis that is plotted when ``brain`` is 4D
        and ``mean`` is False, i.e. ``brain_data[:, :, :, loc]``.
    mean : bool
        If True and ``brain`` is 4D, plot the mean over the 4th axis instead
        of a single volume.
    minthr : int
        Lower percentile threshold. NOTE(review): not currently forwarded to
        ``plot_overlays_skullstrip``; kept for signature compatibility.
    maxthr : int
        Upper percentile threshold. Not currently forwarded (see ``minthr``).
    edge : bool
        Not currently forwarded (see ``minthr``).

    Notes
    -----
    The image is written to ``{outdir}/qa_skullstrip__{original name}.png``.
    """
    import numpy as np

    original_name = get_filename(original)
    # nibabel's documented equivalent of the deprecated get_data().
    brain_data = np.asanyarray(nb.load(brain).dataobj)
    if brain_data.ndim == 4:
        # 4D data: reduce the volume axis so a 3D image can be plotted.
        brain_data = brain_data.mean(axis=3) if mean else brain_data[:, :, :, loc]

    fig = plot_overlays_skullstrip(brain_data, original)

    # name and save the file
    fig.savefig(f"{outdir}/qa_skullstrip__{original_name}.png", format="png")
    # Release the figure; without this every call leaves an open matplotlib figure.
    plt.close(fig)
def extract_t1w_brain(t1w, out, tmpdir, skull="none"):
    """Extract the brain from a T1w image using AFNI-based helpers.

    The skull is stripped first, into an intermediate ``*_noskull.nii.gz``
    file inside ``tmpdir``. ``apply_mask`` then combines that intermediate
    with the input to write ``out``.

    Parameters
    ----------
    t1w : str
        Path of the input T1w image.
    out : str
        Path of the output brain image.
    tmpdir : str
        Directory where intermediate images are stored.
    skull : str, optional
        Skull-strip parameter preset passed to ``t1w_skullstrip``.
        Default is "none".
    """
    # Intermediate image: the T1w with its skull removed.
    noskull_path = "{}/{}_noskull.nii.gz".format(tmpdir, gen_utils.get_filename(t1w))

    # 3dSkullStrip step, then 3dcalc-based masking of the input image.
    t1w_skullstrip(t1w, noskull_path, skull)
    apply_mask(t1w, noskull_path, out)
def atlas2t1w2dwi_align(self, atlas, dsn=True):
    """Align an atlas to T1w space and, unless ``dsn`` is True, on to DWI space.

    The atlas is always affinely aligned to the MNI-aligned T1w first.

    * If ``self.simple`` is False and ``dsn`` is False, a nonlinear route is
      tried: inverse T1w->MNI warp, then T1w->DWI. If any step raises, a
      linear (flirt) route is used instead.
    * If ``self.simple`` is True and ``dsn`` is False, only the linear route
      is used.
    * If ``dsn`` is True, the atlas stays in MNI-aligned T1w space.

    Note: ``t1w2dwi_align`` must have been called first.

    Parameters
    ----------
    atlas : str
        Path to the atlas file you want to use.
    dsn : bool, optional
        Whether the tractography space is native-dsn, by default True.

    Returns
    -------
    str
        Path to the aligned atlas file: ``self.dwi_aligned_atlas`` when
        ``dsn`` is False, otherwise ``self.aligned_atlas_t1mni``.

    Side effects
    ------------
    * Sets ``self.atlas``, ``self.atlas_name``, ``self.aligned_atlas_t1mni``,
      ``self.aligned_atlas_skull``, ``self.dwi_aligned_atlas``,
      ``self.atlas_img`` and ``self.atlas_data``.
    * Writes the binarized ``self.wm_gm_int_in_dwi`` image to
      ``self.wm_gm_int_in_dwi_bin``.
    * Overwrites the returned atlas file with integer labels.
    * Writes a QA png into ``self.qa_reg``.
    """
    self.atlas = atlas
    self.atlas_name = gen_utils.get_filename(self.atlas)
    self.aligned_atlas_t1mni = (
        f"{self.reg_a}/{self.atlas_name}_aligned_atlas_t1w_mni.nii.gz"
    )
    self.aligned_atlas_skull = (
        f"{self.reg_a}/{self.atlas_name}_aligned_atlas_skull.nii.gz"
    )
    self.dwi_aligned_atlas = (
        f"{self.reg_anat}/{self.atlas_name}_aligned_atlas.nii.gz"
    )

    def _linear_atlas_to_dwi(init_searchrad, refine_reference):
        # Linear route: flirt atlas -> T1w, combine with T1w -> DWI, apply to atlas.
        # Create transform to align atlas to T1w using flirt
        reg_utils.align(
            self.atlas,
            self.t1w_brain,
            xfm=self.xfm_atlas2t1w_init,
            init=None,
            bins=None,
            dof=6,
            cost="mutualinfo",
            searchrad=init_searchrad,
            interp="spline",
            out=None,
            sch=None,
        )
        reg_utils.align(
            self.atlas,
            refine_reference,
            xfm=self.xfm_atlas2t1w,
            out=None,
            dof=6,
            searchrad=True,
            bins=None,
            interp="spline",
            cost="mutualinfo",
            init=self.xfm_atlas2t1w_init,
        )
        # Combine our linear transform from t1w to template with our transform
        # from dwi to t1w space to get a transform from atlas ->(-> t1w ->)-> dwi
        reg_utils.combine_xfms(
            self.xfm_atlas2t1w, self.t1wtissue2dwi_xfm, self.temp2dwi_xfm
        )
        # Apply linear transformation from template to dwi space
        reg_utils.applyxfm(
            self.nodif_B0, self.atlas, self.temp2dwi_xfm, self.dwi_aligned_atlas
        )

    # Affine (12 dof) atlas -> MNI-aligned T1w; used by every branch below.
    reg_utils.align(
        self.atlas,
        self.t1_aligned_mni,
        init=None,
        xfm=None,
        out=self.aligned_atlas_t1mni,
        dof=12,
        searchrad=True,
        interp="nearestneighbour",
        cost="mutualinfo",
    )

    if (self.simple is False) and (dsn is False):
        try:
            # Apply warp resulting from the inverse of T1w-->MNI created earlier
            reg_utils.apply_warp(
                self.t1w_brain,
                self.aligned_atlas_t1mni,
                self.aligned_atlas_skull,
                warp=self.mni2t1w_warp,
                interp="nn",
                sup=True,
            )
            # Apply transform to dwi space
            reg_utils.align(
                self.aligned_atlas_skull,
                self.nodif_B0,
                init=self.t1wtissue2dwi_xfm,
                xfm=None,
                out=self.dwi_aligned_atlas,
                dof=6,
                searchrad=True,
                interp="nearestneighbour",
                cost="mutualinfo",
            )
        except Exception as err:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # propagate; the failure cause is reported instead of hidden.
            print(
                "Warning: Atlas is not in correct dimensions, or input is low quality,\nusing linear template registration."
            )
            print(f"Nonlinear registration error: {err!r}")
            # NOTE(review): this fallback refines against self.t1_aligned_mni
            # (with a search-range initial fit), whereas the simple branch below
            # refines against self.t1w_brain; confirm which is intended.
            _linear_atlas_to_dwi(
                init_searchrad=True, refine_reference=self.t1_aligned_mni
            )
    elif dsn is False:
        _linear_atlas_to_dwi(init_searchrad=None, refine_reference=self.t1w_brain)
    # else (dsn is True): the atlas stays in MNI-aligned T1w space.

    if dsn is False:
        aligned_path = self.dwi_aligned_atlas
        qa_reference = self.nodif_B0
    else:
        aligned_path = self.aligned_atlas_t1mni
        qa_reference = self.t1_aligned_mni

    # Set intensities to int.
    # np.asanyarray(img.dataobj) is nibabel's equivalent of the deprecated get_data().
    self.atlas_img = nib.load(aligned_path)
    self.atlas_data = np.around(np.asanyarray(self.atlas_img.dataobj)).astype("int16")
    node_num = len(np.unique(self.atlas_data))
    # Zero out any label value larger than the number of distinct labels.
    self.atlas_data[self.atlas_data > node_num] = 0

    # Binarize the WM/GM interface image (voxels > 0).
    t_img = load_img(self.wm_gm_int_in_dwi)
    mask = math_img("img > 0", img=t_img)
    mask.to_filename(self.wm_gm_int_in_dwi_bin)

    # Overwrite the aligned atlas with the integer labels, then save QA.
    nib.save(
        nib.Nifti1Image(
            self.atlas_data.astype(np.int32),
            affine=self.atlas_img.affine,
            header=self.atlas_img.header,
        ),
        aligned_path,
    )
    reg_mri_pngs(aligned_path, qa_reference, self.qa_reg)
    return aligned_path
def m2g_dwi_worker(
    dwi,
    bvals,
    bvecs,
    t1w,
    atlas,
    mask,
    parcellations,
    outdir,
    vox_size="2mm",
    mod_type="det",
    track_type="local",
    mod_func="csa",
    seeds=20,
    reg_style="native",
    skipeddy=False,
    skipreg=False,
    skull=None,
):
    """Create brain graphs (connectomes) from diffusion and anatomical MRI data.

    Parameters
    ----------
    dwi : str
        Path for the dwi file(s).
    bvals : str
        Path for the bval file(s).
    bvecs : str
        Path for the bvec file(s).
    t1w : str
        Location of anatomical input file(s).
    atlas : str
        Location of atlas file. NOTE(review): only checked for existence
        here; not otherwise used by this function.
    mask : str
        Location of T1w brain mask file; make sure the proper voxel size is
        used. NOTE(review): only checked for existence here; not otherwise
        used by this function.
    parcellations : list
        Filepaths to the parcellations we're using. The value is passed
        through ``gen_utils.as_list`` before use.
    outdir : str
        The directory where the output files should be stored. Prepopulate
        this folder with results of participant-level analysis if running
        group analysis.
    vox_size : str
        Voxel size to use for template registrations: '1mm' or '2mm'.
        Default is '2mm'.
    mod_type : str
        Deterministic (det) or probabilistic (prob) tracking. Default is det.
    track_type : str
        Tracking approach: eudx or local. Default is local.
    mod_func : str
        Diffusion model: csd, csa. Default is csa.
    seeds : int
        Density of seeding for native-space tractography. Cast to int and
        passed as ``dens`` to ``track.build_seed_list``.
    reg_style : str
        Space for tractography: "native" or "native_dsn". Default is native.
    skipeddy : bool
        Whether or not to skip the eddy correction if it has already been
        run. Default is False.
    skipreg : bool
        Whether or not to skip registration steps whose outputs already
        exist. Default is False.
    skull : str, optional
        Skullstrip parameter preset, passed to ``register.DmriReg``.
        Default is None.

    Raises
    ------
    ValueError
        Raised if the downsampling voxel size is not supported.
    ValueError
        Raised if ``reg_style`` is not "native" or "native_dsn".
    FileNotFoundError
        Raised if any input file does not exist.
    ValueError
        Raised if the bval/bvecs are potentially corrupted.
    """
    # -------- Initial Setup ------------------ #
    # print starting arguments for clarity in log
    # (must stay first so locals() contains only the arguments)
    args = locals().copy()
    for arg, value in args.items():
        print(f"{arg} = {value}")

    # initial assertions
    if vox_size not in ["1mm", "2mm"]:
        raise ValueError("Voxel size not supported. Use 2mm or 1mm")
    # Fail fast: any other value leaves `tracks` unbound in connectome
    # estimation, which previously crashed only after tractography had run.
    if reg_style not in ["native", "native_dsn"]:
        raise ValueError(
            f"Registration style {reg_style!r} not supported. Use native or native_dsn"
        )

    # Normalize before the existence check below. Without this, a single
    # path given as a str would be unpacked character by character by
    # `*parcellations`.
    parcellations = gen_utils.as_list(parcellations)

    print("Checking inputs...")
    for file_ in [t1w, bvals, bvecs, dwi, atlas, mask, *parcellations]:
        if not os.path.isfile(file_):
            raise FileNotFoundError(f"Input {file_} not found. Exiting m2g.")
        print(f"Input {file_} found.")

    # time m2g execution
    startTime = datetime.now()

    # initial variables
    outdir = Path(outdir)
    dwi_name = gen_utils.get_filename(dwi)

    # make output directory
    print("Adding directory tree...")
    gen_utils.make_initial_directories(outdir, parcellations=parcellations)

    # generate list of connectome file locations
    connectomes = []
    for parc in parcellations:
        name = gen_utils.get_filename(parc)
        folder = outdir / f"connectomes/{name}"
        connectome = f"{dwi_name}_{name}_connectome.csv"
        connectomes.append(str(folder / connectome))

    warm_welcome = welcome_message(connectomes)
    print(warm_welcome)

    # -------- Preprocessing Steps --------------------------------- #

    # set up directories
    prep_dwi: Path = outdir / "dwi/preproc"
    eddy_corrected_data: Path = prep_dwi / "eddy_corrected_data.nii.gz"

    # check that skipping eddy correct is possible
    if skipeddy:
        # do it anyway if eddy_corrected_data doesn't exist
        if not eddy_corrected_data.is_file():
            print("Cannot skip preprocessing if it has not already been run!")
            skipeddy = False

    # if we're not skipping eddy correct, perform it
    if not skipeddy:
        prep_dwi = gen_utils.as_directory(prep_dwi, remove=True, return_as_path=True)
        preproc.eddy_correct(dwi, str(eddy_corrected_data), 0)

    # copy bval/bvec files to output directory
    bvec_scaled = str(outdir / "dwi/preproc/bvec_scaled.bvec")
    fbval = str(outdir / "dwi/preproc/bval.bval")
    fbvec = str(outdir / "dwi/preproc/bvec.bvec")
    shutil.copyfile(bvecs, fbvec)
    shutil.copyfile(bvals, fbval)

    # Correct any corrupted bvecs/bvals
    bvals, bvecs = read_bvals_bvecs(fbval, fbvec)
    # Rows with any component of magnitude >= 10 are replaced by [1, 0, 0].
    bvecs[np.where(np.any(abs(bvecs) >= 10, axis=1))] = [1, 0, 0]
    # Volumes with b == 0 get a zero gradient direction.
    bvecs[np.where(bvals == 0)] = 0
    bvecs_0_loc = np.all(abs(bvecs) == np.array([0, 0, 0]), axis=1)
    # A diffusion-weighted volume (b > 50) with a zero direction is suspicious.
    if np.any(np.logical_and(bvals > 50, bvecs_0_loc)):
        raise ValueError(
            "WARNING: Encountered potentially corrupted bval/bvecs. Check to ensure volumes with a "
            "diffusion weighting are not being treated as B0's along the bvecs"
        )
    np.savetxt(fbval, bvals)
    np.savetxt(fbvec, bvecs)

    # Rescale bvecs
    print("Rescaling b-vectors...")
    preproc.rescale_bvec(fbvec, bvec_scaled)

    # Check orientation (eddy_corrected_data)
    eddy_corrected_data, bvecs = gen_utils.reorient_dwi(
        eddy_corrected_data, bvec_scaled, prep_dwi
    )

    # Check dimensions
    eddy_corrected_data = gen_utils.match_target_vox_res(
        eddy_corrected_data, vox_size, outdir, sens="dwi"
    )

    # Build gradient table
    print("fbval: ", fbval)
    print("bvecs: ", bvecs)
    print("fbvec: ", fbvec)
    print("eddy_corrected_data: ", eddy_corrected_data)
    gtab, nodif_B0, nodif_B0_mask = gen_utils.make_gtab_and_bmask(
        fbval, fbvec, eddy_corrected_data, prep_dwi
    )

    # Get B0 header and affine
    eddy_corrected_data_img = nib.load(str(eddy_corrected_data))
    hdr = eddy_corrected_data_img.header

    # -------- Registration Steps ----------------------------------- #

    # define registration directory locations
    # TODO: possibly just pull these from a container generated by `gen_utils.make_initial_directories`
    reg_dirs = ["anat/preproc", "anat/registered", "tmp/reg_a", "tmp/reg_m"]
    reg_dirs = [outdir / loc for loc in reg_dirs]
    prep_anat, reg_anat, tmp_rega, tmp_regm = reg_dirs

    # When not skipping registration, clear any previous anatomical and
    # temporary registration outputs.
    if not skipreg:
        for dir_ in [prep_anat, reg_anat]:
            if gen_utils.has_files(dir_):
                gen_utils.as_directory(dir_, remove=True)
        if gen_utils.has_files(tmp_rega) or gen_utils.has_files(tmp_regm):
            for tmp in [tmp_regm, tmp_rega]:
                gen_utils.as_directory(tmp, remove=True)

    # Check orientation (t1w)
    t1w = gen_utils.reorient_t1w(t1w, prep_anat)
    t1w = gen_utils.match_target_vox_res(t1w, vox_size, outdir, sens="anat")

    print("Running registration in native space...")

    # Instantiate registration
    reg = register.DmriReg(
        outdir, nodif_B0, nodif_B0_mask, t1w, vox_size, skull, simple=False
    )

    # Perform anatomical segmentation
    if skipreg and os.path.isfile(reg.wm_edge):
        print("Found existing gentissue run!")
    else:
        reg.gen_tissue()

    # Align t1w to dwi
    t1w2dwi_align_files = [reg.t1w2dwi, reg.mni2t1w_warp, reg.t1_aligned_mni]
    existing_files = all(map(os.path.isfile, t1w2dwi_align_files))
    if skipreg and existing_files:
        print("Found existing t1w2dwi run!")
    else:
        reg.t1w2dwi_align()

    # Align tissue classifiers
    tissue_align_files = [
        reg.wm_gm_int_in_dwi,
        reg.vent_csf_in_dwi,
        reg.corpuscallosum_dwi,
    ]
    existing_files = all(map(os.path.isfile, tissue_align_files))
    if skipreg and existing_files:
        print("Found existing tissue2dwi run!")
    else:
        reg.tissue2dwi_align()

    # Align atlas to dwi-space and check that the atlas hasn't lost any of the rois
    labels_im_file_list = reg_utils.skullstrip_check(
        reg, parcellations, outdir, prep_anat, vox_size, reg_style
    )

    # -------- Tensor Fitting and Fiber Tractography ---------------- #

    # initial path setup
    prep_track: Path = outdir / "dwi/fiber"
    start_time = time.time()
    qa_tensor = str(outdir / "qa/tensor/Tractography_Model_Peak_Directions.png")

    # build seeds
    seeds = track.build_seed_list(reg.wm_gm_int_in_dwi, np.eye(4), dens=int(seeds))
    print("Using " + str(len(seeds)) + " seeds...")

    # Compute direction model and track fiber streamlines
    print("Beginning tractography in native space...")
    # TODO: could add a --skiptrack flag here that checks if `streamlines.trk` already exists to skip to connectome estimation more quickly
    trct = track.RunTrack(
        eddy_corrected_data,
        nodif_B0_mask,
        reg.gm_in_dwi,
        reg.vent_csf_in_dwi,
        reg.csf_mask_dwi,
        reg.wm_in_dwi,
        gtab,
        mod_type,
        track_type,
        mod_func,
        qa_tensor,
        seeds,
        np.eye(4),
    )
    streamlines = trct.run()
    trk_hdr = trct.make_hdr(streamlines, hdr)
    tractogram = nib.streamlines.Tractogram(
        streamlines, affine_to_rasmm=trk_hdr["voxel_to_rasmm"]
    )
    trkfile = nib.streamlines.trk.TrkFile(tractogram, header=trk_hdr)
    streams = os.path.join(prep_track, "streamlines.trk")
    nib.streamlines.save(trkfile, streams)

    print("Streamlines complete")
    print(f"Tractography runtime: {np.round(time.time() - start_time, 1)}")

    if reg_style == "native_dsn":
        fa_path = track.tens_mod_fa_est(gtab, eddy_corrected_data, nodif_B0_mask)

        # Normalize streamlines
        print("Running DSN...")
        streamlines_mni, streams_mni = register.direct_streamline_norm(
            streams, fa_path, outdir
        )

        # Save streamlines to disk
        print("Saving DSN-registered streamlines: " + streams_mni)

    # ------- Connectome Estimation --------------------------------- #
    # reg_style was validated above, so it is either "native" or "native_dsn".
    tracks = streamlines_mni if reg_style == "native_dsn" else streamlines

    # Generate graphs from streamlines for each parcellation
    for idx, parc in enumerate(parcellations):
        print(f"Generating graph for {parc} parcellation...")
        print(f"Applying native-space alignment to {parc}")
        g1 = graph.GraphTools(
            attr=parc,
            rois=labels_im_file_list[idx],
            tracks=tracks,
            affine=np.eye(4),
            outdir=outdir,
            connectome_path=connectomes[idx],
        )
        g1.g = g1.make_graph()
        g1.summary()
        g1.save_graph_png(connectomes[idx])
        g1.save_graph(connectomes[idx])

    exe_time = datetime.now() - startTime

    if "M2G_URL" in os.environ:
        print("Note: tractography QA does not work in a Docker environment.")
    else:
        qa_tractography_out = outdir / "qa/fibers"
        qa_tractography(streams, str(qa_tractography_out), str(eddy_corrected_data))
        print("QA tractography Completed.")

    print(f"Total execution time: {exe_time}")
    print("M2G Complete.")
    print(f"Output contents: {os.listdir(outdir / 'connectomes')}")
    print("~~~~~~~~~~~~~~\n\n")
    print(
        "NOTE :: m2g uses native-space registration to generate connectomes.\n Without post-hoc normalization, multiple connectomes generated with m2g cannot be compared directly."
    )