예제 #1
0
def reg_mri_pngs(
    mri, atlas, outdir, loc=0, mean=False, minthr=2, maxthr=95, edge=False
):
    """
    Create and save a QA figure overlaying a registered image on its reference.

    The figure is written to ``<outdir>/<mri name>_2_<atlas name>.png``, where
    the names come from ``get_filename``.

    Parameters
    ----------
    mri: nifti file
        the registered brain file generated in each registration step.
    atlas: nifti file
        the reference brain file used in each registration step.
    outdir: str
        directory where output png file is saved.
    loc: int
        index along the 4th axis to plot when the mri data is 4d and
        ``mean`` is False (mri_data[:, :, :, loc]).
    mean: bool
        whether to average 4d mri data over its 4th axis instead of
        selecting the ``loc`` volume. Ignored for 3d data.
    minthr: int
        lower percentile threshold, forwarded to ``plot_overlays``.
    maxthr: int
        upper percentile threshold, forwarded to ``plot_overlays``.
    edge: bool
        forwarded unchanged to ``plot_overlays``.
    """
    # get_fdata() replaces get_data(), which was removed in nibabel 5.0.
    atlas_data = nb.load(atlas).get_fdata()
    mri_data = nb.load(mri).get_fdata()

    if mri_data.ndim == 4:  # 4d data, so we need to reduce a dimension
        mr_data = mri_data.mean(axis=3) if mean else mri_data[:, :, :, loc]
    else:  # dim=3
        mr_data = mri_data

    cmap1 = LinearSegmentedColormap.from_list("mycmap1", ["white", "magenta"])
    cmap2 = LinearSegmentedColormap.from_list("mycmap2", ["white", "green"])

    fig = plot_overlays(atlas_data, mr_data, [cmap1, cmap2], minthr, maxthr, edge)
    try:
        # name and save the file
        fig.savefig(
            f"{outdir}/{get_filename(mri)}_2_{get_filename(atlas)}.png",
            format="png",
        )
    finally:
        # Close this specific figure (not merely the "current" one), even if
        # saving fails, so repeated calls do not accumulate open figures.
        plt.close(fig)
예제 #2
0
def gen_overlay_pngs(brain,
                     original,
                     outdir,
                     loc=0,
                     mean=False,
                     minthr=2,
                     maxthr=95,
                     edge=False):
    """Generate a QA image for skullstrip.

    Loads the skull-stripped brain, reduces it to 3d if needed, calls
    ``plot_overlays_skullstrip`` and saves the resulting figure to
    ``<outdir>/qa_skullstrip__<original name>.png``.

    Parameters
    ----------
    brain: nifti file
        Path to the skull-stripped nifti brain
    original: nifti file
        Path to the original t1w brain, with the skull included
    outdir: str
        Path to the directory where QA will be saved
    loc: int
        index along the 4th axis to plot when the brain data is 4d and
        ``mean`` is False (brain_data[:, :, :, loc])
    mean: bool
        whether to average 4d brain data over its 4th axis instead of
        selecting the ``loc`` volume. Ignored for 3d data.
    minthr: int
        lower percentile threshold. Accepted for signature compatibility;
        not used by this function.
    maxthr: int
        upper percentile threshold. Accepted for signature compatibility;
        not used by this function.
    edge: bool
        Accepted for signature compatibility; not used by this function.
    """
    original_name = get_filename(original)

    # get_fdata() replaces get_data(), which was removed in nibabel 5.0.
    brain_data = nb.load(brain).get_fdata()
    if brain_data.ndim == 4:  # 4d data, so we need to reduce a dimension
        brain_data = brain_data.mean(axis=3) if mean else brain_data[:, :, :, loc]

    fig = plot_overlays_skullstrip(brain_data, original)

    # name and save the file
    # NOTE(review): the figure is not closed after saving; consider closing it
    # if this is called many times in one process.
    fig.savefig(f"{outdir}/qa_skullstrip__{original_name}.png", format="png")
예제 #3
0
def extract_t1w_brain(t1w, out, tmpdir, skull="none"):
    """Extract the brain from a T1w image.

    Runs ``t1w_skullstrip`` to write an intermediate skull-stripped image
    into ``tmpdir``, then ``apply_mask`` to mask the input T1w with that
    intermediate and write the result to ``out``.

    Parameters
    ----------
    t1w : str
        path for the input T1w image
    out : str
        path for the output brain image
    tmpdir : str
        Path for the temporary directory to store images
    skull : str, optional
        skullstrip parameter pre-set, forwarded to ``t1w_skullstrip``.
        Default is "none".
    """
    # Intermediate skull-stripped image, named after the input file.
    stem = gen_utils.get_filename(t1w)
    stripped_path = "{}/{}_noskull.nii.gz".format(tmpdir, stem)

    # Step 1: skull-strip the T1w into the intermediate file.
    t1w_skullstrip(t1w, stripped_path, skull)

    # Step 2: use the skull-stripped image as a mask over the input T1w.
    apply_mask(t1w, stripped_path, out)
예제 #4
0
파일: register.py 프로젝트: CaseyWeiner/m2g
    def atlas2t1w2dwi_align(self, atlas, dsn=True):
        """Align an atlas to T1w-MNI space and, when ``dsn`` is False, on to dwi space.

        The atlas is always first aligned linearly to ``self.t1_aligned_mni``.
        When ``dsn`` is False the result is additionally brought into dwi
        space: if ``self.simple`` is False the nonlinear route is tried first
        (inverse T1w-->MNI warp followed by the t1w-->dwi transform), and if
        that raises, a purely linear registration is used instead. If
        ``self.simple`` is not False the linear route is used directly.

        The aligned atlas is then rounded to integer labels, re-saved in
        place, and a QA png is generated with ``reg_mri_pngs``. As a side
        effect a binarized copy of ``self.wm_gm_int_in_dwi`` is written to
        ``self.wm_gm_int_in_dwi_bin``.

        Note: for this to work, must first have called t1w2dwi_align.

        Parameters
        ----------
        atlas : str
            path to atlas file you want to use
        dsn : bool, optional
            is your space for tractography native-dsn, by default True

        Returns
        -------
        str
            path to aligned atlas file (dwi space when ``dsn`` is False,
            otherwise T1w-MNI space)
        """

        self.atlas = atlas
        self.atlas_name = gen_utils.get_filename(self.atlas)
        self.aligned_atlas_t1mni = (
            f"{self.reg_a}/{self.atlas_name}_aligned_atlas_t1w_mni.nii.gz")
        self.aligned_atlas_skull = (
            f"{self.reg_a}/{self.atlas_name}_aligned_atlas_skull.nii.gz")
        self.dwi_aligned_atlas = (
            f"{self.reg_anat}/{self.atlas_name}_aligned_atlas.nii.gz")

        # Linear alignment of the atlas onto the MNI-aligned T1w
        reg_utils.align(
            self.atlas,
            self.t1_aligned_mni,
            init=None,
            xfm=None,
            out=self.aligned_atlas_t1mni,
            dof=12,
            searchrad=True,
            interp="nearestneighbour",
            cost="mutualinfo",
        )

        if (self.simple is False) and (dsn is False):
            try:
                # Apply warp resulting from the inverse of T1w-->MNI created earlier
                reg_utils.apply_warp(
                    self.t1w_brain,
                    self.aligned_atlas_t1mni,
                    self.aligned_atlas_skull,
                    warp=self.mni2t1w_warp,
                    interp="nn",
                    sup=True,
                )

                # Apply transform to dwi space
                reg_utils.align(
                    self.aligned_atlas_skull,
                    self.nodif_B0,
                    init=self.t1wtissue2dwi_xfm,
                    xfm=None,
                    out=self.dwi_aligned_atlas,
                    dof=6,
                    searchrad=True,
                    interp="nearestneighbour",
                    cost="mutualinfo",
                )
            except Exception as err:
                # Catch Exception rather than everything, so KeyboardInterrupt
                # and SystemExit still propagate instead of triggering the
                # fallback registration.
                print(
                    "Warning: Atlas is not in correct dimensions, or input is low quality,\nusing linear template registration."
                )
                print(f"Nonlinear atlas registration failed with: {err!r}")
                # Create transform to align atlas to T1w using flirt
                reg_utils.align(
                    self.atlas,
                    self.t1w_brain,
                    xfm=self.xfm_atlas2t1w_init,
                    init=None,
                    bins=None,
                    dof=6,
                    cost="mutualinfo",
                    searchrad=True,
                    interp="spline",
                    out=None,
                    sch=None,
                )
                # NOTE(review): this fallback refines against t1_aligned_mni,
                # whereas the `simple` branch below refines against t1w_brain.
                # Confirm which reference is intended.
                reg_utils.align(
                    self.atlas,
                    self.t1_aligned_mni,
                    xfm=self.xfm_atlas2t1w,
                    out=None,
                    dof=6,
                    searchrad=True,
                    bins=None,
                    interp="spline",
                    cost="mutualinfo",
                    init=self.xfm_atlas2t1w_init,
                )

                # Combine our linear transform from t1w to template with our transform from dwi to t1w space to get a transform from atlas ->(-> t1w ->)-> dwi
                reg_utils.combine_xfms(self.xfm_atlas2t1w,
                                       self.t1wtissue2dwi_xfm,
                                       self.temp2dwi_xfm)

                # Apply linear transformation from template to dwi space
                reg_utils.applyxfm(self.nodif_B0, self.atlas,
                                   self.temp2dwi_xfm, self.dwi_aligned_atlas)
        elif dsn is False:
            # Create transform to align atlas to T1w using flirt
            # NOTE(review): searchrad is None here but True in the equivalent
            # call of the fallback path above; confirm this is intentional.
            reg_utils.align(
                self.atlas,
                self.t1w_brain,
                xfm=self.xfm_atlas2t1w_init,
                init=None,
                bins=None,
                dof=6,
                cost="mutualinfo",
                searchrad=None,
                interp="spline",
                out=None,
                sch=None,
            )
            reg_utils.align(
                self.atlas,
                self.t1w_brain,
                xfm=self.xfm_atlas2t1w,
                out=None,
                dof=6,
                searchrad=True,
                bins=None,
                interp="spline",
                cost="mutualinfo",
                init=self.xfm_atlas2t1w_init,
            )

            # Combine our linear transform from t1w to template with our transform from dwi to t1w space to get a transform from atlas ->(-> t1w ->)-> dwi
            reg_utils.combine_xfms(self.xfm_atlas2t1w, self.t1wtissue2dwi_xfm,
                                   self.temp2dwi_xfm)

            # Apply linear transformation from template to dwi space
            reg_utils.applyxfm(self.nodif_B0, self.atlas, self.temp2dwi_xfm,
                               self.dwi_aligned_atlas)

        # Pick the atlas that was produced above and the image it is
        # compared against in the QA figure.
        if dsn is False:
            out_atlas, qa_reference = self.dwi_aligned_atlas, self.nodif_B0
        else:
            out_atlas, qa_reference = self.aligned_atlas_t1mni, self.t1_aligned_mni

        # Set intensities to int
        self.atlas_img = nib.load(out_atlas)
        # np.asanyarray(img.dataobj) is the documented drop-in replacement for
        # the removed nibabel get_data().
        self.atlas_data = np.around(
            np.asanyarray(self.atlas_img.dataobj)).astype("int16")
        # Zero any label larger than the number of distinct values present.
        node_num = len(np.unique(self.atlas_data))
        self.atlas_data[self.atlas_data > node_num] = 0

        # Write a binarized (> 0) copy of the wm/gm interface image
        t_img = load_img(self.wm_gm_int_in_dwi)
        mask = math_img("img > 0", img=t_img)
        mask.to_filename(self.wm_gm_int_in_dwi_bin)

        # Overwrite the aligned atlas with its integer-valued version
        nib.save(
            nib.Nifti1Image(
                self.atlas_data.astype(np.int32),
                affine=self.atlas_img.affine,
                header=self.atlas_img.header,
            ),
            out_atlas,
        )
        reg_mri_pngs(out_atlas, qa_reference, self.qa_reg)
        return out_atlas
예제 #5
0
def m2g_dwi_worker(
    dwi,
    bvals,
    bvecs,
    t1w,
    atlas,
    mask,
    parcellations,
    outdir,
    vox_size="2mm",
    mod_type="det",
    track_type="local",
    mod_func="csa",
    seeds=20,
    reg_style="native",
    skipeddy=False,
    skipreg=False,
    skull=None,
):
    """Creates a brain graph from MRI data

    Runs, in order: input validation, eddy correction and gradient-table
    preparation, anatomical registration, tractography, and connectome
    estimation for each parcellation. One connectome csv per parcellation is
    written under ``<outdir>/connectomes/<parcellation name>/``.

    Parameters
    ----------
    dwi : str
        Path for the dwi file(s)
    bvals : str
        Path for the bval file(s)
    bvecs : str
        Path for the bvec file(s)
    t1w : str
        Location of anatomical input file(s)
    atlas : str
        Location of atlas file. Only checked for existence here.
    mask : str
        Location of T1w brain mask file, make sure the proper voxel size is used.
        Only checked for existence here.
    parcellations : list
        Filepaths to the parcellations we're using. Passed through
        ``gen_utils.as_list`` before use.
    outdir : str
        The directory where the output files should be stored. Prepopulate this folder with results of participants level analysis if running group analysis.
    vox_size : str
        Voxel size to use for template registrations. Must be '1mm' or '2mm'. Default is '2mm'.
    mod_type : str
        Deterministic (det) or probabilistic (prob) tracking. Default is det.
    track_type : str
        Tracking approach: eudx or local. Default is local.
    mod_func : str
        Diffusion model: csd, csa. Default is csa.
    seeds : int
        Density of seeding for native-space tractography.
    reg_style : str
        Space for tractography: native or native_dsn. Default is native.
    skipeddy : bool
        Whether or not to skip the eddy correction if it has already been run. Default is False.
    skipreg : bool
        Whether or not to skip registration. Default is False.
    skull : str, optional
        skullstrip parameter pre-set. Default is None.

    Raises
    ------
    ValueError
        Raised if downsampling voxel size is not supported
    ValueError
        Raised if reg_style is not supported
    FileNotFoundError
        Raised if any input file does not exist
    ValueError
        Raised if bval/bvecs are potentially corrupted
    """

    # -------- Initial Setup ------------------ #
    # print starting arguments for clarity in log
    # (must stay first: locals() here contains exactly the call arguments)
    args = locals().copy()
    for arg, value in args.items():
        print(f"{arg} = {value}")

    # initial assertions
    if vox_size not in ["1mm", "2mm"]:
        raise ValueError("Voxel size not supported. Use 2mm or 1mm")
    # Fail fast: an unsupported style would otherwise only surface after
    # tractography, as an unbound `tracks` variable.
    if reg_style not in ("native", "native_dsn"):
        raise ValueError(
            f"Registration style '{reg_style}' not supported. Use native or native_dsn"
        )

    # Normalize before the existence check so a single path given as a string
    # is not unpacked character by character.
    parcellations = gen_utils.as_list(parcellations)

    print("Checking inputs...")
    for file_ in [t1w, bvals, bvecs, dwi, atlas, mask, *parcellations]:
        if not os.path.isfile(file_):
            raise FileNotFoundError(f"Input {file_} not found. Exiting m2g.")
        print(f"Input {file_} found.")

    # time m2g execution
    pipeline_start = datetime.now()

    # initial variables
    outdir = Path(outdir)
    dwi_name = gen_utils.get_filename(dwi)

    # make output directory
    print("Adding directory tree...")
    gen_utils.make_initial_directories(outdir, parcellations=parcellations)

    # generate list of connectome file locations
    connectomes = []
    for parc in parcellations:
        name = gen_utils.get_filename(parc)
        folder = outdir / f"connectomes/{name}"
        connectome = f"{dwi_name}_{name}_connectome.csv"
        connectomes.append(str(folder / connectome))

    warm_welcome = welcome_message(connectomes)
    print(warm_welcome)

    # -------- Preprocessing Steps --------------------------------- #

    # set up directories
    prep_dwi: Path = outdir / "dwi/preproc"
    eddy_corrected_data: Path = prep_dwi / "eddy_corrected_data.nii.gz"

    # skipping eddy correct is only possible if its output already exists
    if skipeddy and not eddy_corrected_data.is_file():
        print("Cannot skip preprocessing if it has not already been run!")
        skipeddy = False

    # if we're not skipping eddy correct, perform it
    if not skipeddy:
        prep_dwi = gen_utils.as_directory(prep_dwi, remove=True, return_as_path=True)
        preproc.eddy_correct(dwi, str(eddy_corrected_data), 0)

    # copy bval/bvec files to output directory
    bvec_scaled = str(outdir / "dwi/preproc/bvec_scaled.bvec")
    fbval = str(outdir / "dwi/preproc/bval.bval")
    fbvec = str(outdir / "dwi/preproc/bvec.bvec")
    shutil.copyfile(bvecs, fbvec)
    shutil.copyfile(bvals, fbval)

    # Correct any corrupted bvecs/bvals
    bvals, bvecs = read_bvals_bvecs(fbval, fbvec)
    # rows with any component of magnitude >= 10 are replaced by [1, 0, 0]
    bvecs[np.any(np.abs(bvecs) >= 10, axis=1)] = [1, 0, 0]
    # b0 volumes get a zero gradient direction
    bvecs[bvals == 0] = 0
    bvecs_0_loc = np.all(np.abs(bvecs) == 0, axis=1)
    # a zero direction paired with a b-value above 50 is treated as corruption
    if np.any(np.logical_and(bvals > 50, bvecs_0_loc)):
        raise ValueError(
            "WARNING: Encountered potentially corrupted bval/bvecs. Check to ensure volumes with a "
            "diffusion weighting are not being treated as B0's along the bvecs"
        )
    np.savetxt(fbval, bvals)
    np.savetxt(fbvec, bvecs)

    # Rescale bvecs
    print("Rescaling b-vectors...")
    preproc.rescale_bvec(fbvec, bvec_scaled)

    # Check orientation (eddy_corrected_data)
    eddy_corrected_data, bvecs = gen_utils.reorient_dwi(
        eddy_corrected_data, bvec_scaled, prep_dwi
    )

    # Check dimensions
    eddy_corrected_data = gen_utils.match_target_vox_res(
        eddy_corrected_data, vox_size, outdir, sens="dwi"
    )

    # Build gradient table
    print("fbval: ", fbval)
    print("bvecs: ", bvecs)
    print("fbvec: ", fbvec)
    print("eddy_corrected_data: ", eddy_corrected_data)
    gtab, nodif_B0, nodif_B0_mask = gen_utils.make_gtab_and_bmask(
        fbval, fbvec, eddy_corrected_data, prep_dwi
    )

    # Get the header of the preprocessed dwi (used for the .trk header below)
    eddy_corrected_data_img = nib.load(str(eddy_corrected_data))
    hdr = eddy_corrected_data_img.header

    # -------- Registration Steps ----------------------------------- #

    # define registration directory locations
    # TODO: possibly just pull these from a container generated by `gen_utils.make_initial_directories`
    reg_dirs = ["anat/preproc", "anat/registered", "tmp/reg_a", "tmp/reg_m"]
    reg_dirs = [outdir / loc for loc in reg_dirs]
    prep_anat, reg_anat, tmp_rega, tmp_regm = reg_dirs

    # when re-running registration, clear out any previous results
    if not skipreg:
        for dir_ in [prep_anat, reg_anat]:
            if gen_utils.has_files(dir_):
                gen_utils.as_directory(dir_, remove=True)
        if gen_utils.has_files(tmp_rega) or gen_utils.has_files(tmp_regm):
            for tmp in [tmp_regm, tmp_rega]:
                gen_utils.as_directory(tmp, remove=True)

    # Check orientation (t1w)
    t1w = gen_utils.reorient_t1w(t1w, prep_anat)
    t1w = gen_utils.match_target_vox_res(t1w, vox_size, outdir, sens="anat")

    print("Running registration in native space...")

    # Instantiate registration
    reg = register.DmriReg(
        outdir, nodif_B0, nodif_B0_mask, t1w, vox_size, skull, simple=False
    )

    # Perform anatomical segmentation
    if skipreg and os.path.isfile(reg.wm_edge):
        print("Found existing gentissue run!")
    else:
        reg.gen_tissue()

    # Align t1w to dwi
    t1w2dwi_align_files = [reg.t1w2dwi, reg.mni2t1w_warp, reg.t1_aligned_mni]
    existing_files = all(map(os.path.isfile, t1w2dwi_align_files))
    if skipreg and existing_files:
        print("Found existing t1w2dwi run!")
    else:
        reg.t1w2dwi_align()

    # Align tissue classifiers
    tissue_align_files = [
        reg.wm_gm_int_in_dwi,
        reg.vent_csf_in_dwi,
        reg.corpuscallosum_dwi,
    ]
    existing_files = all(map(os.path.isfile, tissue_align_files))
    if skipreg and existing_files:
        print("Found existing tissue2dwi run!")
    else:
        reg.tissue2dwi_align()

    # Align atlas to dwi-space and check that the atlas hasn't lost any of the rois
    labels_im_file_list = reg_utils.skullstrip_check(
        reg, parcellations, outdir, prep_anat, vox_size, reg_style
    )

    # -------- Tensor Fitting and Fiber Tractography ---------------- #

    # initial path setup
    prep_track: Path = outdir / "dwi/fiber"
    start_time = time.time()
    qa_tensor = str(outdir / "qa/tensor/Tractography_Model_Peak_Directions.png")

    # build seeds
    seed_list = track.build_seed_list(
        reg.wm_gm_int_in_dwi, np.eye(4), dens=int(seeds)
    )
    print(f"Using {len(seed_list)} seeds...")

    # Compute direction model and track fiber streamlines
    print("Beginning tractography in native space...")
    # TODO: could add a --skiptrack flag here that checks if `streamlines.trk` already exists to skip to connectome estimation more quickly
    trct = track.RunTrack(
        eddy_corrected_data,
        nodif_B0_mask,
        reg.gm_in_dwi,
        reg.vent_csf_in_dwi,
        reg.csf_mask_dwi,
        reg.wm_in_dwi,
        gtab,
        mod_type,
        track_type,
        mod_func,
        qa_tensor,
        seed_list,
        np.eye(4),
    )
    streamlines = trct.run()
    trk_hdr = trct.make_hdr(streamlines, hdr)
    tractogram = nib.streamlines.Tractogram(
        streamlines, affine_to_rasmm=trk_hdr["voxel_to_rasmm"]
    )
    trkfile = nib.streamlines.trk.TrkFile(tractogram, header=trk_hdr)
    streams = os.path.join(prep_track, "streamlines.trk")
    nib.streamlines.save(trkfile, streams)

    print("Streamlines complete")
    print(f"Tractography runtime: {np.round(time.time() - start_time, 1)}")

    # Streamlines used for graph generation; replaced below for native_dsn.
    tracks = streamlines
    if reg_style == "native_dsn":
        fa_path = track.tens_mod_fa_est(gtab, eddy_corrected_data, nodif_B0_mask)
        # Normalize streamlines
        print("Running DSN...")
        streamlines_mni, streams_mni = register.direct_streamline_norm(
            streams, fa_path, outdir
        )
        # Save streamlines to disk
        print("Saving DSN-registered streamlines: " + streams_mni)
        tracks = streamlines_mni

    # ------- Connectome Estimation --------------------------------- #
    # Generate graphs from streamlines for each parcellation

    for idx, parc in enumerate(parcellations):
        print(f"Generating graph for {parc} parcellation...")
        print(f"Applying native-space alignment to {parc}")
        g1 = graph.GraphTools(
            attr=parc,
            rois=labels_im_file_list[idx],
            tracks=tracks,
            affine=np.eye(4),
            outdir=outdir,
            connectome_path=connectomes[idx],
        )
        g1.g = g1.make_graph()
        g1.summary()
        g1.save_graph_png(connectomes[idx])
        g1.save_graph(connectomes[idx])

    # Reported execution time deliberately excludes the QA step below.
    exe_time = datetime.now() - pipeline_start

    if "M2G_URL" in os.environ:
        print("Note: tractography QA does not work in a Docker environment.")
    else:
        qa_tractography_out = outdir / "qa/fibers"
        qa_tractography(streams, str(qa_tractography_out), str(eddy_corrected_data))
        print("QA tractography Completed.")

    print(f"Total execution time: {exe_time}")
    print("M2G Complete.")
    print(f"Output contents: {os.listdir(outdir / 'connectomes')}")
    print("~~~~~~~~~~~~~~\n\n")
    print(
        "NOTE :: m2g uses native-space registration to generate connectomes.\n Without post-hoc normalization, multiple connectomes generated with m2g cannot be compared directly."
    )