def iexp(self, q=None, grad=False, cache_result=False, recompute=True):
    """Exponentiate the negated Lie parameters (inverse transform)."""
    if q is None:
        q = self.dat
    if grad:
        recompute = True
    if recompute or getattr(self, '_icache', None) is None:
        iaff = linalg._expm(-q, self.basis, grad_X=grad)
    else:
        iaff = self._icache
    if cache_result:
        # `_expm` returns a (matrix, gradient) tuple when `grad_X=True`;
        # only the matrix is cached. A separate `_icache` is used so that
        # the inverse does not clobber the forward cache used by `exp`.
        self._icache = iaff[0] if grad else iaff
    return iaff
def exp(self, q=None, grad=False, cache_result=False, recompute=True):
    """Exponentiate the Lie parameters (forward transform)."""
    if q is None:
        q = self.dat
    if grad:
        recompute = True
    if recompute or getattr(self, '_cache', None) is None:
        aff = linalg._expm(q, self.basis, grad_X=grad)
    else:
        aff = self._cache
    if cache_result:
        self._cache = aff[0] if grad else aff
    return aff
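# The sketch below is illustrative only (not part of the class above). It
# shows the Lie-group exponential that `exp`/`iexp` delegate to
# `linalg._expm`, assuming `basis` is an (F, D, D) stack of generators and
# `q` an (..., F) parameter tensor. `torch.matrix_exp` covers the forward
# map but, unlike `linalg._expm`, does not return gradients.
def _toy_lie_exp(q, basis, inverse=False):
    """Exponentiate a linear combination of basis matrices (sketch)."""
    q = -q if inverse else q
    # Weight each generator by its parameter, sum, then matrix-exponentiate
    X = torch.sum(q[..., None, None] * basis, dim=-3)
    return torch.matrix_exp(X)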
def _mean_space(Mat, Dim, vx=None):
    """Compute a (mean) model space from individual spaces.

    Args:
        Mat (torch.tensor): N subjects' orientation matrices (N, 4, 4).
        Dim (torch.tensor): N subjects' dimensions (N, 3).
        vx (torch.tensor|tuple|float, optional): Voxel size (3,),
            defaults to None (estimate from input).

    Returns:
        mat (torch.tensor): Mean orientation matrix (4, 4).
        dim (torch.tensor): Mean dimensions (3,).
        vx (torch.tensor): Mean voxel size (3,).

    Authors:
        John Ashburner, as part of the SPM12 software.

    """
    device = Mat.device
    dtype = Mat.dtype
    N = Mat.shape[0]  # Number of subjects
    inf = float('inf')
    one = torch.tensor(1.0, device=device, dtype=dtype)
    if vx is None:
        vx = torch.tensor([inf, inf, inf], device=device, dtype=dtype)
    if isinstance(vx, (int, float)):
        vx = (vx,) * 3
    if isinstance(vx, tuple) and len(vx) == 3:
        vx = torch.tensor([vx[0], vx[1], vx[2]], device=device, dtype=dtype)
    # Cast to a common dtype
    Mat = Mat.type(dtype)
    Dim = Dim.type(dtype)
    # Get affine basis
    basis = 'SE'
    dim = 3 if Dim[0, 2] > 1 else 2
    B = affine_basis(basis, dim, device=device, dtype=dtype)
    # Find the combination of 90 degree rotations and flips that brings all
    # the matrices closest to axial
    Mat0 = Mat.clone()
    pmatrix = torch.tensor([[0, 1, 2],
                            [1, 0, 2],
                            [2, 0, 1],
                            [2, 1, 0],
                            [0, 2, 1],
                            [1, 2, 0]], device=device)
    for n in range(N):  # Loop over subjects
        vx1 = voxel_size(Mat[n, ...])
        # Remove voxel scaling to isolate the rotation part
        R = Mat[n, ...].mm(
            torch.diag(torch.cat((vx1, one[..., None]))).inverse())[:-1, :-1]
        minss = inf
        minR = torch.eye(3, dtype=dtype, device=device)
        for i in range(6):  # Permute (= 'rotate + flip') axes
            R1 = torch.zeros((3, 3), dtype=dtype, device=device)
            R1[pmatrix[i, 0], 0] = 1
            R1[pmatrix[i, 1], 1] = 1
            R1[pmatrix[i, 2], 2] = 1
            for j in range(8):  # Mirror (= 'flip') axes
                fd = [(j & 1) * 2 - 1, (j & 2) - 1, (j & 4) / 2 - 1]
                F = torch.diag(torch.tensor(fd, dtype=dtype, device=device))
                R2 = F.mm(R1)
                ss = torch.sum((R.mm(R2.inverse())
                                - torch.eye(3, dtype=dtype, device=device))**2)
                if ss < minss:
                    minss = ss
                    minR = R2
        rdim = torch.abs(minR.mm(Dim[n, ...][..., None] - 1))
        R2 = minR.inverse()
        R22 = R2.mm((torch.div(torch.sum(R2, dim=0, keepdim=True).t(),
                               2, rounding_mode='floor') - 1) * rdim)
        minR = torch.cat((R2, R22), dim=1)
        minR = torch.cat((minR,
                          torch.tensor([0, 0, 0, 1],
                                       device=device, dtype=dtype)[None, ...]),
                         dim=0)
        Mat[n, ...] = Mat[n, ...].mm(minR)
    # Average of the matrices in Mat
    mat = meanm(Mat)
    # If the average involves shears, then find the closest matrix that does
    # not require them.
    C_ix = torch.tensor([0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15],
                        device=device)  # column-major ordering from (4, 4) tensor
    p = _imatrix(mat)
    if torch.sum(p[[9, 10, 11]]**2) > 1e-8:  # p[9:12] are the shears
        B2 = torch.zeros((3, 4, 4), device=device, dtype=dtype)
        B2[0, 0, 0] = 1
        B2[1, 1, 1] = 1
        B2[2, 2, 2] = 1
        p = torch.zeros(9, device=device, dtype=dtype)
        for n_iter in range(10000):
            # Rotations + translations
            R, dR = _expm(p[[0, 1, 2, 3, 4, 5]], B, grad_X=True)
            # Zooms
            Z, dZ = _expm(p[[6, 7, 8]], B2, grad_X=True)
            M = R.mm(Z)
            dM = torch.zeros((4, 4, 9), device=device, dtype=dtype)
            for n in range(6):
                dM[..., n] = dR[n, ...].mm(Z)
            for n in range(3):
                dM[..., 6 + n] = R.mm(dZ[n, ...])
            dM = dM.reshape((16, 9))
            d = M.flatten() - mat.flatten()
            gr = dM.t().mm(d[..., None])
            Hes = dM.t().mm(dM)
            p = p - lmdiv(Hes, gr)[:, 0]
            if torch.sum(gr**2) < 1e-8:
                break
        mat = M.clone()
    # Set required voxel size
    vx_out = vx.clone()
    vx = voxel_size(mat)
    vx_out[~torch.isfinite(vx_out)] = vx[~torch.isfinite(vx_out)]
    mat = mat.mm(torch.cat((vx_out / vx, one[..., None])).diag())
    vx = voxel_size(mat)
    # Ensure that the FoV covers all images, with a few voxels to spare
    mn_all = torch.zeros([3, N], device=device, dtype=dtype)
    mx_all = torch.zeros([3, N], device=device, dtype=dtype)
    for n in range(N):
        dm = Dim[n, ...]
        corners = torch.tensor(
            [[1, dm[0], 1, dm[0], 1, dm[0], 1, dm[0]],
             [1, 1, dm[1], dm[1], 1, 1, dm[1], dm[1]],
             [1, 1, 1, 1, dm[2], dm[2], dm[2], dm[2]],
             [1, 1, 1, 1, 1, 1, 1, 1]], device=device, dtype=dtype)
        M = lmdiv(mat, Mat0[n])
        vx1 = M[:-1, :].mm(corners)
        mx_all[..., n] = torch.max(vx1, dim=1)[0]
        mn_all[..., n] = torch.min(vx1, dim=1)[0]
    mx = mx_all.max(dim=1)[0]
    mn = mn_all.min(dim=1)[0]
    mx = torch.ceil(mx)
    mn = torch.floor(mn)
    # Make output dimensions and orientation matrix
    dim = mx - mn + 1  # Output dimensions
    off = torch.tensor([0, 0, 0], device=device, dtype=dtype)
    mat = mat.mm(torch.tensor(
        [[1, 0, 0, mn[0] - (off[0] + 1)],
         [0, 1, 0, mn[1] - (off[1] + 1)],
         [0, 0, 1, mn[2] - (off[2] + 1)],
         [0, 0, 0, 1]], device=device, dtype=dtype))
    return mat, dim, vx
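# Illustrative usage sketch for `_mean_space` (not library code): average the
# spaces of two subjects that share an identity orientation matrix but have
# different grid sizes, asking for 1 mm isotropic output voxels. The subject
# matrices and dimensions below are made up for the example.
def _demo_mean_space():
    Mat = torch.eye(4, dtype=torch.float64).repeat(2, 1, 1)  # (N=2, 4, 4)
    Dim = torch.tensor([[64., 64., 64.],
                        [72., 72., 48.]], dtype=torch.float64)  # (N=2, 3)
    mat, dim, vx = _mean_space(Mat, Dim, vx=1.0)
    return mat, dim, vx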
def __call__(self, logaff, grad=False, hess=False,
             gradmov=False, hessmov=False, in_line_search=False):
    """Compute the loss (and optionally its derivatives) at `logaff`.

    Parameters
    ----------
    logaff : (..., nb) tensor
        Lie parameters.
    grad : bool, default=False
        Whether to compute and return the gradient wrt `logaff`.
    hess : bool, default=False
        Whether to compute and return the Hessian wrt `logaff`.
    gradmov : bool, default=False
        Whether to compute and return the gradient wrt `moving`.
    hessmov : bool, default=False
        Whether to compute and return the Hessian wrt `moving`.

    Returns
    -------
    ll : () tensor
        Loss value (objective to minimize).
    g : (..., logaff) tensor, optional
        Gradient wrt Lie parameters.
    h : (..., logaff) tensor, optional
        Hessian wrt Lie parameters.
    gm : (..., *spatial, dim) tensor, optional
        Gradient wrt moving.
    hm : (..., *spatial, ?) tensor, optional
        Hessian wrt moving.

    """
    # This function performs the forward pass, and computes
    # derivatives along the way.

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    logplot = max(self.max_iter // 20, 1)
    do_plot = ((not in_line_search) and self.plot
               and (self.n_iter - 1) % logplot == 0)

    # jitter
    # if not hasattr(self, '_fixed'):
    #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
    #                                 jitter=True,
    #                                 **utils.backend(self.fixed))
    #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
    #     del idj
    # fixed = self._fixed
    fixed = self.fixed

    # forward
    if not torch.is_tensor(self.basis):
        self.basis = spatial.affine_basis(self.basis, self.dim,
                                          **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    with torch.no_grad():
        _, gaff = linalg._expm(logaff, self.basis,
                               grad_X=True, hess_X=False)

    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    # /!\ derivatives are not "homogeneous" (they do not have a one
    # on the bottom right): we should *not* use affine_matmul and
    # such (I only lost a day...)
    gaff = torch.matmul(gaff, self.affine_fixed)
    gaff = linalg.lmdiv(self.affine_moving, gaff)
    # haff = torch.matmul(haff, self.affine_fixed)
    # haff = linalg.lmdiv(self.affine_moving, haff)

    if self.id is None:
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff),
                                        jitter=False)
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, cat=iscat,
                    dim=self.dim)

    # gradient/Hessian of the log-likelihood in observed space
    if not grad and not hess:
        llx = self.loss.loss(warped, fixed)
    elif not hess:
        llx, grad = self.loss.loss_grad(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
    else:
        llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
        if hessmov:
            hessmov = spatial.grid_push(hess, grid, **pullopt)
    del warped

    # compose with spatial gradients + dot product with grid
    if grad is not False or hess is not False:
        mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
        grad = jg(mugrad, grad)
        if hess is not False:
            hess = jhj(mugrad, hess)
            grad, hess = regutils.affine_grid_backward(grad, hess,
                                                       grid=self.id)
        else:
            grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
        dim2 = self.dim * (self.dim + 1)
        grad = grad.reshape([*grad.shape[:-2], dim2])
        gaff = gaff[..., :-1, :]
        gaff = gaff.reshape([*gaff.shape[:-2], dim2])
        grad = linalg.matvec(gaff, grad)
        if hess is not False:
            hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
            # haff = haff[..., :-1, :, :-1, :]
            # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
            hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
            # majorize with a diagonally-dominant Hessian
            hess = hess.abs().sum(-1).diag_embed()
        del mugrad

    # print objective
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} '
                  f'= {ll:12.6g}', end='\n')
        else:
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} '
                  f'= {ll:12.6g} | {gain:12.6g}', end='\n')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [ll]
    if grad is not False:
        out.append(grad)
    if hess is not False:
        out.append(hess)
    if gradmov is not False:
        out.append(gradmov)
    if hessmov is not False:
        out.append(hessmov)
    return tuple(out) if len(out) > 1 else out[0]
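# Illustrative sketch (not library code) of how a closure with the signature
# above can drive one Gauss-Newton update with a backtracking line search.
# It assumes the closure returns (loss, grad, hess) when called with
# grad=True, hess=True, and just the loss otherwise; the library itself
# solves the Newton system with `linalg.lmdiv`, here `torch.linalg.solve`
# stands in for it.
def _demo_gauss_newton_step(closure, logaff, max_line_search=6):
    ll, g, h = closure(logaff, grad=True, hess=True)
    delta = torch.linalg.solve(h, g.unsqueeze(-1)).squeeze(-1)
    armijo = 1.0
    for _ in range(max_line_search):
        candidate = logaff - armijo * delta
        if closure(candidate, in_line_search=True) < ll:
            return candidate  # improvement: accept the step
        armijo *= 0.5  # otherwise halve the step and retry
    return logaff  # no improvement found: keep the current parameters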
def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
                   basis='affine', fwhm=None, joint=False, prm=None,
                   max_iter_gn=100, max_iter_em=32, max_line_search=6,
                   progressive=False, verbose=1):
    """Fit an affine transformation between a TPM and a set of images.

    Parameters
    ----------
    dat : (B, J|1, *spatial) tensor
    tpm : (B|1, K, *spatial) tensor
    affine : (4, 4) tensor
    affine_tpm : (4, 4) tensor
    weights : (B, 1, *spatial) tensor
    basis : {'translation', 'rotation', 'rigid', 'similitude', 'affine'}
    fwhm : float, default=J/32
    joint : bool, default=False
    prm : (B, F) tensor, optional
        Initial parameters.
    max_iter_gn : int, default=100
    max_iter_em : int, default=32
    max_line_search : int, default=6
    progressive : bool, default=False
    verbose : int, default=1

    Returns
    -------
    mi : (B,) tensor
    aff : (B, 4, 4) tensor
    prm : (B, F) tensor

    """
    dim = dat.dim() - 2

    # ------------------------------------------------------------------
    #       RECURSIVE PROGRESSIVE FIT
    # ------------------------------------------------------------------
    if progressive:
        nb_se = dim * (dim + 1) // 2
        nb_aff = dim * (dim + 1)
        basis_recursion = {'Aff+': 'CSO', 'CSO': 'SE', 'SE': 'T'}
        basis_nb_feat = {'Aff+': nb_aff, 'CSO': nb_se + 1, 'SE': nb_se}
        basis = convert_basis(basis)
        next_basis = basis_recursion.get(basis, None)
        if next_basis:
            # Fit the next (smaller) basis first, then use its parameters
            # to initialize the current one.
            *_, prm = fit_affine_tpm(dat, tpm, affine, affine_tpm, weights,
                                     basis=next_basis, fwhm=fwhm,
                                     joint=joint, prm=prm,
                                     max_iter_gn=max_iter_gn,
                                     max_iter_em=max_iter_em,
                                     max_line_search=max_line_search,
                                     progressive=True, verbose=verbose)
            B = len(dat)
            F = basis_nb_feat[basis]
            prm0 = prm
            prm = prm0.new_zeros([1 if joint else B, F])
            if basis == 'SE':
                prm[:, :dim] = prm0[:, :dim]
            else:
                nb_se = dim * (dim + 1) // 2
                prm[:, :nb_se] = prm0[:, :nb_se]
                if basis == 'Aff+':
                    # Spread the isotropic zoom of 'CSO' across the `dim`
                    # zoom parameters of 'Aff+'
                    prm[:, nb_se:nb_se + dim] \
                        = prm0[:, nb_se:nb_se + 1] * (dim ** (-0.5))
    basis_name = basis

    # ------------------------------------------------------------------
    #       PREPARE
    # ------------------------------------------------------------------
    B = len(dat)
    if affine is None:
        affine = spatial.affine_default(dat.shape[-dim:])
    if affine_tpm is None:
        affine_tpm = spatial.affine_default(tpm.shape[-dim:])
    affine = affine.to(**utils.backend(tpm))
    affine_tpm = affine_tpm.to(**utils.backend(tpm))
    shape = dat.shape[-dim:]

    tpm = tpm.to(dat.device)
    basis = make_basis(basis, dim, **utils.backend(tpm))
    F = len(basis)

    if prm is None:
        prm = tpm.new_zeros([1 if joint else B, F])
    aff, gaff = linalg._expm(prm, basis, grad_X=True)

    em_opt = dict(fwhm=fwhm, max_iter=max_iter_em, weights=weights,
                  verbose=verbose - 2)
    drv_opt = dict(weights=weights)
    pull_opt = dict(bound='replicate', extrapolate=True)

    # ------------------------------------------------------------------
    #       OPTIMIZE
    # ------------------------------------------------------------------
    prior = None
    mi = torch.as_tensor(-float('inf'))
    delta = torch.zeros_like(prm)
    for n_iter in range(max_iter_gn):

        # --------------------------------------------------------------
        #       LINE SEARCH
        # --------------------------------------------------------------
        prior0, prm0, mi0 = prior, prm, mi
        armijo = 1
        success = False
        for n_ls in range(max_line_search):

            # --- take a step ------------------------------------------
            prm = prm0 - armijo * delta

            # --- build transformation field ---------------------------
            aff, gaff = linalg._expm(prm, basis, grad_X=True)
            phi = lmdiv(affine_tpm, mm(aff, affine))
            phi = spatial.affine_grid(phi, shape)

            # --- warp TPM ---------------------------------------------
            mov = spatial.grid_pull(tpm, phi, **pull_opt)

            # --- mutual info ------------------------------------------
            mi, Nm, prior = em_prior(mov, dat, prior0, **em_opt)
            mi = mi / Nm

            success = mi.sum() > mi0.sum()
            if verbose >= 2:
                end = '\n' if verbose >= 3 else '\r'
                happy = ':D' if success else ':('
                print(f'(search) | {n_ls:02d} | {mi.mean():12.6g} | {happy}',
                      end=end)
            if success:
                break
            armijo *= 0.5
        # if verbose == 2:
        #     print('')

        # --------------------------------------------------------------
        #       DID IT WORK?
        # --------------------------------------------------------------
        if not success:
            prior, prm, mi = prior0, prm0, mi0
            break

        # DEBUG
        # plot_registration(dat, mov, f'{basis_name} | {n_iter}')

        space = ' ' * max(0, 6 - len(basis_name))
        if verbose >= 1:
            end = '\n' if verbose >= 2 else '\r'
            print(f'({basis_name[:6]}){space} | {n_iter:02d} '
                  f'| {mi.mean():12.6g}', end=end)

        if mi.mean() - mi0.mean() < 1e-5:
            break

        # --------------------------------------------------------------
        #       GAUSS-NEWTON
        # --------------------------------------------------------------
        # --- derivatives ----------------------------------------------
        g, h = derivatives_intensity(mov, dat, prior, **drv_opt)

        # --- chain rule -----------------------------------------------
        gmov = spatial.grid_grad(tpm, phi, **pull_opt)
        if joint and len(mov) == 1:
            g = g.sum(0, keepdim=True)
            h = h.sum(0, keepdim=True)
        else:
            gmov = gmov.expand([B, *gmov.shape[1:]])
        gaff = lmdiv(affine_tpm, mm(gaff, affine))
        g, h = chain_rule(g, h, gmov, gaff, maj=False)
        del gmov
        if joint and len(g) > 1:
            g = g.sum(0, keepdim=True)
            h = h.sum(0, keepdim=True)

        # --- Gauss-Newton ---------------------------------------------
        delta = lmdiv(h, g.unsqueeze(-1)).squeeze(-1)

    if verbose == 1:
        print('')
    return mi, aff, prm
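# Illustrative usage sketch for `fit_affine_tpm` (not library code): register
# two random "images" to a random 3-class probability map with a rigid basis.
# The tensors below are placeholders; real use would pass actual intensity
# images, a proper tissue probability map, and their orientation matrices.
def _demo_fit_affine_tpm():
    dat = torch.rand(2, 1, 32, 32, 32)                            # (B, 1, *spatial)
    tpm = torch.softmax(torch.randn(1, 3, 32, 32, 32), dim=1)     # (1, K, *spatial)
    mi, aff, prm = fit_affine_tpm(dat, tpm, basis='rigid', verbose=0)
    return mi, aff, prm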