def responsibilities(image, means, precisions, proportions):
    """Compute Gaussian-mixture responsibilities and their logarithm.

    Shape annotations below are carried over from the original inline
    comments; they are not checked at runtime.

    Parameters
    ----------
    image : tensor
        Observed image, channels on dimension 1 (presumably
        ``[B, C, *spatial]`` -- TODO confirm against callers).
    means : tensor
        Class means, annotated ``[B, K, C]``.
    precisions : tensor
        Class precision (inverse covariance) matrices, ``[B, K, C, C]``.
    proportions : tensor
        Class mixing proportions, ``[B, K]``.

    Returns
    -------
    z : tensor
        Softmax over classes (dimension 1) of the per-class log-density.
    logz : tensor
        Log-softmax of the same quantity.
    """
    nb_dim = image.dim() - 2

    # reshape everything so that it broadcasts against the voxel lattice
    resid = channel2last(image).unsqueeze(-2)                   # [B, ..., 1, C]
    prop = unsqueeze(proportions, dim=1, ndim=nb_dim)           # [B, ones, K]
    mu = unsqueeze(means, dim=1, ndim=nb_dim)                   # [B, ones, K, C]
    prec = unsqueeze(precisions, dim=1, ndim=nb_dim)            # [B, ones, K, C, C]
    del image, means, precisions, proportions

    # voxel-wise term: -0.5 * (x - m)' A (x - m)
    resid = resid - mu
    maha = matvec(prec, resid)
    maha = (maha * resid).sum(dim=-1)                           # [B, ..., K]
    maha = -0.5 * maha

    # constant term: 0.5 * (logdet(A) - C * log(2 pi)) + log(p)
    twopi = torch.as_tensor(2 * pi, dtype=prec.dtype, device=prec.device)
    log_norm = torch.logdet(prec) - prec.shape[-1] * twopi.log()
    log_norm = 0.5 * log_norm + prop.log()
    logit = maha + log_norm

    # softmax across classes, once classes are back on dimension 1
    logit = last2channel(logit)
    logz = torch.nn.functional.log_softmax(logit, dim=1)
    z = torch.nn.functional.softmax(logit, dim=1)
    return z, logz
def forward(self, x, v=None):
    """Compute the (negated, weighted) multi-radius log-evidence of `x`.

    For each radius in `self.radii`, Hessian eigenvalues of `x` are
    computed, softly sorted by decreasing magnitude, and combined into a
    per-voxel score to which the log of the matching `self.pradii` entry
    is added. Scores are merged across radii with a log-sum-exp and
    averaged with weights `v`.

    Parameters
    ----------
    x : tensor
        Input with two leading non-spatial dimensions and 2 or 3 spatial
        dimensions (`x.dim() - 2` must be 2 or 3).
    v : tensor, optional
        Weights used in the final average. Defaults to `x` itself.

    Returns
    -------
    loss : scalar tensor
        `-(score * v).sum() / (v.sum() + 1e-3)`

    Raises
    ------
    ValueError
        If the number of spatial dimensions is neither 2 nor 3.
    """
    dim = x.dim() - 2
    if dim not in (2, 3):
        raise ValueError(f'{type(self).__name__} only implemented '
                         f'in 2D or 3D.')
    # move the radii and their (log) prior to the dtype/device of `x`
    radii = self.radii.to(**utils.backend(x))
    pradii = self.pradii.to(**utils.backend(x)).log()

    # compute joint log-likelihood `ln p(x, radius | v)`
    # (one slice of `loss` per radius)
    loss = x.new_zeros([len(radii), *x.shape])
    for i, (p, r) in enumerate(zip(pradii, radii)):
        # compute unsorted eigenvalues
        e = spatial.hessian_eig(x, r, dim=dim, sort=None)
        # soft sort (by decreasing absolute value)
        P = math.softsort(e.abs(), tau=self.tau_sort, descending=True)
        e = linalg.matvec(P, e)
        # eigenvalue index goes first so that e[k] is the k-th eigenvalue
        e = utils.movedim(e, -1, 0)
        # compute penalties
        loss[i] = -self.tau_large * e[1:].sum(0)  # white ridges
        # from here on, work with log-squared eigenvalues (clamped so
        # that the log stays finite)
        e = e.square().clamp_min_(1e-32).log()
        if dim == 3:
            loss[i] += self.tau_ratio1 * (e[1] - e[2])  # tubes
        # NOTE(review): the original indentation of the next line was lost;
        # only the `e[2]` term above strictly requires 3D. Confirm whether
        # the "not plates" term should also be restricted to 3D.
        loss[i] += self.tau_ratio0 * (e[1] - e[0])  # not plates
        loss[i] += p  # radius prior

    # compute (stable) log-sum-exp (== model evidence `ln p(x | v)`)
    loss = math.logsumexp(loss, dim=0)

    # weight by probability to be a vessel and return `E_v[ln p(x | v)]`
    if v is None:
        v = x
    # the small constant guards against a division by zero
    return -(loss * v).sum() / (v.sum() + 1e-3)
def affine_grid(mat, shape):
    """Create a dense transformation grid from an affine matrix.

    Parameters
    ----------
    mat : (..., D[+1], D[+1]) tensor
        Affine matrix (or matrices).
    shape : (D,) sequence[int]
        Shape of the grid, with length D.

    Returns
    -------
    grid : (..., *shape, D) tensor
        Dense transformation grid

    Raises
    ------
    ValueError
        If the dimension implied by the matrix differs from `len(shape)`,
        or if the matrix does not have D or D+1 rows.
    """
    mat = torch.as_tensor(mat)
    shape = list(shape)
    nb_dim = mat.shape[-1] - 1
    if nb_dim != len(shape):
        raise ValueError('Dimension of the affine matrix ({}) and shape ({}) '
                         'are not the same.'.format(nb_dim, len(shape)))
    if mat.shape[-2] not in (nb_dim, nb_dim + 1):
        # The original template contained a malformed field (`{1]`), which
        # made `.format` itself fail and hid this message from the caller.
        raise ValueError(
            'First argument should be matrices of shape '
            '(..., {0}, {1}) or (..., {1}, {1}) but got {2}.'.format(
                nb_dim, nb_dim + 1, tuple(mat.shape)))
    batch_shape = mat.shape[:-2]
    grid = identity_grid(shape, mat.dtype, mat.device)
    # prepend singleton batch dimensions to the grid, and insert singleton
    # spatial dimensions in the matrix, so that both broadcast
    grid = utils.unsqueeze(grid, dim=0, ndim=len(batch_shape))
    mat = utils.unsqueeze(mat, dim=-3, ndim=nb_dim)
    lin = mat[..., :nb_dim, :nb_dim]
    off = mat[..., :nb_dim, -1]
    grid = linalg.matvec(lin, grid) + off
    return grid
def forward(self, batch=1, **overload):
    """Generate a random resampling grid (dense field composed with affine).

    Parameters
    ----------
    batch : int, default=1
        Batch shape.

    Other Parameters
    ----------------
    shape : sequence[int], optional
    device : torch.device, optional
    dtype : torch.dtype, optional
        Each defaults to the matching attribute of
        `self.grid.velocity.field`.

    Returns
    -------
    grid : (batch, *shape, 3) tensor
        Resampling grid
    """
    shape = overload.get('shape', self.grid.velocity.field.shape)
    dtype = overload.get('dtype', self.grid.velocity.field.dtype)
    device = overload.get('device', self.grid.velocity.field.device)
    backend = dict(dtype=dtype, device=device)

    if self.grid.velocity.field.amplitude == 0:
        # No dense component: start from the identity.
        # `identity_grid` yields an unbatched (*shape, D) grid (this is how
        # it is used by the other functions of this file), whereas the code
        # below expects a leading batch dimension. Without it,
        # `grid.shape[1:-1]` silently dropped the first spatial dimension.
        # The singleton batch broadcasts against the affine matrices
        # (assuming `linalg.matvec` broadcasts batch dimensions).
        grid = identity_grid(shape, **backend)[None]
    else:
        grid = self.grid(batch, shape=shape, **backend)

    # from now on, trust the generated grid for dtype/device/shape
    dtype = grid.dtype
    device = grid.device
    backend = dict(dtype=dtype, device=device)
    shape = grid.shape[1:-1]
    dim = len(shape)
    aff = self.affine(batch, dim=dim, **backend)

    # shift center of rotation
    # (translation by -(shape - 1) / 2, i.e. to the center of the lattice)
    aff_shift = torch.cat((
        torch.eye(dim, **backend),
        torch.as_tensor(shape, **backend)[:, None].sub_(1).div_(-2)),
        dim=1)
    aff_shift = as_euclidean(aff_shift)

    # conjugate the affine by the shift: shift \ (aff @ shift)
    aff = affine_matmul(aff, aff_shift)
    aff = affine_lmdiv(aff_shift, aff)

    # compose
    aff = utils.unsqueeze(aff, dim=-3, ndim=dim)
    lin = aff[..., :dim, :dim]
    off = aff[..., :dim, -1]
    grid = linalg.matvec(lin, grid) + off

    return grid
def exp(self, velocity, affine=None, displacement=False):
    """Build a deformation grid from tangent-space parameters.

    Parameters
    ----------
    velocity : (batch, *spatial, nb_dim)
        Stationary velocity field
    affine : (batch, nb_prm), optional
        Affine parameters. Each batch element is exponentiated
        separately with `self.affexp`.
    displacement : bool, default=False
        Return a displacement field (voxel to shift) rather than
        a transformation field (voxel to voxel).

    Returns
    -------
    grid : (batch, *spatial, nb_dim)
        Deformation grid (transformation or displacement).
    """
    backend = {'dtype': velocity.dtype, 'device': velocity.device}
    full_shape = velocity.shape[1:-1]

    # dense part: resize -> exponentiate -> resize back to input lattice
    grid = self.velexp(self.resize(velocity, type='displacement'))
    grid = self.resize(grid, shape=full_shape, type='grid')

    if affine is not None:
        # exponentiate each set of parameters
        affine = torch.stack([self.affexp(prm) for prm in affine], dim=0)

        # shift center of rotation (translation by -shape/2)
        shift = torch.cat(
            (torch.eye(self.dim, **backend),
             -torch.as_tensor(full_shape, **backend)[:, None] / 2),
            dim=1)
        affine = spatial.affine_matmul(affine, shift)
        affine = spatial.affine_lmdiv(shift, affine)

        # compose: apply the affine to every grid coordinate
        affine = unsqueeze(affine, dim=-3, ndim=self.dim)
        linear = affine[..., :self.dim, :self.dim]
        offset = affine[..., :self.dim, -1]
        grid = matvec(linear, grid) + offset

    if displacement:
        grid = grid - spatial.identity_grid(grid.shape[1:-1], **backend)
    return grid
def se_sample_svd(shape, sigma, lam, mu=None, repeats=1, **backend):
    """Sample random fields with a squared exponential kernel.

    The square root of the covariance matrix is obtained from its SVD.

    Parameters
    ----------
    shape : sequence[int]
        Shape of the image / volume.
    sigma : float
        SE amplitude.
    lam : float
        SE length-scale.
    mu : () or (*shape) tensor_like, optional
        SE mean
    repeats : int, default=1
        Number of sampled fields.

    Returns
    -------
    field : (repeats, *shape) tensor
        Sampled random fields.
    """
    # SE covariance: sigma^2 * exp(-0.5 * d / lam^2), computed in place
    # (presumably `dist_map` returns squared distances -- TODO confirm)
    cov = dist_map(shape, **backend)
    backend = utils.backend(cov)
    cov.mul_(-0.5 / (lam**2)).exp_().mul_(sigma**2)

    # cov = U diag(s) V'  ->  square root is U diag(sqrt(s))
    u, s, _ = torch.svd(cov)
    s = s.sqrt_()

    # white noise, flattened over space, then coloured by the square root
    full_shape = (repeats, *shape)
    noise = torch.randn(full_shape, **backend).reshape([repeats, -1])
    field = linalg.matvec(u, noise.mul_(s)).reshape(full_shape)

    if mu is not None:
        field += torch.as_tensor(mu, **backend)
    return field
def cc_sample(shape, sigma, alpha, mu=None, repeats=1, **backend):
    """Sample random fields with a constant correlation.

    The square root of the correlation matrix is obtained from its SVD.

    Parameters
    ----------
    shape : sequence[int]
        Shape of the image / volume.
    sigma : float
        Scale applied to the unit-variance samples.
        NOTE(review): historically documented as "Variance", but the
        samples are multiplied by `sigma` itself, which makes it act as
        a standard deviation -- confirm the intended meaning.
    alpha : float
        Correlation (off-diagonal value of the correlation matrix).
    mu : () or (*shape) tensor_like, optional
        Mean
    repeats : int, default=1
        Number of sampled fields.

    Returns
    -------
    field : (repeats, *shape) tensor
        Sampled random fields.
    """
    # correlation matrix: ones on the diagonal, `alpha` everywhere else
    nvox = py.prod(shape)
    corr = torch.full([nvox, nvox], alpha, **backend)
    corr.diagonal(0, -1, -2).add_(1 - alpha)
    backend = utils.backend(corr)

    # corr = U diag(s) V'  ->  square root is U diag(sqrt(s))
    u, s, _ = torch.svd(corr)
    s = s.sqrt_()

    # white noise, flattened over space, then coloured and scaled
    full_shape = (repeats, *shape)
    noise = torch.randn(full_shape, **backend).reshape([repeats, -1])
    field = linalg.matvec(u, noise.mul_(s))
    field.mul_(sigma)
    field = field.reshape(full_shape)

    if mu is not None:
        field += torch.as_tensor(mu, **backend)
    return field
def _rotate_grad(grad, aff=None, dense=None):
    """Rotate `grad` by the (transposed) Jacobian of `aff o dense`.

    Parameters
    ----------
    grad : (..., dim) tensor
        Spatial gradients
    aff : (dim+1, dim+1) tensor, optional
        Affine matrix
    dense : (..., dim) tensor, optional
        Dense vox2vox displacement field

    Returns
    -------
    grad : (..., dim) tensor
        Rotated gradients. The input is returned untouched when neither
        `aff` nor `dense` is provided.
    """
    if aff is None and dense is None:
        return grad

    ndim = grad.shape[-1]
    linear = None if aff is None else aff[:ndim, :ndim]
    if dense is None:
        # purely affine: the Jacobian is the linear part
        jac = linear
    else:
        jac = spatial.grid_jacobian(dense, type='disp')
        if linear is not None:
            jac = torch.matmul(linear, jac)
    return linalg.matvec(jac.transpose(-1, -2), grad)
def nll(image, resp, means, precisions):
    """Responsibility-weighted Mahalanobis term of a Gaussian mixture.

    Computes ``0.5 * sum_k z_k (x - m_k)' A_k (x - m_k)`` at each voxel.
    Shape annotations below are carried over from the original inline
    comments; they are not checked at runtime.

    Parameters
    ----------
    image : tensor
        Observed image, channels on dimension 1.
    resp : tensor
        Responsibilities, classes on dimension 1.
    means : tensor
        Class means, annotated ``[B, K, C]``.
    precisions : tensor
        Class precision matrices, annotated ``[B, K, C, C]``.

    Returns
    -------
    loss : tensor
        Voxel-wise loss, annotated ``[B, ...]``.
    """
    nb_dim = image.dim() - 2

    resid = channel2last(image).unsqueeze(-2)            # [B, ..., 1, C]
    weights = channel2last(resp)                         # [B, ..., K]
    mu = unsqueeze(means, dim=1, ndim=nb_dim)            # [B, ones, K, C]
    prec = unsqueeze(precisions, dim=1, ndim=nb_dim)     # [B, ones, K, C, C]
    del image, resp, means, precisions

    resid = resid - mu
    maha = matvec(prec, resid)
    maha = (maha * resid).sum(dim=-1)                    # [B, ..., K]
    weighted = (maha * weights).sum(dim=-1)              # [B, ...]
    return weighted * 0.5
def jg(jac, grad, dim=None):
    """Jacobian-gradient product: J*g

    Parameters
    ----------
    jac : (..., K, *spatial, D)
    grad : (..., K, *spatial)
    dim : int, optional
        Number of spatial dimensions. When falsy (None or 0), it is
        taken to be `grad.dim() - 1`.

    Returns
    -------
    new_grad : (..., *spatial, D)
        None if `grad` is None.
    """
    if grad is None:
        return None
    if not dim:
        dim = grad.dim() - 1
    # bring the contracted dimension K to the last position of both inputs
    grad_last = utils.movedim(grad, -dim - 1, -1)
    jac_last = utils.movedim(jac, -dim - 2, -1)
    return linalg.matvec(jac_last, grad_last)
def compose_grad(g, h, g_mu, g_aff):
    """Chain rule: derivatives wrt the image -> wrt Lie parameters.

    Parameters
    ----------
    g, h : tensor
        Gradient/Hessian of the loss wrt the moving image.
        `h` may be None; the helpers called here are expected to pass
        None through.
    g_mu : (..., dim) tensor
        Spatial gradients of the moving image
    g_aff : tensor
        Gradient of the affine matrix wrt the Lie parameters

    Returns
    -------
    g, h : tensor
        Gradient/Hessian of the loss wrt the Lie parameters
        (`h` is None if the input Hessian was None).
    """
    ndim = g_mu.shape[-1]
    nb_entries = ndim * (ndim + 1)

    # image space -> dense grid -> affine matrix entries
    g = jg(g_mu, g)
    h = jhj(g_mu, h)
    g, h = regutils.affine_grid_backward(g, h)
    g = g.reshape([*g.shape[:-2], nb_entries])

    # drop the last row of the affine derivatives and flatten each matrix
    g_aff = g_aff[..., :-1, :]
    g_aff = g_aff.reshape([*g_aff.shape[:-2], nb_entries])

    # affine matrix entries -> Lie parameters
    g = linalg.matvec(g_aff, g)
    if h is not None:
        h = h.reshape([*h.shape[:-4], nb_entries, nb_entries])
        h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
    return g, h
def pull_grad(self, grid, rotate=False):
    """Sample the image gradients at dense coordinates.

    Parameters
    ----------
    grid : (*spatial, dim) tensor or None
        Dense transformation field.
        If None, `self.grad()` is returned instead.
    rotate : bool, default=False
        Rotate the gradients using the (transposed) Jacobian
        of the transformation.

    Returns
    -------
    grad : ([C], *spatial, dim) tensor
    """
    if grid is None:
        return self.grad()

    sampled = spatial.grid_grad(self.dat, grid, bound=self.bound,
                                extrapolate=self.extrapolate)
    if not rotate:
        return sampled
    jac_t = spatial.grid_jacobian(grid).transpose(-1, -2)
    return linalg.matvec(jac_t, sampled)
def exp_backward(vel, *grad_and_hess, inverse=False, steps=8,
                 interpolation='linear', bound='dft', rotate_grad=False):
    """Backward pass of SVF exponentiation.

    This should be much more memory-efficient than the autograd pass
    as we don't have to store intermediate grids.

    I am using DARTEL's derivatives (from the code, not the paper).
    From what I get, it corresponds to pushing forward the gradient
    (computed in observation space) recursively while squaring the
    (inverse) transform.
    Remember that the push forward of g by phi is
                    |iphi| iphi' * g(iphi)
    where iphi is the inverse of phi. We could also have implemented
    this operation as: inverse(phi)' * push(g, phi), since
    push(g, phi) ~= |iphi| g(iphi). It has the advantage of using push
    rather than pull, which might preserve better positive-definiteness
    of the Hessian, but requires the inversion of (potentially
    ill-behaved) Jacobian matrices.

    Note that gradients must first be rotated using the Jacobian of
    the exponentiated transform so that the denominator refers to the
    initial velocity (we want dL/dV0, not dL/dPsi).
    This is only done inside this function when `rotate_grad=True`.

    Parameters
    ----------
    vel : (..., *spatial, dim) tensor
        Velocity
    grad : (..., *spatial, dim) tensor
        Gradient with respect to the output grid
    hess : (..., *spatial, dim*(dim+1)//2) tensor, optional
        Symmetric hessian with respect to the output grid.
    inverse : bool, default=False
        Whether the grid is an inverse
    steps : int, default=8
        Number of scaling and squaring steps
    interpolation : str or int, default='linear'
    bound : str, default='dft'
    rotate_grad : bool, default=False
        If True, rotate the gradients using the Jacobian of exp(vel).

    Returns
    -------
    grad : (..., *spatial, dim) tensor
        Gradient with respect to the SVF
    hess : (..., *spatial, dim*(dim+1)//2) tensor, optional
        Approximate (block diagonal) Hessian with respect to the SVF.
        Only returned if a Hessian was passed as input.
    """
    # split the variadic positional inputs into gradient and (optional)
    # Hessian; remember whether a Hessian slot was provided at all
    has_hess = len(grad_and_hess) > 1
    grad, *hess = grad_and_hess
    hess = hess[0] if hess else None
    del grad_and_hess

    opt = dict(bound=bound, interpolation=interpolation)
    dim = vel.shape[-1]
    shape = vel.shape[-dim - 1:-1]
    id = identity_grid(shape, **utils.backend(vel))
    # `vel` is modified in place below: work on a copy
    vel = vel.clone()

    if rotate_grad:
        # It forces us to perform a forward exponentiation, which
        # is a bit annoying...
        # Maybe save the Jacobian after the forward pass? But it take space
        _, jac = exp_forward(vel, jacobian=True, steps=steps,
                             displacement=True, **opt, _anagrad=True)
        jac = jac.transpose(-1, -2)
        grad = linalg.matvec(jac, grad)
        if hess is not None:
            hess = _jhj(jac, hess)
        del jac

    # scaling step: smallest-scale velocity (negated unless `inverse`)
    vel /= (-1 if not inverse else 1) * (2**steps)
    jac = grid_jacobian(vel, bound=bound, type='disp')

    for _ in range(steps):
        det = jac.det()
        jac = jac.transpose(-1, -2)
        grad0 = grad
        # push forward, in three steps: pull, rotate, modulate
        grad = _pull_vel(grad, id + vel, **opt)
        grad = linalg.matvec(jac, grad)
        grad *= det[..., None]
        grad += grad0  # add all scales (SVF)
        if hess is not None:
            # same push-forward applied to the Hessian
            hess0 = hess
            hess = _pull_vel(hess, id + vel, **opt)
            hess = _jhj(jac, hess)
            hess *= det[..., None]
            hess += hess0
        # squaring (undo the transposition first)
        jac = jac.transpose(-1, -2)
        jac = _composition_jac(jac, vel, type='disp', identity=id, **opt)
        vel += _pull_vel(vel, id + vel, **opt)

    # NOTE(review): the in-place operations below act on new tensors as
    # soon as one loop iteration (or the rotation) has run; with
    # `steps=0` and `rotate_grad=False` they would act on the caller's
    # tensors -- confirm that this case is never used.
    if inverse:
        grad.neg_()

    grad /= (2**steps)
    if hess is not None:
        hess /= (2**steps)

    return (grad, hess) if has_hess else grad
def __call__(self, logaff, grad=False, hess=False, gradmov=False,
             hessmov=False, in_line_search=False):
    """Evaluate the affine registration objective and its derivatives.

    Parameters
    ----------
    logaff : (..., nb) tensor, Lie parameters
    grad : Whether to compute and return the gradient wrt `logaff`
    hess : Whether to compute and return the Hessian wrt `logaff`
    gradmov : Whether to compute and return the gradient wrt `moving`
    hessmov : Whether to compute and return the Hessian wrt `moving`
    in_line_search : If True, neither plot nor print, and do not update
        the iteration counter and the stored objective values

    Returns
    -------
    ll : float, loss value (objective to minimize)
    g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
    h : (..., logaff) tensor, optional, Hessian wrt Lie parameters
    gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
    hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

    A single value is returned bare; several values are returned
    as a tuple.
    """
    # This loop performs the forward pass, and computes
    # derivatives along the way.

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    # plot about 20 times over the course of the optimization
    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
        and (self.n_iter - 1) % logplot == 0

    # jitter
    # if not hasattr(self, '_fixed'):
    #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
    #                                 jitter=True,
    #                                 **utils.backend(self.fixed))
    #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
    #     del idj
    # fixed = self._fixed
    fixed = self.fixed

    # forward
    # (the basis is built lazily and cached on the instance)
    if not torch.is_tensor(self.basis):
        self.basis = spatial.affine_basis(self.basis, self.dim,
                                          **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    with torch.no_grad():
        _, gaff = linalg._expm(logaff, self.basis,
                               grad_X=True, hess_X=False)

    # moving \ (exp(logaff) @ fixed)
    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    # /!\ derivatives are not "homogeneous" (they do not have a one
    # on the bottom right): we should *not* use affine_matmul and
    # such (I only lost a day...)
    gaff = torch.matmul(gaff, self.affine_fixed)
    gaff = linalg.lmdiv(self.affine_moving, gaff)
    # haff = torch.matmul(haff, self.affine_fixed)
    # haff = linalg.lmdiv(self.affine_moving, haff)
    # (the identity grid is built lazily and cached on the instance)
    if self.id is None:
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff),
                                        jitter=False)
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, cat=iscat, dim=self.dim)

    # gradient/Hessian of the log-likelihood in observed space
    # From here on, `grad`/`hess`/`gradmov`/`hessmov` hold either their
    # input flag or the tensor that replaces it.
    # NOTE(review): if `gradmov`/`hessmov` are requested without `grad`/
    # `hess`, they are never computed and the flags themselves end up in
    # the output -- confirm that callers never do this.
    if not grad and not hess:
        llx = self.loss.loss(warped, fixed)
    elif not hess:
        llx, grad = self.loss.loss_grad(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
    else:
        llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
        if hessmov:
            hessmov = spatial.grid_push(hess, grid, **pullopt)
    del warped

    # compose with spatial gradients + dot product with grid
    if grad is not False or hess is not False:
        mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
        grad = jg(mugrad, grad)
        if hess is not False:
            hess = jhj(mugrad, hess)
            grad, hess = regutils.affine_grid_backward(grad, hess, grid=self.id)
        else:
            grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
        # flatten the derivatives wrt the dim x (dim+1) affine entries
        dim2 = self.dim * (self.dim + 1)
        grad = grad.reshape([*grad.shape[:-2], dim2])
        # drop the last row of the affine derivatives, then flatten
        gaff = gaff[..., :-1, :]
        gaff = gaff.reshape([*gaff.shape[:-2], dim2])
        # chain rule: affine entries -> Lie parameters
        grad = linalg.matvec(gaff, grad)
        if hess is not False:
            hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
            # haff = haff[..., :-1, :, :-1, :]
            # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
            hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
            # replace the Hessian by a diagonal matrix that holds the
            # sum of absolute values of each row
            hess = hess.abs().sum(-1).diag_embed()
        del mugrad

    # print objective
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}', end='\n')
        else:
            # improvement, relative to the distance to the largest
            # objective value seen so far
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}', end='\n')
        # NOTE(review): this bookkeeping only happens in verbose mode (as
        # reconstructed from the token order) -- confirm this is intended.
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    # gather everything that is not (still) False
    out = [ll]
    if grad is not False:
        out.append(grad)
    if hess is not False:
        out.append(hess)
    if gradmov is not False:
        out.append(gradmov)
    if hessmov is not False:
        out.append(hessmov)
    return tuple(out) if len(out) > 1 else out[0]
def shoot(vel, greens=None, absolute=_default_absolute, membrane=_default_membrane,
          bending=_default_bending, lame=_default_lame, factor=1, voxel_size=1,
          return_inverse=False, displacement=False, steps=8, fast=True, verbose=False):
    """Exponentiate a velocity field by geodesic shooting.

    Notes
    -----
    .. This function generates the *inverse* deformation, if we follow
       LDDMM conventions. This allows the velocity to be defined in the
       space of the moving image.
    .. If the greens function is provided, the penalty parameters
       are not used to build it (they are still used to compute the
       initial momentum).

    Parameters
    ----------
    vel : ([batch], *spatial, dim) tensor
        Initial velocity in moving space.
    greens : (*spatial, [dim, dim]) tensor, optional
        Greens function generated by `greens`.
    absolute : float, default=0.0001
        Penalty on absolute displacements.
    membrane : float, default=0.001
        Penalty on the membrane energy.
    bending : float, default=0.2
        Penalty on the bending energy.
    lame : float or (float, float), default=(0.05, 0.2)
        Linear elastic penalty.
    factor : float, default=1
        Forwarded to the Greens function and to the regulariser
        (presumably a global scaling of the penalty -- TODO confirm).
    voxel_size : [sequence of] float, default=1
        Needed when greens is provided if lame == 0.
    return_inverse : bool, default=False
        Return the inverse on top of the forward transform.
    displacement : bool, default=False
        Return a displacement field instead of a transformation field.
    steps : int, default=8
        Number of integration steps.
        If falsy (None or 0), use an educated guess based on the
        magnitude of `vel`.
    fast : bool, default=True
        If True, use a faster integration scheme, which may induce
        some numerical error (the energy is not exactly preserved
        along time). Else, use the slower but more precise scheme.
    verbose : bool, default=False
        Print an energy value at each integration step.

    Returns
    -------
    grid : ([batch], *spatial, dim) tensor
        Transformation from fixed to moving space.
        (It is used to warp a moving image to a fixed one).

    igrid : ([batch], *spatial, dim) tensor, if return_inverse
        Inverse transformation, from moving to fixed space.
        (It is used to warp a fixed image to a moving one).

    """
    # Authors
    # -------
    # .. John Ashburner <*****@*****.**> : original Matlab code
    # .. Yael Balbastre <*****@*****.**> : Python port
    #
    # License
    # -------
    # The original Matlab code is (C) 2012-2019 WCHN / John Ashburner
    # and was distributed as part of [SPM](https://www.fil.ion.ucl.ac.uk/spm)
    # under the GNU General Public Licence (version >= 2).

    vel = torch.as_tensor(vel)
    backend = utils.backend(vel)
    dim = vel.shape[-1]
    # NOTE: within this function, `spatial` is the spatial shape of the
    # field (it shadows any module of the same name)
    spatial = vel.shape[-dim - 1:-1]

    prm = dict(absolute=absolute, membrane=membrane, bending=bending,
               lame=lame, voxel_size=voxel_size, factor=factor)
    pull_prm = dict(bound='dft', interpolation=1, extrapolate=True)
    if greens is None:
        greens = _greens(spatial, **prm, **backend)
    greens = torch.as_tensor(greens, **backend)

    if not steps:
        # Number of time steps from an educated guess about how far to move
        # (largest velocity norm, rounded down, plus one)
        with torch.no_grad():
            steps = vel.square().sum(dim=-1).max().sqrt().floor().int().item() + 1

    id = identity_grid(spatial, **backend)
    # initial momentum; `mom0` is kept for the non-fast scheme
    mom = mom0 = regulariser_grid(vel, **prm, bound='dft')
    vel = vel / steps
    disp = -vel
    # the inverse displacement is also needed by the non-fast scheme
    if return_inverse or not fast:
        idisp = vel.clone()

    for i in range(1, abs(steps)):
        if fast:
            # JA: the update of u_t is not exactly as described in the paper,
            # but describing this might be a bit tricky. The approach here
            # was the most stable one I could find - although it does lose some
            # energy as < v_t, u_t> decreases over time steps.
            jac = _jacobian(-vel)
            mom = linalg.matvec(jac.transpose(-1, -2), mom)
            mom = _push_grid(mom, id + vel, **pull_prm)
        else:
            # transport the *initial* momentum with the current
            # inverse transform
            jac = _jacobian(idisp).inverse()
            mom = linalg.matvec(jac.transpose(-1, -2), mom0)
            mom = _push_grid(mom, id + idisp, **pull_prm)

        # Convolve with Greens function of L
        # v_t \gets L^g u_t
        vel = greens_apply(mom, greens, factor=factor, voxel_size=voxel_size)
        vel = vel.div_(steps)
        if verbose:
            # five values per line
            print(f'{0.5*steps*(vel*mom).sum().item()/py.prod(spatial):6g}',
                  end='\n' if not (i % 5) else ' ', flush=True)

        # $\psi \gets \psi \circ (id - \tfrac{1}{T} v)$
        # JA: I found that simply using
        # $\psi \gets \psi - \tfrac{1}{T} (D \psi) v$ was not so stable.
        disp = _pull_grid(disp, id - vel, **pull_prm).sub_(vel)

        if return_inverse or not fast:
            idisp += _pull_grid(vel, id + idisp, **pull_prm)

    if verbose:
        print('')

    # convert displacements to transformations
    if not displacement:
        disp += id
        if return_inverse:
            idisp += id
    return (disp, idisp) if return_inverse else disp
def do_affine(self, logaff, grad=False, hess=False, in_line_search=False):
    """Forward pass for updating the affine component (nonlin is not None)

    The dense (nonlinear) component is kept fixed (it is fetched with
    `recompute=False`) and the objective is evaluated, optionally with
    its derivatives, as a function of the affine Lie parameters.
    All losses in `self.losses` are accumulated, weighted by their
    `factor`.

    Parameters
    ----------
    logaff : tensor
        Affine Lie parameters.
    grad : bool, default=False
        Compute and return the gradient wrt `logaff`.
    hess : bool, default=False
        Compute and return the Hessian wrt `logaff`.
    in_line_search : bool, default=False
        If True: do not cache the exponentiated affine, restore the
        state of each loss after evaluating it, do not plot, and do
        not update the iteration bookkeeping.

    Returns
    -------
    ll : scalar tensor
        Objective value (data term + `self.llv`).
    g : tensor, if `grad`
    h : tensor, if `hess`

    A single value is returned bare; several values are returned
    as a tuple.
    """

    sumloss = None
    sumgrad = None
    sumhess = None

    # ==============================================================
    #                     EXPONENTIATE TRANSFORMS
    # ==============================================================
    logaff0 = logaff
    # first letter of the affine position; compared to 'f', 'm', 's' below
    # (presumably fixed / moving / symmetric -- TODO confirm)
    aff_pos = self.affine.position[0].lower()
    if any(loss.backward for loss in self.losses):
        # at least one backward loss: inverse transforms are needed too
        aff0, iaff0, gaff0, igaff0 = \
            self.affine.exp2(logaff0, grad=True,
                             cache_result=not in_line_search)
        phi0, iphi0 = self.nonlin.exp2(cache_result=True, recompute=False)
    else:
        iaff0, igaff0, iphi0 = None, None, None
        aff0, gaff0 = self.affine.exp(logaff0, grad=True,
                                      cache_result=not in_line_search)
        phi0 = self.nonlin.exp(cache_result=True, recompute=False)

    has_printed = False
    for loss in self.losses:

        moving, fixed, factor = loss.moving, loss.fixed, loss.factor
        # backward losses use the inverse transforms
        if loss.backward:
            phi00, aff00, gaff00 = iphi0, iaff0, igaff0
        else:
            phi00, aff00, gaff00 = phi0, aff0, gaff0

        # ----------------------------------------------------------
        # build left and right affine matrices
        # ----------------------------------------------------------
        # NOTE(review): the original indentation of this section was
        # lost. It is reconstructed so that the change of coordinates
        # (`lmdiv`) is always applied, whereas the optimized affine is
        # only inserted on the side(s) selected by `aff_pos` -- confirm.
        aff_right, gaff_right = fixed.affine, None
        if aff_pos in 'fs':
            gaff_right = gaff00 @ aff_right
            gaff_right = linalg.lmdiv(self.nonlin.affine, gaff_right)
            aff_right = aff00 @ aff_right
        aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
        aff_left, gaff_left = self.nonlin.affine, None
        if aff_pos in 'ms':
            gaff_left = gaff00 @ aff_left
            gaff_left = linalg.lmdiv(moving.affine, gaff_left)
            aff_left = aff00 @ aff_left
        aff_left = linalg.lmdiv(moving.affine, aff_left)

        # ----------------------------------------------------------
        # build full transform
        # ----------------------------------------------------------
        # `right`/`left` stay None when the matching affine is (almost)
        # the identity and no resampling is needed
        if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
            right = None
            phi = spatial.add_identity_grid(phi00)
        else:
            right = spatial.affine_grid(aff_right, fixed.shape)
            phi = regutils.smart_pull_grid(phi00, right)
            phi += right
        # keep the transform before the left affine is applied: it is
        # used to push derivatives back when computing gradients
        phi_right = phi
        if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
            left = None
        else:
            left = spatial.affine_grid(aff_left, self.nonlin.shape)
            phi = spatial.affine_matvec(aff_left, phi)

        # ----------------------------------------------------------
        # forward pass
        # ----------------------------------------------------------
        warped, mask = moving.pull(phi, mask=True)
        # combine the mask of the warped image with that of the fixed image
        if fixed.masked:
            if mask is None:
                mask = fixed.mask
            else:
                mask = mask * fixed.mask

        # plot at most once per call, only at verbosity >= 3, never
        # during line search and never for backward losses
        do_print = not (has_printed or self.verbose < 3 or in_line_search
                        or loss.backward)
        if do_print:
            has_printed = True
            if moving.previewed:
                preview = moving.pull(phi, preview=True, dat=False)
            else:
                preview = warped
            # `init`: moving image resampled using the orientation
            # matrices only (no optimized transform)
            init = spatial.affine_lmdiv(moving.affine, fixed.affine)
            if _almost_identity(init) and moving.shape == fixed.shape:
                init = moving.dat
            else:
                init = spatial.affine_grid(init, fixed.shape)
                init = moving.pull(init, preview=True, dat=False)
            self.mov2fix(fixed.dat, init, preview, dim=fixed.dim,
                         title=f'(affine) {self.n_iter:03d}')

        # ----------------------------------------------------------
        # derivatives wrt moving
        # ----------------------------------------------------------
        g = h = None
        loss_args = (warped, fixed.dat)
        loss_kwargs = dict(dim=fixed.dim, mask=mask)
        # save the state of the loss so that a line-search evaluation
        # leaves it untouched
        state = loss.loss.get_state()
        if not grad and not hess:
            llx = loss.loss.loss(*loss_args, **loss_kwargs)
        elif not hess:
            llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
        else:
            llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
        del loss_args, loss_kwargs
        if in_line_search:
            loss.loss.set_state(state)

        # ----------------------------------------------------------
        # chain rule -> derivatives wrt Lie parameters
        # ----------------------------------------------------------

        def compose_grad(g, h, g_mu, g_aff):
            """
            g, h : gradient/Hessian of loss wrt moving image
            g_mu : spatial gradients of moving image
            g_aff : gradient of affine matrix wrt Lie parameters
            returns g, h: gradient/Hessian of loss wrt Lie parameters
            """
            # Note that `h` can be `None`, but the functions I
            # use deal with this case correctly.
            dim = g_mu.shape[-1]
            g = jg(g_mu, g)
            h = jhj(g_mu, h)
            g, h = regutils.affine_grid_backward(g, h)
            # flatten the derivatives wrt the dim x (dim+1) affine entries
            dim2 = dim * (dim + 1)
            g = g.reshape([*g.shape[:-2], dim2])
            # drop the last row of the affine derivatives, then flatten
            g_aff = g_aff[..., :-1, :]
            g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
            g = linalg.matvec(g_aff, g)
            if h is not None:
                h = h.reshape([*h.shape[:-4], dim2, dim2])
                h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
                # h = h.abs().sum(-1).diag_embed()
            return g, h

        if grad or hess:
            g0, g = g, None
            h0, h = h, None
            if aff_pos in 'ms':
                # affine on the left: push derivatives to the space of
                # the nonlinear field first
                g_left = regutils.smart_push(g0, phi_right, shape=self.nonlin.shape)
                h_left = regutils.smart_push(h0, phi_right, shape=self.nonlin.shape)
                mugrad = moving.pull_grad(left, rotate=False)
                g_left, h_left = compose_grad(g_left, h_left, mugrad, gaff_left)
                g, h = g_left, h_left
            if aff_pos in 'fs':
                # affine on the right: rotate the spatial gradients by
                # the Jacobian of everything applied on its left
                g_right, h_right = g0, h0
                mugrad = moving.pull_grad(phi, rotate=False)
                # NOTE(review): this uses `phi0` even for backward losses,
                # whereas `phi00` is used when building `phi` -- confirm.
                jac = spatial.grid_jacobian(phi0, right, type='disp',
                                            extrapolate=False)
                jac = torch.matmul(aff_left[:-1, :-1], jac)
                mugrad = linalg.matvec(jac.transpose(-1, -2), mugrad)
                g_right, h_right = compose_grad(g_right, h_right, mugrad, gaff_right)
                # accumulate with the left-hand side contribution, if any
                g = g_right if g is None else g.add_(g_right)
                h = h_right if h is None else h.add_(h_right)

            if loss.backward:
                g = g.neg_()
            # weighted accumulation across losses (in place)
            sumgrad = (g.mul_(factor) if sumgrad is None else
                       sumgrad.add_(g, alpha=factor))
            if hess:
                sumhess = (h.mul_(factor) if sumhess is None else
                           sumhess.add_(h, alpha=factor))
        sumloss = (llx.mul_(factor) if sumloss is None else
                   sumloss.add_(llx, alpha=factor))

    # TODO add regularization term
    lla = 0

    # ==============================================================
    #                           VERBOSITY
    # ==============================================================
    llx = sumloss.item()   # data term only
    sumloss += lla
    sumloss += self.llv
    self.loss_value = sumloss.item()
    # line-search evaluations are only printed at verbosity > 1
    if self.verbose and (self.verbose > 1 or not in_line_search):
        ll = sumloss.item()
        llv = self.llv
        if in_line_search:
            line = '(search) | '
        else:
            line = '(affine) | '
        line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
        if not in_line_search:
            if self.ll_prev is not None:
                gain = self.ll_prev - ll
                # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                line += f' | {gain:12.6g}'
            # NOTE(review): this bookkeeping only happens in verbose mode
            # (as reconstructed from the token order) -- confirm.
            self.all_ll.append(ll)
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)
            self.n_iter += 1
        # carriage return: the next print overwrites this line
        print(line, end='\r')

    # ==============================================================
    #                           RETURN
    # ==============================================================
    out = [sumloss]
    if grad:
        out.append(sumgrad)
    if hess:
        out.append(sumhess)
    return tuple(out) if len(out) > 1 else out[0]
def forward(self, batch=1, **overload):
    """Sample a random vector field from the prior encoded by a Greens function.

    White complex noise is generated in the Fourier domain, multiplied
    by a square root of the Greens function and transformed back to
    the spatial domain.

    Parameters
    ----------
    batch : int, default=1
        Batch size
    overload : dict
        Overrides for the instance attributes `shape`, `mean`,
        `voxel_size`, `dtype` and `device`.

    Returns
    -------
    field : (batch, channel, *shape) tensor
        Generated random field
        NOTE(review): as written, the code below generates a
        (batch, *shape, nb_dim) tensor -- confirm the documented layout.
    """
    # get arguments
    shape = overload.get('shape', self.shape)
    mean = overload.get('mean', self.mean)
    voxel_size = overload.get('voxel_size', self.voxel_size)
    dtype = overload.get('dtype', self.dtype)
    device = overload.get('device', self.device)
    backend = dict(dtype=dtype, device=device)

    nb_dim = len(shape)
    # plain list, so that it can be compared with the cached value
    voxel_size = utils.make_vector(voxel_size, nb_dim, **backend)
    voxel_size = voxel_size.tolist()
    lame = py.make_list(self.lame, 2)

    if (hasattr(self, '_greens')
            and self._voxel_size == voxel_size
            and self._shape == shape):
        # reuse the cached square root
        greens = self._greens.to(dtype=dtype, device=device)
    else:
        greens = spatial.greens(
            shape,
            absolute=self.absolute,
            membrane=self.membrane,
            bending=self.bending,
            lame=self.lame,
            voxel_size=voxel_size,
            device=device,
            dtype=dtype)
        if any(lame):
            # Matrix-valued Greens function: G = U diag(s) V'.
            # A square root R such that R R' = G is U diag(sqrt(s)),
            # i.e. the *columns* of U scaled by sqrt(s). The scales must
            # therefore broadcast along the last dimension
            # (`unsqueeze(-2)`); `unsqueeze(-1)` scaled the rows and
            # yielded diag(sqrt(s)) U, whose product with its transpose
            # is diag(s) rather than G.
            greens, scale, _ = torch.svd(greens)
            scale = scale.sqrt_()
            greens *= scale.unsqueeze(-2)
        else:
            # scalar Greens function: element-wise square root
            greens = greens.sqrt_()

        if self.cache_greens:
            self._greens = greens
            self._voxel_size = voxel_size
            self._shape = shape

    # real and imaginary parts of the white noise
    sample = torch.randn([2, batch, *shape, nb_dim], **backend)

    # multiply by square root of greens
    if greens.dim() > nb_dim:  # lame
        sample = linalg.matvec(greens, sample)
    else:
        sample = sample * greens.unsqueeze(-1)
    voxel_size = utils.make_vector(voxel_size, nb_dim, **backend)
    sample = sample / voxel_size.sqrt()
    sample = fft.complex(sample[0], sample[1])

    # inverse Fourier transform (across spatial dimensions)
    dims = list(range(-nb_dim-1, -1))
    sample = fft.real(fft.ifftn(sample, dim=dims))
    sample *= py.prod(shape)

    # add mean
    sample += mean

    return sample
def shim(fmap, max_order=2, mask=None, isocenter=None, dim=None,
         returns='corrected'):
    """Subtract a linear combination of spherical harmonics that
    minimize gradients

    The combination is fitted (pseudo-inverse solve) on the forward
    finite differences of the field map, then the corresponding
    non-differentiated harmonics are subtracted from the field map.

    Parameters
    ----------
    fmap : (..., *spatial) tensor
        Field map
    max_order : int, default=2
        Maximum order of the spherical harmonics
    mask : tensor, optional
        Mask of voxels to include (typically brain mask)
    isocenter : [sequence of] float, default=shape/2
        Coordinate of isocenter, in voxels
    dim : int, default=fmap.dim()
        Number of spatial dimensions
    returns : combination of {'corrected', 'correction', 'parameters'}, default='corrected'
        Components to return, joined by '+' (e.g. 'corrected+parameters').
        Any token starting with 'p' selects the parameters.
        Unrecognized or empty tokens are ignored.

    Returns
    -------
    corrected : (..., *spatial) tensor, if 'corrected' in `returns`
        Corrected field map (with spherical harmonics subtracted)
    correction : (..., *spatial) tensor, if 'correction' in `returns`
        Linear combination of spherical harmonics.
    parameters : (..., k) tensor, if 'parameters' in `returns`
        Parameters of the linear combination

    A single requested component is returned as is; several components
    are returned as a tuple, in the order in which they were requested.

    """
    fmap = torch.as_tensor(fmap)
    dim = dim or fmap.dim()
    shape = fmap.shape[-dim:]
    batch = fmap.shape[:-dim]
    backend = utils.backend(fmap)
    dims = list(range(-dim, 0))

    if mask is not None:
        mask = ~mask  # make it a mask of background voxels

    # compute gradients
    gmap = diff(fmap, dim=dims, side='f', bound='dct2')
    if mask is not None:
        gmap[..., mask, :] = 0  # background does not contribute to the fit
    gmap = gmap.reshape([*batch, -1])

    # compute basis of spherical harmonics (differentiated like the map)
    basis = []
    for order in range(1, max_order + 1):
        b = spherical_harmonics(shape, order, isocenter, **backend)
        b = utils.movedim(b, -1, 0)
        b = diff(b, dim=dims, side='f', bound='dct2')
        if mask is not None:
            b[..., mask, :] = 0
        b = b.reshape([b.shape[0], *batch, -1])
        basis.append(b)
    basis = torch.cat(basis, 0)
    basis = utils.movedim(basis, 0, -1)  # (*batch, vox*dim, k)

    # solve system
    prm = linalg.lmdiv(basis, gmap[..., None], method='pinv')[..., 0]
    # > (*batch, k)

    # rebuild basis (without taking gradients)
    basis = []
    for order in range(1, max_order + 1):
        b = spherical_harmonics(shape, order, isocenter, **backend)
        b = utils.movedim(b, -1, 0)
        b = b.reshape([b.shape[0], *batch, *shape])
        basis.append(b)
    basis = torch.cat(basis, 0)
    basis = utils.movedim(basis, 0, -1)  # (*batch, *shape, k)

    # voxel-wise dot product between the harmonics and their weights
    comb = linalg.matvec(basis.unsqueeze(-2), utils.unsqueeze(prm, -2, dim))
    comb = comb[..., 0]
    fmap = fmap - comb

    out = []
    for ret in returns.split('+'):
        if ret == 'corrected':
            out.append(fmap)
        elif ret == 'correction':
            out.append(comb)
        elif ret.startswith('p'):
            # `startswith` (rather than `ret[0]`) tolerates empty tokens,
            # e.g. from a trailing '+'
            out.append(prm)
    return out[0] if len(out) == 1 else tuple(out)
def write_outputs(z, prm, options):
    """Write the outputs requested in `options` to disk.

    Outputs are written in up to three spaces: native (`*_nat`),
    template/MNI (`*_mni`) and warped (`*_wrp`). For each output, the
    matching option holds a user-provided filename pattern; when it is
    empty but the corresponding `all_*` option is set, a default pattern
    is used. Patterns are filled with the dictionary returned by
    `get_format_dict`.

    Parameters
    ----------
    z : tensor
        Class responsibilities, with classes along the first dimension.
    prm : dict
        Fitted parameters. The keys read here are 'bias', 'warp' and
        'affine' ('bias' and 'warp' only if the matching option is set).
    options : object
        Parsed options. Attributes read here: `input`, `mask`, `tpm`,
        `output`, `verbose`, `bias`, `warp`, `mni_vx`, and the
        `{prob,labels,bias,nobias,warp,all}_{nat,mni,wrp}` output
        selectors.

    Returns
    -------
    None

    """
    # prepare filenames
    ref_native = options.input[0]
    ref_mni = options.tpm[0] if options.tpm else path_spm_prior()
    format_dict = get_format_dict(ref_native, options.output)

    def resolve(user_pattern, default_pattern, label):
        # Fill the filename pattern (from the first input) and log it
        fname = (user_pattern or default_pattern).format(**format_dict)
        if options.verbose > 0:
            print(f'{label} ->', fname)
        return fname

    backend = utils.backend(z)

    # the input intensities are only needed for bias-corrected outputs
    if (options.nobias_nat or options.nobias_mni or options.nobias_wrp or
            options.all_nat or options.all_mni or options.all_wrp):
        dat, _, _ = get_data(options.input, options.mask, None, 3, **backend)

    # --- native space -------------------------------------------------
    if options.prob_nat or options.all_nat:
        fname = resolve(options.prob_nat,
                        '{dir}{sep}{base}.prob.nat{ext}', 'prob.nat')
        # move channels to back
        io.savef(torch.movedim(z, 0, -1), fname, like=ref_native,
                 dtype='float32')

    if options.labels_nat or options.all_nat:
        fname = resolve(options.labels_nat,
                        '{dir}{sep}{base}.labels.nat{ext}', 'labels.nat')
        io.save(z.argmax(0), fname, like=ref_native, dtype='int16')

    if (options.bias_nat or options.all_nat) and options.bias:
        pattern = options.bias_nat or '{dir}{sep}{base}.bias.nat{ext}'
        _save_per_channel(prm['bias'], pattern, 'bias.nat', options,
                          format_dict, dtype='float32')

    if (options.nobias_nat or options.all_nat) and options.bias:
        nobias = dat * prm['bias']
        pattern = options.nobias_nat or '{dir}{sep}{base}.nobias.nat{ext}'
        _save_per_channel(nobias, pattern, 'nobias.nat', options,
                          format_dict)
        del nobias

    if (options.warp_nat or options.all_nat) and options.warp:
        fname = resolve(options.warp_nat,
                        '{dir}{sep}{base}.warp.nat{ext}', 'warp.nat')
        io.savef(prm['warp'], fname, like=ref_native, dtype='float32')

    # --- MNI space ----------------------------------------------------
    if options.tpm is False:
        # No template -> no MNI space
        return

    fref = io.map(ref_mni)
    mni_affine, mni_shape = fref.affine, fref.shape[:3]
    dat_affine = io.map(ref_native).affine
    mni_affine = mni_affine.to(**backend)
    dat_affine = dat_affine.to(**backend)
    prm_affine = prm['affine'].to(**backend)
    # fold the fitted affine into the native orientation matrix
    dat_affine = prm_affine @ dat_affine
    if options.mni_vx:
        # resample the template grid to the requested voxel size
        vx = spatial.voxel_size(mni_affine)
        scl = vx / options.mni_vx
        mni_affine, mni_shape = spatial.affine_resize(
            mni_affine, mni_shape, scl, anchor='f')

    if options.prob_mni or options.labels_mni or options.all_mni:
        z_mni = spatial.reslice(z, dat_affine, mni_affine, mni_shape)
        if options.prob_mni or options.all_mni:
            fname = resolve(options.prob_mni,
                            '{dir}{sep}{base}.prob.mni{ext}', 'prob.mni')
            io.savef(torch.movedim(z_mni, 0, -1), fname, like=ref_native,
                     affine=mni_affine, dtype='float32')
        if options.labels_mni or options.all_mni:
            fname = resolve(options.labels_mni,
                            '{dir}{sep}{base}.labels.mni{ext}', 'labels.mni')
            io.save(z_mni.argmax(0), fname, like=ref_native,
                    affine=mni_affine, dtype='int16')
        del z_mni

    if options.bias and (options.bias_mni or options.nobias_mni
                         or options.all_mni):
        bias = spatial.reslice(prm['bias'], dat_affine, mni_affine, mni_shape,
                               interpolation=3, prefilter=False, bound='dct2')
        if options.bias_mni or options.all_mni:
            pattern = options.bias_mni or '{dir}{sep}{base}.bias.mni{ext}'
            _save_per_channel(bias, pattern, 'bias.mni', options,
                              format_dict, affine=mni_affine, dtype='float32')
        if options.nobias_mni or options.all_mni:
            nobias = spatial.reslice(dat, dat_affine, mni_affine, mni_shape)
            nobias *= bias
            pattern = options.nobias_mni or '{dir}{sep}{base}.nobias.mni{ext}'
            _save_per_channel(nobias, pattern, 'nobias.mni', options,
                              format_dict, affine=mni_affine)
            del nobias
        del bias

    # everything below needs the inverse warp
    need_iwarp = (options.warp_mni or options.prob_wrp or options.labels_wrp or
                  options.bias_wrp or options.nobias_wrp or
                  options.all_mni or options.all_wrp)
    need_iwarp = need_iwarp and options.warp
    if not need_iwarp:
        return

    iwarp = spatial.grid_inv(prm['warp'], type='disp')
    iwarp = iwarp.movedim(-1, 0)
    iwarp = spatial.reslice(iwarp, dat_affine, mni_affine, mni_shape,
                            interpolation=2, bound='dft', extrapolate=True)
    iwarp = iwarp.movedim(0, -1)
    # apply the linear part of the native-to-template voxel mapping
    # to the resliced vectors
    iaff = mni_affine.inverse() @ dat_affine
    iwarp = linalg.matvec(iaff[:3, :3], iwarp)

    if (options.warp_mni or options.all_mni) and options.warp:
        fname = resolve(options.warp_mni,
                        '{dir}{sep}{base}.warp.mni{ext}', 'warp.mni')
        io.savef(iwarp, fname, like=ref_native, affine=mni_affine,
                 dtype='float32')

    # --- Warped space -------------------------------------------------
    iwarp = spatial.add_identity_grid_(iwarp)
    iwarp = spatial.affine_matvec(dat_affine.inverse() @ mni_affine, iwarp)

    if options.prob_wrp or options.labels_wrp or options.all_wrp:
        z_wrp = spatial.grid_pull(z, iwarp)
        if options.prob_wrp or options.all_wrp:
            fname = resolve(options.prob_wrp,
                            '{dir}{sep}{base}.prob.wrp{ext}', 'prob.wrp')
            io.savef(torch.movedim(z_wrp, 0, -1), fname, like=ref_native,
                     affine=mni_affine, dtype='float32')
        if options.labels_wrp or options.all_wrp:
            fname = resolve(options.labels_wrp,
                            '{dir}{sep}{base}.labels.wrp{ext}', 'labels.wrp')
            io.save(z_wrp.argmax(0), fname, like=ref_native,
                    affine=mni_affine, dtype='int16')
        del z_wrp

    if options.bias and (options.bias_wrp or options.nobias_wrp
                         or options.all_wrp):
        bias = spatial.grid_pull(prm['bias'], iwarp,
                                 interpolation=3, prefilter=False, bound='dct2')
        if options.bias_wrp or options.all_wrp:
            pattern = options.bias_wrp or '{dir}{sep}{base}.bias.wrp{ext}'
            _save_per_channel(bias, pattern, 'bias.wrp', options,
                              format_dict, affine=mni_affine, dtype='float32')
        if options.nobias_wrp or options.all_wrp:
            nobias = spatial.grid_pull(dat, iwarp)
            nobias *= bias
            pattern = options.nobias_wrp or '{dir}{sep}{base}.nobias.wrp{ext}'
            _save_per_channel(nobias, pattern, 'nobias.wrp', options,
                              format_dict, affine=mni_affine)
            del nobias
        del bias


def _save_per_channel(vol, pattern, label, options, format_dict, **save_opt):
    """Save a multi-channel volume: one file, or one file per input.

    With a single input file, channels are moved to the last dimension
    and everything is written to one file named after that input. With
    several inputs, each channel is written to its own file, named
    after (and saved `like`) the matching input.

    Parameters
    ----------
    vol : tensor
        Volume to save, with channels along the first dimension.
    pattern : str
        Filename pattern, filled with `str.format`.
    label : str
        Output name used in the verbose log (e.g. 'bias.nat').
    options : object
        Parsed options (reads `input`, `output` and `verbose`).
    format_dict : dict
        Pattern fields for the first input (used in the single-input case).
    save_opt : dict
        Extra keywords forwarded to `io.savef` (e.g. `affine`, `dtype`).

    """
    if len(options.input) == 1:
        fname = pattern.format(**format_dict)
        if options.verbose > 0:
            print(f'{label} ->', fname)
        io.savef(torch.movedim(vol, 0, -1), fname, like=options.input[0],
                 **save_opt)
        return

    for c, (vol1, ref1) in enumerate(zip(vol, options.input)):
        # Format the untouched pattern for each channel: formatting it in
        # place would make every channel reuse the first channel's name.
        fname1 = pattern.format(**get_format_dict(ref1, options.output))
        if options.verbose > 0:
            print(f'{label}.{c+1} ->', fname1)
        io.savef(vol1, fname1, like=ref1, **save_opt)
def forward(self, image, **overload):
    """Deform an image with a random (affine + nonlinear) transformation.

    Parameters
    ----------
    image : (batch, channel, *shape) tensor
        Input image
    overload : dict
        Call-time overrides. Keys read here: `vel_amplitude`, `vel_fwhm`,
        `vel_bound`, `interpolation`, `dtype`, `device`, `translation`,
        `rotation`, `zoom`, `shear` and `image_bound`. A missing key
        falls back to the value stored on the matching submodule.

    Returns
    -------
    warped : (batch, channel, *shape) tensor
        Deformed image
    grid : (batch, *shape, dim) tensor
        Resampling grid

    """
    image = torch.as_tensor(image)
    nbatch, _, *spatial_shape = image.shape
    ndim = image.dim() - 2
    backend = dict(dtype=image.dtype, device=image.device)
    pick = overload.get

    # gather the options of each submodule
    grid_opt = dict(
        dim=ndim,
        shape=spatial_shape,
        amplitude=pick('vel_amplitude', self.grid.amplitude),
        fwhm=pick('vel_fwhm', self.grid.fwhm),
        bound=pick('vel_bound', self.grid.bound),
        interpolation=pick('interpolation', self.grid.interpolation),
        dtype=pick('dtype', self.grid.dtype),
        device=pick('device', self.grid.device),
    )
    affine_opt = dict(
        dim=ndim,
        translation=pick('translation', self.affine.translation),
        rotation=pick('rotation', self.affine.rotation),
        zoom=pick('zoom', self.affine.zoom),
        shear=pick('shear', self.affine.shear),
        dtype=pick('dtype', self.affine.dtype),
        device=pick('device', self.affine.device),
    )
    pull_opt = dict(
        bound=pick('image_bound', self.pull.bound),
        interpolation=pick('interpolation', self.pull.interpolation),
    )

    # sample the nonlinear grid first, then the affine matrix
    grid = self.grid(nbatch, **grid_opt)
    affine = self.affine(nbatch, **affine_opt)

    # shift center of rotation: conjugate the affine by a translation
    # of minus half the field of view
    extent = torch.as_tensor(grid_opt['shape'], **backend)[:, None]
    recenter = torch.cat((torch.eye(ndim, **backend), -extent / 2), dim=1)
    affine = affine_matmul(affine, recenter)
    affine = affine_lmdiv(recenter, affine)

    # compose: apply the affine on top of the sampled grid
    affine = unsqueeze(affine, dim=-3, ndim=ndim)
    linear = affine[..., :ndim, :ndim]
    offset = affine[..., :ndim, -1]
    grid = matvec(linear, grid) + offset

    # pull
    warped = self.pull(image, grid, **pull_opt)
    return warped, grid