Esempio n. 1
0
def test_apply_warp():
    """Warp the example T1w image to MNI space and check that output is written.

    Inputs come from ``examples/003/anat``: the subject's T1w image, a
    nonlinear warp (``highres2standard_warp.nii.gz``, which had to be inverted
    with ``invwarp`` first) and a linear matrix (``highres2standard.mat``).
    The reference grid is pynets' bundled ``MNI152_T1_brain_2mm.nii.gz``.
    """
    import pkg_resources

    base_dir = str(Path(__file__).parent / "examples")
    anat_dir = f"{base_dir}/003/anat"
    ref = pkg_resources.resource_filename(
        "pynets", "templates/MNI152_T1_brain_2mm.nii.gz")
    inp = f"{anat_dir}/sub-003_T1w.nii.gz"
    out = f"{anat_dir}/highres2standard_test_apply_warp.nii.gz"
    warp = f"{anat_dir}/highres2standard_warp.nii.gz"
    xfm = f"{anat_dir}/highres2standard.mat"

    reg_utils.apply_warp(ref,
                         inp,
                         out,
                         warp,
                         xfm=xfm,
                         mask=None,
                         interp=None,
                         sup=False)

    # TODO: this only checks that an output file was produced. A voxelwise
    # comparison against highres2standard_nonlinear.nii.gz (np.allclose on
    # the float32 data) was drafted but never enabled.
    assert os.path.isfile(out)
Esempio n. 2
0
    def atlas2t1wmni_align(self, uatlas, atlas):
        """
        Align an atlas to the MNI-aligned T1w image and mask it with grey matter.

        Steps:

        1. Register ``uatlas`` to ``self.t1_aligned_mni`` with
           ``regutils.align`` (12 dof, mutual information, nearest-neighbour
           interpolation), writing ``<anat_path>/<atlas>_t1w_mni.nii.gz``.
        2. Apply ``self.mni2t1w_warp`` to that atlas (reference
           ``self.t1w_brain``), writing ``<atlas>_t1w_skull.nii.gz``.
        3. Apply the same warp to ``self.gm_mask_thr``, writing
           ``<t1w_name>_gm_mask_t1w_mni.nii.gz``.
        4. Rewrite the step-1 atlas with int32 labels and mask it with the
           step-3 image using ``fslmaths -mas``.

        No nonlinear-to-linear fallback is attempted here; errors raised by
        the ``regutils`` calls propagate to the caller.

        Parameters
        ----------
        uatlas : str
            Path to the atlas image.
        atlas : str
            Atlas name, used as the prefix of the output file names.

        Returns
        -------
        str
            Path of the grey-matter-masked atlas,
            ``<anat_path>/<atlas>_t1w_mni_gm.nii.gz``.
        """
        aligned_atlas_t1mni = f"{self.anat_path}/{atlas}_t1w_mni.nii.gz"
        aligned_atlas_skull = f"{self.anat_path}/{atlas}_t1w_skull.nii.gz"
        gm_mask_mni = (f"{self.anat_path}/{self.t1w_name}"
                       f"_gm_mask_t1w_mni.nii.gz")
        aligned_atlas_t1mni_gm = f"{self.anat_path}/{atlas}_t1w_mni_gm.nii.gz"

        regutils.align(uatlas, self.t1_aligned_mni, init=None, xfm=None, out=aligned_atlas_t1mni, dof=12,
                       searchrad=True, interp="nearestneighbour", cost='mutualinfo')

        # Apply warp resulting from the inverse of T1w-->MNI created earlier
        regutils.apply_warp(self.t1w_brain, aligned_atlas_t1mni, aligned_atlas_skull, warp=self.mni2t1w_warp,
                            interp='nn', sup=True)

        # Apply the same MNI->T1w warp to the thresholded GM mask.
        # NOTE(review): the output is named "*_t1w_mni" but the reference grid
        # is self.t1w_brain; confirm which space gm_mask_mni really ends up in
        # before relying on it to mask the MNI-aligned atlas below.
        regutils.apply_warp(self.t1w_brain, self.gm_mask_thr, gm_mask_mni, warp=self.mni2t1w_warp, interp='nn',
                            sup=True)

        # Atlas labels must be integers: truncate the registered values and
        # store them as int32 (a single cast; the old int64 -> int32 double
        # cast gave the same result for in-range labels).
        atlas_img = nib.load(aligned_atlas_t1mni)
        atlas_data = atlas_img.get_fdata().astype(np.int32)
        nib.save(nib.Nifti1Image(atlas_data, affine=atlas_img.affine,
                                 header=atlas_img.header), aligned_atlas_t1mni)
        os.system("fslmaths {} -mas {} {}".format(aligned_atlas_t1mni, gm_mask_mni, aligned_atlas_t1mni_gm))

        return aligned_atlas_t1mni_gm
Esempio n. 3
0
def waymask2dwi_align(
    waymask,
    t1w_brain,
    ap_path,
    mni2t1w_warp,
    mni2t1_xfm,
    t1wtissue2dwi_xfm,
    waymask_in_t1w,
    waymask_in_dwi,
    simple,
):
    """
    Carry a waymask from MNI space into T1w space, then into dwi space.

    The MNI -> T1w hop uses the nonlinear ``mni2t1w_warp`` when ``simple`` is
    exactly ``False``; any other value selects the linear ``mni2t1_xfm``.
    Either way ``t1w_brain`` is the reference grid and the result goes to
    ``waymask_in_t1w``. The T1w -> dwi hop always applies the linear
    ``t1wtissue2dwi_xfm`` with ``ap_path`` as the reference grid.

    Returns
    -------
    str
        ``waymask_in_dwi``, the output path given to the final resampling.
    """
    from pynets.registration import reg_utils as regutils

    use_linear = simple is not False
    if use_linear:
        regutils.applyxfm(t1w_brain, waymask, mni2t1_xfm, waymask_in_t1w)
    else:
        regutils.apply_warp(t1w_brain, waymask, waymask_in_t1w,
                            warp=mni2t1w_warp)

    # T1w -> native dwi space.
    regutils.applyxfm(ap_path, waymask_in_t1w, t1wtissue2dwi_xfm,
                      waymask_in_dwi)
    return waymask_in_dwi
Esempio n. 4
0
def roi2t1w_align(roi, t1w_brain, mni2t1_xfm, mni2t1w_warp, roi_in_t1w,
                  template, simple):
    """
    Resample an MNI-space ROI onto ``template``'s grid, then move it to T1w.

    The ROI is resampled with nearest-neighbour interpolation onto the grid
    of ``template`` and saved beside ``roi`` with a ``_res.nii.gz`` suffix.
    That copy is then brought into T1w space (reference ``t1w_brain``) with
    the nonlinear ``mni2t1w_warp`` when ``simple`` is exactly ``False``, or
    with the linear ``mni2t1_xfm`` otherwise.

    Returns
    -------
    str
        ``roi_in_t1w``.
    """
    import time
    from pynets.registration import reg_utils as regutils
    from nilearn.image import resample_to_img

    resampled_path = f"{roi.split('.nii')[0]}_res.nii.gz"
    resampled_img = resample_to_img(nib.load(roi), nib.load(template),
                                    interpolation="nearest")
    nib.save(resampled_img, resampled_path)

    if simple is not False:
        regutils.applyxfm(t1w_brain, resampled_path, mni2t1_xfm, roi_in_t1w)
    else:
        regutils.apply_warp(t1w_brain, resampled_path, roi_in_t1w,
                            warp=mni2t1w_warp)

    # Brief pause kept from the original; presumably lets the output settle
    # on disk before callers read it (TODO confirm it is still needed).
    time.sleep(0.5)
    return roi_in_t1w
Esempio n. 5
0
def waymask2dwi_align(
    waymask,
    t1w_brain,
    ap_path,
    mni2t1w_warp,
    mni2t1_xfm,
    t1wtissue2dwi_xfm,
    waymask_in_t1w,
    waymask_in_dwi,
    template,
    simple,
):
    """
    Carry a waymask from MNI space into T1w space, then into dwi space.

    The waymask is first resampled onto the grid of ``template`` and saved
    beside ``waymask`` with a ``_res.nii.gz`` suffix. That copy is moved to
    T1w (reference ``t1w_brain``) with ``mni2t1w_warp`` when ``simple`` is
    exactly ``False``, else with ``mni2t1_xfm``. It is then moved to dwi
    (reference ``ap_path``) with ``t1wtissue2dwi_xfm`` and finally binarised
    in place with ``img > 0.01``.

    Returns
    -------
    str
        ``waymask_in_dwi``.
    """
    import time
    from nilearn.image import math_img
    from pynets.registration import reg_utils as regutils
    from nilearn.image import resample_to_img

    source_img = nib.load(waymask)
    grid_img = nib.load(template)
    res_path = f"{waymask.split('.nii')[0]}_res.nii.gz"
    # NOTE(review): no ``interpolation`` is passed, so nilearn's default
    # (continuous) is used on a mask; the final ``> 0.01`` threshold
    # re-binarises it. Confirm this is intended.
    nib.save(resample_to_img(source_img, grid_img), res_path)

    # MNI -> T1w, nonlinear or linear depending on ``simple``.
    if simple is not False:
        regutils.applyxfm(t1w_brain, res_path, mni2t1_xfm, waymask_in_t1w)
    else:
        regutils.apply_warp(t1w_brain, res_path, waymask_in_t1w,
                            warp=mni2t1w_warp)
    time.sleep(0.5)

    # T1w -> native dwi space.
    regutils.applyxfm(ap_path, waymask_in_t1w, t1wtissue2dwi_xfm,
                      waymask_in_dwi)
    time.sleep(0.5)

    binary = math_img("img > 0.01", img=nib.load(waymask_in_dwi))
    binary.to_filename(waymask_in_dwi)
    return waymask_in_dwi
Esempio n. 6
0
def roi2t1w_align(roi, t1w_brain, mni2t1_xfm, mni2t1w_warp, roi_in_t1w,
                  simple):
    """
    Transform an MNI-space ROI into T1w space (reference ``t1w_brain``).

    ``simple is False`` selects the nonlinear ``mni2t1w_warp``; any other
    value selects the linear ``mni2t1_xfm``. The result is written to
    ``roi_in_t1w``.

    Returns
    -------
    str
        ``roi_in_t1w``.
    """
    from pynets.registration import reg_utils as regutils

    if simple is not False:
        regutils.applyxfm(t1w_brain, roi, mni2t1_xfm, roi_in_t1w)
        return roi_in_t1w

    regutils.apply_warp(t1w_brain, roi, roi_in_t1w, warp=mni2t1w_warp)
    return roi_in_t1w
Esempio n. 7
0
    def tissue2dwi_align(self):
        """
        Bring tissue masks and MNI-space ROIs into dwi space and derive
        tractography masks from them.

        1. Register ``self.mni_atlas`` to ``self.input_mni_brain`` (6 dof) to
           obtain ``self.xfm_roi2mni_init``, and use it to resample the
           ventricle ROI (``self.mni_vent_loc``) to ``self.vent_mask_mni``.
        2. Move the ventricle and corpus-callosum ROIs into T1w space
           (reference ``self.t1w_brain``): with ``self.mni2t1w_warp`` when
           ``self.simple is False``, otherwise with ``self.mni2t1_xfm``.
        3. Resample the T1w brain mask (if set), ventricle, CSF, GM, WM and
           corpus-callosum masks into dwi space with
           ``self.t1wtissue2dwi_xfm`` (reference ``self.ap_path``).
        4. Binarise WM (> 0.10), GM (> 0.15) and CSF (> 0.95), mask each map
           with its binary mask, and build the ventricular-CSF,
           corpus-callosum and GM-WM interface masks with ``fslmaths``.

        Must be called after ``t1w2dwi_align``. If a bundled ROI cannot be
        loaded the process exits with status 1; the message suggests missing
        git-lfs objects.

        Raises
        ------
        FileNotFoundError
            If ``self.mni_atlas`` does not exist.
        """
        import sys
        import time
        import os.path as op

        def _require_roi(path, label):
            # Exit with status 1 if a bundled ROI image cannot be loaded. The
            # caught exception is platform-dependent, as before:
            # indexed_gzip.ZranError off Windows, ImportError on Windows.
            message = (f"\nCannot load {label} ROI. "
                       f"Do you have git-lfs installed?")
            if sys.platform.startswith('win') is False:
                try:
                    nib.load(path)
                except indexed_gzip.ZranError as e:
                    print(e, message)
                    sys.exit(1)
            else:
                try:
                    nib.load(path)
                except ImportError as e:
                    print(e, message)
                    sys.exit(1)

        # Register Lateral Ventricles and Corpus Callosum rois to t1w
        if not op.isfile(self.mni_atlas):
            raise FileNotFoundError("FSL atlas for ventricle reference not"
                                    " found!")

        # Estimate the MNI-atlas -> input-MNI-brain transform with flirt; it
        # is used below to resample the ventricle ROI.
        regutils.align(
            self.mni_atlas,
            self.input_mni_brain,
            xfm=self.xfm_roi2mni_init,
            init=None,
            bins=None,
            dof=6,
            cost="mutualinfo",
            searchrad=True,
            interp="spline",
            out=None,
        )
        time.sleep(0.5)

        _require_roi(self.mni_vent_loc, "ventricle")

        # Resample the ventricle ROI with the transform estimated above.
        regutils.applyxfm(
            self.input_mni_brain,
            self.mni_vent_loc,
            self.xfm_roi2mni_init,
            self.vent_mask_mni,
        )
        time.sleep(0.5)
        if self.simple is False:
            # Apply warp resulting from the inverse MNI->T1w created earlier
            regutils.apply_warp(
                self.t1w_brain,
                self.vent_mask_mni,
                self.vent_mask_t1w,
                warp=self.mni2t1w_warp,
                interp="nn",
                sup=True,
            )
            time.sleep(0.5)

            _require_roi(self.corpuscallosum, "Corpus Callosum")

            regutils.apply_warp(
                self.t1w_brain,
                self.corpuscallosum,
                self.corpuscallosum_mask_t1w,
                warp=self.mni2t1w_warp,
                interp="nn",
                sup=True,
            )

        else:
            # Linear MNI -> T1w. applyxfm takes (reference, input, xfm,
            # output), as in every other call in this method: the T1w brain
            # is the reference grid and the ROI is the image being moved.
            regutils.applyxfm(
                self.t1w_brain,
                self.vent_mask_mni,
                self.mni2t1_xfm,
                self.vent_mask_t1w)
            time.sleep(0.5)
            regutils.applyxfm(
                self.t1w_brain,
                self.corpuscallosum,
                self.mni2t1_xfm,
                self.corpuscallosum_mask_t1w,
            )
            time.sleep(0.5)

        # Applyxfm tissue maps to dwi space
        if self.t1w_brain_mask is not None:
            regutils.applyxfm(
                self.ap_path,
                self.t1w_brain_mask,
                self.t1wtissue2dwi_xfm,
                self.t1w_brain_mask_in_dwi,
            )
            time.sleep(0.5)
        regutils.applyxfm(
            self.ap_path,
            self.vent_mask_t1w,
            self.t1wtissue2dwi_xfm,
            self.vent_mask_dwi)
        time.sleep(0.5)
        regutils.applyxfm(
            self.ap_path,
            self.csf_mask,
            self.t1wtissue2dwi_xfm,
            self.csf_mask_dwi)
        time.sleep(0.5)
        regutils.applyxfm(
            self.ap_path, self.gm_mask, self.t1wtissue2dwi_xfm, self.gm_in_dwi
        )
        time.sleep(0.5)
        regutils.applyxfm(
            self.ap_path, self.wm_mask, self.t1wtissue2dwi_xfm, self.wm_in_dwi
        )
        time.sleep(0.5)

        regutils.applyxfm(
            self.ap_path,
            self.corpuscallosum_mask_t1w,
            self.t1wtissue2dwi_xfm,
            self.corpuscallosum_dwi,
        )
        time.sleep(0.5)

        # Threshold WM to binary in dwi space
        thr_img = nib.load(self.wm_in_dwi)
        thr_img = math_img("img > 0.10", img=thr_img)
        nib.save(thr_img, self.wm_in_dwi_bin)

        # Threshold GM to binary in dwi space
        thr_img = nib.load(self.gm_in_dwi)
        thr_img = math_img("img > 0.15", img=thr_img)
        nib.save(thr_img, self.gm_in_dwi_bin)

        # Threshold CSF to binary in dwi space
        thr_img = nib.load(self.csf_mask_dwi)
        thr_img = math_img("img > 0.95", img=thr_img)
        nib.save(thr_img, self.csf_mask_dwi_bin)

        # Mask the WM map with its binary mask
        self.wm_in_dwi = regutils.apply_mask_to_image(self.wm_in_dwi,
                                                      self.wm_in_dwi_bin,
                                                      self.wm_in_dwi)
        time.sleep(0.5)
        # Mask the GM map with its binary mask
        self.gm_in_dwi = regutils.apply_mask_to_image(self.gm_in_dwi,
                                                      self.gm_in_dwi_bin,
                                                      self.gm_in_dwi)
        time.sleep(0.5)
        # Mask the CSF map with its binary mask.
        # NOTE(review): unlike the WM/GM lines above, the result is stored in
        # self.csf_mask (the CSF mask resampled into dwi space above) rather
        # than self.csf_mask_dwi; confirm callers rely on this.
        self.csf_mask = regutils.apply_mask_to_image(self.csf_mask_dwi,
                                                     self.csf_mask_dwi_bin,
                                                     self.csf_mask_dwi)
        time.sleep(0.5)
        # Create ventricular CSF mask
        print("Creating Ventricular CSF mask...")
        os.system(
            f"fslmaths {self.vent_mask_dwi} -kernel sphere 10 -ero "
            f"-bin {self.vent_mask_dwi}"
        )
        time.sleep(1)
        os.system(
            f"fslmaths {self.csf_mask_dwi} -add {self.vent_mask_dwi} "
            f"-bin {self.vent_csf_in_dwi}"
        )
        time.sleep(1)
        print("Creating Corpus Callosum mask...")
        os.system(
            f"fslmaths {self.corpuscallosum_dwi} -mas {self.wm_in_dwi_bin} "
            f"-sub {self.vent_csf_in_dwi} "
            f"-bin {self.corpuscallosum_dwi}")
        time.sleep(1)
        # Create gm-wm interface image
        os.system(
            f"fslmaths {self.gm_in_dwi_bin} -mul {self.wm_in_dwi_bin} "
            f"-add {self.corpuscallosum_dwi} "
            f"-mas {self.B0_mask} -bin {self.wm_gm_int_in_dwi}")
        time.sleep(1)
        return
Esempio n. 8
0
    def tissue2dwi_align(self):
        """
        A function to perform alignment of ventricle ROI's from MNI space --> dwi and CSF from T1w space --> dwi.
        First generates and performs dwi space alignment of avoidance/waypoint masks for tractography.
        First creates ventricle ROI. Then creates transforms from stock MNI template to dwi space.
        For this to succeed, must first have called both t1w2dwi_align and atlas2t1w2dwi_align.

        Outputs: binary WM (>= 0.2), GM (>= 0.2) and CSF (>= 0.9) masks in dwi
        space, a ventricular-CSF mask and a GM-WM interface mask.

        Raises
        ------
        ValueError
            If ``self.mni_atlas`` does not exist.
        """

        def _zero_below(src, dst, thr):
            # Write a copy of ``src`` to ``dst`` with every voxel below ``thr``
            # set to 0. A new image must be built: nib.save writes the image's
            # dataobj, so editing the array returned by get_fdata() in place
            # never reaches the saved file.
            img = nib.load(src)
            data = img.get_fdata()
            data[data < thr] = 0
            nib.save(nib.Nifti1Image(data, affine=img.affine,
                                     header=img.header), dst)

        # Create MNI-space ventricle mask
        print('Creating MNI-space ventricle ROI...')
        if not os.path.isfile(self.mni_atlas):
            raise ValueError('FSL atlas for ventricle reference not found!')
        # Extract volumes 2 and 13 of the atlas (right/left ventricle, going by
        # the output names) and merge them into one binary ROI.
        os.system("fslroi {} {} 2 1".format(self.mni_atlas, self.rvent_out_file))
        os.system("fslroi {} {} 13 1".format(self.mni_atlas, self.lvent_out_file))
        os.system("fslmaths {} -add {} -thr 0.1 -bin {}".format(self.lvent_out_file, self.rvent_out_file,
                                                                self.mni_vent_loc))

        # Create transform to MNI atlas to T1w using flirt. This will be use to transform the ventricles to dwi space.
        regutils.align(self.mni_atlas, self.input_mni_brain, xfm=self.xfm_roi2mni_init, init=None, bins=None, dof=6,
                       cost='mutualinfo', searchrad=True, interp="spline", out=None)

        # Create transform to align roi to mni and T1w using flirt
        regutils.applyxfm(self.input_mni_brain, self.mni_vent_loc, self.xfm_roi2mni_init, self.vent_mask_mni)

        if self.simple is False:
            # Apply warp resulting from the inverse MNI->T1w created earlier
            regutils.apply_warp(self.t1w_brain, self.vent_mask_mni, self.vent_mask_t1w, warp=self.mni2t1w_warp,
                                interp='nn', sup=True)
        # NOTE(review): when self.simple is not False nothing here writes
        # self.vent_mask_t1w, so the next applyxfm relies on it already existing.

        # Applyxfm tissue maps to dwi space
        regutils.applyxfm(self.fa_path, self.vent_mask_t1w, self.t1wtissue2dwi_xfm, self.vent_mask_dwi)
        regutils.applyxfm(self.fa_path, self.csf_mask, self.t1wtissue2dwi_xfm, self.csf_mask_dwi)
        regutils.applyxfm(self.fa_path, self.gm_mask, self.t1wtissue2dwi_xfm, self.gm_in_dwi)
        regutils.applyxfm(self.fa_path, self.wm_mask, self.t1wtissue2dwi_xfm, self.wm_in_dwi)

        # Zero out sub-threshold voxels in dwi space. The CSF map is
        # overwritten in place, as before.
        _zero_below(self.wm_in_dwi, self.wm_in_dwi_bin, 0.2)
        _zero_below(self.gm_in_dwi, self.gm_in_dwi_bin, 0.2)
        _zero_below(self.csf_mask_dwi, self.csf_mask_dwi, 0.9)

        # Binarise the thresholded WM, GM and CSF maps.
        for src, dst in ((self.wm_in_dwi_bin, self.wm_in_dwi_bin),
                         (self.gm_in_dwi_bin, self.gm_in_dwi_bin),
                         (self.csf_mask_dwi, self.csf_mask_dwi_bin)):
            mask = math_img('img > 0', img=load_img(src))
            mask.to_filename(dst)

        # Create ventricular CSF mask
        print('Creating ventricular CSF mask...')
        os.system("fslmaths {} -kernel sphere 10 -ero -bin {}".format(self.vent_mask_dwi, self.vent_mask_dwi))
        os.system("fslmaths {} -add {} -bin {} ".format(self.csf_mask_dwi, self.vent_mask_dwi, self.vent_csf_in_dwi))

        # Create gm-wm interface image
        os.system("fslmaths {} -mul {} -mas {} -bin {}".format(self.gm_in_dwi_bin, self.wm_in_dwi_bin, self.B0_mask,
                                                               self.wm_gm_int_in_dwi))

        return
Esempio n. 9
0
    def atlas2t1w2dwi_align(self, uatlas, atlas):
        """
        A function to perform atlas alignment atlas --> T1 --> dwi.

        The atlas is always first registered to ``self.t1_aligned_mni``
        (12 dof), writing ``<anat_path>/<atlas>_t1w_mni.nii.gz``.

        When ``self.simple is False`` the nonlinear route is tried: that atlas
        is warped to T1w with ``self.mni2t1w_warp`` and registered to
        ``self.fa_path`` seeded with ``self.t1wtissue2dwi_xfm``. If any step
        raises, or when ``self.simple`` is not ``False``, a linear route is
        used: ``uatlas`` -> ``self.t1w_brain`` with flirt, chained with the
        T1w -> dwi transform and applied with ``self.fa_path`` as reference.

        The dwi-space atlas is then cast to int32 labels and masked with
        ``self.B0_mask`` and the binarised GM-WM interface mask.
        For this to succeed, must first have called t1w2dwi_align.

        Returns
        -------
        tuple of str
            ``(dwi_aligned_atlas_wmgm_int, dwi_aligned_atlas,
            aligned_atlas_t1mni)``.
        """
        aligned_atlas_t1mni = f"{self.anat_path}/{atlas}_t1w_mni.nii.gz"
        aligned_atlas_skull = f"{self.anat_path}/{atlas}_t1w_skull.nii.gz"
        dwi_aligned_atlas = f"{self.reg_path_img}/{atlas}_dwi_track.nii.gz"
        dwi_aligned_atlas_wmgm_int = (f"{self.reg_path_img}/{atlas}"
                                      f"_dwi_track_wmgm_int.nii.gz")

        def _linear_atlas_to_dwi(init_searchrad):
            # Linear route atlas -> T1w -> dwi. The first 6-dof estimate
            # seeds the second; both target the T1w brain so the result is a
            # genuine atlas -> T1w matrix that can be chained with T1w -> dwi.
            regutils.align(uatlas, self.t1w_brain, xfm=self.xfm_atlas2t1w_init, init=None, bins=None, dof=6,
                           cost='mutualinfo', searchrad=init_searchrad, interp="spline", out=None, sch=None)
            regutils.align(uatlas, self.t1w_brain, xfm=self.xfm_atlas2t1w, out=None, dof=6, searchrad=True,
                           bins=None, interp="spline", cost='mutualinfo', init=self.xfm_atlas2t1w_init)

            # Combine our linear transform from t1w to template with our transform from dwi to t1w space to get a
            # transform from atlas ->(-> t1w ->)-> dwi
            regutils.combine_xfms(self.xfm_atlas2t1w, self.t1wtissue2dwi_xfm, self.temp2dwi_xfm)

            # Apply linear transformation from template to dwi space
            regutils.applyxfm(self.fa_path, uatlas, self.temp2dwi_xfm, dwi_aligned_atlas)

        regutils.align(uatlas, self.t1_aligned_mni, init=None, xfm=None, out=aligned_atlas_t1mni, dof=12,
                       searchrad=True, interp="nearestneighbour", cost='mutualinfo')

        if self.simple is False:
            try:
                # Apply warp resulting from the inverse of T1w-->MNI created earlier
                regutils.apply_warp(self.t1w_brain, aligned_atlas_t1mni, aligned_atlas_skull,
                                    warp=self.mni2t1w_warp, interp='nn', sup=True)

                # Apply transform to dwi space
                regutils.align(aligned_atlas_skull, self.fa_path, init=self.t1wtissue2dwi_xfm, xfm=None,
                               out=dwi_aligned_atlas, dof=6, searchrad=True, interp="nearestneighbour",
                               cost='mutualinfo')
            except Exception:
                # Catch ordinary errors only, so Ctrl-C / sys.exit still stop
                # the pipeline instead of triggering the fallback.
                print("Warning: Atlas is not in correct dimensions, or input is low quality,\nusing linear template "
                      "registration.")
                _linear_atlas_to_dwi(init_searchrad=True)
        else:
            _linear_atlas_to_dwi(init_searchrad=None)

        # Set intensities to int
        atlas_img = nib.load(dwi_aligned_atlas)
        atlas_data = atlas_img.get_fdata().astype('int')
        t_img = load_img(self.wm_gm_int_in_dwi)
        mask = math_img('img > 0', img=t_img)
        mask.to_filename(self.wm_gm_int_in_dwi_bin)
        nib.save(nib.Nifti1Image(atlas_data.astype(np.int32), affine=atlas_img.affine, header=atlas_img.header),
                 dwi_aligned_atlas)
        os.system("fslmaths {} -mas {} -mas {} {}".format(dwi_aligned_atlas, self.B0_mask, self.wm_gm_int_in_dwi_bin,
                                                          dwi_aligned_atlas_wmgm_int))

        return dwi_aligned_atlas_wmgm_int, dwi_aligned_atlas, aligned_atlas_t1mni
Esempio n. 10
0
def atlas2t1w2dwi_align(
    uatlas,
    uatlas_parcels,
    atlas,
    t1w_brain,
    t1w_brain_mask,
    mni2t1w_warp,
    t1_aligned_mni,
    ap_path,
    t1w2dwi_bbr_xfm,
    mni2t1_xfm,
    t1w2dwi_xfm,
    wm_gm_int_in_dwi,
    aligned_atlas_t1mni,
    aligned_atlas_skull,
    dwi_aligned_atlas,
    dwi_aligned_atlas_wmgm_int,
    B0_mask,
    simple,
):
    """
    A function to perform atlas alignment atlas --> T1 --> dwi.

    ``uatlas_parcels`` (if truthy) or ``uatlas`` is resampled with nearest
    neighbour onto the grid of ``t1_aligned_mni``. Non-integer voxel values
    are zeroed, and the result is written to ``aligned_atlas_t1mni`` as int32.

    When ``simple is False`` the atlas is warped to T1w with
    ``mni2t1w_warp`` and registered to ``ap_path`` seeded with
    ``t1w2dwi_bbr_xfm``. If either step raises, a linear route is used
    instead: 6-dof to ``t1w_brain`` seeded with ``mni2t1_xfm``, then to
    ``ap_path`` seeded with ``t1w2dwi_bbr_xfm``. Otherwise the same linear
    route is used, with ``t1w2dwi_xfm`` seeding the dwi step.

    The dwi-space atlas is rewritten with integer (uint32) labels. A binary
    mask combining it with ``wm_gm_int_in_dwi`` is written to
    ``dwi_aligned_atlas_wmgm_int``. Both outputs are masked with ``B0_mask``
    via ``fslmaths``. For this to succeed, must first have called
    t1w2dwi_align. ``atlas`` is accepted for interface compatibility but is
    not used.

    Returns
    -------
    tuple of str
        ``(dwi_aligned_atlas_wmgm_int, dwi_aligned_atlas, aligned_atlas_t1mni)``.
    """
    from nilearn.image import resample_to_img
    from pynets.core.utils import checkConsecutive
    from pynets.registration import reg_utils as regutils
    from nilearn.image import math_img
    from nilearn.masking import intersect_masks

    def _template_to_t1w_linear():
        # 6-dof registration of the template-space atlas to the T1w brain,
        # seeded with the MNI -> T1w matrix.
        regutils.align(
            aligned_atlas_t1mni,
            t1w_brain,
            init=mni2t1_xfm,
            out=aligned_atlas_skull,
            dof=6,
            searchrad=True,
            interp="nearestneighbour",
            cost="mutualinfo",
        )

    def _t1w_to_dwi(init_xfm):
        # 6-dof registration of the T1w-space atlas to ap_path, seeded with
        # the given T1w -> dwi matrix.
        regutils.align(
            aligned_atlas_skull,
            ap_path,
            init=init_xfm,
            out=dwi_aligned_atlas,
            dof=6,
            searchrad=True,
            interp="nearestneighbour",
            cost="mutualinfo",
        )

    template_img = nib.load(t1_aligned_mni)
    source_atlas = uatlas_parcels if uatlas_parcels else uatlas
    uatlas_res_template = resample_to_img(nib.load(source_atlas),
                                          template_img,
                                          interpolation="nearest")
    template_data = np.asarray(uatlas_res_template.dataobj)
    # Labels must be whole numbers; anything else is treated as background.
    template_data[template_data != template_data.astype(int)] = 0

    uatlas_res_template = nib.Nifti1Image(
        template_data.astype("int32"),
        affine=uatlas_res_template.affine,
        header=uatlas_res_template.header,
    )
    nib.save(uatlas_res_template, aligned_atlas_t1mni)

    if simple is False:
        try:
            regutils.apply_warp(
                t1w_brain,
                aligned_atlas_t1mni,
                aligned_atlas_skull,
                warp=mni2t1w_warp,
                interp="nn",
                sup=True,
                mask=t1w_brain_mask,
            )

            # Apply linear transformation from template to dwi space
            _t1w_to_dwi(t1w2dwi_bbr_xfm)

        except Exception:
            # Catch ordinary errors only, so Ctrl-C / sys.exit still stop the
            # pipeline instead of triggering the linear fallback.
            print(
                "Warning: Atlas is not in correct dimensions, or input is low quality,\nusing linear template "
                "registration.")
            _template_to_t1w_linear()
            _t1w_to_dwi(t1w2dwi_bbr_xfm)

    else:
        _template_to_t1w_linear()
        _t1w_to_dwi(t1w2dwi_xfm)

    atlas_img = nib.load(dwi_aligned_atlas)
    wm_gm_img = nib.load(wm_gm_int_in_dwi)
    wm_gm_mask_img = math_img("img > 0", img=wm_gm_img)
    atlas_mask_img = math_img("img > 0", img=atlas_img)

    atlas_data = np.asarray(atlas_img.dataobj)
    atlas_data[atlas_data != atlas_data.astype(int)] = 0

    atlas_img_corr = nib.Nifti1Image(
        atlas_data.astype("uint32"),
        affine=atlas_img.affine,
        header=atlas_img.header,
    )

    # NOTE(review): per the nilearn docs, threshold=0 makes intersect_masks
    # return the UNION of the two masks, not their intersection. Confirm
    # that is what the "_wmgm_int" output should contain.
    dwi_aligned_atlas_wmgm_int_img = intersect_masks(
        [wm_gm_mask_img, atlas_mask_img], threshold=0, connected=False)

    nib.save(atlas_img_corr, dwi_aligned_atlas)
    nib.save(dwi_aligned_atlas_wmgm_int_img, dwi_aligned_atlas_wmgm_int)

    os.system(
        f"fslmaths {dwi_aligned_atlas} -mas {B0_mask} {dwi_aligned_atlas} "
        f"2>/dev/null")

    os.system(f"fslmaths {dwi_aligned_atlas_wmgm_int} -mas {B0_mask} "
              f"{dwi_aligned_atlas_wmgm_int} 2>/dev/null")

    # Sorted distinct label values (np.unique sorts, so no set/sort round-trip).
    final_dat = atlas_img_corr.get_fdata()
    unique_a = list(np.unique(final_dat))

    if not checkConsecutive(unique_a):
        print("Warning! Non-consecutive integers found in parcellation...")

    atlas_img.uncache()
    atlas_img_corr.uncache()
    atlas_mask_img.uncache()
    wm_gm_img.uncache()
    wm_gm_mask_img.uncache()

    return dwi_aligned_atlas_wmgm_int, dwi_aligned_atlas, aligned_atlas_t1mni
Esempio n. 11
0
def atlas2t1w_align(
    uatlas,
    uatlas_parcels,
    atlas,
    t1w_brain,
    t1w_brain_mask,
    t1_aligned_mni,
    mni2t1w_warp,
    mni2t1_xfm,
    gm_mask,
    aligned_atlas_t1mni,
    aligned_atlas_skull,
    aligned_atlas_gm,
    simple,
):
    """
    A function to perform atlas alignment from atlas --> T1w.

    ``uatlas_parcels`` (if truthy) or ``uatlas`` is resampled with nearest
    neighbour onto the grid of ``t1_aligned_mni``. Non-integer values are
    zeroed, and the result is written to ``aligned_atlas_t1mni`` as uint16.

    The atlas is then moved into T1w space (``aligned_atlas_skull``). When
    ``simple is False`` this uses ``mni2t1w_warp``, falling back to a 6-dof
    registration seeded with ``mni2t1_xfm`` if the warp raises. Otherwise the
    6-dof registration is used directly.

    The T1w atlas is masked with ``gm_mask`` into ``aligned_atlas_gm`` and
    rewritten with uint32 labels. If the labels are non-consecutive and more
    than one label was lost relative to the template-space atlas,
    ``aligned_atlas_gm`` is regenerated by masking with ``t1w_brain_mask``.
    ``atlas`` is accepted for interface compatibility but is not used.

    Returns
    -------
    tuple of str
        ``(aligned_atlas_gm, aligned_atlas_skull)``.
    """
    from pynets.registration import reg_utils as regutils
    from nilearn.image import resample_to_img
    from pynets.core.utils import checkConsecutive

    def _template_to_t1w_linear():
        # 6-dof registration of the template-space atlas to the T1w brain,
        # seeded with the MNI -> T1w matrix.
        regutils.align(
            aligned_atlas_t1mni,
            t1w_brain,
            init=mni2t1_xfm,
            out=aligned_atlas_skull,
            dof=6,
            searchrad=True,
            interp="nearestneighbour",
            cost="mutualinfo",
        )

    template_img = nib.load(t1_aligned_mni)
    source_atlas = uatlas_parcels if uatlas_parcels else uatlas
    uatlas_res_template = resample_to_img(nib.load(source_atlas),
                                          template_img,
                                          interpolation="nearest")
    template_data = np.asarray(uatlas_res_template.dataobj)
    # Labels must be whole numbers; anything else is treated as background.
    template_data[template_data != template_data.astype(int)] = 0
    template_labels = template_data.astype("uint16")

    # Number of distinct values (background included) in the template-space
    # atlas, before registration and grey-matter masking. Previously this
    # count was taken after masking, so "Labels dropped" was always 0 and the
    # T1w-mask fallback below could never trigger.
    template_label_count = len(np.unique(template_labels))

    uatlas_res_template = nib.Nifti1Image(
        template_labels,
        affine=uatlas_res_template.affine,
        header=uatlas_res_template.header,
    )
    nib.save(uatlas_res_template, aligned_atlas_t1mni)

    if simple is False:
        try:
            regutils.apply_warp(
                t1w_brain,
                aligned_atlas_t1mni,
                aligned_atlas_skull,
                warp=mni2t1w_warp,
                interp="nn",
                sup=True,
                mask=t1w_brain_mask,
            )

        except Exception:
            # Catch ordinary errors only, so Ctrl-C / sys.exit still stop the
            # pipeline instead of triggering the linear fallback.
            print(
                "Warning: Atlas is not in correct dimensions, or input is low quality,\nusing linear template "
                "registration.")
            _template_to_t1w_linear()

    else:
        _template_to_t1w_linear()

    os.system(f"fslmaths {aligned_atlas_skull} -mas {gm_mask} "
              f"{aligned_atlas_gm} 2>/dev/null")

    atlas_img = nib.load(aligned_atlas_gm)

    atlas_data = np.asarray(atlas_img.dataobj)
    atlas_data[atlas_data != atlas_data.astype(int)] = 0
    atlas_img_corr = nib.Nifti1Image(
        atlas_data.astype("uint32"),
        affine=atlas_img.affine,
        header=atlas_img.header,
    )
    nib.save(atlas_img_corr, aligned_atlas_gm)
    # Sorted distinct label values (np.unique sorts, so no set/sort round-trip).
    unique_a = list(np.unique(atlas_img_corr.get_fdata()))

    if not checkConsecutive(unique_a):
        new_count = len(unique_a)
        # Built-in int arithmetic: np.int no longer exists in NumPy >= 1.24.
        diff = abs(new_count - template_label_count)
        print("\nWarning! Non-consecutive integers found in parcellation...")
        print(f"Previous label count: {template_label_count}")
        print(f"New label count: {new_count}")
        print(f"Labels dropped: {diff}")
        if diff > 1:
            print('Grey-Matter mask too restrictive for this parcellation. '
                  'Falling back to the T1w mask...')
            os.system(f"fslmaths {aligned_atlas_skull} -mas {t1w_brain_mask} "
                      f"{aligned_atlas_gm} 2>/dev/null")
    template_img.uncache()

    return aligned_atlas_gm, aligned_atlas_skull
Esempio n. 12
0
def atlas2t1w2dwi_align(
    uatlas,
    uatlas_parcels,
    atlas,
    t1w_brain,
    t1w_brain_mask,
    mni2t1w_warp,
    t1_aligned_mni,
    ap_path,
    t1w2dwi_bbr_xfm,
    mni2t1_xfm,
    t1w2dwi_xfm,
    wm_gm_int_in_dwi,
    aligned_atlas_t1mni,
    aligned_atlas_skull,
    dwi_aligned_atlas,
    dwi_aligned_atlas_wmgm_int,
    B0_mask,
    mni2dwi_xfm,
    simple,
):
    """
    A function to perform atlas alignment atlas --> T1 --> dwi.

    Tries nonlinear registration first, and if that fails, does a linear
    registration instead. For this to succeed, must first have called
    t1w2dwi_align.

    Parameters
    ----------
    uatlas : str
        Path to the parcellation image (used when ``uatlas_parcels`` is
        falsy).
    uatlas_parcels : str or None
        Optional parcellation image; used instead of ``uatlas`` when given.
    atlas : str
        Atlas name (not used by this function).
    t1w_brain, t1w_brain_mask : str
        Skull-stripped T1w image and its brain mask.
    mni2t1w_warp : str
        Nonlinear MNI --> T1w warp applied in the non-simple path.
    t1_aligned_mni : str
        Image whose grid the atlas is resampled onto before registration.
    ap_path : str
        Reference image for the dwi-space linear transforms.
    t1w2dwi_bbr_xfm : str
        T1w --> dwi matrix used in the non-simple path and its fallback.
    mni2t1_xfm : str
        MNI --> T1w matrix combined into ``mni2dwi_xfm`` for linear paths.
    t1w2dwi_xfm : str
        T1w --> dwi matrix used when ``simple`` is not False.
    wm_gm_int_in_dwi : str
        WM/GM interface image in dwi space.
    aligned_atlas_t1mni, aligned_atlas_skull : str
        Output paths for the resampled and T1w-space atlas.
    dwi_aligned_atlas, dwi_aligned_atlas_wmgm_int : str
        Output paths for the dwi-space atlas and its WM/GM-interface union.
    B0_mask : str
        Mask applied to both dwi-space outputs.
    mni2dwi_xfm : str
        Path passed to ``combine_xfms`` as the combined MNI --> dwi matrix.
    simple : bool
        Only the value ``False`` enables the nonlinear path.

    Returns
    -------
    tuple of str
        ``(dwi_aligned_atlas_wmgm_int, dwi_aligned_atlas,
        aligned_atlas_t1mni)``.
    """
    import time
    from nilearn.image import resample_to_img
    from pynets.core.utils import checkConsecutive
    from pynets.registration import reg_utils as regutils
    from nilearn.image import math_img
    from nilearn.masking import intersect_masks

    template_img = nib.load(t1_aligned_mni)
    if uatlas_parcels:
        atlas_img_orig = nib.load(uatlas_parcels)
    else:
        atlas_img_orig = nib.load(uatlas)

    old_count = len(np.unique(np.asarray(atlas_img_orig.dataobj)))

    # Nearest-neighbour resampling keeps the integer parcel labels intact.
    uatlas_res_template = resample_to_img(atlas_img_orig,
                                          template_img,
                                          interpolation="nearest")

    uatlas_res_template = nib.Nifti1Image(
        np.asarray(uatlas_res_template.dataobj).astype('uint16'),
        affine=uatlas_res_template.affine,
        header=uatlas_res_template.header,
    )
    nib.save(uatlas_res_template, aligned_atlas_t1mni)

    if simple is False:
        try:
            regutils.apply_warp(
                t1w_brain,
                aligned_atlas_t1mni,
                aligned_atlas_skull,
                warp=mni2t1w_warp,
                interp="nn",
                sup=True,
                mask=t1w_brain_mask,
            )
            time.sleep(0.5)

            # Apply linear transformation from template to dwi space
            regutils.applyxfm(ap_path,
                              aligned_atlas_skull,
                              t1w2dwi_bbr_xfm,
                              dwi_aligned_atlas,
                              interp="nearestneighbour")
            time.sleep(0.5)
        except Exception:
            # Exception (not BaseException) so Ctrl-C / SystemExit still
            # abort instead of silently triggering the linear fallback.
            print(
                "Warning: Atlas is not in correct dimensions, or input is low"
                " quality,\nusing linear template registration.")

            combine_xfms(mni2t1_xfm, t1w2dwi_bbr_xfm, mni2dwi_xfm)
            time.sleep(0.5)
            regutils.applyxfm(ap_path,
                              aligned_atlas_t1mni,
                              mni2dwi_xfm,
                              dwi_aligned_atlas,
                              interp="nearestneighbour")
            time.sleep(0.5)
    else:
        combine_xfms(mni2t1_xfm, t1w2dwi_xfm, mni2dwi_xfm)
        time.sleep(0.5)
        regutils.applyxfm(ap_path,
                          aligned_atlas_t1mni,
                          mni2dwi_xfm,
                          dwi_aligned_atlas,
                          interp="nearestneighbour")
        time.sleep(0.5)

    atlas_img = nib.load(dwi_aligned_atlas)
    wm_gm_img = nib.load(wm_gm_int_in_dwi)
    wm_gm_mask_img = math_img("img > 0", img=wm_gm_img)
    atlas_mask_img = math_img("img > 0", img=atlas_img)

    atlas_img_corr = nib.Nifti1Image(
        np.asarray(atlas_img.dataobj).astype('uint16'),
        affine=atlas_img.affine,
        header=atlas_img.header,
    )

    # Get the union of masks (intersect_masks with threshold=0 keeps a voxel
    # if it is set in any of the input masks).
    dwi_aligned_atlas_wmgm_int_img = intersect_masks(
        [wm_gm_mask_img, atlas_mask_img], threshold=0, connected=False)

    nib.save(atlas_img_corr, dwi_aligned_atlas)
    nib.save(dwi_aligned_atlas_wmgm_int_img, dwi_aligned_atlas_wmgm_int)

    dwi_aligned_atlas = regutils.apply_mask_to_image(dwi_aligned_atlas,
                                                     B0_mask,
                                                     dwi_aligned_atlas)

    time.sleep(0.5)

    dwi_aligned_atlas_wmgm_int = regutils.apply_mask_to_image(
        dwi_aligned_atlas_wmgm_int, B0_mask, dwi_aligned_atlas_wmgm_int)

    time.sleep(0.5)

    # Label inventory of the uint16 atlas (taken before B0 masking).
    # np.unique is sorted and avoids building a per-voxel Python list.
    unique_a = np.unique(atlas_img_corr.get_fdata()).tolist()

    if not checkConsecutive(unique_a):
        print("Warning! Non-consecutive integers found in parcellation...")

    new_count = len(unique_a)
    # np.int was removed in NumPy 1.24; plain int arithmetic is equivalent.
    diff = abs(new_count - old_count)
    print(f"Previous label count: {old_count}")
    print(f"New label count: {new_count}")
    print(f"Labels dropped: {diff}")

    template_img.uncache()
    atlas_img.uncache()
    atlas_img_corr.uncache()
    atlas_img_orig.uncache()
    atlas_mask_img.uncache()
    wm_gm_img.uncache()
    wm_gm_mask_img.uncache()

    return dwi_aligned_atlas_wmgm_int, dwi_aligned_atlas, aligned_atlas_t1mni
Esempio n. 13
0
def atlas2t1w_align(uatlas,
                    uatlas_parcels,
                    atlas,
                    t1w_brain,
                    t1w_brain_mask,
                    t1_aligned_mni,
                    mni2t1w_warp,
                    mni2t1_xfm,
                    gm_mask,
                    aligned_atlas_t1mni,
                    aligned_atlas_skull,
                    aligned_atlas_gm,
                    simple,
                    gm_fail_tol=5):
    """
    A function to perform atlas alignment from atlas --> T1w.

    The parcellation is resampled (nearest neighbour) onto the
    ``t1_aligned_mni`` grid and saved to ``aligned_atlas_t1mni``. It is then
    moved into T1w space and written to ``aligned_atlas_skull``:
    - When ``simple`` is False, the nonlinear ``mni2t1w_warp`` is used.
    - When warping raises, or ``simple`` is anything else, the linear
      ``mni2t1_xfm`` is used.
    Finally the atlas is masked with ``t1w_brain_mask`` and re-saved as
    uint16.

    Parameters
    ----------
    uatlas : str
        Path to the parcellation image (used when ``uatlas_parcels`` is
        falsy).
    uatlas_parcels : str or None
        Optional parcellation image; used instead of ``uatlas`` when given.
    atlas : str
        Atlas name (not used by this function).
    t1w_brain, t1w_brain_mask : str
        Skull-stripped T1w image and its brain mask.
    t1_aligned_mni : str
        Image whose grid the atlas is resampled onto.
    mni2t1w_warp : str
        Nonlinear MNI --> T1w warp.
    mni2t1_xfm : str
        Linear MNI --> T1w matrix.
    gm_mask : str
        Currently unused: grey-matter masking is disabled and the T1w brain
        mask is applied instead. Kept for interface compatibility.
    aligned_atlas_t1mni, aligned_atlas_skull, aligned_atlas_gm : str
        Output paths.
    simple : bool
        Only the value ``False`` enables the nonlinear path.
    gm_fail_tol : int
        Currently unused; kept for interface compatibility.

    Returns
    -------
    tuple of str
        ``(aligned_atlas_gm, aligned_atlas_skull)``, where
        ``aligned_atlas_gm`` is the path returned by ``apply_mask_to_image``.
    """
    import time
    from pynets.registration import reg_utils as regutils
    from nilearn.image import resample_to_img

    template_img = nib.load(t1_aligned_mni)
    if uatlas_parcels:
        atlas_img_orig = nib.load(uatlas_parcels)
    else:
        atlas_img_orig = nib.load(uatlas)

    uatlas_res_template = resample_to_img(atlas_img_orig,
                                          template_img,
                                          interpolation="nearest")

    uatlas_res_template = nib.Nifti1Image(
        np.asarray(uatlas_res_template.dataobj).astype('uint16'),
        affine=uatlas_res_template.affine,
        header=uatlas_res_template.header,
    )
    nib.save(uatlas_res_template, aligned_atlas_t1mni)

    # Nonlinear warp only when simple is exactly False; any failure there
    # falls back to the same linear transform used by the simple path.
    use_linear = simple is not False
    if not use_linear:
        try:
            regutils.apply_warp(
                t1w_brain,
                aligned_atlas_t1mni,
                aligned_atlas_skull,
                warp=mni2t1w_warp,
                interp="nn",
                sup=True,
                mask=t1w_brain_mask,
            )
        except Exception:
            # Exception (not BaseException) so Ctrl-C / SystemExit still
            # abort instead of silently triggering the linear fallback.
            print(
                "Warning: Atlas is not in correct dimensions, or input is low "
                "quality,\nusing linear template registration.")
            use_linear = True

    if use_linear:
        regutils.applyxfm(t1w_brain,
                          aligned_atlas_t1mni,
                          mni2t1_xfm,
                          aligned_atlas_skull,
                          interp="nearestneighbour")
    time.sleep(0.5)

    # NOTE: grey-matter masking with `gm_mask` (and the `gm_fail_tol`
    # fallback) is disabled; the atlas is restricted to the T1w brain mask.
    aligned_atlas_gm = regutils.apply_mask_to_image(aligned_atlas_skull,
                                                    t1w_brain_mask,
                                                    aligned_atlas_gm)

    time.sleep(0.5)
    atlas_img = nib.load(aligned_atlas_gm)

    # Re-save with integer labels.
    atlas_img_corr = nib.Nifti1Image(
        np.asarray(atlas_img.dataobj).astype('uint16'),
        affine=atlas_img.affine,
        header=atlas_img.header,
    )
    nib.save(atlas_img_corr, aligned_atlas_gm)

    template_img.uncache()
    atlas_img_orig.uncache()
    atlas_img.uncache()
    atlas_img_corr.uncache()

    return aligned_atlas_gm, aligned_atlas_skull
Esempio n. 14
0
def RegisterParcellation2MNIFunc_align(uatlas, template, template_mask,
                                       t1w_brain, t1w2mni_xfm,
                                       aligned_atlas_t1w, aligned_atlas_mni,
                                       t1w2mni_warp, simple):
    """
    A function to perform atlas alignment from T1w atlas --> MNI.

    The T1w-space parcellation is resampled (nearest neighbour) onto the
    ``t1w_brain`` grid and saved as uint16 to ``aligned_atlas_t1w``. It is
    then moved into template space and written to ``aligned_atlas_mni``:
    - When ``simple`` is False, the nonlinear ``t1w2mni_warp`` is used.
    - When warping raises, or ``simple`` is anything else, a 6-dof
      ``regutils.align`` initialised with ``t1w2mni_xfm`` is used.

    Parameters
    ----------
    uatlas : str
        Path to the T1w-space parcellation.
    template : str
        Template (MNI) reference image.
    template_mask : str
        Not used by this function.
    t1w_brain : str
        T1w image whose grid the atlas is resampled onto.
    t1w2mni_xfm : str
        Linear T1w --> MNI matrix used to initialise the linear alignment.
    aligned_atlas_t1w, aligned_atlas_mni : str
        Output paths.
    t1w2mni_warp : str
        Nonlinear T1w --> MNI warp.
    simple : bool
        Only the value ``False`` enables the nonlinear path.

    Returns
    -------
    str
        ``aligned_atlas_mni``.
    """
    import time
    from pynets.registration import reg_utils as regutils
    from nilearn.image import resample_to_img

    atlas_img = nib.load(uatlas)
    t1w_brain_img = nib.load(t1w_brain)

    uatlas_res_template = resample_to_img(atlas_img,
                                          t1w_brain_img,
                                          interpolation="nearest")

    uatlas_res_template = nib.Nifti1Image(
        np.asarray(uatlas_res_template.dataobj).astype('uint16'),
        affine=uatlas_res_template.affine,
        header=uatlas_res_template.header,
    )
    nib.save(uatlas_res_template, aligned_atlas_t1w)

    # Nonlinear warp only when simple is exactly False; any failure there
    # falls back to the same linear alignment used by the simple path.
    use_linear = simple is not False
    if not use_linear:
        try:
            regutils.apply_warp(
                template,
                aligned_atlas_t1w,
                aligned_atlas_mni,
                warp=t1w2mni_warp,
                interp="nn",
                sup=True,
            )
        except Exception:
            # Exception (not BaseException) so Ctrl-C / SystemExit still
            # abort instead of silently triggering the linear fallback.
            print("Warning: Atlas is not in correct dimensions, or input is "
                  "low quality,\nusing linear template registration.")
            use_linear = True

    if use_linear:
        regutils.align(
            aligned_atlas_t1w,
            template,
            init=t1w2mni_xfm,
            out=aligned_atlas_mni,
            dof=6,
            searchrad=True,
            interp="nearestneighbour",
            cost="mutualinfo",
        )
    time.sleep(0.5)

    atlas_img.uncache()
    t1w_brain_img.uncache()

    return aligned_atlas_mni
Esempio n. 15
0
    def tissue2dwi_align(self):
        """
        Align ventricle ROIs (MNI --> dwi) and tissue masks (T1w --> dwi).

        Generates and aligns dwi-space avoidance/waypoint masks for
        tractography:
        - builds an MNI-space ventricle ROI from ``self.mni_atlas``;
        - carries it into T1w space (nonlinear warp, only when
          ``self.simple`` is False) and then into dwi space;
        - transforms the T1w CSF/GM/WM masks into dwi space, thresholds
          them (WM/GM at 0.2, CSF at 0.9) and binarizes them;
        - writes the ventricular-CSF mask and the GM-WM interface image.

        NOTE: for this to work, must first have called both t1w2dwi_align
        and atlas2t1w2dwi_align.

        Raises
        ------
        ValueError
            If ``self.mni_atlas`` does not exist on disk.
        """

        def _threshold_to_file(in_path, out_path, thr):
            # Zero voxels below `thr` and save a NEW image built from the
            # modified array. Editing the array returned by get_fdata() only
            # touches nibabel's cache; nib.save writes from the image's
            # dataobj, so an in-place edit would never reach disk.
            img = nib.load(in_path)
            data = img.get_fdata()
            data[data < thr] = 0
            nib.save(nib.Nifti1Image(data, img.affine, img.header), out_path)

        # Create MNI-space ventricle mask
        print('Creating MNI-space ventricle ROI...')
        if not os.path.isfile(self.mni_atlas):
            raise ValueError('FSL atlas for ventricle reference not found!')
        cmd = 'fslroi ' + self.mni_atlas + ' ' + self.rvent_out_file + ' 2 1'
        os.system(cmd)
        cmd = 'fslroi ' + self.mni_atlas + ' ' + self.lvent_out_file + ' 13 1'
        os.system(cmd)
        self.args = "%s%s%s" % (' -add ', self.rvent_out_file,
                                ' -thr 0.1 -bin ')
        cmd = 'fslmaths ' + self.lvent_out_file + self.args + self.mni_vent_loc
        os.system(cmd)

        # Create transform to MNI atlas to T1w using flirt. This will be used
        # to transform the ventricles to dwi space.
        mgru.align(self.mni_atlas,
                   self.input_mni_brain,
                   xfm=self.xfm_roi2mni_init,
                   init=None,
                   bins=None,
                   dof=6,
                   cost='mutualinfo',
                   searchrad=True,
                   interp="spline",
                   out=None)

        # Create transform to align roi to mni and T1w using flirt
        mgru.applyxfm(self.input_mni_brain, self.mni_vent_loc,
                      self.xfm_roi2mni_init, self.vent_mask_mni)

        if self.simple is False:
            # Apply warp resulting from the inverse MNI->T1w created earlier
            mgru.apply_warp(self.t1w_brain,
                            self.vent_mask_mni,
                            self.vent_mask_t1w,
                            warp=self.mni2t1w_warp,
                            interp='nn',
                            sup=True)
        # NOTE(review): when self.simple is not False, self.vent_mask_t1w is
        # not produced here, yet the first applyxfm below reads it — confirm
        # it is created elsewhere in that case.

        # Applyxfm tissue maps to dwi space
        mgru.applyxfm(self.fa_path, self.vent_mask_t1w, self.t1wtissue2dwi_xfm,
                      self.vent_mask_dwi)
        mgru.applyxfm(self.fa_path, self.csf_mask, self.t1wtissue2dwi_xfm,
                      self.csf_mask_dwi)
        mgru.applyxfm(self.fa_path, self.gm_mask, self.t1wtissue2dwi_xfm,
                      self.gm_in_dwi)
        mgru.applyxfm(self.fa_path, self.wm_mask, self.t1wtissue2dwi_xfm,
                      self.wm_in_dwi)

        # Threshold WM, GM and CSF in dwi space (CSF is thresholded in place
        # because the un-binarized csf_mask_dwi is reused below).
        _threshold_to_file(self.wm_in_dwi, self.wm_in_dwi_bin, 0.2)
        _threshold_to_file(self.gm_in_dwi, self.gm_in_dwi_bin, 0.2)
        _threshold_to_file(self.csf_mask_dwi, self.csf_mask_dwi, 0.9)

        # Binarize WM in dwi space
        self.t_img = load_img(self.wm_in_dwi_bin)
        self.mask = math_img('img > 0', img=self.t_img)
        self.mask.to_filename(self.wm_in_dwi_bin)

        # Binarize GM in dwi space
        self.t_img = load_img(self.gm_in_dwi_bin)
        self.mask = math_img('img > 0', img=self.t_img)
        self.mask.to_filename(self.gm_in_dwi_bin)

        # Binarize CSF in dwi space
        self.t_img = load_img(self.csf_mask_dwi)
        self.mask = math_img('img > 0', img=self.t_img)
        self.mask.to_filename(self.csf_mask_dwi_bin)

        # Create ventricular CSF mask
        print('Creating ventricular CSF mask...')
        cmd = 'fslmaths ' + self.vent_mask_dwi + ' -kernel sphere 10 -ero -bin ' + self.vent_mask_dwi
        os.system(cmd)
        cmd = 'fslmaths ' + self.csf_mask_dwi + ' -add ' + self.vent_mask_dwi + ' -bin ' + self.vent_csf_in_dwi
        os.system(cmd)

        # Create gm-wm interface image
        cmd = 'fslmaths ' + self.gm_in_dwi_bin + ' -mul ' + self.wm_in_dwi_bin + ' -mas ' + self.nodif_B0_mask + ' -bin ' + self.wm_gm_int_in_dwi
        os.system(cmd)

        return
Esempio n. 16
0
    def atlas2t1w2dwi_align(self, uatlas_select, atlas_select):
        """
        Align an atlas from atlas --> T1 --> dwi.

        The atlas is first aligned (12-dof) to ``self.t1_aligned_mni``. It is
        then taken to dwi space:
        - When ``self.simple`` is False, the nonlinear path is used: the
          ``self.mni2t1w_warp`` warp followed by a 6-dof align to
          ``self.fa_path``.
        - If that path raises, or ``self.simple`` is anything else, a linear
          path is used: atlas --> T1w matrices combined with
          ``self.t1wtissue2dwi_xfm``.
        The dwi-space atlas is cast to int32 and masked with the B0 mask and
        the binarized WM/GM interface.

        NOTE: for this to work, must first have called t1w2dwi_align.

        Parameters
        ----------
        uatlas_select : str
            Path to the atlas image.
        atlas_select : str
            Atlas name, used to build the output file names.

        Returns
        -------
        tuple of str
            ``(dwi_aligned_atlas_wmgm_int, dwi_aligned_atlas,
            aligned_atlas_t1mni)`` output paths.
        """
        self.atlas = uatlas_select
        self.atlas_name = atlas_select
        self.aligned_atlas_t1mni = "{}/{}_t1w_mni.nii.gz".format(
            self.basedir_path, self.atlas_name)
        self.aligned_atlas_skull = "{}/{}_t1w_skull.nii.gz".format(
            self.anat_path, self.atlas_name)
        self.dwi_aligned_atlas = "{}/{}_dwi_track.nii.gz".format(
            self.reg_path_img, self.atlas_name)
        self.dwi_aligned_atlas_wmgm_int = "{}/{}_dwi_track_wmgm_int.nii.gz".format(
            self.reg_path_img, self.atlas_name)

        mgru.align(self.atlas,
                   self.t1_aligned_mni,
                   init=None,
                   xfm=None,
                   out=self.aligned_atlas_t1mni,
                   dof=12,
                   searchrad=True,
                   interp="nearestneighbour",
                   cost='mutualinfo')

        if self.simple is False:
            try:
                # Apply warp resulting from the inverse of T1w-->MNI created earlier
                mgru.apply_warp(self.t1w_brain,
                                self.aligned_atlas_t1mni,
                                self.aligned_atlas_skull,
                                warp=self.mni2t1w_warp,
                                interp='nn',
                                sup=True)

                # Apply transform to dwi space
                mgru.align(self.aligned_atlas_skull,
                           self.fa_path,
                           init=self.t1wtissue2dwi_xfm,
                           xfm=None,
                           out=self.dwi_aligned_atlas,
                           dof=6,
                           searchrad=True,
                           interp="nearestneighbour",
                           cost='mutualinfo')
            except Exception:
                # Exception (not a bare except) so Ctrl-C / SystemExit still
                # abort instead of silently triggering the linear fallback.
                print(
                    "Warning: Atlas is not in correct dimensions, or input is low quality,\nusing linear template registration."
                )

                # Create transform to align atlas to T1w using flirt
                mgru.align(self.atlas,
                           self.t1w_brain,
                           xfm=self.xfm_atlas2t1w_init,
                           init=None,
                           bins=None,
                           dof=6,
                           cost='mutualinfo',
                           searchrad=True,
                           interp="spline",
                           out=None,
                           sch=None)
                mgru.align(self.atlas,
                           self.t1_aligned_mni,
                           xfm=self.xfm_atlas2t1w,
                           out=None,
                           dof=6,
                           searchrad=True,
                           bins=None,
                           interp="spline",
                           cost='mutualinfo',
                           init=self.xfm_atlas2t1w_init)

                # Combine the atlas --> t1w transform with the t1w --> dwi
                # transform to get atlas ->(-> t1w ->)-> dwi
                mgru.combine_xfms(self.xfm_atlas2t1w, self.t1wtissue2dwi_xfm,
                                  self.temp2dwi_xfm)

                # Apply linear transformation from template to dwi space
                mgru.applyxfm(self.fa_path, self.atlas, self.temp2dwi_xfm,
                              self.dwi_aligned_atlas)
        else:
            # NOTE(review): this linear path differs from the fallback above:
            # - the first align uses searchrad=None;
            # - the second align targets self.t1w_brain instead of
            #   self.t1_aligned_mni.
            # Confirm which variant is intended.
            # Create transform to align atlas to T1w using flirt
            mgru.align(self.atlas,
                       self.t1w_brain,
                       xfm=self.xfm_atlas2t1w_init,
                       init=None,
                       bins=None,
                       dof=6,
                       cost='mutualinfo',
                       searchrad=None,
                       interp="spline",
                       out=None,
                       sch=None)
            mgru.align(self.atlas,
                       self.t1w_brain,
                       xfm=self.xfm_atlas2t1w,
                       out=None,
                       dof=6,
                       searchrad=True,
                       bins=None,
                       interp="spline",
                       cost='mutualinfo',
                       init=self.xfm_atlas2t1w_init)

            # Combine the atlas --> t1w transform with the t1w --> dwi
            # transform to get atlas ->(-> t1w ->)-> dwi
            mgru.combine_xfms(self.xfm_atlas2t1w, self.t1wtissue2dwi_xfm,
                              self.temp2dwi_xfm)

            # Apply linear transformation from template to dwi space
            mgru.applyxfm(self.fa_path, self.atlas, self.temp2dwi_xfm,
                          self.dwi_aligned_atlas)

        # Set intensities to int
        self.atlas_img = nib.load(self.dwi_aligned_atlas)
        self.atlas_data = self.atlas_img.get_fdata().astype('int')
        # Binarize the WM/GM interface so it can be used as a mask.
        t_img = load_img(self.wm_gm_int_in_dwi)
        mask = math_img('img > 0', img=t_img)
        mask.to_filename(self.wm_gm_int_in_dwi_bin)
        nib.save(
            nib.Nifti1Image(self.atlas_data.astype(np.int32),
                            affine=self.atlas_img.affine,
                            header=self.atlas_img.header),
            self.dwi_aligned_atlas)
        cmd = 'fslmaths ' + self.dwi_aligned_atlas + ' -mas ' + self.nodif_B0_mask + ' -mas ' + self.wm_gm_int_in_dwi_bin + ' ' + self.dwi_aligned_atlas_wmgm_int
        os.system(cmd)

        return self.dwi_aligned_atlas_wmgm_int, self.dwi_aligned_atlas, self.aligned_atlas_t1mni