Exemple #1
0
def atlas2t1w2dwi_align(
    uatlas,
    uatlas_parcels,
    atlas,
    t1w_brain,
    t1w_brain_mask,
    mni2t1w_warp,
    t1_aligned_mni,
    ap_path,
    t1w2dwi_bbr_xfm,
    mni2t1_xfm,
    t1w2dwi_xfm,
    wm_gm_int_in_dwi,
    aligned_atlas_t1mni,
    aligned_atlas_skull,
    dwi_aligned_atlas,
    dwi_aligned_atlas_wmgm_int,
    B0_mask,
    simple,
):
    """
    Align a parcellation atlas into DWI space (atlas --> T1w --> DWI).

    The atlas is resampled onto the grid of ``t1_aligned_mni`` (nearest
    neighbour) and saved to ``aligned_atlas_t1mni``. When ``simple is False``
    a warp-based alignment into T1w space is tried first; if any step of it
    raises an ``Exception``, a linear alignment is used instead. Any other
    value of ``simple`` goes straight to the linear alignment. For this to
    succeed, ``t1w2dwi_align`` must have been called first.

    Parameters
    ----------
    uatlas : path
        Atlas image, used when ``uatlas_parcels`` is falsy.
    uatlas_parcels : path or falsy
        Atlas image used in preference to ``uatlas`` when truthy.
    atlas
        Not used by this function; kept for interface compatibility.
    t1w_brain, t1w_brain_mask : path
        T1w reference image and the mask handed to ``regutils.apply_warp``.
    mni2t1w_warp : path
        Warp passed to ``regutils.apply_warp``.
    t1_aligned_mni : path
        Image whose grid the atlas is resampled onto.
    ap_path : path
        Reference image for the alignment into DWI space.
    t1w2dwi_bbr_xfm, mni2t1_xfm, t1w2dwi_xfm : path
        Initial transforms passed as ``init=`` to ``regutils.align``.
    wm_gm_int_in_dwi : path
        Image thresholded (> 0) and combined with the atlas mask.
    aligned_atlas_t1mni, aligned_atlas_skull, dwi_aligned_atlas, \
dwi_aligned_atlas_wmgm_int : path
        Output file paths written by this function.
    B0_mask : path
        Mask applied to both DWI-space outputs via ``fslmaths -mas``.
    simple : bool
        ``False`` enables the warp-based attempt (identity check, so only
        the literal ``False`` enables it).

    Returns
    -------
    tuple
        ``(dwi_aligned_atlas_wmgm_int, dwi_aligned_atlas,
        aligned_atlas_t1mni)`` -- the output paths as passed in.

    Raises
    ------
    FileNotFoundError
        If the ``fslmaths`` executable cannot be found on ``PATH``.
    """
    import subprocess
    from nilearn.image import resample_to_img
    from pynets.core.utils import checkConsecutive
    from pynets.registration import reg_utils as regutils
    from nilearn.image import math_img
    from nilearn.masking import intersect_masks

    template_img = nib.load(t1_aligned_mni)
    atlas_source = uatlas_parcels if uatlas_parcels else uatlas
    uatlas_res_template = resample_to_img(nib.load(atlas_source),
                                          template_img,
                                          interpolation="nearest")

    # Zero out any voxel that does not hold an exact integer label.
    uatlas_res_template_data = np.asarray(uatlas_res_template.dataobj)
    uatlas_res_template_data[
        uatlas_res_template_data != uatlas_res_template_data.astype(int)] = 0

    uatlas_res_template = nib.Nifti1Image(
        uatlas_res_template_data.astype("int32"),
        affine=uatlas_res_template.affine,
        header=uatlas_res_template.header,
    )
    nib.save(uatlas_res_template, aligned_atlas_t1mni)

    # Options shared by every regutils.align call below.
    align_opts = dict(
        dof=6,
        searchrad=True,
        interp="nearestneighbour",
        cost="mutualinfo",
    )

    def _linear_atlas_to_t1w():
        # Template-space atlas --> T1w space, initialised with mni2t1_xfm.
        regutils.align(
            aligned_atlas_t1mni,
            t1w_brain,
            init=mni2t1_xfm,
            out=aligned_atlas_skull,
            **align_opts,
        )

    def _atlas_t1w_to_dwi(init_xfm):
        # T1w-space atlas --> DWI space, initialised with ``init_xfm``.
        regutils.align(
            aligned_atlas_skull,
            ap_path,
            init=init_xfm,
            out=dwi_aligned_atlas,
            **align_opts,
        )

    if simple is False:
        try:
            regutils.apply_warp(
                t1w_brain,
                aligned_atlas_t1mni,
                aligned_atlas_skull,
                warp=mni2t1w_warp,
                interp="nn",
                sup=True,
                mask=t1w_brain_mask,
            )

            # Apply linear transformation from template to dwi space
            _atlas_t1w_to_dwi(t1w2dwi_bbr_xfm)

        # Exception (not BaseException) so that KeyboardInterrupt and
        # SystemExit propagate instead of triggering the fallback.
        except Exception:
            print(
                "Warning: Atlas is not in correct dimensions, or input is low quality,\nusing linear template "
                "registration.")

            _linear_atlas_to_t1w()
            _atlas_t1w_to_dwi(t1w2dwi_bbr_xfm)

    else:
        _linear_atlas_to_t1w()
        _atlas_t1w_to_dwi(t1w2dwi_xfm)

    atlas_img = nib.load(dwi_aligned_atlas)
    wm_gm_img = nib.load(wm_gm_int_in_dwi)
    wm_gm_mask_img = math_img("img > 0", img=wm_gm_img)
    atlas_mask_img = math_img("img > 0", img=atlas_img)

    # Same integer-label clean-up as above, now on the DWI-space atlas.
    uatlas_res_template_data = np.asarray(atlas_img.dataobj)
    uatlas_res_template_data[
        uatlas_res_template_data != uatlas_res_template_data.astype(int)] = 0

    atlas_img_corr = nib.Nifti1Image(
        uatlas_res_template_data.astype("uint32"),
        affine=atlas_img.affine,
        header=atlas_img.header,
    )

    # threshold=0 makes intersect_masks return the union of the two masks
    # (per the nilearn documentation).
    dwi_aligned_atlas_wmgm_int_img = intersect_masks(
        [wm_gm_mask_img, atlas_mask_img], threshold=0, connected=False)

    nib.save(atlas_img_corr, dwi_aligned_atlas)
    nib.save(dwi_aligned_atlas_wmgm_int_img, dwi_aligned_atlas_wmgm_int)

    # Mask both outputs in place with the B0 mask. An argument list (no
    # shell) keeps paths with spaces or shell metacharacters intact. The
    # exit status is ignored and stderr discarded, as before.
    for masked_path in (dwi_aligned_atlas, dwi_aligned_atlas_wmgm_int):
        subprocess.run(
            ["fslmaths", str(masked_path), "-mas", str(B0_mask),
             str(masked_path)],
            stderr=subprocess.DEVNULL,
            check=False,
        )

    final_dat = atlas_img_corr.get_fdata()
    # np.unique already returns the sorted distinct values.
    unique_a = list(np.unique(final_dat))

    if not checkConsecutive(unique_a):
        print("Warning! Non-consecutive integers found in parcellation...")

    atlas_img.uncache()
    atlas_img_corr.uncache()
    atlas_mask_img.uncache()
    wm_gm_img.uncache()
    wm_gm_mask_img.uncache()

    return dwi_aligned_atlas_wmgm_int, dwi_aligned_atlas, aligned_atlas_t1mni
Exemple #2
0
def atlas2t1w2dwi_align(
    uatlas,
    uatlas_parcels,
    atlas,
    t1w_brain,
    t1w_brain_mask,
    mni2t1w_warp,
    t1_aligned_mni,
    ap_path,
    mni2t1_xfm,
    t1w2dwi_xfm,
    wm_gm_int_in_dwi,
    aligned_atlas_t1mni,
    aligned_atlas_skull,
    dwi_aligned_atlas,
    dwi_aligned_atlas_wmgm_int,
    B0_mask,
    mni2dwi_xfm,
    simple,
):
    """
    Align a parcellation atlas into DWI space (atlas --> T1w --> DWI).

    The atlas is resampled onto the grid of ``t1_aligned_mni`` (nearest
    neighbour), cast to ``uint16`` and saved to ``aligned_atlas_t1mni``.
    When ``simple is False`` a warp-based alignment is tried first; if any
    step of it raises an ``Exception``, the linear transforms are applied
    instead. Any other value of ``simple`` goes straight to the linear
    path. For this to succeed, ``t1w2dwi_align`` must have been called
    first.

    Parameters
    ----------
    uatlas : path
        Atlas image, used when ``uatlas_parcels`` is falsy.
    uatlas_parcels : path or falsy
        Atlas image used in preference to ``uatlas`` when truthy.
    atlas
        Not used by this function; kept for interface compatibility.
    t1w_brain, t1w_brain_mask : path
        T1w reference image and the mask handed to ``regutils.apply_warp``.
    mni2t1w_warp : path
        Warp passed to ``regutils.apply_warp``.
    t1_aligned_mni : path
        Image whose grid the atlas is resampled onto.
    ap_path : path
        Reference image for resampling into DWI space.
    mni2t1_xfm, t1w2dwi_xfm : path
        Transforms handed to ``regutils.applyxfm`` / ``combine_xfms``.
    mni2dwi_xfm : path
        Passed as the third argument to ``combine_xfms`` and then used as
        the transform for the template --> DWI ``applyxfm`` call.
    wm_gm_int_in_dwi : path
        Image thresholded (> 0) and combined with the atlas mask.
    aligned_atlas_t1mni, aligned_atlas_skull, dwi_aligned_atlas, \
dwi_aligned_atlas_wmgm_int : path
        Output file paths written by this function.
    B0_mask : path
        Mask applied to both DWI-space outputs.
    simple : bool
        ``False`` enables the warp-based attempt (identity check, so only
        the literal ``False`` enables it).

    Returns
    -------
    tuple
        ``(dwi_aligned_atlas_wmgm_int, dwi_aligned_atlas,
        aligned_atlas_skull)``; the first two are the values returned by
        ``regutils.apply_mask_to_image``.
    """
    import time
    from nilearn.image import resample_to_img
    from pynets.core.utils import checkConsecutive
    from pynets.registration import utils as regutils
    from nilearn.image import math_img
    from nilearn.masking import intersect_masks

    # NOTE(review): the 0.5 s pauses after each external step presumably
    # give the filesystem time to settle -- confirm they are still needed.
    pause = 0.5

    template_img = nib.load(t1_aligned_mni)
    atlas_img_orig = nib.load(uatlas_parcels if uatlas_parcels else uatlas)

    # Distinct values in the untouched atlas, for the report further down.
    old_count = len(np.unique(np.asarray(atlas_img_orig.dataobj)))

    uatlas_res_template = resample_to_img(atlas_img_orig,
                                          template_img,
                                          interpolation="nearest")

    uatlas_res_template = nib.Nifti1Image(
        np.asarray(uatlas_res_template.dataobj).astype('uint16'),
        affine=uatlas_res_template.affine,
        header=uatlas_res_template.header,
    )
    nib.save(uatlas_res_template, aligned_atlas_t1mni)

    def _apply_linear_xfms():
        # Template-space atlas --> T1w space.
        regutils.applyxfm(t1w_brain,
                          aligned_atlas_t1mni,
                          mni2t1_xfm,
                          aligned_atlas_skull,
                          interp="nearestneighbour")
        time.sleep(pause)
        # NOTE(review): combine_xfms is not imported inside this function;
        # presumably it is available at module level -- confirm.
        combine_xfms(mni2t1_xfm, t1w2dwi_xfm, mni2dwi_xfm)
        time.sleep(pause)
        # Template-space atlas --> DWI space using mni2dwi_xfm.
        regutils.applyxfm(ap_path,
                          aligned_atlas_t1mni,
                          mni2dwi_xfm,
                          dwi_aligned_atlas,
                          interp="nearestneighbour")
        time.sleep(pause)

    if simple is False:
        try:
            regutils.apply_warp(
                t1w_brain,
                aligned_atlas_t1mni,
                aligned_atlas_skull,
                warp=mni2t1w_warp,
                interp="nn",
                sup=True,
                mask=t1w_brain_mask,
            )
            time.sleep(pause)

            # Apply linear transformation from template to dwi space
            regutils.applyxfm(ap_path,
                              aligned_atlas_skull,
                              t1w2dwi_xfm,
                              dwi_aligned_atlas,
                              interp="nearestneighbour")
            time.sleep(pause)
        # Exception (not BaseException) so that KeyboardInterrupt and
        # SystemExit propagate instead of triggering the fallback.
        except Exception:
            print(
                "Warning: Atlas is not in correct dimensions, or input is low"
                " quality,\nusing linear template registration.")
            _apply_linear_xfms()
    else:
        _apply_linear_xfms()

    atlas_img = nib.load(dwi_aligned_atlas)
    wm_gm_img = nib.load(wm_gm_int_in_dwi)
    wm_gm_mask_img = math_img("img > 0", img=wm_gm_img)
    atlas_mask_img = math_img("img > 0", img=atlas_img)

    atlas_img_corr = nib.Nifti1Image(
        np.asarray(atlas_img.dataobj).astype('uint16'),
        affine=atlas_img.affine,
        header=atlas_img.header,
    )

    # Get the union of masks (threshold=0)
    dwi_aligned_atlas_wmgm_int_img = intersect_masks(
        [wm_gm_mask_img, atlas_mask_img], threshold=0, connected=False)

    nib.save(atlas_img_corr, dwi_aligned_atlas)
    nib.save(dwi_aligned_atlas_wmgm_int_img, dwi_aligned_atlas_wmgm_int)

    dwi_aligned_atlas = regutils.apply_mask_to_image(dwi_aligned_atlas,
                                                     B0_mask,
                                                     dwi_aligned_atlas)

    time.sleep(pause)

    dwi_aligned_atlas_wmgm_int = regutils.apply_mask_to_image(
        dwi_aligned_atlas_wmgm_int, B0_mask, dwi_aligned_atlas_wmgm_int)

    time.sleep(pause)
    final_dat = atlas_img_corr.get_fdata()
    # np.unique already returns the sorted distinct values.
    unique_a = list(np.unique(final_dat))

    if not checkConsecutive(unique_a):
        print("Warning! Non-consecutive integers found in parcellation...")

    new_count = len(unique_a)
    # Both counts are plain ints; the removed ``np.int`` alias is not needed.
    diff = abs(new_count - old_count)
    print(f"Previous label count: {old_count}")
    print(f"New label count: {new_count}")
    print(f"Labels dropped: {diff}")

    atlas_img.uncache()
    atlas_img_corr.uncache()
    atlas_img_orig.uncache()
    atlas_mask_img.uncache()
    wm_gm_img.uncache()
    wm_gm_mask_img.uncache()

    return dwi_aligned_atlas_wmgm_int, dwi_aligned_atlas, aligned_atlas_skull
Exemple #3
0
def atlas2t1w_align(
    uatlas,
    uatlas_parcels,
    atlas,
    t1w_brain,
    t1w_brain_mask,
    t1_aligned_mni,
    mni2t1w_warp,
    mni2t1_xfm,
    gm_mask,
    aligned_atlas_t1mni,
    aligned_atlas_skull,
    aligned_atlas_gm,
    simple,
):
    """
    Align a parcellation atlas into T1w space (atlas --> T1w).

    The atlas is resampled onto the grid of ``t1_aligned_mni`` (nearest
    neighbour) and saved to ``aligned_atlas_t1mni``. When ``simple is False``
    a warp-based alignment is tried first; if it raises an ``Exception``, a
    linear alignment is used instead. Any other value of ``simple`` goes
    straight to the linear alignment. The T1w-space atlas is then masked
    with ``gm_mask`` and written to ``aligned_atlas_gm``. If the masked
    labels are non-consecutive and the distinct-value count changed by more
    than one relative to the template-space atlas, the atlas is re-masked
    with ``t1w_brain_mask`` instead.

    Parameters
    ----------
    uatlas : path
        Atlas image, used when ``uatlas_parcels`` is falsy.
    uatlas_parcels : path or falsy
        Atlas image used in preference to ``uatlas`` when truthy.
    atlas
        Not used by this function; kept for interface compatibility.
    t1w_brain, t1w_brain_mask : path
        T1w reference image and brain mask.
    t1_aligned_mni : path
        Image whose grid the atlas is resampled onto.
    mni2t1w_warp : path
        Warp passed to ``regutils.apply_warp``.
    mni2t1_xfm : path
        Initial transform passed as ``init=`` to ``regutils.align``.
    gm_mask : path
        Mask applied to the T1w-space atlas via ``fslmaths -mas``.
    aligned_atlas_t1mni, aligned_atlas_skull, aligned_atlas_gm : path
        Output file paths written by this function.
    simple : bool
        ``False`` enables the warp-based attempt (identity check, so only
        the literal ``False`` enables it).

    Returns
    -------
    tuple
        ``(aligned_atlas_gm, aligned_atlas_skull)`` -- the output paths as
        passed in.

    Raises
    ------
    FileNotFoundError
        If the ``fslmaths`` executable cannot be found on ``PATH``.
    """
    import subprocess
    from pynets.registration import reg_utils as regutils
    from nilearn.image import resample_to_img
    from pynets.core.utils import checkConsecutive

    def _mask_skull_atlas(mask_path):
        # Mask aligned_atlas_skull with ``mask_path`` into aligned_atlas_gm.
        # An argument list (no shell) keeps paths with spaces or shell
        # metacharacters intact. The exit status is ignored and stderr
        # discarded, as before.
        subprocess.run(
            ["fslmaths", str(aligned_atlas_skull), "-mas", str(mask_path),
             str(aligned_atlas_gm)],
            stderr=subprocess.DEVNULL,
            check=False,
        )

    def _linear_atlas_to_t1w():
        # Template-space atlas --> T1w space, initialised with mni2t1_xfm.
        regutils.align(
            aligned_atlas_t1mni,
            t1w_brain,
            init=mni2t1_xfm,
            out=aligned_atlas_skull,
            dof=6,
            searchrad=True,
            interp="nearestneighbour",
            cost="mutualinfo",
        )

    template_img = nib.load(t1_aligned_mni)
    atlas_source = uatlas_parcels if uatlas_parcels else uatlas
    uatlas_res_template = resample_to_img(nib.load(atlas_source),
                                          template_img,
                                          interpolation="nearest")

    # Zero out any voxel that does not hold an exact integer label.
    template_data = np.asarray(uatlas_res_template.dataobj)
    template_data[template_data != template_data.astype(int)] = 0
    template_labels = template_data.astype("uint16")

    # Distinct-value count before any masking. This must be taken here:
    # computing it from the GM-masked data would always equal the new count
    # and the fallback below could never trigger.
    old_count = len(np.unique(template_labels))

    uatlas_res_template = nib.Nifti1Image(
        template_labels,
        affine=uatlas_res_template.affine,
        header=uatlas_res_template.header,
    )
    nib.save(uatlas_res_template, aligned_atlas_t1mni)

    if simple is False:
        try:
            regutils.apply_warp(
                t1w_brain,
                aligned_atlas_t1mni,
                aligned_atlas_skull,
                warp=mni2t1w_warp,
                interp="nn",
                sup=True,
                mask=t1w_brain_mask,
            )

        # Exception (not BaseException) so that KeyboardInterrupt and
        # SystemExit propagate instead of triggering the fallback.
        except Exception:
            print(
                "Warning: Atlas is not in correct dimensions, or input is low quality,\nusing linear template "
                "registration.")

            _linear_atlas_to_t1w()

    else:
        _linear_atlas_to_t1w()

    _mask_skull_atlas(gm_mask)

    atlas_img = nib.load(aligned_atlas_gm)

    # Same integer-label clean-up, now on the GM-masked atlas.
    masked_data = np.asarray(atlas_img.dataobj)
    masked_data[masked_data != masked_data.astype(int)] = 0
    atlas_img_corr = nib.Nifti1Image(
        masked_data.astype("uint32"),
        affine=atlas_img.affine,
        header=atlas_img.header,
    )
    nib.save(atlas_img_corr, aligned_atlas_gm)
    final_dat = atlas_img_corr.get_fdata()
    # np.unique already returns the sorted distinct values.
    unique_a = list(np.unique(final_dat))

    if not checkConsecutive(unique_a):
        new_count = len(unique_a)
        # Both counts are plain ints; the removed ``np.int`` alias is not
        # needed.
        diff = abs(new_count - old_count)
        print("\nWarning! Non-consecutive integers found in parcellation...")
        print(f"Previous label count: {old_count}")
        print(f"New label count: {new_count}")
        print(f"Labels dropped: {diff}")
        if diff > 1:
            print('Grey-Matter mask too restrictive for this parcellation. '
                  'Falling back to the T1w mask...')
            _mask_skull_atlas(t1w_brain_mask)
    template_img.uncache()

    return aligned_atlas_gm, aligned_atlas_skull