Beispiel #1
0
    def atlas2t1wmni_align(self, uatlas, atlas):
        """
        Align an atlas to the T1w-in-MNI image and restrict it to grey matter.

        Steps performed:

        1. Linearly register ``uatlas`` to ``self.t1_aligned_mni`` (12 dof,
           mutual information, nearest-neighbour interpolation), writing
           ``<anat_path>/<atlas>_t1w_mni.nii.gz``.
        2. Apply ``self.mni2t1w_warp`` to that image with ``self.t1w_brain``
           as the first argument of ``regutils.apply_warp`` (presumably the
           reference grid -- confirm), writing
           ``<anat_path>/<atlas>_t1w_skull.nii.gz``. This file is produced
           but not returned.
        3. Apply the same warp to ``self.gm_mask_thr``, writing
           ``<anat_path>/<t1w_name>_gm_mask_t1w_mni.nii.gz``.
        4. Rewrite the step-1 atlas in place with int32 labels.
        5. Mask the step-1 atlas with the step-3 grey-matter mask using
           ``fslmaths -mas``.

        Only a linear registration is performed; there is no non-linear
        attempt and no fallback.

        Parameters
        ----------
        uatlas : str
            Path to the atlas image to align.
        atlas : str
            Atlas name, used as the prefix of the output file names.

        Returns
        -------
        str
            Path to the grey-matter-masked atlas,
            ``<anat_path>/<atlas>_t1w_mni_gm.nii.gz``.

        Raises
        ------
        subprocess.CalledProcessError
            If ``fslmaths`` exits with a non-zero status.
        """
        import subprocess

        aligned_atlas_t1mni = f"{self.anat_path}/{atlas}_t1w_mni.nii.gz"
        aligned_atlas_skull = f"{self.anat_path}/{atlas}_t1w_skull.nii.gz"
        gm_mask_mni = f"{self.anat_path}/{self.t1w_name}_gm_mask_t1w_mni.nii.gz"
        aligned_atlas_t1mni_gm = f"{self.anat_path}/{atlas}_t1w_mni_gm.nii.gz"

        regutils.align(uatlas, self.t1_aligned_mni, init=None, xfm=None, out=aligned_atlas_t1mni, dof=12,
                       searchrad=True, interp="nearestneighbour", cost='mutualinfo')

        # Apply warp resulting from the inverse of T1w-->MNI created earlier
        regutils.apply_warp(self.t1w_brain, aligned_atlas_t1mni, aligned_atlas_skull, warp=self.mni2t1w_warp,
                            interp='nn', sup=True)

        # Apply the same MNI-->T1w warp to the thresholded grey-matter mask
        regutils.apply_warp(self.t1w_brain, self.gm_mask_thr, gm_mask_mni, warp=self.mni2t1w_warp, interp='nn',
                            sup=True)

        # Store the atlas as integer labels (float -> int32 truncates)
        atlas_img = nib.load(aligned_atlas_t1mni)
        atlas_data = atlas_img.get_fdata().astype(np.int32)
        nib.save(nib.Nifti1Image(atlas_data, affine=atlas_img.affine,
                                 header=atlas_img.header), aligned_atlas_t1mni)

        # Pass an argument list (no shell) so paths with spaces or shell
        # metacharacters are safe. check=True keeps a failed fslmaths from
        # returning the path of a file that was never written.
        subprocess.run(["fslmaths", aligned_atlas_t1mni, "-mas", gm_mask_mni,
                        aligned_atlas_t1mni_gm], check=True)

        return aligned_atlas_t1mni_gm
Beispiel #2
0
def test_align():
    """
    Test linear registration with ``reg_utils.align``.

    Registers the example subject's skull-stripped T1w image to the MNI152
    2mm brain template packaged with pynets (12 dof, mutual information).
    Then checks that both the affine matrix and the resampled image were
    written. Outputs from a previous run are deleted first, so the checks
    can only pass if this run produced them.
    """
    import os
    import pkg_resources

    # Linear registration
    base_dir = str(Path(__file__).parent / "examples")
    anat_dir = f"{base_dir}/003/anat"
    inp = f"{anat_dir}/sub-003_T1w_brain.nii.gz"
    ref = pkg_resources.resource_filename(
        "pynets", "templates/MNI152_T1_brain_2mm.nii.gz")
    # Outputs stay in the examples directory. Other tests may reuse them --
    # verify before relocating.
    out = f"{anat_dir}/highres2standard.nii.gz"
    xfm_out = f"{anat_dir}/highres2standard.mat"

    # A stale file would let nib.load succeed even if align() failed.
    for stale in (out, xfm_out):
        if os.path.exists(stale):
            os.remove(stale)

    reg_utils.align(inp,
                    ref,
                    xfm=xfm_out,
                    out=out,
                    dof=12,
                    searchrad=True,
                    bins=256,
                    interp=None,
                    cost="mutualinfo",
                    sch=None,
                    wmseg=None,
                    init=None)

    assert os.path.isfile(xfm_out)
    highres2standard_linear = nib.load(out)
    # The registered image is resampled onto the reference grid.
    assert highres2standard_linear.shape[:3] == nib.load(ref).shape[:3]
Beispiel #3
0
def test_align():
    """
    Test linear registration with ``reg_utils.align``.

    Registers the example subject's skull-stripped T1w image to the MNI152
    2mm brain template stored next to it (12 dof, mutual information).
    Then checks that both the affine matrix and the resampled image were
    written. Outputs from a previous run are deleted first, so the checks
    can only pass if this run produced them.
    """
    import os

    # Linear registration
    base_dir = str(Path(__file__).parent / "examples")
    anat_dir = base_dir + '/003/anat'
    inp = anat_dir + '/sub-003_T1w_brain.nii.gz'
    ref = anat_dir + '/MNI152_T1_2mm_brain.nii.gz'
    # Outputs stay in the examples directory. Other tests may reuse them --
    # verify before relocating.
    out = anat_dir + '/highres2standard.nii.gz'
    xfm_out = anat_dir + '/highres2standard.mat'

    # A stale file would let nib.load succeed even if align() failed.
    for stale in (out, xfm_out):
        if os.path.exists(stale):
            os.remove(stale)

    reg_utils.align(inp,
                    ref,
                    xfm=xfm_out,
                    out=out,
                    dof=12,
                    searchrad=True,
                    bins=256,
                    interp=None,
                    cost="mutualinfo",
                    sch=None,
                    wmseg=None,
                    init=None)

    assert os.path.isfile(xfm_out)
    highres2standard_linear = nib.load(out)
    # The registered image is resampled onto the reference grid.
    assert highres2standard_linear.shape[:3] == nib.load(ref).shape[:3]
Beispiel #4
0
    def t1w2mni_align(self):
        """
        A function to perform alignment from T1w --> MNI.

        First estimates a 12-dof linear T1w-->MNI initializer
        (``self.t12mni_xfm_init``). Then attempts a FNIRT non-linear
        registration seeded with it. That step writes ``self.t1_aligned_mni``
        and ``self.warp_t1w2mni``, the inverse warp ``self.mni2t1w_warp``, and
        the inverse of the linear initializer, ``self.mni2t1_xfm_init``.

        If the non-linear stage raises ``RuntimeError``, a warning is printed
        and the method returns without producing those outputs.
        """

        # Create linear transform/ initializer T1w-->MNI
        regutils.align(self.t1w_brain, self.input_mni_brain, xfm=self.t12mni_xfm_init, bins=None, interp="spline",
                       out=None, dof=12, cost='mutualinfo', searchrad=True)

        # Attempt non-linear registration of T1 to MNI template
        try:
            print('Running non-linear registration: T1w-->MNI ...')
            # Use FNIRT to nonlinearly align T1 to MNI template
            regutils.align_nonlinear(self.t1w_brain, self.input_mni, xfm=self.t12mni_xfm_init, out=self.t1_aligned_mni,
                                     warp=self.warp_t1w2mni, ref_mask=self.input_mni_mask, config=self.input_mni_sched)

            # Get warp from MNI -> T1
            regutils.inverse_warp(self.t1w_brain, self.mni2t1w_warp, self.warp_t1w2mni)

            # Get mat from MNI -> T1
            os.system("convert_xfm -omat {} -inverse {}".format(self.mni2t1_xfm_init, self.t12mni_xfm_init))

        # `except` must name an exception class. Naming an instance such as
        # RuntimeError('...') makes the clause itself raise TypeError.
        except RuntimeError as e:
            print(f"Warning: FNIRT non-linear registration T1w-->MNI failed "
                  f"({e}); non-linear outputs were not produced.")
Beispiel #5
0
    def tissue2dwi_align(self):
        """
        A function to perform alignment of ventricle ROI's from MNI
        space --> dwi and CSF from T1w space --> dwi. First generates and
        performs dwi space alignment of avoidance/waypoint masks for
        tractography. First creates ventricle ROI. Then creates transforms
        from stock MNI template to dwi space.

        Prerequisites: t1w2dwi_align must have been called, since it provides
        ``self.t1wtissue2dwi_xfm``. The MNI-->T1w warp
        (``self.mni2t1w_warp``) must also already exist; when ``self.simple``
        is not False, the MNI-->T1w matrix (``self.mni2t1_xfm``) is used
        instead.

        In dwi space, the WM, GM and CSF maps are binarized at 0.10, 0.15 and
        0.95, and each map is then masked by its binary version. The masked
        results are stored back in ``self.wm_in_dwi``, ``self.gm_in_dwi`` and
        ``self.csf_mask_dwi``. Finally, a ventricular CSF mask, a corpus
        callosum mask and a GM-WM interface mask are built with ``fslmaths``.

        Raises
        ------
        FileNotFoundError
            If ``self.mni_atlas`` does not exist.

        If the ventricle or corpus callosum ROI cannot be read, the error is
        printed with a git-lfs hint and the process exits via ``sys.exit(1)``.
        """
        import sys
        import time
        import os.path as op

        def _require_roi(roi_path, message):
            # Exit with a git-lfs hint if an ROI image cannot be read. The
            # expected exception differs by platform. The except expression
            # is evaluated only when nib.load raises, so indexed_gzip is
            # never looked up on the success path.
            try:
                nib.load(roi_path)
            except (ImportError if sys.platform.startswith('win')
                    else indexed_gzip.ZranError) as e:
                print(e, message)
                sys.exit(1)

        # Register Lateral Ventricles and Corpus Callosum rois to t1w
        if not op.isfile(self.mni_atlas):
            raise FileNotFoundError("FSL atlas for ventricle reference not"
                                    " found!")

        # Transform from the FSL MNI atlas onto the MNI template brain. It is
        # used below to bring the ventricle ROI into template space.
        regutils.align(
            self.mni_atlas,
            self.input_mni_brain,
            xfm=self.xfm_roi2mni_init,
            init=None,
            bins=None,
            dof=6,
            cost="mutualinfo",
            searchrad=True,
            interp="spline",
            out=None,
        )
        time.sleep(0.5)

        _require_roi(self.mni_vent_loc,
                     "\nCannot load ventricle ROI. Do you have git-lfs "
                     "installed?")

        # Resample the ventricle ROI into template space
        regutils.applyxfm(
            self.input_mni_brain,
            self.mni_vent_loc,
            self.xfm_roi2mni_init,
            self.vent_mask_mni,
        )
        time.sleep(0.5)

        if self.simple is False:
            # Apply warp resulting from the inverse MNI->T1w created earlier
            regutils.apply_warp(
                self.t1w_brain,
                self.vent_mask_mni,
                self.vent_mask_t1w,
                warp=self.mni2t1w_warp,
                interp="nn",
                sup=True,
            )
            time.sleep(0.5)

            _require_roi(self.corpuscallosum,
                         "\nCannot load Corpus Callosum ROI. "
                         "Do you have git-lfs installed?")

            regutils.apply_warp(
                self.t1w_brain,
                self.corpuscallosum,
                self.corpuscallosum_mask_t1w,
                warp=self.mni2t1w_warp,
                interp="nn",
                sup=True,
            )
        else:
            # Linear MNI->T1w path. Like every other applyxfm call here, the
            # reference (T1w brain) comes first and the image to move second.
            regutils.applyxfm(
                self.t1w_brain,
                self.vent_mask_mni,
                self.mni2t1_xfm,
                self.vent_mask_t1w)
            time.sleep(0.5)
            regutils.applyxfm(
                self.t1w_brain,
                self.corpuscallosum,
                self.mni2t1_xfm,
                self.corpuscallosum_mask_t1w,
            )
            time.sleep(0.5)

        # Applyxfm tissue maps and ROIs from T1w to dwi space
        t1w_to_dwi = [
            (self.vent_mask_t1w, self.vent_mask_dwi),
            (self.csf_mask, self.csf_mask_dwi),
            (self.gm_mask, self.gm_in_dwi),
            (self.wm_mask, self.wm_in_dwi),
            (self.corpuscallosum_mask_t1w, self.corpuscallosum_dwi),
        ]
        if self.t1w_brain_mask is not None:
            t1w_to_dwi.insert(0, (self.t1w_brain_mask,
                                  self.t1w_brain_mask_in_dwi))
        for t1w_img, dwi_img in t1w_to_dwi:
            regutils.applyxfm(
                self.ap_path, t1w_img, self.t1wtissue2dwi_xfm, dwi_img)
            time.sleep(0.5)

        # Threshold WM, GM and CSF to binary in dwi space
        for prob_map, formula, binary_out in (
            (self.wm_in_dwi, "img > 0.10", self.wm_in_dwi_bin),
            (self.gm_in_dwi, "img > 0.15", self.gm_in_dwi_bin),
            (self.csf_mask_dwi, "img > 0.95", self.csf_mask_dwi_bin),
        ):
            nib.save(math_img(formula, img=nib.load(prob_map)), binary_out)

        # Restrict each dwi-space map to its binary mask. The CSF result is a
        # dwi-space image, so it goes in csf_mask_dwi. self.csf_mask keeps
        # pointing at the T1w-space CSF map.
        self.wm_in_dwi = regutils.apply_mask_to_image(self.wm_in_dwi,
                                                      self.wm_in_dwi_bin,
                                                      self.wm_in_dwi)
        time.sleep(0.5)
        self.gm_in_dwi = regutils.apply_mask_to_image(self.gm_in_dwi,
                                                      self.gm_in_dwi_bin,
                                                      self.gm_in_dwi)
        time.sleep(0.5)
        self.csf_mask_dwi = regutils.apply_mask_to_image(
            self.csf_mask_dwi, self.csf_mask_dwi_bin, self.csf_mask_dwi)
        time.sleep(0.5)

        # Create ventricular CSF mask
        print("Creating Ventricular CSF mask...")
        os.system(
            f"fslmaths {self.vent_mask_dwi} -kernel sphere 10 -ero "
            f"-bin {self.vent_mask_dwi}"
        )
        time.sleep(1)
        os.system(
            f"fslmaths {self.csf_mask_dwi} -add {self.vent_mask_dwi} "
            f"-bin {self.vent_csf_in_dwi}"
        )
        time.sleep(1)
        print("Creating Corpus Callosum mask...")
        os.system(
            f"fslmaths {self.corpuscallosum_dwi} -mas {self.wm_in_dwi_bin} "
            f"-sub {self.vent_csf_in_dwi} "
            f"-bin {self.corpuscallosum_dwi}")
        time.sleep(1)
        # Create gm-wm interface image
        os.system(
            f"fslmaths {self.gm_in_dwi_bin} -mul {self.wm_in_dwi_bin} "
            f"-add {self.corpuscallosum_dwi} "
            f"-mas {self.B0_mask} -bin {self.wm_gm_int_in_dwi}")
        time.sleep(1)
        return
Beispiel #6
0
    def t1w2dwi_align(self):
        """
        A function to perform alignment from T1w --> DWI. Uses a local
        optimization cost function to get the two images close, and then uses
        bbr to obtain a good alignment of brain boundaries.
        Assumes input dwi is already preprocessed and brain extracted.

        When ``self.simple`` is False:

        - A boundary-based (BBR, 7 dof) FA-->T1w fit is estimated, and its
          inverse seeds the final 7-dof T1w-->DWI registration.
        - If any step of that path raises an ``Exception``, a warning is
          printed and the final registration is seeded with the plain
          mutual-information matrix instead.

        Otherwise, a 6-dof final registration is used. The final step writes
        ``self.t1wtissue2dwi_xfm`` and the resampled T1w image
        ``self.t1w2dwi``.
        """
        import time

        def _t1w_to_dwi(init, dof):
            # Final T1w-->DWI registration seeded with `init`. It writes the
            # matrix later applied to the tissue maps and the resampled T1w.
            regutils.align(
                self.t1w_brain,
                self.ap_path,
                init=init,
                xfm=self.t1wtissue2dwi_xfm,
                bins=None,
                interp="spline",
                dof=dof,
                cost="mutualinfo",
                out=self.t1w2dwi,
                searchrad=True,
                sch=None,
            )
            time.sleep(0.5)

        # Align T1w-->DWI
        # NOTE(review): ap_path is passed in the first (input) slot and
        # t1w_brain in the second (reference) slot. That reads as a DWI-->T1w
        # fit, yet the result is stored as t1w2dwi_xfm and its inverse as
        # dwi2t1w_xfm. Confirm the direction before relying on these names.
        regutils.align(
            self.ap_path,
            self.t1w_brain,
            xfm=self.t1w2dwi_xfm,
            bins=None,
            interp="spline",
            dof=6,
            cost="mutualinfo",
            out=None,
            searchrad=True,
            sch=None,
        )
        time.sleep(0.5)
        self.dwi2t1w_xfm = regutils.invert_xfm(self.t1w2dwi_xfm,
                                               self.dwi2t1w_xfm)
        time.sleep(0.5)

        if self.simple is False:
            # Flirt bbr
            try:
                print("Learning a Boundary-Based Mapping from T1w-->DWI ...")
                regutils.align(
                    self.fa_path,
                    self.t1w_brain,
                    wmseg=self.wm_edge,
                    xfm=self.dwi2t1w_bbr_xfm,
                    init=self.dwi2t1w_xfm,
                    bins=256,
                    dof=7,
                    searchrad=True,
                    interp="spline",
                    out=None,
                    cost="bbr",
                    sch="${FSLDIR}/etc/flirtsch/bbr.sch",
                )
                time.sleep(0.5)
                self.t1w2dwi_bbr_xfm = regutils.invert_xfm(
                    self.dwi2t1w_bbr_xfm, self.t1w2dwi_bbr_xfm)
                time.sleep(0.5)
                _t1w_to_dwi(self.t1w2dwi_bbr_xfm, dof=7)
            # Exception, not BaseException, so Ctrl-C and sys.exit() still
            # propagate instead of triggering the fallback.
            except Exception as e:
                print(f"Warning: boundary-based T1w-->DWI registration "
                      f"failed ({e}); falling back to the mutual-information "
                      f"initializer.")
                _t1w_to_dwi(self.t1w2dwi_xfm, dof=7)
        else:
            _t1w_to_dwi(self.t1w2dwi_xfm, dof=6)

        return
Beispiel #7
0
    def t1w2mni_align(self):
        """
        A function to perform alignment from T1w --> MNI template.

        Always estimates a 12-dof linear T1w-->template initializer,
        ``self.t12mni_xfm_init``.

        When ``self.simple`` is False, a FNIRT non-linear registration seeded
        with that initializer is attempted. It writes ``self.t1_aligned_mni``,
        ``self.warp_t1w2mni`` and the inverse warp ``self.mni2t1w_warp``, and
        ``self.mni2t1_xfm`` becomes the inverse of the linear initializer.

        If that raises an ``Exception``, or when ``self.simple`` is not False,
        a refined 12-dof linear registration is run instead. It writes the
        T1w-->template matrix ``self.t12mni_xfm`` and
        ``self.t1_aligned_mni``, and ``self.mni2t1_xfm`` becomes the inverse
        of ``self.t12mni_xfm``.
        """
        import time

        def _linear_t1w2mni():
            # Refine the initializer into the T1w-->template matrix, then
            # invert that same matrix to get template-->T1w.
            regutils.align(
                self.t1w_brain,
                self.input_mni_brain,
                xfm=self.t12mni_xfm,
                init=self.t12mni_xfm_init,
                bins=None,
                dof=12,
                cost="mutualinfo",
                searchrad=True,
                interp="spline",
                out=self.t1_aligned_mni,
                sch=None,
            )
            time.sleep(0.5)
            # Get mat from MNI -> T1
            self.mni2t1_xfm = regutils.invert_xfm(self.t12mni_xfm,
                                                  self.mni2t1_xfm)
            time.sleep(0.5)

        # Create linear transform/ initializer T1w-->MNI
        regutils.align(
            self.t1w_brain,
            self.input_mni_brain,
            xfm=self.t12mni_xfm_init,
            bins=None,
            interp="spline",
            out=None,
            dof=12,
            cost="mutualinfo",
            searchrad=True,
        )
        time.sleep(0.5)

        # Attempt non-linear registration of T1 to MNI template
        if self.simple is False:
            try:
                print(
                    f"Learning a non-linear mapping from T1w --> "
                    f"{self.template_name} ..."
                )
                # Use FNIRT to nonlinearly align T1 to MNI template
                regutils.align_nonlinear(
                    self.t1w_brain,
                    self.input_mni,
                    xfm=self.t12mni_xfm_init,
                    out=self.t1_aligned_mni,
                    warp=self.warp_t1w2mni,
                    ref_mask=self.input_mni_mask,
                )
                time.sleep(0.5)
                # Get warp from MNI -> T1
                regutils.inverse_warp(
                    self.t1w_brain, self.mni2t1w_warp, self.warp_t1w2mni
                )
                time.sleep(0.5)
                # Get mat from MNI -> T1
                self.mni2t1_xfm = regutils.invert_xfm(self.t12mni_xfm_init,
                                                      self.mni2t1_xfm)
                time.sleep(0.5)
            # Exception, not BaseException, so Ctrl-C and sys.exit() are not
            # swallowed by the fallback.
            except Exception as e:
                print(f"Warning: non-linear T1w --> template registration "
                      f"failed ({e}); falling back to linear registration.")
                _linear_t1w2mni()
        else:
            _linear_t1w2mni()
Beispiel #8
0
    def tissue2dwi_align(self):
        """
        A function to perform alignment of ventricle ROI's from MNI space --> dwi and CSF from T1w space --> dwi.
        First generates and performs dwi space alignment of avoidance/waypoint masks for tractography.
        First creates ventricle ROI. Then creates transforms from stock MNI template to dwi space.
        For this to succeed, must first have called both t1w2dwi_align and atlas2t1w2dwi_align.

        In dwi space, voxels below a cutoff are zeroed (WM and GM: 0.2,
        CSF: 0.9) and the survivors are binarized into the ``*_bin`` masks.
        The thresholded CSF map replaces ``self.csf_mask_dwi`` and feeds the
        ventricular CSF mask.

        Raises
        ------
        ValueError
            If ``self.mni_atlas`` does not exist.
        """

        def _zero_below(src, cutoff, dst):
            # Write a copy of `src` with voxels below `cutoff` set to 0.
            # A new image is built from the edited array because nib.save
            # writes the image's dataobj. Editing the array returned by
            # get_fdata() would never reach the file.
            img = nib.load(src)
            data = img.get_fdata()
            data[data < cutoff] = 0
            nib.save(nib.Nifti1Image(data, img.affine, img.header), dst)

        # Create MNI-space ventricle mask from atlas volumes 2 (right) and
        # 13 (left)
        print('Creating MNI-space ventricle ROI...')
        if not os.path.isfile(self.mni_atlas):
            raise ValueError('FSL atlas for ventricle reference not found!')
        os.system("fslroi {} {} 2 1".format(self.mni_atlas, self.rvent_out_file))
        os.system("fslroi {} {} 13 1".format(self.mni_atlas, self.lvent_out_file))
        os.system("fslmaths {} -add {} -thr 0.1 -bin {}".format(self.lvent_out_file, self.rvent_out_file,
                                                                self.mni_vent_loc))

        # Create transform to MNI atlas to T1w using flirt. This will be use to transform the ventricles to dwi space.
        regutils.align(self.mni_atlas, self.input_mni_brain, xfm=self.xfm_roi2mni_init, init=None, bins=None, dof=6,
                       cost='mutualinfo', searchrad=True, interp="spline", out=None)

        # Create transform to align roi to mni and T1w using flirt
        regutils.applyxfm(self.input_mni_brain, self.mni_vent_loc, self.xfm_roi2mni_init, self.vent_mask_mni)

        if self.simple is False:
            # Apply warp resulting from the inverse MNI->T1w created earlier
            regutils.apply_warp(self.t1w_brain, self.vent_mask_mni, self.vent_mask_t1w, warp=self.mni2t1w_warp,
                                interp='nn', sup=True)
        # NOTE(review): when self.simple is not False, self.vent_mask_t1w is
        # not written here and is assumed to exist already -- verify.

        # Applyxfm tissue maps to dwi space
        regutils.applyxfm(self.fa_path, self.vent_mask_t1w, self.t1wtissue2dwi_xfm, self.vent_mask_dwi)
        regutils.applyxfm(self.fa_path, self.csf_mask, self.t1wtissue2dwi_xfm, self.csf_mask_dwi)
        regutils.applyxfm(self.fa_path, self.gm_mask, self.t1wtissue2dwi_xfm, self.gm_in_dwi)
        regutils.applyxfm(self.fa_path, self.wm_mask, self.t1wtissue2dwi_xfm, self.wm_in_dwi)

        # Zero sub-threshold voxels of the WM, GM and CSF maps in dwi space
        _zero_below(self.wm_in_dwi, 0.2, self.wm_in_dwi_bin)
        _zero_below(self.gm_in_dwi, 0.2, self.gm_in_dwi_bin)
        _zero_below(self.csf_mask_dwi, 0.9, self.csf_mask_dwi)

        # Binarize: every voxel that survived its threshold joins the mask
        for thresholded, binary_out in (
            (self.wm_in_dwi_bin, self.wm_in_dwi_bin),
            (self.gm_in_dwi_bin, self.gm_in_dwi_bin),
            (self.csf_mask_dwi, self.csf_mask_dwi_bin),
        ):
            t_img = load_img(thresholded)
            mask = math_img('img > 0', img=t_img)
            mask.to_filename(binary_out)

        # Create ventricular CSF mask
        print('Creating ventricular CSF mask...')
        os.system("fslmaths {} -kernel sphere 10 -ero -bin {}".format(self.vent_mask_dwi, self.vent_mask_dwi))
        os.system("fslmaths {} -add {} -bin {} ".format(self.csf_mask_dwi, self.vent_mask_dwi, self.vent_csf_in_dwi))

        # Create gm-wm interface image
        os.system("fslmaths {} -mul {} -mas {} -bin {}".format(self.gm_in_dwi_bin, self.wm_in_dwi_bin, self.B0_mask,
                                                               self.wm_gm_int_in_dwi))

        return
Beispiel #9
0
    def atlas2t1w2dwi_align(self, uatlas, atlas):
        """
        A function to perform atlas alignment atlas --> T1 --> dwi.
        Tries nonlinear registration first, and if that fails, does a linear registration instead. For this to succeed,
        must first have called t1w2dwi_align.

        Steps performed:

        - The atlas is always linearly registered to ``self.t1_aligned_mni``,
          writing ``<anat_path>/<atlas>_t1w_mni.nii.gz``.
        - When ``self.simple`` is False, that image is warped with
          ``self.mni2t1w_warp`` and then registered onto ``self.fa_path``.
        - If that raises an ``Exception``, or when ``self.simple`` is not
          False, a linear path is used instead: an atlas-->T1w matrix is
          chained with ``self.t1wtissue2dwi_xfm`` and applied to the atlas.
        - The dwi-space atlas is then rewritten as int32 labels and masked
          with ``self.B0_mask`` and the binarized GM-WM interface.

        Parameters
        ----------
        uatlas : str
            Path to the atlas image.
        atlas : str
            Atlas name, used as the prefix of the output file names.

        Returns
        -------
        tuple of str
            ``(dwi_aligned_atlas_wmgm_int, dwi_aligned_atlas,
            aligned_atlas_t1mni)``.
        """
        aligned_atlas_t1mni = f"{self.anat_path}/{atlas}_t1w_mni.nii.gz"
        aligned_atlas_skull = f"{self.anat_path}/{atlas}_t1w_skull.nii.gz"
        dwi_aligned_atlas = f"{self.reg_path_img}/{atlas}_dwi_track.nii.gz"
        dwi_aligned_atlas_wmgm_int = f"{self.reg_path_img}/{atlas}_dwi_track_wmgm_int.nii.gz"

        def _linear_atlas2dwi(init_searchrad):
            # Create transform to align atlas to T1w using flirt, then refine
            # it against the same native T1w image. xfm_atlas2t1w must map
            # atlas-->T1w because it is chained with the T1w-->dwi transform.
            regutils.align(uatlas, self.t1w_brain, xfm=self.xfm_atlas2t1w_init, init=None, bins=None, dof=6,
                           cost='mutualinfo', searchrad=init_searchrad, interp="spline", out=None, sch=None)
            regutils.align(uatlas, self.t1w_brain, xfm=self.xfm_atlas2t1w, out=None, dof=6, searchrad=True,
                           bins=None, interp="spline", cost='mutualinfo', init=self.xfm_atlas2t1w_init)

            # Combine our linear transform from atlas to t1w with our transform from t1w to dwi space to get a
            # transform from atlas ->(-> t1w ->)-> dwi
            regutils.combine_xfms(self.xfm_atlas2t1w, self.t1wtissue2dwi_xfm, self.temp2dwi_xfm)

            # Apply linear transformation from template to dwi space
            regutils.applyxfm(self.fa_path, uatlas, self.temp2dwi_xfm, dwi_aligned_atlas)

        regutils.align(uatlas, self.t1_aligned_mni, init=None, xfm=None, out=aligned_atlas_t1mni, dof=12,
                       searchrad=True, interp="nearestneighbour", cost='mutualinfo')

        if self.simple is False:
            try:
                # Apply warp resulting from the inverse of T1w-->MNI created earlier
                regutils.apply_warp(self.t1w_brain, aligned_atlas_t1mni, aligned_atlas_skull,
                                    warp=self.mni2t1w_warp, interp='nn', sup=True)

                # Apply transform to dwi space
                regutils.align(aligned_atlas_skull, self.fa_path, init=self.t1wtissue2dwi_xfm, xfm=None,
                               out=dwi_aligned_atlas, dof=6, searchrad=True, interp="nearestneighbour",
                               cost='mutualinfo')
            # A bare `except:` would also swallow KeyboardInterrupt/SystemExit
            except Exception as e:
                print(e)
                print("Warning: Atlas is not in correct dimensions, or input is low quality,\nusing linear template "
                      "registration.")
                _linear_atlas2dwi(init_searchrad=True)
        else:
            _linear_atlas2dwi(init_searchrad=None)

        # Set intensities to int
        atlas_img = nib.load(dwi_aligned_atlas)
        atlas_data = atlas_img.get_fdata().astype('int')
        # Binarize the GM-WM interface used to mask the atlas below
        t_img = load_img(self.wm_gm_int_in_dwi)
        mask = math_img('img > 0', img=t_img)
        mask.to_filename(self.wm_gm_int_in_dwi_bin)
        nib.save(nib.Nifti1Image(atlas_data.astype(np.int32), affine=atlas_img.affine, header=atlas_img.header),
                 dwi_aligned_atlas)
        os.system("fslmaths {} -mas {} -mas {} {}".format(dwi_aligned_atlas, self.B0_mask, self.wm_gm_int_in_dwi_bin,
                                                          dwi_aligned_atlas_wmgm_int))

        return dwi_aligned_atlas_wmgm_int, dwi_aligned_atlas, aligned_atlas_t1mni
Beispiel #10
0
    def t1w2dwi_align(self):
        """
        A function to perform alignment from T1w --> MNI and T1w_MNI --> DWI. Uses a local optimisation
        cost function to get the two images close, and then uses bbr to obtain a good alignment of brain boundaries.
        Assumes input dwi is already preprocessed and brain extracted.

        When ``self.simple`` is False, FNIRT (T1w-->MNI) and FLIRT BBR
        (T1w-->DWI) are attempted. If either raises ``RuntimeError``, a
        warning is printed and the same linear step used when ``self.simple``
        is not False runs instead, so ``self.t1_aligned_mni`` and
        ``self.t1wtissue2dwi_xfm`` are still produced.
        """

        def _linear_t1w2mni():
            # Linear T1w-->MNI registration refined from the initializer
            regutils.align(self.t1w_brain, self.input_mni_brain, xfm=self.t12mni_xfm, init=self.t12mni_xfm_init,
                           bins=None, dof=12, cost='mutualinfo', searchrad=True, interp="spline",
                           out=self.t1_aligned_mni, sch=None)

        def _t1w_to_dwi(init, dof):
            # Final T1w-->DWI registration seeded with `init`
            regutils.align(self.t1w_brain, self.fa_path, init=init, xfm=self.t1wtissue2dwi_xfm, bins=None,
                           interp="spline", dof=dof, cost='mutualinfo', out=self.t1w2dwi, searchrad=True,
                           sch=None)

        # Create linear transform/ initializer T1w-->MNI
        regutils.align(self.t1w_brain, self.input_mni_brain, xfm=self.t12mni_xfm_init, bins=None, interp="spline",
                       out=None, dof=12, cost='mutualinfo', searchrad=True)

        # Attempt non-linear registration of T1 to MNI template
        if self.simple is False:
            try:
                print('Running non-linear registration: T1w-->MNI ...')
                # Use FNIRT to nonlinearly align T1 to MNI template
                regutils.align_nonlinear(self.t1w_brain, self.input_mni, xfm=self.t12mni_xfm_init,
                                         out=self.t1_aligned_mni, warp=self.warp_t1w2mni, ref_mask=self.input_mni_mask,
                                         config=self.input_mni_sched)

                # Get warp from MNI -> T1
                regutils.inverse_warp(self.t1w_brain, self.mni2t1w_warp, self.warp_t1w2mni)

                # Get mat from MNI -> T1
                os.system("convert_xfm -omat {} -inverse {}".format(self.mni2t1_xfm_init, self.t12mni_xfm_init))

            # `except` must name a class; RuntimeError('...') (an instance)
            # would make the clause itself raise TypeError.
            except RuntimeError as e:
                print(f"Warning: FNIRT failed ({e}); falling back to linear T1w-->MNI registration.")
                _linear_t1w2mni()
        else:
            # Falling back to linear registration
            _linear_t1w2mni()

        # Align T1w-->DWI
        regutils.align(self.fa_path, self.t1w_brain, xfm=self.t1w2dwi_xfm, bins=None, interp="spline", dof=6,
                       cost='mutualinfo', out=None, searchrad=True, sch=None)
        os.system("convert_xfm -omat {} -inverse {}".format(self.dwi2t1w_xfm, self.t1w2dwi_xfm))

        if self.simple is False:
            # Flirt bbr
            try:
                print('Running FLIRT BBR registration: T1w-->DWI ...')
                regutils.align(self.fa_path, self.t1w_brain, wmseg=self.wm_edge, xfm=self.dwi2t1w_bbr_xfm,
                               init=self.dwi2t1w_xfm, bins=256, dof=7, searchrad=True, interp="spline", out=None,
                               cost='bbr', sch="${FSLDIR}/etc/flirtsch/bbr.sch")
                os.system("convert_xfm -omat {} -inverse {}".format(self.t1w2dwi_bbr_xfm, self.dwi2t1w_bbr_xfm))

                # Apply the alignment
                _t1w_to_dwi(self.t1w2dwi_bbr_xfm, dof=7)
            except RuntimeError as e:
                print(f"Warning: FLIRT BBR failed ({e}); falling back to mutual-information T1w-->DWI alignment.")
                _t1w_to_dwi(self.t1w2dwi_xfm, dof=6)
        else:
            # Apply the alignment
            _t1w_to_dwi(self.t1w2dwi_xfm, dof=6)

        return
Beispiel #11
0
def atlas2t1w2dwi_align(
    uatlas,
    uatlas_parcels,
    atlas,
    t1w_brain,
    t1w_brain_mask,
    mni2t1w_warp,
    t1_aligned_mni,
    ap_path,
    t1w2dwi_bbr_xfm,
    mni2t1_xfm,
    t1w2dwi_xfm,
    wm_gm_int_in_dwi,
    aligned_atlas_t1mni,
    aligned_atlas_skull,
    dwi_aligned_atlas,
    dwi_aligned_atlas_wmgm_int,
    B0_mask,
    simple,
):
    """
    A function to perform atlas alignment atlas --> T1 --> dwi.

    Tries nonlinear registration first, and if that fails, does a linear
    registration instead. For this to succeed, must first have called
    t1w2dwi_align.

    Parameters
    ----------
    uatlas : str
        Path to the atlas parcellation (used when ``uatlas_parcels`` is
        falsy).
    uatlas_parcels : str or None
        Optional alternative parcellation path; takes precedence over
        ``uatlas`` when truthy.
    atlas : str
        Atlas name. Not used by this function; kept for interface
        compatibility.
    t1w_brain, t1w_brain_mask : str
        Skull-stripped T1w image and its brain mask.
    mni2t1w_warp : str
        Nonlinear template --> T1w warp.
    t1_aligned_mni : str
        Template-space T1w image; the atlas is resampled onto its grid.
    ap_path : str
        dwi-space reference image.
    t1w2dwi_bbr_xfm, t1w2dwi_xfm : str
        Linear T1w --> dwi transforms. The BBR one is used when
        ``simple is False``, the plain one otherwise.
    mni2t1_xfm : str
        Linear template --> T1w transform used by the linear path.
    wm_gm_int_in_dwi : str
        dwi-space white/grey-matter interface image.
    aligned_atlas_t1mni, aligned_atlas_skull, dwi_aligned_atlas, dwi_aligned_atlas_wmgm_int : str
        Output paths written by this function.
    B0_mask : str
        dwi-space mask applied to both dwi-space outputs.
    simple : bool
        When exactly ``False``, nonlinear registration is attempted first.
        Any other value uses linear registration only.

    Returns
    -------
    tuple of str
        ``(dwi_aligned_atlas_wmgm_int, dwi_aligned_atlas, aligned_atlas_t1mni)``
    """
    import subprocess
    from nilearn.image import resample_to_img
    from pynets.core.utils import checkConsecutive
    from pynets.registration import reg_utils as regutils
    from nilearn.image import math_img
    from nilearn.masking import intersect_masks

    def _linear_atlas_to_dwi(t1w2dwi):
        # Linear-only path: template --> T1w using `mni2t1_xfm`, then
        # T1w --> dwi initialised with the given T1w --> dwi transform.
        regutils.align(
            aligned_atlas_t1mni,
            t1w_brain,
            init=mni2t1_xfm,
            out=aligned_atlas_skull,
            dof=6,
            searchrad=True,
            interp="nearestneighbour",
            cost="mutualinfo",
        )
        regutils.align(
            aligned_atlas_skull,
            ap_path,
            init=t1w2dwi,
            out=dwi_aligned_atlas,
            dof=6,
            searchrad=True,
            interp="nearestneighbour",
            cost="mutualinfo",
        )

    def _mask_with_b0(path):
        # Same as `fslmaths path -mas B0_mask path 2>/dev/null`, but without
        # a shell, so paths with spaces or shell metacharacters are passed
        # intact. Failures stay non-fatal, as with the former os.system call.
        try:
            subprocess.run(["fslmaths", path, "-mas", B0_mask, path],
                           stderr=subprocess.DEVNULL, check=False)
        except OSError as err:
            print(f"Warning: could not run fslmaths on {path}: {err}")

    # Resample the atlas onto the template-space T1w grid and keep only
    # integer-valued labels.
    template_img = nib.load(t1_aligned_mni)
    atlas_path = uatlas_parcels if uatlas_parcels else uatlas
    uatlas_res_template = resample_to_img(nib.load(atlas_path),
                                          template_img,
                                          interpolation="nearest")
    template_labels = np.asarray(uatlas_res_template.dataobj)
    template_labels[template_labels != template_labels.astype(int)] = 0
    nib.save(
        nib.Nifti1Image(
            template_labels.astype("int32"),
            affine=uatlas_res_template.affine,
            header=uatlas_res_template.header,
        ),
        aligned_atlas_t1mni,
    )
    template_img.uncache()

    if simple is False:
        try:
            regutils.apply_warp(
                t1w_brain,
                aligned_atlas_t1mni,
                aligned_atlas_skull,
                warp=mni2t1w_warp,
                interp="nn",
                sup=True,
                mask=t1w_brain_mask,
            )

            # Apply linear transformation from template to dwi space
            regutils.align(
                aligned_atlas_skull,
                ap_path,
                init=t1w2dwi_bbr_xfm,
                out=dwi_aligned_atlas,
                dof=6,
                searchrad=True,
                interp="nearestneighbour",
                cost="mutualinfo",
            )
        except Exception:
            # `Exception`, not `BaseException`: Ctrl-C / SystemExit must
            # abort the run rather than trigger the linear fallback.
            print(
                "Warning: Atlas is not in correct dimensions, or input is low quality,\nusing linear template "
                "registration.")
            _linear_atlas_to_dwi(t1w2dwi_bbr_xfm)
    else:
        _linear_atlas_to_dwi(t1w2dwi_xfm)

    # Clean the dwi-space labels. Drop non-integer values, and also negative
    # values, which would otherwise wrap to huge labels when cast to uint32.
    atlas_img = nib.load(dwi_aligned_atlas)
    dwi_labels = np.asarray(atlas_img.dataobj)
    dwi_labels[(dwi_labels != dwi_labels.astype(int)) | (dwi_labels < 0)] = 0
    atlas_img_corr = nib.Nifti1Image(
        dwi_labels.astype("uint32"),
        affine=atlas_img.affine,
        header=atlas_img.header,
    )

    wm_gm_img = nib.load(wm_gm_int_in_dwi)
    wm_gm_mask_img = math_img("img > 0", img=wm_gm_img)
    # Build the atlas mask from the cleaned labels so it agrees with the
    # atlas actually written to disk.
    atlas_mask_img = math_img("img > 0", img=atlas_img_corr)

    # threshold=0 makes intersect_masks return the UNION of the two masks
    # (threshold=1 would give the intersection).
    dwi_aligned_atlas_wmgm_int_img = intersect_masks(
        [wm_gm_mask_img, atlas_mask_img], threshold=0, connected=False)

    nib.save(atlas_img_corr, dwi_aligned_atlas)
    nib.save(dwi_aligned_atlas_wmgm_int_img, dwi_aligned_atlas_wmgm_int)

    # Restrict both dwi-space outputs to the B0 brain mask (in place).
    _mask_with_b0(dwi_aligned_atlas)
    _mask_with_b0(dwi_aligned_atlas_wmgm_int)

    # Sorted unique labels (checked before B0 masking, as before). np.unique
    # avoids building a Python list with one float per voxel.
    unique_a = np.unique(np.asarray(atlas_img_corr.dataobj)).tolist()

    if not checkConsecutive(unique_a):
        print("Warning! Non-consecutive integers found in parcellation...")

    atlas_img.uncache()
    atlas_img_corr.uncache()
    atlas_mask_img.uncache()
    wm_gm_img.uncache()
    wm_gm_mask_img.uncache()

    return dwi_aligned_atlas_wmgm_int, dwi_aligned_atlas, aligned_atlas_t1mni
Beispiel #12
0
def atlas2t1w_align(
    uatlas,
    uatlas_parcels,
    atlas,
    t1w_brain,
    t1w_brain_mask,
    t1_aligned_mni,
    mni2t1w_warp,
    mni2t1_xfm,
    gm_mask,
    aligned_atlas_t1mni,
    aligned_atlas_skull,
    aligned_atlas_gm,
    simple,
):
    """
    A function to perform atlas alignment from atlas --> T1w.

    The atlas is resampled onto the template-space T1w grid and moved to
    T1w space, nonlinearly when ``simple is False`` (with a linear fallback
    on failure), otherwise linearly. It is then masked by the grey-matter
    mask. If that masking leaves non-consecutive labels and drops more than
    one label relative to the template-space atlas, the T1w brain mask is
    used instead of the GM mask.

    Parameters
    ----------
    uatlas : str
        Path to the atlas parcellation (used when ``uatlas_parcels`` is
        falsy).
    uatlas_parcels : str or None
        Optional alternative parcellation path; takes precedence when truthy.
    atlas : str
        Atlas name. Not used by this function; kept for interface
        compatibility.
    t1w_brain, t1w_brain_mask : str
        Skull-stripped T1w image and its brain mask.
    t1_aligned_mni : str
        Template-space T1w image; the atlas is resampled onto its grid.
    mni2t1w_warp : str
        Nonlinear template --> T1w warp.
    mni2t1_xfm : str
        Linear template --> T1w transform used by the linear path.
    gm_mask : str
        T1w-space grey-matter mask.
    aligned_atlas_t1mni, aligned_atlas_skull, aligned_atlas_gm : str
        Output paths written by this function.
    simple : bool
        When exactly ``False``, nonlinear registration is attempted first.

    Returns
    -------
    tuple of str
        ``(aligned_atlas_gm, aligned_atlas_skull)``
    """
    import subprocess
    from pynets.registration import reg_utils as regutils
    from nilearn.image import resample_to_img
    from pynets.core.utils import checkConsecutive

    def _linear_atlas_to_t1w():
        # Linear-only path: template --> T1w using `mni2t1_xfm`.
        regutils.align(
            aligned_atlas_t1mni,
            t1w_brain,
            init=mni2t1_xfm,
            out=aligned_atlas_skull,
            dof=6,
            searchrad=True,
            interp="nearestneighbour",
            cost="mutualinfo",
        )

    def _apply_mask(src, mask, out):
        # Same as `fslmaths src -mas mask out 2>/dev/null`, without a shell,
        # so paths with spaces or shell metacharacters are passed intact.
        # Failures stay non-fatal, as with the former os.system call.
        try:
            subprocess.run(["fslmaths", src, "-mas", mask, out],
                           stderr=subprocess.DEVNULL, check=False)
        except OSError as err:
            print(f"Warning: could not run fslmaths on {src}: {err}")

    # Resample the atlas onto the template-space T1w grid and keep only
    # integer-valued labels. `template_labels` is kept separately: it is the
    # reference for the dropped-label count below.
    template_img = nib.load(t1_aligned_mni)
    atlas_path = uatlas_parcels if uatlas_parcels else uatlas
    uatlas_res_template = resample_to_img(nib.load(atlas_path),
                                          template_img,
                                          interpolation="nearest")
    template_labels = np.asarray(uatlas_res_template.dataobj)
    template_labels[template_labels != template_labels.astype(int)] = 0
    nib.save(
        nib.Nifti1Image(
            template_labels.astype("uint16"),
            affine=uatlas_res_template.affine,
            header=uatlas_res_template.header,
        ),
        aligned_atlas_t1mni,
    )

    if simple is False:
        try:
            regutils.apply_warp(
                t1w_brain,
                aligned_atlas_t1mni,
                aligned_atlas_skull,
                warp=mni2t1w_warp,
                interp="nn",
                sup=True,
                mask=t1w_brain_mask,
            )
        except Exception:
            # `Exception`, not `BaseException`: Ctrl-C / SystemExit must
            # abort the run rather than trigger the linear fallback.
            print(
                "Warning: Atlas is not in correct dimensions, or input is low quality,\nusing linear template "
                "registration.")
            _linear_atlas_to_t1w()
    else:
        _linear_atlas_to_t1w()

    _apply_mask(aligned_atlas_skull, gm_mask, aligned_atlas_gm)

    # Clean the GM-masked labels. Use a new name so the template-space
    # labels above are not overwritten.
    atlas_img = nib.load(aligned_atlas_gm)
    gm_labels = np.asarray(atlas_img.dataobj)
    gm_labels[gm_labels != gm_labels.astype(int)] = 0
    atlas_img_corr = nib.Nifti1Image(
        gm_labels.astype("uint32"),
        affine=atlas_img.affine,
        header=atlas_img.header,
    )
    nib.save(atlas_img_corr, aligned_atlas_gm)
    unique_a = np.unique(np.asarray(atlas_img_corr.dataobj)).tolist()

    if not checkConsecutive(unique_a):
        # Compare against the template-space atlas. Previously this counted
        # the GM-masked data itself, so `diff` was always 0 and the fallback
        # below could never run.
        old_count = len(np.unique(template_labels))
        new_count = len(unique_a)
        # Plain int arithmetic (`np.int` was removed in NumPy 1.24).
        diff = abs(new_count - old_count)
        print("\nWarning! Non-consecutive integers found in parcellation...")
        print(f"Previous label count: {old_count}")
        print(f"New label count: {new_count}")
        print(f"Labels dropped: {diff}")
        if diff > 1:
            print('Grey-Matter mask too restrictive for this parcellation. '
                  'Falling back to the T1w mask...')
            _apply_mask(aligned_atlas_skull, t1w_brain_mask, aligned_atlas_gm)

    atlas_img.uncache()
    template_img.uncache()

    return aligned_atlas_gm, aligned_atlas_skull
Beispiel #13
0
def RegisterParcellation2MNIFunc_align(uatlas, template, template_mask,
                                       t1w_brain, t1w2mni_xfm,
                                       aligned_atlas_t1w, aligned_atlas_mni,
                                       t1w2mni_warp, simple):
    """
    A function to perform atlas alignment from T1w atlas --> MNI.

    The atlas is resampled onto the ``t1w_brain`` grid (nearest neighbour),
    cast to uint16 and saved to ``aligned_atlas_t1w``. It is then warped to
    ``template`` with ``t1w2mni_warp`` when ``simple is False``. If that
    fails, or when ``simple`` is anything else, a linear FLIRT alignment
    initialised with ``t1w2mni_xfm`` is used instead.

    Parameters
    ----------
    uatlas : str
        T1w-space atlas image.
    template : str
        Template (MNI) reference image.
    template_mask : str
        Not used by this function; kept for interface compatibility.
    t1w_brain : str
        T1w image defining the resampling grid.
    t1w2mni_xfm, t1w2mni_warp : str
        Linear and nonlinear T1w --> template transforms.
    aligned_atlas_t1w, aligned_atlas_mni : str
        Output paths written by this function.
    simple : bool
        When exactly ``False``, the nonlinear warp is attempted first.

    Returns
    -------
    str
        ``aligned_atlas_mni``.
    """
    import time
    from pynets.registration import reg_utils as regutils
    from nilearn.image import resample_to_img

    atlas_img = nib.load(uatlas)
    t1w_brain_img = nib.load(t1w_brain)

    uatlas_res_template = resample_to_img(atlas_img,
                                          t1w_brain_img,
                                          interpolation="nearest")

    uatlas_res_template = nib.Nifti1Image(
        np.asarray(uatlas_res_template.dataobj).astype('uint16'),
        affine=uatlas_res_template.affine,
        header=uatlas_res_template.header,
    )
    nib.save(uatlas_res_template, aligned_atlas_t1w)

    warped_nonlinearly = False
    if simple is False:
        try:
            regutils.apply_warp(
                template,
                aligned_atlas_t1w,
                aligned_atlas_mni,
                warp=t1w2mni_warp,
                interp="nn",
                sup=True,
            )
        except Exception:
            # `Exception`, not `BaseException`: Ctrl-C / SystemExit must
            # abort the run rather than trigger the linear fallback.
            print("Warning: Atlas is not in correct dimensions, or input is "
                  "low quality,\nusing linear template registration.")
        else:
            warped_nonlinearly = True

    if not warped_nonlinearly:
        # Linear path: used when `simple` is not False, or as the fallback.
        regutils.align(
            aligned_atlas_t1w,
            template,
            init=t1w2mni_xfm,
            out=aligned_atlas_mni,
            dof=6,
            searchrad=True,
            interp="nearestneighbour",
            cost="mutualinfo",
        )

    # Short pause after writing the output, kept from the original.
    # NOTE(review): presumably lets the file land on disk; confirm it is
    # still needed.
    time.sleep(0.5)
    return aligned_atlas_mni
Beispiel #14
0
    def tissue2dwi_align(self):
        """
        Align ventricle ROIs from MNI space --> dwi and tissue maps from
        T1w space --> dwi, then build tractography avoidance/waypoint masks.

        Steps:
          1. Build an MNI-space ventricle ROI from volumes 2 and 13 of
             ``self.mni_atlas`` (via ``fslroi`` and ``fslmaths``).
          2. Register it to ``self.input_mni_brain``. When
             ``self.simple is False``, also warp it to T1w space with
             ``self.mni2t1w_warp``.
          3. Apply ``self.t1wtissue2dwi_xfm`` to the ventricle, CSF, GM and
             WM maps.
          4. Zero WM/GM below 0.2 and CSF below 0.9, then binarize.
          5. Write the ventricular-CSF mask and the GM/WM interface mask
             (restricted to ``self.nodif_B0_mask``).

        NOTE: for this to work, must first have called both t1w2dwi_align
        and atlas2t1w2dwi_align.

        Raises
        ------
        ValueError
            If ``self.mni_atlas`` does not exist.
        """

        def _zero_below(in_path, cutoff, out_path):
            # Zero voxels of `in_path` below `cutoff` and save to `out_path`.
            # The edited array goes into a NEW image because nibabel writes
            # `img.dataobj` on save, not the array cached by `get_fdata()`.
            # Editing `img.get_fdata()` in place (as before) never reached
            # disk, so the thresholds were silently skipped.
            img = nib.load(in_path)
            data = img.get_fdata()
            data[data < cutoff] = 0
            nib.save(nib.Nifti1Image(data, affine=img.affine,
                                     header=img.header), out_path)

        # Create MNI-space ventricle mask
        print('Creating MNI-space ventricle ROI...')
        if not os.path.isfile(self.mni_atlas):
            raise ValueError('FSL atlas for ventricle reference not found!')
        cmd = 'fslroi ' + self.mni_atlas + ' ' + self.rvent_out_file + ' 2 1'
        os.system(cmd)
        cmd = 'fslroi ' + self.mni_atlas + ' ' + self.lvent_out_file + ' 13 1'
        os.system(cmd)
        self.args = "%s%s%s" % (' -add ', self.rvent_out_file,
                                ' -thr 0.1 -bin ')
        cmd = 'fslmaths ' + self.lvent_out_file + self.args + self.mni_vent_loc
        os.system(cmd)

        # Create transform to MNI atlas to T1w using flirt. This will be use to transform the ventricles to dwi space.
        mgru.align(self.mni_atlas,
                   self.input_mni_brain,
                   xfm=self.xfm_roi2mni_init,
                   init=None,
                   bins=None,
                   dof=6,
                   cost='mutualinfo',
                   searchrad=True,
                   interp="spline",
                   out=None)

        # Create transform to align roi to mni and T1w using flirt
        mgru.applyxfm(self.input_mni_brain, self.mni_vent_loc,
                      self.xfm_roi2mni_init, self.vent_mask_mni)

        if self.simple is False:
            # Apply warp resulting from the inverse MNI->T1w created earlier
            mgru.apply_warp(self.t1w_brain,
                            self.vent_mask_mni,
                            self.vent_mask_t1w,
                            warp=self.mni2t1w_warp,
                            interp='nn',
                            sup=True)
        # NOTE(review): when self.simple is not False, this method does not
        # write self.vent_mask_t1w, yet it is read just below. Confirm that
        # it is produced elsewhere in that mode.

        # Applyxfm tissue maps to dwi space
        mgru.applyxfm(self.fa_path, self.vent_mask_t1w, self.t1wtissue2dwi_xfm,
                      self.vent_mask_dwi)
        mgru.applyxfm(self.fa_path, self.csf_mask, self.t1wtissue2dwi_xfm,
                      self.csf_mask_dwi)
        mgru.applyxfm(self.fa_path, self.gm_mask, self.t1wtissue2dwi_xfm,
                      self.gm_in_dwi)
        mgru.applyxfm(self.fa_path, self.wm_mask, self.t1wtissue2dwi_xfm,
                      self.wm_in_dwi)

        # Zero sub-threshold voxels in dwi space (WM/GM < 0.2, CSF < 0.9)
        _zero_below(self.wm_in_dwi, 0.2, self.wm_in_dwi_bin)
        _zero_below(self.gm_in_dwi, 0.2, self.gm_in_dwi_bin)
        _zero_below(self.csf_mask_dwi, 0.9, self.csf_mask_dwi)

        # Binarize WM in dwi space
        self.t_img = load_img(self.wm_in_dwi_bin)
        self.mask = math_img('img > 0', img=self.t_img)
        self.mask.to_filename(self.wm_in_dwi_bin)

        # Binarize GM in dwi space
        self.t_img = load_img(self.gm_in_dwi_bin)
        self.mask = math_img('img > 0', img=self.t_img)
        self.mask.to_filename(self.gm_in_dwi_bin)

        # Binarize CSF in dwi space
        self.t_img = load_img(self.csf_mask_dwi)
        self.mask = math_img('img > 0', img=self.t_img)
        self.mask.to_filename(self.csf_mask_dwi_bin)

        # Create ventricular CSF mask
        print('Creating ventricular CSF mask...')
        cmd = 'fslmaths ' + self.vent_mask_dwi + ' -kernel sphere 10 -ero -bin ' + self.vent_mask_dwi
        os.system(cmd)
        cmd = 'fslmaths ' + self.csf_mask_dwi + ' -add ' + self.vent_mask_dwi + ' -bin ' + self.vent_csf_in_dwi
        os.system(cmd)

        # Create gm-wm interface image
        cmd = 'fslmaths ' + self.gm_in_dwi_bin + ' -mul ' + self.wm_in_dwi_bin + ' -mas ' + self.nodif_B0_mask + ' -bin ' + self.wm_gm_int_in_dwi
        os.system(cmd)

        return
Beispiel #15
0
    def atlas2t1w2dwi_align(self, uatlas_select, atlas_select):
        """
        Align an atlas from template space --> T1w --> dwi.

        Tries nonlinear registration first (when ``self.simple is False``),
        and if that fails, does a linear registration instead. The dwi-space
        atlas is then cast to int32 and masked by ``self.nodif_B0_mask`` and
        the binarized GM/WM interface to produce the ``_wmgm_int`` output.
        NOTE: for this to work, must first have called t1w2dwi_align.

        Parameters
        ----------
        uatlas_select : str
            Path to the atlas image (stored as ``self.atlas``).
        atlas_select : str
            Atlas name used to build output file names (``self.atlas_name``).

        Returns
        -------
        tuple of str
            ``(dwi_aligned_atlas_wmgm_int, dwi_aligned_atlas, aligned_atlas_t1mni)``
        """
        self.atlas = uatlas_select
        self.atlas_name = atlas_select
        self.aligned_atlas_t1mni = "{}/{}_t1w_mni.nii.gz".format(
            self.basedir_path, self.atlas_name)
        self.aligned_atlas_skull = "{}/{}_t1w_skull.nii.gz".format(
            self.anat_path, self.atlas_name)
        self.dwi_aligned_atlas = "{}/{}_dwi_track.nii.gz".format(
            self.reg_path_img, self.atlas_name)
        self.dwi_aligned_atlas_wmgm_int = "{}/{}_dwi_track_wmgm_int.nii.gz".format(
            self.reg_path_img, self.atlas_name)

        def _linear_atlas_to_dwi(init_searchrad, refine_reference):
            # Linear-only path. Estimate an atlas --> T1w transform in two
            # FLIRT passes, compose it with the T1w --> dwi transform, and
            # apply the result to the atlas.
            mgru.align(self.atlas,
                       self.t1w_brain,
                       xfm=self.xfm_atlas2t1w_init,
                       init=None,
                       bins=None,
                       dof=6,
                       cost='mutualinfo',
                       searchrad=init_searchrad,
                       interp="spline",
                       out=None,
                       sch=None)
            mgru.align(self.atlas,
                       refine_reference,
                       xfm=self.xfm_atlas2t1w,
                       out=None,
                       dof=6,
                       searchrad=True,
                       bins=None,
                       interp="spline",
                       cost='mutualinfo',
                       init=self.xfm_atlas2t1w_init)

            # Combine our linear transform from t1w to template with our transform from dwi to t1w space to get a transform from atlas ->(-> t1w ->)-> dwi
            mgru.combine_xfms(self.xfm_atlas2t1w, self.t1wtissue2dwi_xfm,
                              self.temp2dwi_xfm)

            # Apply linear transformation from template to dwi space
            mgru.applyxfm(self.fa_path, self.atlas, self.temp2dwi_xfm,
                          self.dwi_aligned_atlas)

        mgru.align(self.atlas,
                   self.t1_aligned_mni,
                   init=None,
                   xfm=None,
                   out=self.aligned_atlas_t1mni,
                   dof=12,
                   searchrad=True,
                   interp="nearestneighbour",
                   cost='mutualinfo')

        if self.simple is False:
            try:
                # Apply warp resulting from the inverse of T1w-->MNI created earlier
                mgru.apply_warp(self.t1w_brain,
                                self.aligned_atlas_t1mni,
                                self.aligned_atlas_skull,
                                warp=self.mni2t1w_warp,
                                interp='nn',
                                sup=True)

                # Apply transform to dwi space
                mgru.align(self.aligned_atlas_skull,
                           self.fa_path,
                           init=self.t1wtissue2dwi_xfm,
                           xfm=None,
                           out=self.dwi_aligned_atlas,
                           dof=6,
                           searchrad=True,
                           interp="nearestneighbour",
                           cost='mutualinfo')
            except Exception:
                # `Exception` instead of a bare `except:` so that Ctrl-C /
                # SystemExit abort the run rather than trigger the fallback.
                print(
                    "Warning: Atlas is not in correct dimensions, or input is low quality,\nusing linear template registration."
                )
                # NOTE(review): this fallback refines against
                # t1_aligned_mni, while the simple path refines against
                # t1w_brain. The result is composed with a T1w --> dwi
                # transform, so t1w_brain looks intended. Confirm before
                # changing.
                _linear_atlas_to_dwi(True, self.t1_aligned_mni)
        else:
            _linear_atlas_to_dwi(None, self.t1w_brain)

        # Set intensities to int
        self.atlas_img = nib.load(self.dwi_aligned_atlas)
        self.atlas_data = self.atlas_img.get_fdata().astype('int')
        t_img = load_img(self.wm_gm_int_in_dwi)
        mask = math_img('img > 0', img=t_img)
        mask.to_filename(self.wm_gm_int_in_dwi_bin)
        nib.save(
            nib.Nifti1Image(self.atlas_data.astype(np.int32),
                            affine=self.atlas_img.affine,
                            header=self.atlas_img.header),
            self.dwi_aligned_atlas)
        cmd = 'fslmaths ' + self.dwi_aligned_atlas + ' -mas ' + self.nodif_B0_mask + ' -mas ' + self.wm_gm_int_in_dwi_bin + ' ' + self.dwi_aligned_atlas_wmgm_int
        os.system(cmd)

        return self.dwi_aligned_atlas_wmgm_int, self.dwi_aligned_atlas, self.aligned_atlas_t1mni