Example #1
def iexp(self, q=None, grad=False, cache_result=False, recompute=True):
    """Inverse matrix exponential: expm(-q) in the stored basis."""
    if q is None:
        q = self.dat
    if grad:
        # Gradients are never cached, so force a fresh computation.
        recompute = True
    if recompute or self._cache is None:
        iaff = linalg._expm(-q, self.basis, grad_X=grad)
    else:
        iaff = self._cache
    if cache_result:
        # With grad=True, `iaff` is a (matrix, gradient) pair; cache only
        # the matrix.
        self._cache = iaff[0] if grad else iaff
    return iaff
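Note that `iaff[0] if grad else iaff` relies on `linalg._expm` returning a `(matrix, gradient)` pair when `grad_X=True` and a bare matrix otherwise (Example #3 below unpacks it the same way). A minimal sketch of that convention, using `torch.matrix_exp` as a stand-in for the real implementation and a finite-difference gradient purely for illustration:

import torch

def _expm_sketch(q, basis, grad_X=False):
    # X = sum_i q_i * B_i in the Lie algebra, then the matrix exponential.
    X = torch.einsum('i,imn->mn', q, basis)
    E = torch.matrix_exp(X)
    if not grad_X:
        return E
    # Forward-difference gradient wrt each parameter (illustration only).
    eps = 1e-6
    G = torch.stack([(torch.matrix_exp(X + eps * B) - E) / eps
                     for B in basis])
    return E, G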
Example #2
def exp(self, q=None, grad=False, cache_result=False, recompute=True):
    """Forward matrix exponential: expm(q) in the stored basis."""
    if q is None:
        q = self.dat
    if grad:
        # Gradients are never cached, so force a fresh computation.
        recompute = True
    # `getattr` needs a default: without one, a missing `_cache` attribute
    # raises AttributeError instead of triggering a recompute.
    if recompute or getattr(self, '_cache', None) is None:
        aff = linalg._expm(q, self.basis, grad_X=grad)
    else:
        aff = self._cache
    if cache_result:
        self._cache = aff[0] if grad else aff
    return aff
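Since expm(-q) is the matrix inverse of expm(q), `exp` and `iexp` are inverses of one another. A quick self-contained check with plain torch (the basis and parameters here are synthetic):

import torch

B = torch.zeros(1, 4, 4)
B[0, 0, 1], B[0, 1, 0] = -1., 1.   # infinitesimal rotation in the xy-plane
q = torch.tensor([0.3])
X = torch.einsum('i,imn->mn', q, B)
assert torch.allclose(torch.matrix_exp(X) @ torch.matrix_exp(-X),
                      torch.eye(4), atol=1e-6)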
Example #3
def _mean_space(Mat, Dim, vx=None):
    """Compute a (mean) model space from individual spaces.

    Args:
        Mat (torch.tensor): N subjects' orientation matrices (N, 4, 4).
        Dim (torch.tensor): N subjects' dimensions (N, 3).
        vx (torch.tensor|tuple|float, optional): Voxel size (3,), defaults to None (estimate from input).

    Returns:
        mat (torch.tensor): Mean orientation matrix (4, 4).
        dim (torch.tensor): Mean dimensions (3,).
        vx (torch.tensor): Mean voxel size (3,).

    Authors:
        John Ashburner, as part of the SPM12 software.

    """
    device = Mat.device
    dtype = Mat.dtype
    N = Mat.shape[0]  # Number of subjects
    inf = float('inf')
    one = torch.tensor(1.0, device=device, dtype=dtype)
    if vx is None:
        vx = torch.tensor([inf, inf, inf], device=device, dtype=dtype)
    if isinstance(vx, (int, float)):
        vx = (vx, ) * 3
    if isinstance(vx, tuple) and len(vx) == 3:
        vx = torch.tensor([vx[0], vx[1], vx[2]], device=device, dtype=dtype)
    # Cast to a common working dtype
    Mat = Mat.type(dtype)
    Dim = Dim.type(dtype)
    # Get affine basis
    basis = 'SE'
    dim = 3 if Dim[0, 2] > 1 else 2
    B = affine_basis(basis, dim, device=device, dtype=dtype)

    # Find combination of 90 degree rotations and flips that brings all
    # the matrices closest to axial
    Mat0 = Mat.clone()
    pmatrix = torch.tensor(
        [[0, 1, 2], [1, 0, 2], [2, 0, 1], [2, 1, 0], [0, 2, 1], [1, 2, 0]],
        device=device)

    for n in range(N):  # Loop over subjects
        vx1 = voxel_size(Mat[n, ...])
        R = Mat[n, ...].mm(
            torch.diag(torch.cat((vx1, one[..., None]))).inverse())[:-1, :-1]
        minss = inf
        minR = torch.eye(3, dtype=dtype, device=device)
        for i in range(6):  # Permute (= 'rotate + flip') axes
            R1 = torch.zeros((3, 3), dtype=dtype, device=device)
            R1[pmatrix[i, 0], 0] = 1
            R1[pmatrix[i, 1], 1] = 1
            R1[pmatrix[i, 2], 2] = 1
            for j in range(8):  # Mirror (= 'flip') axes
                fd = [(j & 1) * 2 - 1, (j & 2) - 1, (j & 4) / 2 - 1]
                F = torch.diag(torch.tensor(fd, dtype=dtype, device=device))
                R2 = F.mm(R1)
                ss = torch.sum((R.mm(R2.inverse()) -
                                torch.eye(3, dtype=dtype, device=device))**2)
                if ss < minss:
                    minss = ss
                    minR = R2
        rdim = torch.abs(minR.mm(Dim[n, ...][..., None] - 1))
        R2 = minR.inverse()
        R22 = R2.mm((torch.div(torch.sum(R2, dim=0, keepdim=True).t(), 2,
                               rounding_mode='floor') - 1) * rdim)
        minR = torch.cat((R2, R22), dim=1)
        minR = torch.cat((minR,
                          torch.tensor([0, 0, 0, 1],
                                       device=device, dtype=dtype)[None, ...]),
                         dim=0)
        Mat[n, ...] = Mat[n, ...].mm(minR)

    # Average of the matrices in Mat
    mat = meanm(Mat)

    # If average involves shears, then find the closest matrix that does not
    # require them.
    p = _imatrix(mat)
    if torch.sum(p[[9, 10, 11]]**2) > 1e-8:
        B2 = torch.zeros((3, 4, 4), device=device, dtype=dtype)
        B2[0, 0, 0] = 1
        B2[1, 1, 1] = 1
        B2[2, 2, 2] = 1

        p = torch.zeros(9, device=device, dtype=dtype)
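        # Gauss-Newton on ||expm(q_r) expm(q_z) - mat||^2, where q_r holds the
        # six rigid (SE) parameters and q_z the three log-zooms: assemble the
        # Jacobian dM of the product by the chain rule, then solve the normal
        # equations Hes @ dp = gr for the parameter update.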
        for n_iter in range(10000):
            # Rotations + Translations
            R, dR = _expm(p[[0, 1, 2, 3, 4, 5]], B, grad_X=True)
            # Zooms
            Z, dZ = _expm(p[[6, 7, 8]], B2, grad_X=True)

            M = R.mm(Z)
            dM = torch.zeros((4, 4, 9), device=device, dtype=dtype)
            for n in range(6):
                dM[..., n] = dR[n, ...].mm(Z)
            for n in range(3):
                dM[..., 6 + n] = R.mm(dZ[n, ...])
            dM = dM.reshape((16, 9))
            d = M.flatten() - mat.flatten()
            gr = dM.t().mm(d[..., None])
            Hes = dM.t().mm(dM)
            p = p - lmdiv(Hes, gr)[:, 0]
            if torch.sum(gr**2) < 1e-8:
                break
        mat = M.clone()

    # Set required voxel size
    vx_out = vx.clone()
    vx = voxel_size(mat)
    vx_out[~torch.isfinite(vx_out)] = vx[~torch.isfinite(vx_out)]
    mat = mat.mm(torch.cat((vx_out / vx, one[..., None])).diag())
    vx = voxel_size(mat)

    # Ensure that the FoV covers all images, with a few voxels to spare
    mn_all = torch.zeros([3, N], device=device, dtype=dtype)
    mx_all = torch.zeros([3, N], device=device, dtype=dtype)
    for n in range(N):
        dm = Dim[n, ...]
        corners = torch.tensor([[1, dm[0], 1, dm[0], 1, dm[0], 1, dm[0]],
                                [1, 1, dm[1], dm[1], 1, 1, dm[1], dm[1]],
                                [1, 1, 1, 1, dm[2], dm[2], dm[2], dm[2]],
                                [1, 1, 1, 1, 1, 1, 1, 1]],
                               device=device,
                               dtype=dtype)
        M = lmdiv(mat, Mat0[n])
        vx1 = M[:-1, :].mm(corners)
        mx_all[..., n] = torch.max(vx1, dim=1)[0]
        mn_all[..., n] = torch.min(vx1, dim=1)[0]
    mx = mx_all.max(dim=1)[0]
    mn = mn_all.min(dim=1)[0]
    mx = torch.ceil(mx)
    mn = torch.floor(mn)

    # Make output dimensions and orientation matrix
    dim = mx - mn + 1  # Output dimensions
    off = torch.tensor([0, 0, 0], device=device, dtype=dtype)
    mat = mat.mm(torch.tensor(
        [[1, 0, 0, mn[0] - (off[0] + 1)],
         [0, 1, 0, mn[1] - (off[1] + 1)],
         [0, 0, 1, mn[2] - (off[2] + 1)],
         [0, 0, 0, 1]],
        device=device, dtype=dtype))

    return mat, dim, vx
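A minimal call, assuming the helpers used above (`affine_basis`, `voxel_size`, `meanm`, `_imatrix`, `_expm`, `lmdiv`) are importable from the surrounding module; the two subject spaces below are synthetic:

import torch

Mat = torch.eye(4, dtype=torch.float64).expand(2, 4, 4).clone()
Mat[1, :3, 3] = torch.tensor([5., -3., 2.], dtype=torch.float64)  # shifted
Dim = torch.tensor([[192., 192., 160.],
                    [192., 192., 160.]], dtype=torch.float64)
mat, dim, vx = _mean_space(Mat, Dim, vx=1.)   # request 1 mm isotropic voxels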
Example #4
    def __call__(self,
                 logaff,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False,
                 in_line_search=False):
        """
        logaff : (..., nb) tensor, Lie parameters
        grad : Whether to compute and return the gradient wrt `logaff`
        hess : Whether to compute and return the Hessian wrt `logaff`
        gradmov : Whether to compute and return the gradient wrt `moving`
        hessmov : Whether to compute and return the Hessian wrt `moving`

        Returns
        -------
        ll : () tensor, loss value (objective to minimize)
        g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
        h : (..., logaff) tensor, optional, Hessian wrt Lie parameters
        gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
        hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

        """
        # This loop performs the forward pass, and computes
        # derivatives along the way.

        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # jitter
        # if not hasattr(self, '_fixed'):
        #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
        #                                 jitter=True,
        #                                 **utils.backend(self.fixed))
        #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
        #     del idj
        # fixed = self._fixed
        fixed = self.fixed

        # forward
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        aff = linalg.expm(logaff, self.basis)
        with torch.no_grad():
            _, gaff = linalg._expm(logaff,
                                   self.basis,
                                   grad_X=True,
                                   hess_X=False)

        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        # /!\ derivatives are not "homogeneous" (they do not have a one
        # on the bottom right): we should *not* use affine_matmul and
        # such (I only lost a day...)
        gaff = torch.matmul(gaff, self.affine_fixed)
        gaff = linalg.lmdiv(self.affine_moving, gaff)
        # haff = torch.matmul(haff, self.affine_fixed)
        # haff = linalg.lmdiv(self.affine_moving, haff)
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape,
                                            **utils.backend(logaff),
                                            jitter=False)
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # gradient/Hessian of the log-likelihood in observed space
        if not grad and not hess:
            llx = self.loss.loss(warped, fixed)
        elif not hess:
            llx, grad = self.loss.loss_grad(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped

        # compose with spatial gradients + dot product with grid
        if grad is not False or hess is not False:
            mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
            grad = jg(mugrad, grad)
            if hess is not False:
                hess = jhj(mugrad, hess)
                grad, hess = regutils.affine_grid_backward(grad,
                                                           hess,
                                                           grid=self.id)
            else:
                grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
            dim2 = self.dim * (self.dim + 1)
            grad = grad.reshape([*grad.shape[:-2], dim2])
            gaff = gaff[..., :-1, :]
            gaff = gaff.reshape([*gaff.shape[:-2], dim2])
            grad = linalg.matvec(gaff, grad)
            if hess is not False:
                hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
                # haff = haff[..., :-1, :, :-1, :]
                # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
                hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
                hess = hess.abs().sum(-1).diag_embed()
            del mugrad

        # print objective
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
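The gradient and Hessian returned here are meant to feed a Gauss-Newton update on the Lie parameters. A toy illustration of that step with synthetic tensors (the enclosing optimizer loop is not part of this example):

import torch

g = torch.tensor([0.5, -1.0, 0.2])   # stand-in gradient wrt Lie parameters
h = torch.eye(3) * 4.                # stand-in (majorized) Hessian
logaff = torch.zeros(3)
logaff = logaff - torch.linalg.solve(h, g)   # one Gauss-Newton step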
Example #5
def fit_affine_tpm(dat,
                   tpm,
                   affine=None,
                   affine_tpm=None,
                   weights=None,
                   basis='affine',
                   fwhm=None,
                   joint=False,
                   prm=None,
                   max_iter_gn=100,
                   max_iter_em=32,
                   max_line_search=6,
                   progressive=False,
                   verbose=1):
    """

    Parameters
    ----------
    dat : (B, J|1, *spatial) tensor
    tpm : (B|1, K, *spatial) tensor
    affine : (4, 4) tensor
    affine_tpm : (4, 4) tensor
    weights : (B, 1, *spatial) tensor
    basis : {'translation', 'rotation', 'rigid', 'similitude', 'affine'}
    fwhm : float, default=J/32
    joint : bool, default=False
    prm : (B|1, F) tensor, optional
    max_iter_gn : int, default=100
    max_iter_em : int, default=32
    max_line_search : int, default=6
    progressive : bool, default=False
    verbose : int, default=1

    Returns
    -------
    mi : (B,) tensor
    aff : (B, 4, 4) tensor
    prm : (B, F) tensor

    """
    dim = dat.dim() - 2

    # ------------------------------------------------------------------
    #       RECURSIVE PROGRESSIVE FIT
    # ------------------------------------------------------------------
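    # With `progressive=True`, the fit climbs a chain of nested bases,
    # T -> SE -> CSO -> Aff+ (translations, then rigid, then similitude,
    # then full affine), warm-starting each level with the parameters
    # estimated at the previous one.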
    if progressive:
        nb_se = dim * (dim + 1) // 2
        nb_aff = dim * (dim + 1)
        basis_recursion = {'Aff+': 'CSO', 'CSO': 'SE', 'SE': 'T'}
        basis_nb_feat = {'Aff+': nb_aff, 'CSO': nb_se + 1, 'SE': nb_se}
        basis = convert_basis(basis)
        next_basis = basis_recursion.get(basis, None)
        if next_basis:
            *_, prm = fit_affine_tpm(dat,
                                     tpm,
                                     affine,
                                     affine_tpm,
                                     weights,
                                     basis=next_basis,
                                     fwhm=fwhm,
                                     joint=joint,
                                     prm=prm,
                                     max_iter_gn=max_iter_gn,
                                     max_iter_em=max_iter_em,
                                     max_line_search=max_line_search)
            B = len(dat)
            F = basis_nb_feat[basis]
            prm0 = prm
            prm = prm0.new_zeros([1 if joint else B, F])
            if basis == 'SE':
                prm[:, :dim] = prm0[:, :dim]
            else:
                nb_se = dim * (dim + 1) // 2
                prm[:, :nb_se] = prm0[:, :nb_se]
                if basis == 'Aff+':
                    # Spread the isotropic CSO zoom across the `dim` log-zooms
                    # (keep a batch axis so the assignment broadcasts).
                    prm[:, nb_se:nb_se + dim] = (prm0[:, nb_se, None]
                                                 * dim ** (-0.5))

    basis_name = basis

    # ------------------------------------------------------------------
    #       PREPARE
    # ------------------------------------------------------------------

    B = len(dat)
    if affine is None:
        affine = spatial.affine_default(dat.shape[-dim:])
    if affine_tpm is None:
        affine_tpm = spatial.affine_default(tpm.shape[-dim:])
    affine = affine.to(**utils.backend(tpm))
    affine_tpm = affine_tpm.to(**utils.backend(tpm))
    shape = dat.shape[-dim:]

    tpm = tpm.to(dat.device)
    basis = make_basis(basis, dim, **utils.backend(tpm))
    F = len(basis)

    if prm is None:
        prm = tpm.new_zeros([1 if joint else B, F])
    aff, gaff = linalg._expm(prm, basis, grad_X=True)

    em_opt = dict(fwhm=fwhm,
                  max_iter=max_iter_em,
                  weights=weights,
                  verbose=verbose - 2)
    drv_opt = dict(weights=weights)
    pull_opt = dict(bound='replicate', extrapolate=True)

    # ------------------------------------------------------------------
    #       OPTIMIZE
    # ------------------------------------------------------------------
    prior = None
    mi = torch.as_tensor(-float('inf'))
    delta = torch.zeros_like(prm)
    for n_iter in range(max_iter_gn):

        # --------------------------------------------------------------
        #       LINE SEARCH
        # --------------------------------------------------------------
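        # Backtracking search: try the full Gauss-Newton step (armijo=1)
        # and halve it until the mutual information improves on the
        # previous outer iteration.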
        prior0, prm0, mi0 = prior, prm, mi
        armijo = 1
        success = False
        for n_ls in range(max_line_search):

            # --- take a step ------------------------------------------
            prm = prm0 - armijo * delta

            # --- build transformation field ---------------------------
            aff, gaff = linalg._expm(prm, basis, grad_X=True)
            phi = lmdiv(affine_tpm, mm(aff, affine))
            phi = spatial.affine_grid(phi, shape)

            # --- warp TPM ---------------------------------------------
            mov = spatial.grid_pull(tpm, phi, **pull_opt)

            # --- mutual info ------------------------------------------
            mi, Nm, prior = em_prior(mov, dat, prior0, **em_opt)
            mi = mi / Nm

            success = mi.sum() > mi0.sum()
            if verbose >= 2:
                end = '\n' if verbose >= 3 else '\r'
                happy = ':D' if success else ':('
                print(f'(search) | {n_ls:02d} | {mi.mean():12.6g} | {happy}',
                      end=end)
            if success:
                break
            armijo *= 0.5
        # if verbose == 2:
        #     print('')

        # --------------------------------------------------------------
        #       DID IT WORK?
        # --------------------------------------------------------------

        if not success:
            prior, prm, mi = prior0, prm0, mi0
            break

        # DEBUG
        # plot_registration(dat, mov, f'{basis_name} | {n_iter}')

        space = ' ' * max(0, 6 - len(basis_name))
        if verbose >= 1:
            end = '\n' if verbose >= 2 else '\r'
            print(
                f'({basis_name[:6]}){space} | {n_iter:02d} | {mi.mean():12.6g}',
                end=end)

        if mi.mean() - mi0.mean() < 1e-5:
            break

        # --------------------------------------------------------------
        #       GAUSS-NEWTON
        # --------------------------------------------------------------

        # --- derivatives ----------------------------------------------
        g, h = derivatives_intensity(mov, dat, prior, **drv_opt)

        # --- chain rule -----------------------------------------------
        gmov = spatial.grid_grad(tpm, phi, **pull_opt)
        if joint and len(mov) == 1:
            g = g.sum(0, keepdim=True)
            h = h.sum(0, keepdim=True)
        else:
            gmov = gmov.expand([B, *gmov.shape[1:]])
        gaff = lmdiv(affine_tpm, mm(gaff, affine))
        g, h = chain_rule(g, h, gmov, gaff, maj=False)
        del gmov

        if joint and len(g) > 1:
            g = g.sum(0, keepdim=True)
            h = h.sum(0, keepdim=True)

        # --- Gauss-Newton ---------------------------------------------
        delta = lmdiv(h, g.unsqueeze(-1)).squeeze(-1)

    if verbose == 1:
        print('')
    return mi, aff, prm
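A hypothetical call, with shapes following the docstring; it assumes the module-level helpers referenced above (`convert_basis`, `make_basis`, `em_prior`, `derivatives_intensity`, `chain_rule`) are available:

import torch

B, K = 2, 4
dat = torch.rand([B, 1, 32, 32, 32])                        # B subject images
tpm = torch.softmax(torch.rand([1, K, 32, 32, 32]), dim=1)  # K-class template
mi, aff, prm = fit_affine_tpm(dat, tpm, basis='rigid',
                              progressive=True, verbose=0)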