예제 #1
0
def test_align(tmp_path):
    """Run ``mgr.align`` into a temp directory and check the saved transform.

    The transform file passed to ``mgr.align`` is read back with
    ``np.loadtxt`` and its fourth row is compared with the fourth row of a
    reference 4x4 matrix.
    """
    workdir = tmp_path / "sub"
    workdir.mkdir()
    xfm_path = workdir / "omat.data"
    nii_path = workdir / "outnii.nii.gz"

    # input_dir / ref_dir are expected to be defined at module level.
    mgr.align(input_dir, ref_dir, xfm_path, nii_path)

    actual = np.array(np.loadtxt(str(xfm_path)))

    # Reference transform the result is compared against.
    expected = np.array([
        [0.8271407155, -0.04727642977, 0.006816218756, 15.99814178],
        [-0.005617838577, 0.8770838128, -0.000978902222, -0.3486071619],
        [-0.006664295282, -0.04974933234, 0.8076537939, -18.14272133],
        [0., 0., 0., 1.],
    ])

    # NOTE(review): only the last row ([0, 0, 0, 1]) is compared, so this
    # mostly verifies that a 4x4 matrix was written -- confirm that is intended.
    assert np.allclose(actual[3, :], expected[3, :])
예제 #2
0
def test_align(tmp_path):
    """Compare the image written by ``mgr.align`` with a stored control image.

    The aligned output and the control volume are flattened and a Dice-style
    loss (1 - smoothed overlap ratio) is computed between them; the test
    passes when that loss is below 0.3.
    """
    workdir = tmp_path / "sub"
    workdir.mkdir()
    xfm = workdir / "omat.data"
    out = workdir / "outnii.nii.gz"

    # Fixture locations (relative to the directory the tests are run from).
    inp = r"../test_data/inputs/align/sub-0025864_ses-1_T1w.nii.gz"
    ref = r"../test_data/inputs/align/MNI152_T1_2mm_brain.nii.gz"
    control_path = r"../test_data/outputs/align/outnii.nii.gz"

    mgr.align(inp, ref, xfm, out)

    # Flatten both volumes to 1-D tensors.
    produced = torch.from_numpy(nib.load(str(out)).get_fdata())
    produced = produced.contiguous().view(-1)
    control = torch.from_numpy(nib.load(str(control_path)).get_fdata())
    control = control.contiguous().view(-1)

    # Dice-style loss with +1 smoothing in numerator and denominator.
    overlap = (produced * control).sum()
    produced_sq = torch.sum(produced * produced)
    control_sq = torch.sum(control * control)
    dice_loss = 1 - ((2. * overlap + 1.) / (produced_sq + control_sq + 1.))

    assert dice_loss < 0.3
    def tissue2dwi_align(self):
        """Bring ventricle, corpus callosum and tissue masks into dwi space.

        Generates the avoidance/waypoint masks for tractography:

        1. Binarizes the ventricle ROI (``self.mni_vent_loc``) and the corpus
           callosum ROI (``self.corpuscallosum``) with ``fslmaths``; both files
           are overwritten in place.
        2. Registers ``self.mni_atlas`` to ``self.input_mni`` and applies that
           transform to the ventricle ROI.
        3. Unless ``self.simple`` is True, warps the ventricle and corpus
           callosum ROIs with ``self.mni2t1w_warp``.
        4. Applies ``self.t1wtissue2dwi_xfm`` to the ventricle, corpus
           callosum, CSF, GM and WM maps, then thresholds and binarizes them.
        5. Builds the ventricular-CSF mask and the GM-WM interface image.

        NOTE: relies on ``self.t1wtissue2dwi_xfm`` (written by
        ``t1w2dwi_align``). The original notes also require
        ``atlas2t1w2dwi_align`` to have been called first.

        Raises
        ------
        ValueError
            Raised if FSL atlas for ventricle reference not found
        """
        # Imported locally so the method does not depend on a module-level alias.
        import numpy as np

        # Create MNI-space ventricle mask
        print("Creating MNI-space ventricle ROI...")
        if not os.path.isfile(self.mni_atlas):
            raise ValueError("FSL atlas for ventricle reference not found!")
        cmd = "fslmaths " + self.mni_vent_loc + " -thr 0.1 -bin " + self.mni_vent_loc
        os.system(cmd)

        cmd = "fslmaths " + self.corpuscallosum + " -bin " + self.corpuscallosum
        os.system(cmd)

        # Remove the ventricle voxels from the corpus callosum ROI.
        cmd = ("fslmaths " + self.corpuscallosum + " -sub " +
               self.mni_vent_loc + " -bin " + self.corpuscallosum)
        os.system(cmd)

        # Register the atlas to the MNI template; only the transform matrix is
        # kept (out=None). It is used below to move the ventricle ROI.
        mgru.align(
            self.mni_atlas,
            self.input_mni,
            xfm=self.xfm_roi2mni_init,
            init=None,
            bins=None,
            dof=6,
            cost="mutualinfo",
            searchrad=True,
            interp="spline",
            out=None,
        )

        # Apply that transform to the ventricle ROI.
        mgru.applyxfm(self.input_mni, self.mni_vent_loc, self.xfm_roi2mni_init,
                      self.vent_mask_mni)

        # NOTE(review): when self.simple is True the *_mask_t1w files used
        # below are not produced here -- confirm they are created elsewhere.
        if self.simple is False:
            # Apply warp resulting from the inverse MNI->T1w created earlier
            for roi_in, roi_out in (
                (self.vent_mask_mni, self.vent_mask_t1w),
                (self.corpuscallosum, self.corpuscallosum_mask_t1w),
            ):
                mgru.apply_warp(
                    self.t1w_brain,
                    roi_in,
                    roi_out,
                    warp=self.mni2t1w_warp,
                    interp="nn",
                    sup=True,
                )

        # Applyxfm tissue maps to dwi space
        for t1w_map, dwi_map in (
            (self.vent_mask_t1w, self.vent_mask_dwi),
            (self.corpuscallosum_mask_t1w, self.corpuscallosum_dwi),
            (self.csf_mask, self.csf_mask_dwi),
            (self.gm_mask, self.gm_in_dwi),
            (self.wm_mask, self.wm_in_dwi),
        ):
            mgru.applyxfm(self.nodif_B0, t1w_map, self.t1wtissue2dwi_xfm,
                          dwi_map)

        # Zero out low-probability voxels of WM, GM and CSF in dwi space.
        # nibabel deprecated get_data() (removed in 5.0), and editing the
        # cached array in place is not guaranteed to be what gets saved, so
        # the threshold is applied to an explicit copy wrapped in a new image.
        for src, cutoff, dst in (
            (self.wm_in_dwi, 0.15, self.wm_in_dwi_bin),
            (self.gm_in_dwi, 0.15, self.gm_in_dwi_bin),
            (self.csf_mask_dwi, 0.99, self.csf_mask_dwi),
        ):
            thr_img = nib.load(src)
            thr_data = np.array(thr_img.dataobj)
            thr_data[thr_data < cutoff] = 0
            nib.save(
                nib.Nifti1Image(
                    thr_data,
                    affine=thr_img.affine,
                    header=thr_img.header,
                ),
                dst,
            )

        # Binarize the thresholded WM, GM and CSF maps. self.t_img / self.mask
        # are kept as attributes, as in the original implementation.
        for src, dst in (
            (self.wm_in_dwi_bin, self.wm_in_dwi_bin),
            (self.gm_in_dwi_bin, self.gm_in_dwi_bin),
            (self.csf_mask_dwi, self.csf_mask_dwi_bin),
        ):
            self.t_img = load_img(src)
            self.mask = math_img("img > 0", img=self.t_img)
            self.mask.to_filename(dst)

        # Create ventricular CSF mask
        print("Creating ventricular CSF mask...")
        cmd = ("fslmaths " + self.vent_mask_dwi +
               " -kernel sphere 10 -ero -bin " + self.vent_mask_dwi)
        os.system(cmd)
        print("Creating Corpus Callosum mask...")
        cmd = ("fslmaths " + self.corpuscallosum_dwi + " -mas " +
               self.wm_in_dwi_bin + " -bin " + self.corpuscallosum_dwi)
        os.system(cmd)
        cmd = ("fslmaths " + self.csf_mask_dwi + " -add " +
               self.vent_mask_dwi + " -bin " + self.vent_csf_in_dwi)
        os.system(cmd)

        # Create gm-wm interface image: (GM * WM + corpus callosum - ventricular
        # CSF), masked by the B0 brain mask and binarized.
        cmd = ("fslmaths " + self.gm_in_dwi_bin + " -mul " +
               self.wm_in_dwi_bin + " -add " + self.corpuscallosum_dwi +
               " -sub " + self.vent_csf_in_dwi + " -mas " +
               self.nodif_B0_mask + " -bin " + self.wm_gm_int_in_dwi)
        os.system(cmd)
    def atlas2t1w2dwi_align(self, atlas, dsn=True):
        """Align an atlas to the MNI-aligned T1w and, optionally, on to dwi space.

        The atlas is always registered to ``self.t1_aligned_mni`` first. When
        ``dsn`` is False it is additionally moved into dwi space: with
        ``self.simple`` False the non-linear warp is tried first and, if that
        raises, a linear registration is used instead; with ``self.simple``
        True only the linear registration is run. The resulting label image
        is rounded to integers, labels larger than the number of unique values
        are zeroed, and the file is rewritten as int32.

        Note: for this to work, must first have called t1w2dwi_align. This
        method also reads ``self.wm_gm_int_in_dwi`` and writes its binarized
        version to ``self.wm_gm_int_in_dwi_bin``.

        Parameters
        ----------
        atlas : str
            path to atlas file you want to use
        dsn : bool, optional
            is your space for tractography native-dsn, by default True.
            Only the exact value ``False`` selects the dwi-space path.

        Returns
        -------
        str
            path to aligned atlas file (``self.dwi_aligned_atlas`` when
            ``dsn`` is False, otherwise ``self.aligned_atlas_t1mni``)
        """

        self.atlas = atlas
        self.atlas_name = self.atlas.split("/")[-1].split(".")[0]
        reg_a_dir = self.namer.dirs["tmp"]["reg_a"]
        self.aligned_atlas_t1mni = "{}/{}_aligned_atlas_t1w_mni.nii.gz".format(
            reg_a_dir, self.atlas_name)
        self.aligned_atlas_skull = "{}/{}_aligned_atlas_skull.nii.gz".format(
            reg_a_dir, self.atlas_name)
        self.dwi_aligned_atlas = "{}/{}_aligned_atlas.nii.gz".format(
            self.namer.dirs["output"]["reg_anat"], self.atlas_name)

        # The original code tests `dsn is False` throughout; keep the identity
        # check so that other falsy values still take the default path.
        to_dwi = dsn is False

        mgru.align(
            self.atlas,
            self.t1_aligned_mni,
            init=None,
            xfm=None,
            out=self.aligned_atlas_t1mni,
            dof=12,
            searchrad=True,
            interp="nearestneighbour",
            cost="mutualinfo",
        )

        if to_dwi and self.simple is False:
            try:
                # Apply warp resulting from the inverse of T1w-->MNI created earlier
                mgru.apply_warp(
                    self.t1w_brain,
                    self.aligned_atlas_t1mni,
                    self.aligned_atlas_skull,
                    warp=self.mni2t1w_warp,
                    interp="nn",
                    sup=True,
                )

                # Apply transform to dwi space
                mgru.align(
                    self.aligned_atlas_skull,
                    self.nodif_B0,
                    init=self.t1wtissue2dwi_xfm,
                    xfm=None,
                    out=self.dwi_aligned_atlas,
                    dof=6,
                    searchrad=True,
                    interp="nearestneighbour",
                    cost="mutualinfo",
                )
            except Exception:
                # `Exception` rather than a bare `except`, so KeyboardInterrupt
                # and SystemExit are no longer swallowed by the fallback.
                print(
                    "Warning: Atlas is not in correct dimensions, or input is low quality,\nusing linear template registration."
                )
                # Create transform to align atlas to T1w using flirt
                mgru.align(
                    self.atlas,
                    self.t1w_brain,
                    xfm=self.xfm_atlas2t1w_init,
                    init=None,
                    bins=None,
                    dof=6,
                    cost="mutualinfo",
                    searchrad=True,
                    interp="spline",
                    out=None,
                    sch=None,
                )
                # NOTE(review): this refinement uses self.t1_aligned_mni as the
                # reference, while the self.simple path below uses
                # self.t1w_brain -- confirm which one is intended.
                mgru.align(
                    self.atlas,
                    self.t1_aligned_mni,
                    xfm=self.xfm_atlas2t1w,
                    out=None,
                    dof=6,
                    searchrad=True,
                    bins=None,
                    interp="spline",
                    cost="mutualinfo",
                    init=self.xfm_atlas2t1w_init,
                )

                # Combine our linear transform from t1w to template with our transform from dwi to t1w space to get a transform from atlas ->(-> t1w ->)-> dwi
                mgru.combine_xfms(self.xfm_atlas2t1w, self.t1wtissue2dwi_xfm,
                                  self.temp2dwi_xfm)

                # Apply linear transformation from template to dwi space
                mgru.applyxfm(self.nodif_B0, self.atlas, self.temp2dwi_xfm,
                              self.dwi_aligned_atlas)
        elif to_dwi:
            # Create transform to align atlas to T1w using flirt
            # NOTE(review): searchrad=None here but True in the fallback path
            # above -- confirm the difference is deliberate.
            mgru.align(
                self.atlas,
                self.t1w_brain,
                xfm=self.xfm_atlas2t1w_init,
                init=None,
                bins=None,
                dof=6,
                cost="mutualinfo",
                searchrad=None,
                interp="spline",
                out=None,
                sch=None,
            )
            mgru.align(
                self.atlas,
                self.t1w_brain,
                xfm=self.xfm_atlas2t1w,
                out=None,
                dof=6,
                searchrad=True,
                bins=None,
                interp="spline",
                cost="mutualinfo",
                init=self.xfm_atlas2t1w_init,
            )

            # Combine our linear transform from t1w to template with our transform from dwi to t1w space to get a transform from atlas ->(-> t1w ->)-> dwi
            mgru.combine_xfms(self.xfm_atlas2t1w, self.t1wtissue2dwi_xfm,
                              self.temp2dwi_xfm)

            # Apply linear transformation from template to dwi space
            mgru.applyxfm(self.nodif_B0, self.atlas, self.temp2dwi_xfm,
                          self.dwi_aligned_atlas)

        # Set intensities to int
        aligned_path = (self.dwi_aligned_atlas
                        if to_dwi else self.aligned_atlas_t1mni)
        self.atlas_img = nib.load(aligned_path)
        # np.asanyarray(img.dataobj) is nibabel's documented replacement for
        # the deprecated get_data().
        self.atlas_data = np.around(
            np.asanyarray(self.atlas_img.dataobj)).astype("int16")
        # Zero out any label larger than the number of distinct values.
        node_num = len(np.unique(self.atlas_data))
        self.atlas_data[self.atlas_data > node_num] = 0

        # Binarize the GM-WM interface image.
        t_img = load_img(self.wm_gm_int_in_dwi)
        mask = math_img("img > 0", img=t_img)
        mask.to_filename(self.wm_gm_int_in_dwi_bin)

        # Overwrite the aligned atlas with the cleaned integer labels.
        nib.save(
            nib.Nifti1Image(
                self.atlas_data.astype(np.int32),
                affine=self.atlas_img.affine,
                header=self.atlas_img.header,
            ),
            aligned_path,
        )
        return aligned_path
    def t1w2dwi_align(self):
        """Register the T1w image to the MNI template and to dwi space.

        Steps, in order:

        1. Linear (dof=12, mutual information) T1w-->MNI initializer.
        2. If ``self.simple`` is False, non-linear T1w-->MNI registration via
           ``mgru.align_nonlinear`` plus the inverse warp and inverse matrix;
           otherwise a linear registration that writes ``self.t1_aligned_mni``.
        3. Rigid (dof=6) registration between ``self.nodif_B0`` and
           ``self.t1w_brain`` and inversion of the resulting matrix.
        4. If ``self.simple`` is False, a BBR refinement using ``self.wm_edge``
           followed by the final alignment; otherwise the final alignment is
           initialized directly from the matrix of step 3.

        The final alignment writes ``self.t1w2dwi`` and
        ``self.t1wtissue2dwi_xfm``. Assumes input dwi is already preprocessed
        and brain extracted.
        """

        # Create linear transform/ initializer T1w-->MNI
        mgru.align(
            self.t1w_brain,
            self.input_mni,
            xfm=self.t12mni_xfm_init,
            bins=None,
            interp="spline",
            out=None,
            dof=12,
            cost="mutualinfo",
            searchrad=True,
        )

        # Attempt non-linear registration of T1 to MNI template
        if self.simple is False:
            try:
                print("Running non-linear registration: T1w-->MNI ...")
                # Use FNIRT to nonlinearly align T1 to MNI template
                mgru.align_nonlinear(
                    self.t1w_brain,
                    self.input_mni,
                    xfm=self.t12mni_xfm_init,
                    out=self.t1_aligned_mni,
                    warp=self.warp_t1w2mni,
                    ref_mask=self.input_mni_mask,
                    config=self.input_mni_sched,
                )

                # Get warp from MNI -> T1
                mgru.inverse_warp(self.t1w_brain, self.mni2t1w_warp,
                                  self.warp_t1w2mni)

                # Get mat from MNI -> T1
                cmd = ("convert_xfm -omat " + self.mni2t1_xfm_init +
                       " -inverse " + self.t12mni_xfm_init)
                print(cmd)
                os.system(cmd)

            except RuntimeError as err:
                # The original `except RuntimeError("...")` matched against an
                # exception *instance*, which makes Python raise TypeError as
                # soon as anything fails. Catch the class instead and keep the
                # original best-effort behaviour (report and carry on).
                print("Error: FNIRT failed! ({})".format(err))
        else:
            # Falling back to linear registration
            mgru.align(
                self.t1w_brain,
                self.input_mni,
                xfm=self.t12mni_xfm,
                init=self.t12mni_xfm_init,
                bins=None,
                dof=12,
                cost="mutualinfo",
                searchrad=True,
                interp="spline",
                out=self.t1_aligned_mni,
                sch=None,
            )

        # Align T1w-->DWI
        # NOTE(review): the input here is nodif_B0 and the reference is
        # t1w_brain, yet the matrix is stored as t1w2dwi_xfm and its inverse
        # as dwi2t1w_xfm -- the names look swapped; confirm against
        # mgru.align's argument order before relying on them.
        mgru.align(
            self.nodif_B0,
            self.t1w_brain,
            xfm=self.t1w2dwi_xfm,
            bins=None,
            interp="spline",
            dof=6,
            cost="mutualinfo",
            out=None,
            searchrad=True,
            sch=None,
        )
        cmd = "convert_xfm -omat " + self.dwi2t1w_xfm + " -inverse " + self.t1w2dwi_xfm
        print(cmd)
        os.system(cmd)

        if self.simple is False:
            # Flirt bbr
            try:
                print("Running FLIRT BBR registration: T1w-->DWI ...")
                mgru.align(
                    self.nodif_B0,
                    self.t1w_brain,
                    wmseg=self.wm_edge,
                    xfm=self.dwi2t1w_bbr_xfm,
                    init=self.dwi2t1w_xfm,
                    bins=256,
                    dof=7,
                    searchrad=True,
                    interp="spline",
                    out=None,
                    cost="bbr",
                    finesearch=5,
                    sch="${FSLDIR}/etc/flirtsch/bbr.sch",
                )
                cmd = ("convert_xfm -omat " + self.t1w2dwi_bbr_xfm +
                       " -inverse " + self.dwi2t1w_bbr_xfm)
                os.system(cmd)

                # Apply the alignment
                mgru.align(
                    self.t1w_brain,
                    self.nodif_B0,
                    init=self.t1w2dwi_bbr_xfm,
                    xfm=self.t1wtissue2dwi_xfm,
                    bins=None,
                    interp="spline",
                    dof=7,
                    cost="mutualinfo",
                    out=self.t1w2dwi,
                    searchrad=True,
                    sch=None,
                )
            except RuntimeError as err:
                # Same fix as above: catch the exception class, not an
                # instance, and keep going as the original intended.
                print("Error: FLIRT BBR failed! ({})".format(err))
        else:
            # Apply the alignment
            mgru.align(
                self.t1w_brain,
                self.nodif_B0,
                init=self.t1w2dwi_xfm,
                xfm=self.t1wtissue2dwi_xfm,
                bins=None,
                interp="spline",
                dof=6,
                cost="mutualinfo",
                out=self.t1w2dwi,
                searchrad=True,
                sch=None,
            )