def _process_reg(dat, mat, mat_a, mat_fix, dim_fix, write):
    """Process registration results."""
    N = len(dat)
    rdat = torch.zeros((N, ) + dim_fix,
                       dtype=dat[0].dtype, device=dat[0].device)
    for n in range(N):  # loop over input images
        if torch.all(mat_a[n] - torch.eye(4, device=mat_a[n].device) == 0):
            rdat[n] = dat[n]
        else:
            mat_r = lmdiv(mat[n], mat_a[n].mm(mat_fix))
            rdat[n] = _reslice_dat_3d(dat[n], mat_r, dim_fix)
        if write == 'reslice':
            dat[n] = rdat[n]
            mat[n] = mat_fix
        elif write == 'affine':
            mat[n] = lmdiv(mat_a[n], mat[n])
    # Write output to disk?
    if write in ['reslice', 'affine']:
        write = True
    else:
        write = False

    return dat, mat, write, rdat

def _msk_fov(dat, mat, mat0, dim0):
    """Mask field-of-view (FOV) of image data according to other image's FOV.

    Parameters
    ----------
    dat : (X, Y, Z) tensor
        Image data.
    mat : (4, 4) tensor
        Image's affine.
    mat0 : (4, 4) tensor
        Other image's affine.
    dim0 : (3,) list/tuple
        Other image's dimensions.

    Returns
    -------
    dat : (X, Y, Z) tensor
        Masked image data.

    """
    dim = dat.shape
    M = lmdiv(mat0, mat)  # mat0\mat1
    grid = affine_grid(M, dim)
    msk = (grid[..., 0] >= 1) & (grid[..., 0] <= dim0[0]) & \
          (grid[..., 1] >= 1) & (grid[..., 1] <= dim0[1]) & \
          (grid[..., 2] >= 1) & (grid[..., 2] <= dim0[2])
    dat[~msk] = 0

    return dat

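# Illustrative usage sketch (not part of the original module): masking a 3D
# volume by the field of view of a second volume. It assumes the module-level
# names used above (`torch`, `lmdiv`, `affine_grid`) are importable; the
# shapes and affines below are made up for demonstration only.
def _demo_msk_fov():
    dat = torch.rand(64, 64, 64)
    mat = torch.eye(4)                         # this image's voxel-to-world affine
    mat0 = torch.eye(4)
    mat0[:3, 3] = torch.tensor([8., 8., 8.])   # other image shifted by 8 voxels
    dim0 = (48, 48, 48)                        # other image's (smaller) dimensions
    return _msk_fov(dat, mat, mat0, dim0)      # voxels outside the other FOV -> 0
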
def _imatrix(M):
    """Return the parameters for creating an affine transformation matrix.

    Args:
        M (torch.tensor): Affine transformation matrix (4, 4).

    Returns:
        P (torch.tensor): Affine parameters (<=12).

    Authors:
        John Ashburner & Stefan Kiebel, as part of the SPM12 software.

    """
    device = M.device
    dtype = M.dtype
    one = torch.tensor(1.0, device=device, dtype=dtype)
    # Translations and Zooms
    R = M[:-1, :-1]
    C = cholesky(R.t().mm(R))
    C = C.t()
    d = torch.diag(C)
    P = torch.tensor([M[0, 3], M[1, 3], M[2, 3], 0, 0, 0,
                      d[0], d[1], d[2], 0, 0, 0],
                     device=device, dtype=dtype)
    if R.det() < 0:  # Fix for -ve determinants
        P[6] = -P[6]
    # Shears
    C = lmdiv(torch.diag(torch.diag(C)), C)
    P[9] = C[0, 1]
    P[10] = C[0, 2]
    P[11] = C[1, 2]
    R0 = affine_matrix_classic(
        torch.tensor([0, 0, 0, 0, 0, 0,
                      P[6], P[7], P[8], P[9], P[10], P[11]])).to(device)
    R0 = R0[:-1, :-1]
    R1 = R.mm(R0.inverse())  # This just leaves rotations in matrix R1
    # Correct rounding errors
    rang = lambda x: torch.min(torch.max(x, -one), one)
    P[4] = torch.asin(rang(R1[0, 2]))
    if (torch.abs(P[4]) - pi / 2) ** 2 < 1e-9:
        P[3] = 0
        P[5] = torch.atan2(-rang(R1[1, 0]), rang(-R1[2, 0] / R1[0, 2]))
    else:
        c = torch.cos(P[4])
        P[3] = torch.atan2(rang(R1[1, 2] / c), rang(R1[2, 2] / c))
        P[5] = torch.atan2(rang(R1[0, 1] / c), rang(R1[0, 0] / c))

    return P

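# Illustrative sketch (not part of the original module): a round-trip check of
# the parameterisation, assuming `torch` and `affine_matrix_classic` (already
# used above) are available at module level. The parameter values are arbitrary.
def _demo_imatrix_roundtrip():
    p_in = torch.tensor([2., -3., 1.,           # translations
                         0.05, 0.0, -0.02,      # rotations (radians)
                         1.1, 0.9, 1.0,         # zooms
                         0.01, 0.0, 0.0])       # shears
    M = affine_matrix_classic(p_in)             # parameters -> (4, 4) affine
    p_out = _imatrix(M)                         # (4, 4) affine -> parameters
    return (p_in - p_out).abs().max()           # expected to be close to zero
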
def _rescale(dat, mn_out=0, mx_out=511):
    """Rescale image intensities between mn_out and mx_out."""
    backend = dict(dtype=dat.dtype, device=dat.device)
    msk = torch.isfinite(dat).bitwise_not_()
    msk = msk.bitwise_or_(dat == dat.min()).bitwise_or_(dat == dat.max())
    dat = dat.masked_fill_(msk, 0)
    # Make scaling to set image intensities between mn_out and mx_out
    mnmx_in = torch.as_tensor([[dat.min(), 1],
                               [dat.max(), 1]], **backend)
    mnmx_out = torch.as_tensor([mn_out, mx_out], **backend)
    sf = linalg.lmdiv(mnmx_in, mnmx_out.unsqueeze(-1)).squeeze(-1)
    # Rescale
    dat = dat.mul_(sf[0]).add_(sf[1])
    # Clamp
    dat = dat.clamp_(mn_out, mx_out)

    return dat

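# Illustrative sketch (not part of the original module): the 2x2 system above
# solves for a scale/offset pair mapping dat.min() -> mn_out and
# dat.max() -> mx_out. A minimal usage example, assuming `torch` is imported
# at module level as elsewhere in this file.
def _demo_rescale():
    dat = 100 + 50 * torch.randn(32, 32, 32)
    dat = _rescale(dat.clone(), mn_out=0, mx_out=511)
    return dat.min(), dat.max()   # approximately (0, 511)
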
def search_direction(self, grad, hess):
    """Compute the (damped) Newton search direction from gradient and Hessian."""
    grad, hess = self._add_marquardt(grad, hess)
    step = linalg.lmdiv(hess, grad[..., None])[..., 0]
    step.mul_(-self.lr)
    return step

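# Minimal sketch (not part of the original module) of the same update without
# the class machinery: a damped Newton/Gauss-Newton step `-lr * H \ g`, where
# `linalg.lmdiv(hess, grad)` plays the role of `torch.linalg.solve`. Assumes
# `torch` only; `grad` is (..., n) and `hess` is (..., n, n).
def _demo_search_direction(grad, hess, lr=1.0):
    step = torch.linalg.solve(hess, grad.unsqueeze(-1)).squeeze(-1)
    return step.mul_(-lr)
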
def __call__(self, logaff, grad=False, hess=False, gradmov=False,
             hessmov=False, in_line_search=False):
    """
    logaff : (..., nb) tensor, Lie parameters
    grad : Whether to compute and return the gradient wrt `logaff`
    hess : Whether to compute and return the Hessian wrt `logaff`
    gradmov : Whether to compute and return the gradient wrt `moving`
    hessmov : Whether to compute and return the Hessian wrt `moving`

    Returns
    -------
    ll : () tensor, loss value (objective to minimize)
    g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
    h : (..., logaff) tensor, optional, Hessian wrt Lie parameters
    gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
    hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

    """
    # This loop performs the forward pass, and computes
    # derivatives along the way.

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % logplot == 0

    # jitter
    # if not hasattr(self, '_fixed'):
    #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
    #                                 jitter=True,
    #                                 **utils.backend(self.fixed))
    #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
    #     del idj
    # fixed = self._fixed
    fixed = self.fixed

    # forward
    if not torch.is_tensor(self.basis):
        self.basis = spatial.affine_basis(self.basis, self.dim,
                                          **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    with torch.no_grad():
        _, gaff = linalg._expm(logaff, self.basis,
                               grad_X=True, hess_X=False)

    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    # /!\ derivatives are not "homogeneous" (they do not have a one
    # on the bottom right): we should *not* use affine_matmul and
    # such (I only lost a day...)
    gaff = torch.matmul(gaff, self.affine_fixed)
    gaff = linalg.lmdiv(self.affine_moving, gaff)
    # haff = torch.matmul(haff, self.affine_fixed)
    # haff = linalg.lmdiv(self.affine_moving, haff)
    if self.id is None:
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff),
                                        jitter=False)
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, cat=iscat, dim=self.dim)

    # gradient/Hessian of the log-likelihood in observed space
    if not grad and not hess:
        llx = self.loss.loss(warped, fixed)
    elif not hess:
        llx, grad = self.loss.loss_grad(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
    else:
        llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
        if hessmov:
            hessmov = spatial.grid_push(hess, grid, **pullopt)
    del warped

    # compose with spatial gradients + dot product with grid
    if grad is not False or hess is not False:
        mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
        grad = jg(mugrad, grad)
        if hess is not False:
            hess = jhj(mugrad, hess)
            grad, hess = regutils.affine_grid_backward(grad, hess, grid=self.id)
        else:
            grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
        dim2 = self.dim * (self.dim + 1)
        grad = grad.reshape([*grad.shape[:-2], dim2])
        gaff = gaff[..., :-1, :]
        gaff = gaff.reshape([*gaff.shape[:-2], dim2])
        grad = linalg.matvec(gaff, grad)
        if hess is not False:
            hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
            # haff = haff[..., :-1, :, :-1, :]
            # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
            hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
            hess = hess.abs().sum(-1).diag_embed()
        del mugrad

    # print objective
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                  end='\n')
        else:
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} '
                  f'| {gain:12.6g}', end='\n')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [ll]
    if grad is not False:
        out.append(grad)
    if hess is not False:
        out.append(hess)
    if gradmov is not False:
        out.append(gradmov)
    if hessmov is not False:
        out.append(hessmov)
    return tuple(out) if len(out) > 1 else out[0]

def fit_curve_cat(f, s, lam=0, gamma=0, vx=1, max_iter=8, tol=1e-8,
                  max_levels=4):
    """Fit the curve that maximizes the categorical likelihood

    Parameters
    ----------
    f : (*shape) tensor
        Observed grid of binary labels or smooth probabilities.
    s : BSplineCurve
        Initial curve (will be modified in-place)

    Other Parameters
    ----------------
    lam : float, default=0
        Centerline regularization (bending)
    gamma : float, default=0
        Radius regularization (membrane)
    vx : float, default=1
        Voxel size
    max_iter : int, default=8
        Maximum number of iterations per level
        (This will be multiplied by 2 at each resolution level, such that
        more iterations are used at coarser levels).
    tol : float, default=1e-8
        Unused
    max_levels : int, default=4
        Number of multi-resolution levels.

    Returns
    -------
    s : BSplineCurve
        Fitted curve

    """
    TINY = 1e-6
    fig = elem = None

    max_iter_position = 8
    max_iter_radius = 4

    backend = utils.backend(s.coeff)
    vx = utils.make_vector(vx, f.dim(), **backend)
    vx0 = vx.clone()
    n0 = f.numel()

    # Build pyramid by restriction
    shapes = [f.shape]
    images = [f]
    vxs = [vx]
    for n_level in range(max_levels - 1):
        shape = [pymath.ceil(s / 2) for s in shapes[-1]]
        if all(s == 1 for s in shape):
            break
        shapes.append(shape)
        images.append(restrict(f.unsqueeze(-1), shapes[-1]).squeeze(-1))
        s.restrict(shapes[-2], shapes[-1])
        vx = vx * (torch.as_tensor(shapes[-2], **backend) /
                   torch.as_tensor(shapes[-1], **backend))
        vxs.append(vx)

    start = time.time()

    shape = None
    level = len(images) + 1
    while images:
        level -= 1
        print('-' * 16, 'level', level, '-' * 16)

        if shape is not None:
            s.prolong(shape, shapes[-1])
        f, shape, vx = images.pop(-1), shapes.pop(-1), vxs.pop(-1)
        scl = vx.prod() / vx0.prod()

        x = identity_grid(f.shape, **backend)
        if lam:
            L = lam * bending3(len(s.coeff), **backend)
            reg = L.matmul(s.coeff).mul_(vx.square())
            reg = 0.5 * (s.coeff * reg).sum(dtype=torch.double)
        else:
            reg = 0
        if gamma:
            Lr = gamma * membrane3(len(s.coeff_radius), **backend)
            Lr /= vx.prod().pow_(1 / len(vx)).square_()
            reg_radius = Lr.matmul(s.coeff_radius)
            reg_radius = 0.5 * (s.coeff_radius * reg_radius).sum(dtype=torch.double)
        else:
            reg_radius = 0

        def get_nll(e):
            ie = (1 - e).log()
            e = e.log()
            if f.dtype is torch.bool:
                ll = e[f].sum(dtype=torch.double) + ie[~f].sum(dtype=torch.double)
            else:
                ll = (e * f).sum(dtype=torch.double) + (ie * (1 - f)).sum(dtype=torch.double)
            ll = -ll
            return ll

        nll = float('inf')
        max_iter_level = max_iter * 2 ** ((level - 1) // 2)
        for n_iter in range(max_iter_level):

            nll0_prev = nll

            for n_iter_position in range(max_iter_position):

                t, d = min_dist(x, s)
                p = s.eval_position(t).sub_(x)  # residuals
                r = s.eval_radius(t)
                r = torch.as_tensor(r, **utils.backend(x))
                e = dist_to_prob(d, r, tiny=TINY)
                nll_prev = nll
                nll = get_nll(e)
                prec = radius_to_prec(r)

                # gradient of the categorical term
                omf = (1 - f) if f.dtype.is_floating_point else f.bitwise_not()
                ome = (1 - e)
                g = (omf / ome - 1) * (-prec)
                h = omf * e / ome.square()
                g = g.unsqueeze(-1)
                h = h.unsqueeze(-1)
                prec = prec.unsqueeze(-1)

                acc = 0.5
                h = h * (prec * p).square()
                if acc != 1:
                    h += (1 - acc) * g.abs()
                g = g * p

                # push
                g = s.push_position(g, t)
                h = s.push_position(h, t)

                # resolution scale
                g *= scl
                h *= scl
                nll *= scl

                # regularisation + solve
                if lam:
                    reg = L.matmul(s.coeff).mul_(vx.square())
                    g += reg
                    reg = 0.5 * (s.coeff * reg).sum(dtype=torch.double)
                    # h += L[1, :].abs().sum()
                    g = torch.stack([
                        linalg.lmdiv(h1.diag() + (v1 * v1) * L, g1[:, None])[:, 0]
                        for v1, g1, h1 in zip(vx, g.T, h.T)], -1)
                else:
                    g.div_(h)
                    reg = 0
                s.coeff.sub_(g)
                # s.coeff.clamp_min_(0)
                # for d, sz in enumerate(f.shape):
                #     s.coeff[:, d].clamp_max_(sz-1)

                fig, elem = plot_nll([nll, reg, reg_radius], e, f, s.waypoints,
                                     fig, elem)
                nll = nll + reg + reg_radius
                print('position', n_iter, n_iter_position, nll.item(),
                      (nll_prev - nll).item() / n0)
                s.update_waypoints()
                # if nll_prev - nll < tol * f.numel():
                #     break

            if level < 3:
                max_iter_radius_level = max_iter_radius
            else:
                max_iter_radius_level = 0
            for n_iter_radius in range(max_iter_radius_level):

                alpha = (2.355 / 2) ** 2
                t, d = min_dist(x, s)
                r = s.eval_radius(t)
                r = torch.as_tensor(r, **utils.backend(x))
                e = dist_to_prob(d, r, TINY)
                d = d.square_()
                nll_prev = nll
                nll = get_nll(e)

                # gradient of the categorical term
                omf = (1 - f) if f.dtype.is_floating_point else f.bitwise_not()
                ome = (1 - e)
                alpha = alpha * d / r.pow(3)
                g = (omf / ome - 1) * alpha
                acc = 0
                h = omf * e / ome.square()
                h *= alpha.square()
                if acc != 1:
                    h += (1 - acc) * g.abs() * 3 / r

                # push
                g = s.push_radius(g, t)
                h = s.push_radius(h, t)

                # resolution scale
                g *= scl
                h *= scl
                nll *= scl

                # regularisation + solve
                if gamma:
                    reg_radius = Lr.matmul(s.coeff_radius)
                    g += reg_radius
                    reg_radius = 0.5 * (s.coeff_radius * reg_radius).sum(dtype=torch.double)
                    # use the radius regulariser (Lr); the centerline one (L)
                    # may not be defined here and has the wrong size
                    g = linalg.lmdiv(h.diag() + Lr, g[:, None])[:, 0]
                else:
                    g.div_(h)
                    reg_radius = 0

                # solve
                s.coeff_radius -= g
                s.coeff_radius.clamp_min_(0.5)

                fig, elem = plot_nll([nll, reg, reg_radius], e, f, s.waypoints,
                                     fig, elem)
                nll = nll + reg + reg_radius
                print('radius', n_iter, n_iter_radius, nll.item(),
                      (nll_prev - nll).item() / n0)
                s.update_radius()
                # if nll_prev - nll < tol * f.numel():
                #     break

            if not n_iter % 10:
                print(n_iter, nll.item(), (nll0_prev - nll).item() / n0)
            # if nll0_prev - nll < tol * f.numel():
            #     print('Converged')
            #     break

    stop = time.time()
    print(stop - start)

def _mean_space(Mat, Dim, vx=None):
    """Compute a (mean) model space from individual spaces.

    Args:
        Mat (torch.tensor): N subjects' orientation matrices (N, 4, 4).
        Dim (torch.tensor): N subjects' dimensions (N, 3).
        vx (torch.tensor|tuple|float, optional): Voxel size (3,),
            defaults to None (estimate from input).

    Returns:
        mat (torch.tensor): Mean orientation matrix (4, 4).
        dim (torch.tensor): Mean dimensions (3,).
        vx (torch.tensor): Mean voxel size (3,).

    Authors:
        John Ashburner, as part of the SPM12 software.

    """
    device = Mat.device
    dtype = Mat.dtype
    N = Mat.shape[0]  # Number of subjects
    inf = float('inf')
    one = torch.tensor(1.0, device=device, dtype=dtype)
    if vx is None:
        vx = torch.tensor([inf, inf, inf], device=device, dtype=dtype)
    if isinstance(vx, float) or isinstance(vx, int):
        vx = (vx, ) * 3
    if isinstance(vx, tuple) and len(vx) == 3:
        vx = torch.tensor([vx[0], vx[1], vx[2]], device=device, dtype=dtype)
    # To float64
    Mat = Mat.type(dtype)
    Dim = Dim.type(dtype)
    # Get affine basis
    basis = 'SE'
    dim = 3 if Dim[0, 2] > 1 else 2
    B = affine_basis(basis, dim, device=device, dtype=dtype)

    # Find combination of 90 degree rotations and flips that brings all
    # the matrices closest to axial
    Mat0 = Mat.clone()
    pmatrix = torch.tensor([[0, 1, 2],
                            [1, 0, 2],
                            [2, 0, 1],
                            [2, 1, 0],
                            [0, 2, 1],
                            [1, 2, 0]], device=device)
    for n in range(N):  # Loop over subjects
        vx1 = voxel_size(Mat[n, ...])
        R = Mat[n, ...].mm(
            torch.diag(torch.cat((vx1, one[..., None]))).inverse())[:-1, :-1]
        minss = inf
        minR = torch.eye(3, dtype=dtype, device=device)
        for i in range(6):  # Permute (= 'rotate + flip') axes
            R1 = torch.zeros((3, 3), dtype=dtype, device=device)
            R1[pmatrix[i, 0], 0] = 1
            R1[pmatrix[i, 1], 1] = 1
            R1[pmatrix[i, 2], 2] = 1
            for j in range(8):  # Mirror (= 'flip') axes
                fd = [(j & 1) * 2 - 1, (j & 2) - 1, (j & 4) / 2 - 1]
                F = torch.diag(torch.tensor(fd, dtype=dtype, device=device))
                R2 = F.mm(R1)
                ss = torch.sum((R.mm(R2.inverse()) -
                                torch.eye(3, dtype=dtype, device=device)) ** 2)
                if ss < minss:
                    minss = ss
                    minR = R2
        rdim = torch.abs(minR.mm(Dim[n, ...][..., None] - 1))
        R2 = minR.inverse()
        R22 = R2.mm((torch.div(
            torch.sum(R2, dim=0, keepdim=True).t(), 2,
            rounding_mode='floor') - 1) * rdim)
        minR = torch.cat((R2, R22), dim=1)
        minR = torch.cat((minR,
                          torch.tensor([0, 0, 0, 1],
                                       device=device, dtype=dtype)[None, ...]),
                         dim=0)
        Mat[n, ...] = Mat[n, ...].mm(minR)

    # Average of the matrices in Mat
    mat = meanm(Mat)

    # If average involves shears, then find the closest matrix that does not
    # require them.
    C_ix = torch.tensor([0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15],
                        device=device)  # column-major ordering from (4, 4) tensor
    p = _imatrix(mat)
    if torch.sum(p[[9, 10, 11]] ** 2) > 1e-8:
        B2 = torch.zeros((3, 4, 4), device=device, dtype=dtype)
        B2[0, 0, 0] = 1
        B2[1, 1, 1] = 1
        B2[2, 2, 2] = 1
        p = torch.zeros(9, device=device, dtype=dtype)
        for n_iter in range(10000):
            # Rotations + Translations
            R, dR = _expm(p[[0, 1, 2, 3, 4, 5]], B, grad_X=True)
            # Zooms
            Z, dZ = _expm(p[[6, 7, 8]], B2, grad_X=True)
            M = R.mm(Z)
            dM = torch.zeros((4, 4, 9), device=device, dtype=dtype)
            for n in range(6):
                dM[..., n] = dR[n, ...].mm(Z)
            for n in range(3):
                dM[..., 6 + n] = R.mm(dZ[n, ...])
            dM = dM.reshape((16, 9))
            d = M.flatten() - mat.flatten()
            gr = dM.t().mm(d[..., None])
            Hes = dM.t().mm(dM)
            p = p - lmdiv(Hes, gr)[:, 0]
            if torch.sum(gr ** 2) < 1e-8:
                break
        mat = M.clone()

    # Set required voxel size
    vx_out = vx.clone()
    vx = voxel_size(mat)
    vx_out[~torch.isfinite(vx_out)] = vx[~torch.isfinite(vx_out)]
    mat = mat.mm(torch.cat((vx_out / vx, one[..., None])).diag())
    vx = voxel_size(mat)

    # Ensure that the FoV covers all images, with a few voxels to spare
    mn_all = torch.zeros([3, N], device=device, dtype=dtype)
    mx_all = torch.zeros([3, N], device=device, dtype=dtype)
    for n in range(N):
        dm = Dim[n, ...]
        corners = torch.tensor([[1, dm[0], 1, dm[0], 1, dm[0], 1, dm[0]],
                                [1, 1, dm[1], dm[1], 1, 1, dm[1], dm[1]],
                                [1, 1, 1, 1, dm[2], dm[2], dm[2], dm[2]],
                                [1, 1, 1, 1, 1, 1, 1, 1]],
                               device=device, dtype=dtype)
        M = lmdiv(mat, Mat0[n])
        vx1 = M[:-1, :].mm(corners)
        mx_all[..., n] = torch.max(vx1, dim=1)[0]
        mn_all[..., n] = torch.min(vx1, dim=1)[0]
    mx = mx_all.max(dim=1)[0]
    mn = mn_all.min(dim=1)[0]
    mx = torch.ceil(mx)
    mn = torch.floor(mn)

    # Make output dimensions and orientation matrix
    dim = mx - mn + 1  # Output dimensions
    off = torch.tensor([0, 0, 0], device=device, dtype=dtype)
    mat = mat.mm(torch.tensor([[1, 0, 0, mn[0] - (off[0] + 1)],
                               [0, 1, 0, mn[1] - (off[1] + 1)],
                               [0, 0, 1, mn[2] - (off[2] + 1)],
                               [0, 0, 0, 1]], device=device, dtype=dtype))

    return mat, dim, vx

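# Illustrative usage sketch (not part of the original module): computing a
# mean space for two toy subjects, assuming `torch` and the helpers used above
# (`voxel_size`, `affine_basis`, `meanm`, `_expm`, `lmdiv`) are importable at
# module level. The shapes and affines below are made up.
def _demo_mean_space():
    Mat = torch.eye(4, dtype=torch.float64).expand(2, 4, 4).clone()
    Mat[1, :3, 3] = torch.tensor([5., -2., 3.], dtype=torch.float64)
    Dim = torch.tensor([[64, 64, 64],
                        [72, 64, 60]], dtype=torch.float64)
    mat, dim, vx = _mean_space(Mat, Dim, vx=1.)
    return mat, dim, vx
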
def slice_correct(x, dim=-1, nb_iter=20):
    """Estimate and remove a multiplicative per-slice intensity scaling along `dim`
    (log-additive model fitted by iteratively reweighted least squares)."""
    n = x.shape[dim]
    x = utils.movedim(x, dim, -1)
    shape = x.shape
    x = x.reshape([-1, n])
    vmax = x.max()

    m = x > 0
    x = x.log()
    x[~m] = 0

    a = x.new_zeros([n])
    g = x.new_zeros([n])
    h = x.new_zeros([n, n])
    for i in range(nb_iter):

        # compute forward differences
        d = a + x
        d = d[:, 1:] - d[:, :-1]
        d[~(m[:, 1:] & m[:, :-1])] = 0
        w = d.abs()
        print(w.mean().item())

        import matplotlib.pyplot as plt
        plt.subplot(1, 2, 1)
        plt.imshow((a + x).reshape(shape)[shape[0] // 2].exp(), vmin=0, vmax=vmax)
        plt.colorbar()
        plt.subplot(1, 2, 2)
        plt.imshow(w.reshape([*shape[:-1], n - 1]).mean(0))
        plt.colorbar()
        plt.show()

        w = w.clamp_min_(1e-5).reciprocal_()
        w[~(m[:, 1:] & m[:, :-1])] = 0

        # compute gradient
        g.zero_()
        g = g.reshape([n])
        g[1:] = (w * d).sum(0)
        g[:-1] -= (w * d).sum(0)

        # compute hessian
        h.zero_()
        h.diagonal(0, -1, -2)[1:] = w.sum(0)
        h.diagonal(0, -1, -2)[:-1] += w.sum(0)
        h.diagonal(1, -1, -2)[:] = -w.sum(0)
        h.diagonal(-1, -1, -2)[:] = h.diagonal(1, -1, -2)

        h = h.reshape([n, n])
        g = g.reshape([n, 1])
        g /= len(x)
        h /= len(x)

        h.diagonal(0, -1, -2).add_(h.diagonal(0, -1, -2).max() * 1e-6)
        a -= linalg.lmdiv(h, g).reshape([n])
        # zero center
        a -= a.mean()

    x = (a + x).exp()
    x = x.reshape(shape).movedim(-1, dim)
    return x, a

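# Illustrative usage sketch (not part of the original module): correcting a
# synthetic volume whose slices along the last dimension were scaled by a
# smooth multiplicative profile, assuming `torch` is available. Note that the
# function above prints and plots at every iteration.
def _demo_slice_correct():
    x = torch.rand(32, 32, 16).add_(1)
    profile = torch.linspace(0.5, 1.5, 16)        # per-slice multiplicative bias
    corrected, a = slice_correct(x * profile, dim=-1, nb_iter=5)
    return corrected, a.exp()                     # exp(a) ~ 1/profile, up to a global scale
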
def shim(fmap, max_order=2, mask=None, isocenter=None, dim=None,
         returns='corrected'):
    """Subtract a linear combination of spherical harmonics that minimize gradients

    Parameters
    ----------
    fmap : (..., *spatial) tensor
        Field map
    max_order : int, default=2
        Maximum order of the spherical harmonics
    mask : tensor, optional
        Mask of voxels to include (typically brain mask)
    isocenter : [sequence of] float, default=shape/2
        Coordinate of isocenter, in voxels
    dim : int, default=fmap.dim()
        Number of spatial dimensions
    returns : combination of {'corrected', 'correction', 'parameters'}, default='corrected'
        Components to return

    Returns
    -------
    corrected : (..., *spatial) tensor, if 'corrected' in `returns`
        Corrected field map (with spherical harmonics subtracted)
    correction : (..., *spatial) tensor, if 'correction' in `returns`
        Linear combination of spherical harmonics.
    parameters : (..., k) tensor, if 'parameters' in `returns`
        Parameters of the linear combination

    """
    fmap = torch.as_tensor(fmap)
    dim = dim or fmap.dim()
    shape = fmap.shape[-dim:]
    batch = fmap.shape[:-dim]
    backend = utils.backend(fmap)
    dims = list(range(-dim, 0))

    if mask is not None:
        mask = ~mask  # make it a mask of background voxels

    # compute gradients
    gmap = diff(fmap, dim=dims, side='f', bound='dct2')
    if mask is not None:
        gmap[..., mask, :] = 0
    gmap = gmap.reshape([*batch, -1])

    # compute basis of spherical harmonics
    basis = []
    for i in range(1, max_order + 1):
        b = spherical_harmonics(shape, i, isocenter, **backend)
        b = utils.movedim(b, -1, 0)
        b = diff(b, dim=dims, side='f', bound='dct2')
        if mask is not None:
            b[..., mask, :] = 0
        b = b.reshape([b.shape[0], *batch, -1])
        basis.append(b)
    basis = torch.cat(basis, 0)
    basis = utils.movedim(basis, 0, -1)  # (*batch, vox*dim, k)

    # solve system
    prm = linalg.lmdiv(basis, gmap[..., None], method='pinv')[..., 0]
    # > (*batch, k)

    # rebuild basis (without taking gradients)
    basis = []
    for i in range(1, max_order + 1):
        b = spherical_harmonics(shape, i, isocenter, **backend)
        b = utils.movedim(b, -1, 0)
        b = b.reshape([b.shape[0], *batch, *shape])
        basis.append(b)
    basis = torch.cat(basis, 0)
    basis = utils.movedim(basis, 0, -1)  # (*batch, *spatial, k)

    comb = linalg.matvec(basis.unsqueeze(-2),
                         utils.unsqueeze(prm, -2, dim))
    comb = comb[..., 0]
    fmap = fmap - comb

    returns = returns.split('+')
    out = []
    for ret in returns:
        if ret == 'corrected':
            out.append(fmap)
        elif ret == 'correction':
            out.append(comb)
        elif ret[0] == 'p':
            out.append(prm)
    return out[0] if len(out) == 1 else tuple(out)

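# Illustrative usage sketch (not part of the original module): shimming a
# synthetic 3D field map with a linear ramp, assuming `torch` and the
# module-level helpers used above (`spherical_harmonics`, `diff`, `utils`,
# `linalg`) are available.
def _demo_shim():
    fmap = torch.randn(32, 32, 32) * 0.1
    fmap += torch.linspace(-1, 1, 32).view(-1, 1, 1)    # add a linear ramp
    corrected, prm = shim(fmap, max_order=1, returns='corrected+parameters')
    return corrected, prm                               # the ramp should be largely removed
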
def do_affine_only(self, logaff, grad=False, hess=False, in_line_search=False):
    """Forward pass for updating the affine component (nonlin is None)"""
    sumloss = None
    sumgrad = None
    sumhess = None

    # ==============================================================
    #                     EXPONENTIATE TRANSFORMS
    # ==============================================================
    logaff0 = logaff
    aff0, iaff0, gaff0, igaff0 = self.affine.exp2(logaff0, grad=True)

    has_printed = False
    for loss in self.losses:

        moving, fixed, factor = loss.moving, loss.fixed, loss.factor
        if loss.backward:
            aff00, gaff00 = iaff0, igaff0
        else:
            aff00, gaff00 = aff0, gaff0

        # ----------------------------------------------------------
        # build full transform
        # ----------------------------------------------------------
        aff = aff00 @ fixed.affine
        aff = linalg.lmdiv(moving.affine, aff)
        gaff = gaff00 @ fixed.affine
        gaff = linalg.lmdiv(moving.affine, gaff)
        phi = spatial.affine_grid(aff, fixed.shape)

        # ----------------------------------------------------------
        # forward pass
        # ----------------------------------------------------------
        warped, mask = moving.pull(phi, mask=True)
        if fixed.masked:
            if mask is None:
                mask = fixed.mask
            else:
                mask = mask * fixed.mask

        do_print = not (has_printed or self.verbose < 3 or in_line_search
                        or loss.backward)
        if do_print:
            has_printed = True
            if moving.previewed:
                preview = moving.pull(phi, preview=True, dat=False)
            else:
                preview = warped
            init = spatial.affine_lmdiv(moving.affine, fixed.affine)
            if _almost_identity(init) and moving.shape == fixed.shape:
                init = moving.preview
            else:
                init = spatial.affine_grid(init, fixed.shape)
                init = moving.pull(init, preview=True, dat=False)
            self.mov2fix(fixed.preview, init, preview, dim=fixed.dim,
                         title=f'(affine) {self.n_iter:03d}')

        # ----------------------------------------------------------
        # derivatives wrt moving
        # ----------------------------------------------------------
        g = h = None
        loss_args = (warped, fixed.dat)
        loss_kwargs = dict(dim=fixed.dim, mask=mask)
        state = loss.loss.get_state()
        if not grad and not hess:
            llx = loss.loss.loss(*loss_args, **loss_kwargs)
        elif not hess:
            llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
        else:
            llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
        del loss_args, loss_kwargs
        if in_line_search:
            loss.loss.set_state(state)

        # ----------------------------------------------------------
        # chain rule -> derivatives wrt Lie parameters
        # ----------------------------------------------------------

        def compose_grad(g, h, g_mu, g_aff):
            """
            g, h : gradient/Hessian of loss wrt moving image
            g_mu : spatial gradients of moving image
            g_aff : gradient of affine matrix wrt Lie parameters

            returns g, h : gradient/Hessian of loss wrt Lie parameters
            """
            # Note that `h` can be `None`, but the functions I
            # use deal with this case correctly.
            dim = g_mu.shape[-1]
            g = jg(g_mu, g)
            h = jhj(g_mu, h)
            g, h = regutils.affine_grid_backward(g, h)
            dim2 = dim * (dim + 1)
            g = g.reshape([*g.shape[:-2], dim2])
            g_aff = g_aff[..., :-1, :]
            g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
            g = linalg.matvec(g_aff, g)
            if h is not None:
                h = h.reshape([*h.shape[:-4], dim2, dim2])
                h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
                # h = h.abs().sum(-1).diag_embed()
            return g, h

        # compose with spatial gradients
        if grad or hess:
            mugrad = moving.pull_grad(phi, rotate=False)
            g, h = compose_grad(g, h, mugrad, gaff)

            if loss.backward:
                g = g.neg_()
            sumgrad = (g.mul_(factor) if sumgrad is None else
                       sumgrad.add_(g, alpha=factor))
            if hess:
                sumhess = (h.mul_(factor) if sumhess is None else
                           sumhess.add_(h, alpha=factor))
        sumloss = (llx.mul_(factor) if sumloss is None else
                   sumloss.add_(llx, alpha=factor))

    # TODO add regularization term
    lla = 0

    # ==============================================================
    #                           VERBOSITY
    # ==============================================================
    llx = sumloss.item()
    sumloss += lla
    ll = sumloss.item()
    self.loss_value = ll
    if self.verbose and (self.verbose > 1 or not in_line_search):
        if in_line_search:
            line = '(search) | '
        else:
            line = '(affine) | '
        line += f'{self.n_iter:03d} | {llx:12.6g} + {lla:12.6g} = {ll:12.6g}'
        if not in_line_search:
            if self.ll_prev is not None:
                gain = self.ll_prev - ll
                # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                line += f' | {gain:12.6g}'
            self.all_ll.append(ll)
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)
            self.n_iter += 1
        print(line, end='\r')

    # ==============================================================
    #                           RETURN
    # ==============================================================
    out = [sumloss]
    if grad:
        out.append(sumgrad)
    if hess:
        out.append(sumhess)
    return tuple(out) if len(out) > 1 else out[0]

def do_affine(self, logaff, grad=False, hess=False, in_line_search=False):
    """Forward pass for updating the affine component (nonlin is not None)"""
    sumloss = None
    sumgrad = None
    sumhess = None

    # ==============================================================
    #                     EXPONENTIATE TRANSFORMS
    # ==============================================================
    logaff0 = logaff
    aff_pos = self.affine.position[0].lower()
    if any(loss.backward for loss in self.losses):
        aff0, iaff0, gaff0, igaff0 = \
            self.affine.exp2(logaff0, grad=True,
                             cache_result=not in_line_search)
        phi0, iphi0 = self.nonlin.exp2(cache_result=True, recompute=False)
    else:
        iaff0, igaff0, iphi0 = None, None, None
        aff0, gaff0 = self.affine.exp(logaff0, grad=True,
                                      cache_result=not in_line_search)
        phi0 = self.nonlin.exp(cache_result=True, recompute=False)

    has_printed = False
    for loss in self.losses:

        moving, fixed, factor = loss.moving, loss.fixed, loss.factor
        if loss.backward:
            phi00, aff00, gaff00 = iphi0, iaff0, igaff0
        else:
            phi00, aff00, gaff00 = phi0, aff0, gaff0

        # ----------------------------------------------------------
        # build left and right affine matrices
        # ----------------------------------------------------------
        aff_right, gaff_right = fixed.affine, None
        if aff_pos in 'fs':
            gaff_right = gaff00 @ aff_right
            gaff_right = linalg.lmdiv(self.nonlin.affine, gaff_right)
            aff_right = aff00 @ aff_right
        aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
        aff_left, gaff_left = self.nonlin.affine, None
        if aff_pos in 'ms':
            gaff_left = gaff00 @ aff_left
            gaff_left = linalg.lmdiv(moving.affine, gaff_left)
            aff_left = aff00 @ aff_left
        aff_left = linalg.lmdiv(moving.affine, aff_left)

        # ----------------------------------------------------------
        # build full transform
        # ----------------------------------------------------------
        if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
            right = None
            phi = spatial.add_identity_grid(phi00)
        else:
            right = spatial.affine_grid(aff_right, fixed.shape)
            phi = regutils.smart_pull_grid(phi00, right)
            phi += right
        phi_right = phi
        if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
            left = None
        else:
            left = spatial.affine_grid(aff_left, self.nonlin.shape)
            phi = spatial.affine_matvec(aff_left, phi)

        # ----------------------------------------------------------
        # forward pass
        # ----------------------------------------------------------
        warped, mask = moving.pull(phi, mask=True)
        if fixed.masked:
            if mask is None:
                mask = fixed.mask
            else:
                mask = mask * fixed.mask

        do_print = not (has_printed or self.verbose < 3 or in_line_search
                        or loss.backward)
        if do_print:
            has_printed = True
            if moving.previewed:
                preview = moving.pull(phi, preview=True, dat=False)
            else:
                preview = warped
            init = spatial.affine_lmdiv(moving.affine, fixed.affine)
            if _almost_identity(init) and moving.shape == fixed.shape:
                init = moving.dat
            else:
                init = spatial.affine_grid(init, fixed.shape)
                init = moving.pull(init, preview=True, dat=False)
            self.mov2fix(fixed.dat, init, preview, dim=fixed.dim,
                         title=f'(affine) {self.n_iter:03d}')

        # ----------------------------------------------------------
        # derivatives wrt moving
        # ----------------------------------------------------------
        g = h = None
        loss_args = (warped, fixed.dat)
        loss_kwargs = dict(dim=fixed.dim, mask=mask)
        state = loss.loss.get_state()
        if not grad and not hess:
            llx = loss.loss.loss(*loss_args, **loss_kwargs)
        elif not hess:
            llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
        else:
            llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
        del loss_args, loss_kwargs
        if in_line_search:
            loss.loss.set_state(state)

        # ----------------------------------------------------------
        # chain rule -> derivatives wrt Lie parameters
        # ----------------------------------------------------------

        def compose_grad(g, h, g_mu, g_aff):
            """
            g, h : gradient/Hessian of loss wrt moving image
            g_mu : spatial gradients of moving image
            g_aff : gradient of affine matrix wrt Lie parameters

            returns g, h : gradient/Hessian of loss wrt Lie parameters
            """
            # Note that `h` can be `None`, but the functions I
            # use deal with this case correctly.
            dim = g_mu.shape[-1]
            g = jg(g_mu, g)
            h = jhj(g_mu, h)
            g, h = regutils.affine_grid_backward(g, h)
            dim2 = dim * (dim + 1)
            g = g.reshape([*g.shape[:-2], dim2])
            g_aff = g_aff[..., :-1, :]
            g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
            g = linalg.matvec(g_aff, g)
            if h is not None:
                h = h.reshape([*h.shape[:-4], dim2, dim2])
                h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
                # h = h.abs().sum(-1).diag_embed()
            return g, h

        if grad or hess:
            g0, g = g, None
            h0, h = h, None
            if aff_pos in 'ms':
                g_left = regutils.smart_push(g0, phi_right, shape=self.nonlin.shape)
                h_left = regutils.smart_push(h0, phi_right, shape=self.nonlin.shape)
                mugrad = moving.pull_grad(left, rotate=False)
                g_left, h_left = compose_grad(g_left, h_left, mugrad, gaff_left)
                g, h = g_left, h_left
            if aff_pos in 'fs':
                g_right, h_right = g0, h0
                mugrad = moving.pull_grad(phi, rotate=False)
                jac = spatial.grid_jacobian(phi0, right, type='disp',
                                            extrapolate=False)
                jac = torch.matmul(aff_left[:-1, :-1], jac)
                mugrad = linalg.matvec(jac.transpose(-1, -2), mugrad)
                g_right, h_right = compose_grad(g_right, h_right, mugrad, gaff_right)
                g = g_right if g is None else g.add_(g_right)
                h = h_right if h is None else h.add_(h_right)

            if loss.backward:
                g = g.neg_()
            sumgrad = (g.mul_(factor) if sumgrad is None else
                       sumgrad.add_(g, alpha=factor))
            if hess:
                sumhess = (h.mul_(factor) if sumhess is None else
                           sumhess.add_(h, alpha=factor))
        sumloss = (llx.mul_(factor) if sumloss is None else
                   sumloss.add_(llx, alpha=factor))

    # TODO add regularization term
    lla = 0

    # ==============================================================
    #                           VERBOSITY
    # ==============================================================
    llx = sumloss.item()
    sumloss += lla
    sumloss += self.llv
    self.loss_value = sumloss.item()
    if self.verbose and (self.verbose > 1 or not in_line_search):
        ll = sumloss.item()
        llv = self.llv
        if in_line_search:
            line = '(search) | '
        else:
            line = '(affine) | '
        line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
        if not in_line_search:
            if self.ll_prev is not None:
                gain = self.ll_prev - ll
                # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                line += f' | {gain:12.6g}'
            self.all_ll.append(ll)
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)
            self.n_iter += 1
        print(line, end='\r')

    # ==============================================================
    #                           RETURN
    # ==============================================================
    out = [sumloss]
    if grad:
        out.append(sumgrad)
    if hess:
        out.append(sumhess)
    return tuple(out) if len(out) > 1 else out[0]

def do_vel(self, vel, grad=False, hess=False, in_line_search=False):
    """Forward pass for updating the nonlinear component"""
    sumloss = None
    sumgrad = None
    sumhess = None

    # ==============================================================
    #                     EXPONENTIATE TRANSFORMS
    # ==============================================================
    if self.affine:
        aff0, iaff0 = self.affine.exp2(cache_result=True, recompute=False)
        aff_pos = self.affine.position[0].lower()
    else:
        aff_pos = 'x'
        aff0 = iaff0 = torch.eye(self.nonlin.dim + 1)
    vel0 = vel
    if any(loss.backward for loss in self.losses):
        phi0, iphi0 = self.nonlin.exp2(vel0, recompute=True,
                                       cache_result=not in_line_search)
        ivel0 = -vel0
    else:
        phi0 = self.nonlin.exp(vel0, recompute=True,
                               cache_result=not in_line_search)
        iphi0 = ivel0 = None
    aff0 = aff0.to(phi0)
    iaff0 = iaff0.to(phi0)

    # ==============================================================
    #                     ACCUMULATE DERIVATIVES
    # ==============================================================
    has_printed = False
    for loss in self.losses:

        # ==========================================================
        #                     ONE LOSS COMPONENT
        # ==========================================================
        moving, fixed, factor = loss.moving, loss.fixed, loss.factor
        if loss.backward:
            phi00, aff00, vel00 = iphi0, iaff0, ivel0
        else:
            phi00, aff00, vel00 = phi0, aff0, vel0

        # ----------------------------------------------------------
        # build left and right affine
        # ----------------------------------------------------------
        aff_right = fixed.affine
        if aff_pos in 'fs':  # affine position: fixed or symmetric
            aff_right = aff00 @ aff_right
        aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
        aff_left = self.nonlin.affine
        if aff_pos in 'ms':  # affine position: moving or symmetric
            aff_left = aff00 @ self.nonlin.affine
        aff_left = linalg.lmdiv(moving.affine, aff_left)

        # ----------------------------------------------------------
        # build full transform
        # ----------------------------------------------------------
        if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
            aff_right = None
            phi = spatial.add_identity_grid(phi00)
            disp = phi00
        else:
            phi = spatial.affine_grid(aff_right, fixed.shape)
            disp = regutils.smart_pull_grid(phi00, phi)
            phi += disp
        if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
            aff_left = None
        else:
            phi = spatial.affine_matvec(aff_left, phi)

        # ----------------------------------------------------------
        # forward pass
        # ----------------------------------------------------------
        warped, mask = moving.pull(phi, mask=True)
        if fixed.masked:
            if mask is None:
                mask = fixed.mask
            else:
                mask = mask * fixed.mask

        do_print = not (has_printed or self.verbose < 3 or in_line_search
                        or loss.backward)
        if do_print:
            has_printed = True
            if moving.previewed:
                preview = moving.pull(phi, preview=True, dat=False)
            else:
                preview = warped
            init = spatial.affine_lmdiv(moving.affine, fixed.affine)
            if _almost_identity(init) and moving.shape == fixed.shape:
                init = moving.dat
            else:
                init = spatial.affine_grid(init, fixed.shape)
                init = moving.pull(init, preview=True, dat=False)
            self.mov2fix(fixed.dat, init, preview, disp, dim=fixed.dim,
                         title=f'(nonlin) {self.n_iter:03d}')

        # ----------------------------------------------------------
        # derivatives wrt moving
        # ----------------------------------------------------------
        g = h = None
        loss_args = (warped, fixed.dat)
        loss_kwargs = dict(dim=fixed.dim, mask=mask)
        state = loss.loss.get_state()
        if not grad and not hess:
            llx = loss.loss.loss(*loss_args, **loss_kwargs)
        elif not hess:
            llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
        else:
            llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
        del loss_args, loss_kwargs
        if in_line_search:
            loss.loss.set_state(state)

        # ----------------------------------------------------------
        # chain rule -> derivatives wrt phi
        # ----------------------------------------------------------
        if grad or hess:
            g, h, mugrad = self.nonlin.propagate_grad(
                g, h, moving, phi00, aff_left, aff_right, inv=loss.backward)
            g = regutils.jg(mugrad, g)
            h = regutils.jhj(mugrad, h)
            if isinstance(self.nonlin, SVFModel):
                # propagate backward by scaling and squaring
                g, h = spatial.exp_backward(vel00, g, h, steps=self.nonlin.steps)

            sumgrad = (g.mul_(factor) if sumgrad is None else
                       sumgrad.add_(g, alpha=factor))
            if hess:
                sumhess = (h.mul_(factor) if sumhess is None else
                           sumhess.add_(h, alpha=factor))
        sumloss = (llx.mul_(factor) if sumloss is None else
                   sumloss.add_(llx, alpha=factor))

    # ==============================================================
    #                         REGULARIZATION
    # ==============================================================
    vgrad = self.nonlin.regulariser(vel0)
    llv = 0.5 * vel0.flatten().dot(vgrad.flatten())
    if grad:
        sumgrad += vgrad
    del vgrad

    # ==============================================================
    #                           VERBOSITY
    # ==============================================================
    llx = sumloss.item()
    sumloss += llv
    sumloss += self.lla
    self.loss_value = sumloss.item()
    if self.verbose and (self.verbose > 1 or not in_line_search):
        llv = llv.item()
        ll = sumloss.item()
        lla = self.lla
        if in_line_search:
            line = '(search) | '
        else:
            line = '(nonlin) | '
        line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
        if not in_line_search:
            if self.ll_prev is not None:
                gain = self.ll_prev - ll
                # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                line += f' | {gain:12.6g}'
            self.llv = llv
            self.all_ll.append(ll)
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)
            self.n_iter += 1
        print(line, end='\r')

    # ==============================================================
    #                           RETURN
    # ==============================================================
    out = [sumloss]
    if grad:
        out.append(sumgrad)
    if hess:
        out.append(sumhess)
    return tuple(out) if len(out) > 1 else out[0]

def _compute_cost(q, grid0, dat_fix, mat_fix, dat, mat, mov, cost_fun, B,
                  mx_int, fwhm, return_res=False):
    """Compute registration cost function.

    Parameters
    ----------
    q : (N, Nq) tensor_like
        Lie algebra of affine registration fit.
    grid0 : (X1, Y1, Z1) tensor_like
        Sub-sampled image data's resampling grid.
    dat_fix : (X1, Y1, Z1) tensor_like
        Fixed image data.
    mat_fix : (4, 4) tensor_like
        Fixed affine matrix.
    dat : [N,] tensor_like
        List of input images.
    mat : [N,] tensor_like
        List of affine matrices.
    mov : [N,] int
        Indices of moving images.
    cost_fun : str
        Cost function to compute (see run_affine_reg).
    B : (Nq, N, N) tensor_like
        Affine basis.
    mx_int : int
        This parameter sets the max intensity in the images, which decides
        how many bins to use in the joint image histograms
        (e.g, mx_int=511 -> H.shape = (512, 512)).
    fwhm : float
        Full-width at half max of Gaussian kernel, for smoothing histogram.
    return_res : bool, default=False
        Return registration results for plotting.

    Returns
    -------
    c : float
        Cost of aligning images with current estimate of q. If
        optimiser='powell', array_like, else tensor_like.
    res : tensor_like
        Registration results, for visualisation (only if return_res=True).

    """
    # Init
    device = grid0.device
    q = q.flatten()
    was_numpy = False
    if isinstance(q, np.ndarray):
        was_numpy = True
        q = torch.from_numpy(q).to(device)  # To torch tensor
    dm_fix = dat_fix.shape
    Nq = B.shape[0]
    N = torch.tensor(len(dat), device=device,
                     dtype=torch.float32)  # For modulating NJTV cost

    if cost_fun in _costs_edge:
        jtv = dat_fix.clone()
        if cost_fun == 'njtv':
            njtv = -dat_fix.sqrt()

    for i, m in enumerate(mov):  # Loop over moving images
        # Get affine matrix
        mat_a = expm(q[torch.arange(i * Nq, i * Nq + Nq)], B)
        # Compose matrices
        M = lmdiv(mat[m], mat_a.mm(mat_fix)).to(grid0.dtype)  # mat_mov\mat_a*mat_fix
        # Transform fixed grid
        grid = affine_matvec(M, grid0)
        # Resample to fixed grid
        dat_new = grid_pull(dat[m], grid, bound='dft',
                            extrapolate=False, interpolation=1)
        if cost_fun in _costs_edge:
            jtv += dat_new
            if cost_fun == 'njtv':
                njtv -= dat_new.sqrt()

    # Compute the cost function
    res = None
    if cost_fun in _costs_hist:
        # Histogram based costs
        # ----------
        # Compute joint histogram
        # OBS: This function expects both images to have the same max and min
        # intensities, this is ensured by the _data_loader() function.
        H = _hist_2d(dat_fix, dat_new, mx_int, fwhm)
        res = H

        # Get probabilities
        pxy = H / H.sum()
        px = pxy.sum(dim=0, keepdim=True)
        py = pxy.sum(dim=1, keepdim=True)

        # Compute cost
        if cost_fun == 'mi':
            # Mutual information
            mi = torch.sum(pxy * torch.log2(pxy / py.mm(px)))
            c = -mi
        elif cost_fun == 'ecc':
            # Entropy Correlation Coefficient
            # Maes, Collignon, Vandermeulen, Marchal & Suetens (1997).
            # "Multimodality image registration by maximisation of mutual
            # information". IEEE Transactions on Medical Imaging 16(2):187-198
            mi = torch.sum(pxy * torch.log2(pxy / py.mm(px)))
            ecc = -2 * mi / (torch.sum(px * px.log2()) + torch.sum(py * py.log2()))
            c = -ecc
        elif cost_fun == 'nmi':
            # Normalised Mutual Information
            # Studholme, Hill & Hawkes (1998).
            # "A normalized entropy measure of 3-D medical image alignment".
            # in Proc. Medical Imaging 1998, vol. 3338, San Diego, CA, pp. 132-143.
            nmi = (torch.sum(px * px.log2()) + torch.sum(py * py.log2())) \
                / torch.sum(pxy * pxy.log2())
            c = -nmi
        elif cost_fun == 'ncc':
            # Normalised Cross Correlation
            i = torch.arange(1, pxy.shape[0] + 1, device=device, dtype=torch.float32)
            j = torch.arange(1, pxy.shape[1] + 1, device=device, dtype=torch.float32)
            m1 = torch.sum(py * i[..., None])
            m2 = torch.sum(px * j[None, ...])
            sig1 = torch.sqrt(torch.sum(py[..., 0] * (i - m1) ** 2))
            sig2 = torch.sqrt(torch.sum(px[0, ...] * (j - m2) ** 2))
            i, j = torch.meshgrid(i - m1, j - m2)
            ncc = torch.sum(torch.sum(pxy * i * j)) / (sig1 * sig2)
            c = -ncc
    elif cost_fun in _costs_edge:
        # Normalised Joint Total Variation
        # M Brudfors, Y Balbastre, J Ashburner (2020).
        # "Groupwise Multimodal Image Registration using Joint Total Variation".
        # in MIUA 2020.
        jtv.sqrt_()
        if cost_fun == 'njtv':
            njtv += torch.sqrt(N) * jtv
            res = njtv
            c = torch.sum(njtv)
        else:
            res = jtv
            c = torch.sum(jtv)

    # _ = show_slices(res, fig_num=1, cmap='coolwarm')  # Can be uncommented for testing

    if was_numpy:
        # Back to numpy array
        c = c.cpu().numpy()

    if return_res:
        return c, res
    else:
        return c

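# Minimal sketch (not part of the original module) of the mutual-information
# computation used above, starting from a precomputed joint histogram `H` and
# assuming only `torch`; `_compute_cost` returns its negation so that the
# optimiser can minimise it.
def _demo_mi_from_hist(H):
    pxy = H / H.sum()                    # joint probabilities
    px = pxy.sum(dim=0, keepdim=True)    # marginal over rows
    py = pxy.sum(dim=1, keepdim=True)    # marginal over columns
    return torch.sum(pxy * torch.log2(pxy / py.mm(px)))
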
def kernel_fit(calib, kernel_size, patterns, lam=0.01):
    """Compute GRAPPA kernels

    All batch elements should have the same sampling pattern

    Parameters
    ----------
    calib : ([*batch], coils, *freq)
        Fully-sampled calibration data
    kernel_size : sequence[int]
        GRAPPA kernel size
    patterns : (N,) tensor[int]
        Code of patterns for which to learn a kernel.
        See `pattern_to_code`.
    lam : float, default=0.01
        Tikhonov regularization

    Returns
    -------
    kernels : dict of int -> ([*batch], coils, coils, nb_elem) tensor
        GRAPPA kernels

    """
    kernel_size = py.make_list(kernel_size)
    ndim = len(kernel_size)
    coils, *freq = calib.shape[-ndim - 1:]
    batch = calib.shape[:-ndim - 1]

    # find all possible patterns
    patterns = utils.as_tensor(patterns, device=calib.device)
    if patterns.dtype is torch.bool:
        patterns = pattern_to_code(patterns, ndim)
    patterns = patterns.flatten()

    # learn one kernel for each pattern
    calib = utils.unfold(calib, kernel_size, collapse=True)  # [*B, C, N, *K]
    calib = utils.movedim(calib, -ndim - 1, -ndim - 2)       # [*B, N, C, *K]

    def t(x):
        return x.transpose(-1, -2)

    def conjt(x):
        return t(x).conj()

    def diag(x):
        return x.diagonal(0, -1, -2)

    kernels = {}
    center = [(k - 1) // 2 for k in kernel_size]
    center = (Ellipsis, *center)
    for pattern_code in patterns:
        if code_has_center(pattern_code, kernel_size):
            continue
        pattern = code_to_pattern(pattern_code, kernel_size,
                                  device=calib.device)
        pattern_size = pattern.sum()
        if pattern_size == 0:
            continue

        calib_target = calib[center]                         # [*B, N, C]
        calib_source = calib[..., pattern]                   # [*B, N, C, P]
        calib_size = calib_target.shape[-2]
        flat_shape = [*batch, calib_size, pattern_size * coils]
        calib_source = calib_source.reshape(flat_shape)      # [*B, N, C*P]

        # solve
        H = conjt(calib_source).matmul(calib_source)         # [*B, C*P, C*P]
        diag(H).add_(lam * diag(H).abs().max(-1, keepdim=True).values)
        diag(H).add_(lam)
        g = conjt(calib_source).matmul(calib_target)         # [*B, C*P, C]
        k = linalg.lmdiv(H, g).transpose(-1, -2)              # [*B, C, C*P]
        k = k.reshape([*batch, coils, coils, pattern_size])   # [*B, C, C, P]
        kernels[pattern_code.item()] = k

    return kernels

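# Minimal sketch (not part of the original module) of the regularised
# least-squares solve performed for each pattern above: given source rows S
# ([*B, N, C*P]) and target rows T ([*B, N, C]), the kernel K solves
# (S^H S + damping) K = S^H T. Assumes only `torch`; names are illustrative.
def _demo_grappa_solve(S, T, lam=0.01):
    H = S.conj().transpose(-1, -2) @ S                         # normal matrix S^H S
    scale = H.diagonal(0, -1, -2).abs().max(-1, keepdim=True).values
    H.diagonal(0, -1, -2).add_(lam * scale + lam)              # Tikhonov damping
    g = S.conj().transpose(-1, -2) @ T                         # right-hand side S^H T
    return torch.linalg.solve(H, g).transpose(-1, -2)          # kernel, [*B, C, C*P]
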