def _atlas_align(dat, mat, rigid=True, pth_atlas=None):
    r"""Affinely align image to some atlas space.

    Parameters
    ----------
    dat : [N, ...], tensor_like
        List of image volumes. The caller's container is not modified.
    mat : [N, ...], tensor_like
        List of affine matrices. The caller's container is not modified.
    rigid : bool, default=True
        Do rigid alignment, else does rigid+isotropic scaling.
    pth_atlas : str, optional
        Path to atlas image to match to. Uses Brain T1w atlas by default.

    Returns
    ----------
    mat_a : (N, 4, 4) tensor_like
        Transformation aligning to MNI space as M_mni\M_mov.
    mat_mni : (4, 4), tensor_like
        Affine matrix of MNI image.
    dim_mni : (3,), tuple, list, tensor_like
        Image dimensions of MNI image.
    mat_cso : (N, 4, 4) tensor_like
        CSO transformation.

    """
    if pth_atlas is None:
        # Get path to nitorch's T1w intensity atlas
        pth_atlas = fetch_data('atlas_t1')
    # Get number of input images
    N = len(dat)
    # Load the atlas on the same device as the input images
    dat_mni, mat_mni, _ = _format_input(pth_atlas, device=dat[0].device,
                                        rand=True, cutoff=(0.0005, 0.9995))
    # Append the atlas at the end of *copies* of the inputs, so the caller's
    # lists are left untouched (previously the atlas leaked into them).
    dat = list(dat) + [dat_mni[0]]
    mat = list(mat) + [mat_mni[0]]
    # Align to MNI atlas; the atlas (index N) is the fixed image.
    group = 'CSO'
    _, mat_mni, dim_mni, q = _affine_align(dat, mat, group=group,
                                           samp=(3, 1.5), cost_fun='nmi',
                                           fix=N, verbose=False,
                                           mean_space=False)
    # Drop the atlas' own parameters
    q = q[:N, ...]
    # Get matrix representation; basis built on q's device so that expm
    # does not mix devices when running on GPU.
    basis_cso = affine_basis(group=group, device=q.device)
    mat_cso = expm(q, basis_cso)
    if rigid:
        # Extract only the rigid part. This assumes the first six CSO
        # parameters are the SE ones (translations + rotations).
        mat_a = expm(q[..., :6], affine_basis(group='SE', device=q.device))
    else:
        mat_a = expm(q, basis_cso)
    return mat_a, mat_mni, dim_mni, mat_cso
def write_transforms(options):
    """Write transformations (affine and nonlin) on disk.

    Every entry of ``options.transformations`` that is a
    ``struct.NonLinear`` is treated as the non-linear transform; any other
    entry is treated as the affine one. If several match, the last one wins.
    Each transform found is written to its own ``.output`` path.

    Parameters
    ----------
    options
        Parsed registration options; only ``options.transformations`` is read.
    """
    nonlin = None
    affine = None
    for trf in options.transformations:
        if isinstance(trf, struct.NonLinear):
            nonlin = trf
        else:
            affine = trf

    if affine:
        lin = linalg.expm(affine.dat, affine.basis)
        if torch.is_tensor(affine.shift):
            # include shift: t <- t + (R - I) @ shift
            shift = affine.shift.to(dtype=lin.dtype, device=lin.device)
            # Spatial dimension taken from the (dim+1, dim+1) matrix instead
            # of a hard-coded 3, so 2D transforms work too.
            eye = torch.eye(lin.shape[-1] - 1, dtype=lin.dtype,
                            device=lin.device)
            # expm is a custom autograd.Function: copy before in-place update
            lin = lin.clone()
            lin[:-1, -1] += torch.matmul(lin[:-1, :-1] - eye, shift)
        io.transforms.savef(lin.cpu(), affine.output, type=2)

    if nonlin:
        nonlin_affine = nonlin.affine
        shape = nonlin.shape
        if isinstance(nonlin, struct.FFD):
            # Rescale the affine by (output shape / control-grid shape)
            factor = [s / g for s, g in zip(shape, nonlin.dat.shape[:-1])]
            nonlin_affine, _ = spatial.affine_resize(nonlin_affine, shape,
                                                     factor)
        io.volumes.savef(nonlin.dat.cpu(), nonlin.output,
                         affine=nonlin_affine.cpu())
def __call__(self, logaff, grad=False, hess=False, gradmov=False, hessmov=False, in_line_search=False):
    """Evaluate the affine registration objective and (optionally) its
    analytic derivatives.

    Parameters
    ----------
    logaff : (..., nb) tensor, Lie parameters
    grad : Whether to compute and return the gradient wrt `logaff`
    hess : Whether to compute and return the Hessian wrt `logaff`
        (if True, the gradient is computed and returned as well, even
        when `grad` is False).
    gradmov : Whether to compute and return the gradient wrt `moving`
        (only computed when `grad` or `hess` is True).
    hessmov : Whether to compute and return the Hessian wrt `moving`
        (only computed when `hess` is True).
    in_line_search : If True, skip plotting, printing and the iteration
        counter increment.

    Returns
    -------
    ll : float, loss value (objective to minimize), from `.item()`
    g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
    h : (..., logaff) tensor, optional, Hessian wrt Lie parameters
        (replaced by a diagonal matrix of absolute row sums).
    gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
    hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

    Only the values whose flag ended up non-False are returned; a lone
    `ll` is returned unwrapped.
    NOTE(review): if `gradmov`/`hessmov` is True but the corresponding
    derivative was not computed, the literal `True` is returned in its slot.
    """
    # This loop performs the forward pass, and computes
    # derivatives along the way.

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    # Plot roughly 20 times over the whole optimisation
    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % logplot == 0

    # jitter
    # if not hasattr(self, '_fixed'):
    #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
    #                                 jitter=True,
    #                                 **utils.backend(self.fixed))
    #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
    #     del idj
    # fixed = self._fixed
    fixed = self.fixed

    # forward
    # Lazily materialise the Lie basis as a tensor on logaff's backend
    if not torch.is_tensor(self.basis):
        self.basis = spatial.affine_basis(self.basis, self.dim,
                                          **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    # Derivative of expm wrt each Lie parameter (no autograd needed)
    with torch.no_grad():
        _, gaff = linalg._expm(logaff, self.basis,
                               grad_X=True, hess_X=False)

    # Voxel-to-voxel map: affine_moving \ (expm(logaff) @ affine_fixed)
    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    # /!\ derivatives are not "homogeneous" (they do not have a one
    # on the bottom right): we should *not* use affine_matmul and
    # such (I only lost a day...)
    gaff = torch.matmul(gaff, self.affine_fixed)
    gaff = linalg.lmdiv(self.affine_moving, gaff)
    # haff = torch.matmul(haff, self.affine_fixed)
    # haff = linalg.lmdiv(self.affine_moving, haff)
    # Identity grid over the fixed image, cached on first call
    if self.id is None:
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff), jitter=False)
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, cat=iscat, dim=self.dim)

    # gradient/Hessian of the log-likelihood in observed space
    # NOTE: `grad`/`hess`/`gradmov`/`hessmov` are rebound from flags to
    # tensors here; later checks use `is not False` for that reason.
    if not grad and not hess:
        llx = self.loss.loss(warped, fixed)
    elif not hess:
        llx, grad = self.loss.loss_grad(warped, fixed)
        if gradmov:
            # Push the observed-space gradient back to moving space
            gradmov = spatial.grid_push(grad, grid, **pullopt)
    else:
        llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
        if hessmov:
            hessmov = spatial.grid_push(hess, grid, **pullopt)
    del warped

    # compose with spatial gradients + dot product with grid
    if grad is not False or hess is not False:
        # Chain rule through the warped image's spatial gradient
        mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
        grad = jg(mugrad, grad)
        if hess is not False:
            hess = jhj(mugrad, hess)
            grad, hess = regutils.affine_grid_backward(grad, hess, grid=self.id)
        else:
            grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
        # Flatten the (dim, dim+1) matrix derivative and project it onto the
        # Lie parameters through gaff (last homogeneous row dropped)
        dim2 = self.dim * (self.dim + 1)
        grad = grad.reshape([*grad.shape[:-2], dim2])
        gaff = gaff[..., :-1, :]
        gaff = gaff.reshape([*gaff.shape[:-2], dim2])
        grad = linalg.matvec(gaff, grad)
        if hess is not False:
            hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
            # haff = haff[..., :-1, :, :-1, :]
            # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
            hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
            # Replace H by diag(sum_j |H_ij|); for a symmetric H this
            # dominates H (Gershgorin), presumably for a stabler Newton step.
            hess = hess.abs().sum(-1).diag_embed()
        del mugrad

    # print objective
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(
                f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}', end='\n')
        else:
            # Relative decrease wrt the largest objective seen so far
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(
                f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}', end='\n')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [ll]
    if grad is not False:
        out.append(grad)
    if hess is not False:
        out.append(hess)
    if gradmov is not False:
        out.append(gradmov)
    if hessmov is not False:
        out.append(hessmov)
    return tuple(out) if len(out) > 1 else out[0]
def __call__(self, logaff, grad=False, hess=False, in_line_search=False):
    """Evaluate the affine registration objective, with the gradient
    obtained through autograd.

    Parameters
    ----------
    logaff : (..., nb) tensor, Lie parameters. When `grad` is True, its
        `requires_grad` flag is switched on for the call (and off at the
        end) and any existing `.grad` is zeroed.
    grad : Whether to compute and return the gradient wrt `logaff`
    hess : Accepted for interface compatibility but ignored: no Hessian is
        computed or returned.
    in_line_search : If True, skip plotting, printing and the iteration
        counter increment.

    Returns
    -------
    ll : () tensor, loss value (objective to minimize)
    g : (..., nb) tensor, optional, Gradient wrt Lie parameters
        (returned only if `grad` is True)
    """
    # This loop performs the forward pass, and computes
    # derivatives along the way.

    # select correct gradient mode
    if grad:
        logaff.requires_grad_()
        if logaff.grad is not None:
            logaff.grad.zero_()
    # Re-enter this method under the grad mode matching `grad`
    if grad and not torch.is_grad_enabled():
        with torch.enable_grad():
            return self(logaff, grad, in_line_search=in_line_search)
    elif not grad and torch.is_grad_enabled():
        with torch.no_grad():
            return self(logaff, grad, in_line_search=in_line_search)

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    # Plot roughly 20 times over the whole optimisation
    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % logplot == 0

    # jitter
    # idj = spatial.identity_grid(self.fixed.shape[-self.dim:], jitter=True,
    #                             **utils.backend(self.fixed))
    # fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
    # del idj
    fixed = self.fixed

    # forward
    # Lazily materialise the Lie basis as a tensor on logaff's backend
    if not torch.is_tensor(self.basis):
        self.basis = spatial.affine_basis(self.basis, self.dim,
                                          **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    # Voxel-to-voxel map: affine_moving \ (expm(logaff) @ affine_fixed)
    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    # Identity grid over the fixed image, cached on first call
    if self.id is None:
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff))
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, cat=iscat, dim=self.dim)

    # gradient/Hessian of the log-likelihood in observed space
    llx = self.loss.loss(warped, fixed)
    del warped

    # print objective
    lll = llx  # keep the tensor (graph) for backward; llx becomes a float
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(
                f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}', end='\n')
        else:
            # Relative decrease wrt the largest objective seen so far
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(
                f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}', end='\n')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [lll]
    if grad is not False:
        lll.backward()
        grad = logaff.grad.clone()
        out.append(grad)
        logaff.requires_grad_(False)
    return tuple(out) if len(out) > 1 else out[0]
def _test_cost(dat, mat, cost_fun='nmi', group='SE', mean_space=False,
               samp=2, ix_par=0, jitter=False, x_step=0.1, x_mn_mx=30,
               verbose=False, mx_int=1023, raw=False, fwhm=7.0):
    """Check cost function behaviour by keeping one image fixed and
    re-aligning a second image by modifying one of the affine parameters.
    Plots cost vs. alignment when finished.

    Parameters
    ----------
    dat : [2, ...], tensor_like
        List of exactly two image volumes.
    mat : [2, ...], tensor_like
        List of the corresponding affine matrices.
    cost_fun : str, default='nmi'
        Cost function evaluated by `_compute_cost`.
    group : str, default='SE'
        Currently unused: the full 12-parameter 'Aff+' basis is always used.
    mean_space : bool, default=False
        Evaluate the cost in a mean space (not available for histogram costs).
    samp : float or sequence of float, default=2
        Sampling step (mm); only the last value is used.
    ix_par : int, default=0
        Index of the Lie parameter that is swept.
    jitter : bool, default=False
        Add random jittering to the resampling grid.
    x_step : float, default=0.1
        Step of the parameter sweep.
    x_mn_mx : float, default=30
        The sweep covers [-x_mn_mx, x_mn_mx).
    verbose : bool, default=False
        Show the registration result at every sweep value.
    mx_int, raw, fwhm
        Forwarded to data loading / `_compute_cost` (see `_affine_align`).

    Raises
    ------
    ValueError
        If len(dat) != 2, or if mean_space is requested with a
        histogram-based cost.

    """
    with torch.no_grad():
        device = dat[0].device
        # Parse algorithm options
        opt = {'cost_fun': cost_fun,
               'samp': samp,
               'mean_space': mean_space,
               'verbose': verbose,
               'raw': raw,
               'mx_int' : mx_int,
               'fwhm' : fwhm}
        if not isinstance(opt['samp'], (list, tuple)):
            opt['samp'] = (opt['samp'], )
        # Some very basic sanity checks
        N = len(dat)
        if N != 2:
            raise ValueError('N != 2')
        mov = list(range(N))  # Indices of images
        fix_img = 0
        mov_img = 1
        if opt['cost_fun'] in _costs_hist and opt['mean_space']:
            raise ValueError('Option mean_space=True not defined for {} cost!'.format(opt['cost_fun']))
        # Load data
        dat = _data_loader(dat, mat, opt)
        # Get full 12 parameter affine basis
        B = affine_basis(group='Aff+', device=device)
        Nq = B.shape[0]
        # Range of parameter
        x = torch.arange(start=-x_mn_mx, end=x_mn_mx, step=x_step,
                         dtype=torch.float32)
        if opt['mean_space']:
            # Use mean-space, so make sure that maximum misalignment is
            # represented in the input to _get_mean_space()
            mat_mn = torch.zeros(Nq, dtype=torch.float64, device=device)
            mat_mx = torch.zeros(Nq, dtype=torch.float64, device=device)
            mat_mn[ix_par] = -x_mn_mx
            mat_mx[ix_par] = x_mn_mx
            mat1 = [expm(mat_mn, B).mm(mat[mov_img]),
                    expm(mat_mx, B).mm(mat[mov_img])]
            # Compute mean-space
            # NOTE(review): these entries hold the *shape* of the moving
            # image rather than image data -- confirm _get_mean_space
            # accepts that.
            dat.append(torch.tensor(dat[mov_img].shape, dtype=torch.float32, device=device))
            dat.append(torch.tensor(dat[mov_img].shape, dtype=torch.float32, device=device))
            mat_fix, dim_fix = _get_mean_space(dat, mat + mat1)
            dat = dat[:2]
            arg_grid = dim_fix
        else:
            mat_fix = mat[fix_img]
            dim_fix = dat[fix_img].shape[:3]
            mov.remove(fix_img)
            arg_grid = dat[fix_img]
        # Get voxel size of fixed image
        vx_fix = voxel_size(mat_fix)
        # Initial guess
        q = torch.zeros((N, Nq), dtype=torch.float64, device=device)
        # Get subsampled fixed image and its resampling grid
        dat_fix, grid = _get_dat_grid(arg_grid, vx_fix, samp=opt['samp'][-1],
                                      jitter=jitter, device=device)
        # Iterate over a range of values
        costs = np.zeros(len(x))
        fig_ax = None  # Used for visualisation
        for i, xi in enumerate(x):
            # Change affine matrix a little bit. _compute_cost applies the
            # i-th block of q to the i-th *moving* image, so row fix_img
            # drives the (single) moving image here.
            q[fix_img, ix_par] = xi
            # Compute cost
            cost, res = _compute_cost(
                q, grid, dat_fix, mat_fix, dat, mat, mov, opt['cost_fun'],
                B, opt['mx_int'], opt['fwhm'], return_res=True)
            # float() works for CPU and CUDA 0-dim tensors alike, whereas
            # storing a CUDA tensor into a NumPy array raises.
            costs[i] = float(cost)
            if opt['verbose']:
                fig_ax = show_slices(res, fig_ax=fig_ax, fig_num=1,
                                     cmap='coolwarm', title='x=' + str(xi))
        # Plot results
        if plt is None:
            return
        fig, ax = plt.subplots(num=2)
        ax.plot(x, costs)
        ax.set(xlabel='Value q[' + str(ix_par) + ']', ylabel='Cost',
               title=opt['cost_fun'].upper() + ' cost function (mean_space=' + str(opt['mean_space']) + ')')
        ax.grid()
        plt.show()
def _affine_align(dat, mat, cost_fun='nmi', group='SE', mean_space=False,
                  samp=(3, 1.5), optimiser='powell', fix=0, verbose=False,
                  fov=None, mx_int=1023, raw=False, jitter=False, fwhm=7.0):
    """Affine registration of a collection of images.

    Parameters
    ----------
    dat : [N, ...], tensor_like
        List of image volumes.
    mat : [N, ...], tensor_like
        List of affine matrices.
    cost_fun : str, default='nmi'
        Pairwise methods:
            * 'nmi'  : Normalised Mutual Information
            * 'mi'   : Mutual Information
            * 'ncc'  : Normalised Cross Correlation
            * 'ecc'  : Entropy Correlation Coefficient
        Groupwise methods:
            * 'njtv' : Normalised Joint Total variation
            * 'jtv'  : Joint Total variation
    group : str, default='SE'
        * 'T'   : Translations
        * 'SO'  : Special Orthogonal (rotations)
        * 'SE'  : Special Euclidean (translations + rotations)
        * 'D'   : Dilations (translations + isotropic scalings)
        * 'CSO' : Conformal Special Orthogonal
                  (translations + rotations + isotropic scalings)
        * 'SL'  : Special Linear (rotations + isovolumic zooms + shears)
        * 'GL+' : General Linear [det>0] (rotations + zooms + shears)
        * 'Aff+': Affine [det>0] (translations + rotations + zooms + shears)
    mean_space : bool, default=False
        Optimise a mean-space fit. Not available for the histogram-based
        costs (a ValueError is raised).
    samp : (float, ), default=(3, 1.5)
        Optimisation sampling steps (mm), one optimisation per value.
    optimiser : str, default='powell'
        'powell' : Optimisation method.
    fix : int, default=0
        Index of image to use as fixed image, not used if mean_space=True.
    verbose : bool, default=False
        Show registration results.
    fov : (2,) tuple, default=None
        A tuple with affine matrix (tensor_like) and dimensions (tuple) of
        mean space.
    mx_int : int, default=1023
        This parameter sets the max intensity in the images, which decides
        how many bins to use in the joint image histograms
        (e.g, mx_int=1023 -> H.shape = (1024, 1024)). This is only done if
        cost_fun is histogram-based.
    raw : bool, default=False
        Do no processing of input images -> work on raw data.
    jitter : bool, default=False
        Add random jittering to resampling grid.
    fwhm : float, default=7
        Full-width at half max of Gaussian kernel, for smoothing
        histogram.

    Returns
    ----------
    mat_a : (N, 4, 4), tensor_like
        Affine alignment matrices (identities when N < 2).
    mat_fix : (4, 4), tensor_like
        Affine matrix of fixed image.
    dim_fix : (3,), tuple, list, tensor_like
        Image dimensions of fixed image.
    q : (N, Nq), tensor_like
        Lie parameters (zeros when N < 2).

    Raises
    ------
    ValueError
        If mean_space=True is combined with a histogram-based cost.

    """
    with torch.no_grad():
        device = dat[0].device
        # Parse algorithm options
        opt = {'optimiser': optimiser,
               'cost_fun': cost_fun,
               'samp': samp,
               'fix': fix,
               'mean_space': mean_space,
               'verbose': verbose,
               'fov': fov,
               'group' : group,
               'raw': raw,
               'jitter': jitter,
               'mx_int' : mx_int,
               'fwhm' : fwhm}
        if not isinstance(opt['samp'], (list, tuple)):
            opt['samp'] = (opt['samp'], )
        # Some very basic sanity checks
        N = len(dat)  # Number of input scans
        mov = list(range(N))  # Indices of images
        if opt['cost_fun'] in _costs_hist and opt['mean_space']:
            raise ValueError('Option mean_space=True not defined for {} cost!'.format(opt['cost_fun']))
        # Get affine basis
        B = affine_basis(group=opt['group'], device=device)
        Nq = B.shape[0]
        # Load data
        dat = _data_loader(dat, mat, opt)
        # Define fixed image space (mat_fix, dim_fix, vx_fix)
        if opt['mean_space']:
            # Use a mean-space
            if opt['fov']:
                # Mean-space given
                mat_fix = opt['fov'][0]
                dim_fix = opt['fov'][1]
            else:
                # Compute mean-space
                mat_fix, dim_fix = _get_mean_space(dat, mat)
            arg_grid = dim_fix
        else:
            # Use one of the input images
            mat_fix = mat[opt['fix']]
            dim_fix = dat[opt['fix']].shape[:3]
            mov.remove(opt['fix'])
            arg_grid = dat[opt['fix']]
        # Get voxel size of fixed image
        vx_fix = voxel_size(mat_fix)
        # Initial guess for registration parameter
        q = torch.zeros((N, Nq), dtype=torch.float64, device=device)
        # With fewer than two images there is nothing to register: q stays
        # zero, so the matrices below are identities. (Previously this case
        # raised a NameError and returned only three values.)
        if N >= 2:
            # Do registration
            for s in opt['samp']:  # Loop over sub-sampling level
                # Get possibly sub-sampled fixed image, and its resampling grid
                dat_fix, grid = _get_dat_grid(
                    arg_grid, vx_fix, s, jitter=opt['jitter'], device=device)
                # Do optimisation
                q, _ = _fit_q(q, dat_fix, grid, mat_fix, dat, mat, mov, B, s, opt)
        # To matrix form
        mat_a = torch.zeros((N, 4, 4), dtype=torch.float64, device=device)
        for n in range(N):
            mat_a[n, ...] = expm(q[n, ...], basis=B)
    return mat_a, mat_fix, dim_fix, q
def _compute_cost(q, grid0, dat_fix, mat_fix, dat, mat, mov, cost_fun, B,
                  mx_int, fwhm, return_res=False):
    """Compute registration cost function.

    Parameters
    ----------
    q : (N, Nq) tensor_like or array_like
        Lie algebra of affine registration fit. It is flattened, and the
        i-th block of Nq values is applied to the i-th entry of `mov`.
    grid0 : (X1, Y1, Z1) tensor_like
        Sub-sampled image data's resampling grid.
    dat_fix : (X1, Y1, Z1) tensor_like
        Fixed image data.
    mat_fix : (4, 4) tensor_like
        Fixed affine matrix.
    dat : [N,] tensor_like
        List of input images.
    mat : [N,] tensor_like
        List of affine matrices.
    mov : [N,] int
        Indices (into `dat`/`mat`) of moving images.
    cost_fun : str
        Cost function to compute (a histogram-based or edge-based cost).
    B : (Nq, 4, 4) tensor_like
        Affine basis.
    mx_int : int
        This parameter sets the max intensity in the images, which decides
        how many bins to use in the joint image histograms
        (e.g, mx_int=511 -> H.shape = (512, 512)).
    fwhm : float
        Full-width at half max of Gaussian kernel, for smoothing
        histogram.
    return_res : bool, default=False
        Return registration results for plotting.

    Returns
    ----------
    c : float
        Cost of aligning images with current estimate of q. A NumPy array
        if `q` was a NumPy array (e.g. with the Powell optimiser), else a
        tensor.
    res : tensor_like
        Registration results, for visualisation (only if return_res=True).

    Raises
    ------
    ValueError
        If `cost_fun` is not a supported cost.

    """
    # Init
    device = grid0.device
    q = q.flatten()
    was_numpy = False
    if isinstance(q, np.ndarray):
        was_numpy = True
        q = torch.from_numpy(q).to(device)  # To torch tensor
    Nq = B.shape[0]
    N = torch.tensor(len(dat), device=device, dtype=torch.float32)  # For modulating NJTV cost
    if cost_fun in _costs_edge:
        jtv = dat_fix.clone()
        if cost_fun == 'njtv':
            njtv = -dat_fix.sqrt()
    # `Tensor.solve` is deprecated since PyTorch 1.9 (and later removed):
    # prefer torch.linalg.solve (PyTorch >= 1.8), fall back otherwise.
    linalg_solve = getattr(getattr(torch, 'linalg', None), 'solve', None)
    for i, m in enumerate(mov):  # Loop over moving images
        # Get affine matrix
        mat_a = expm(q[torch.arange(i * Nq, i * Nq + Nq)], B)
        # Compose matrices: M = mat_mov\mat_a*mat_fix
        rhs = mat_a.mm(mat_fix)
        if linalg_solve is not None:
            M = linalg_solve(mat[m], rhs)
        else:
            M = torch.solve(rhs, mat[m])[0]
        M = M.type(torch.float32)
        # Transform fixed grid
        grid = affine_matvec(M, grid0)
        # Resample to fixed grid
        dat_new = grid_pull(dat[m][None, None, ...], grid[None, ...],
                            bound='dft', extrapolate=False,
                            interpolation=1)[0, 0, ...]
        if cost_fun in _costs_edge:
            jtv += dat_new
            if cost_fun == 'njtv':
                njtv -= dat_new.sqrt()

    # Compute the cost function
    res = None
    if cost_fun in _costs_hist:
        # Histogram based costs
        # ----------
        # Compute joint histogram
        # OBS: This function expects both images to have the same max and
        # min intensities, this is ensured by the _data_loader() function.
        # NOTE: only the last resampled moving image is used.
        H = _hist_2d(dat_fix, dat_new, mx_int, fwhm)
        res = H
        # Get probabilities
        pxy = H / H.sum()
        px = pxy.sum(dim=0, keepdim=True)
        py = pxy.sum(dim=1, keepdim=True)
        # Compute cost
        if cost_fun == 'mi':
            # Mutual information
            mi = torch.sum(pxy * torch.log2(pxy / py.mm(px)))
            c = -mi
        elif cost_fun == 'ecc':
            # Entropy Correlation Coefficient
            # Maes, Collignon, Vandermeulen, Marchal & Suetens (1997).
            # "Multimodality image registration by maximisation of mutual
            # information". IEEE Transactions on Medical Imaging 16(2):187-198
            mi = torch.sum(pxy * torch.log2(pxy / py.mm(px)))
            ecc = -2 * mi / (torch.sum(px * px.log2()) + torch.sum(py * py.log2()))
            c = -ecc
        elif cost_fun == 'nmi':
            # Normalised Mutual Information
            # Studholme, Hill & Hawkes (1998).
            # "A normalized entropy measure of 3-D medical image alignment".
            # in Proc. Medical Imaging 1998, vol. 3338, San Diego, CA, pp. 132-143.
            nmi = (torch.sum(px * px.log2()) + torch.sum(py * py.log2())) / torch.sum(pxy * pxy.log2())
            c = -nmi
        elif cost_fun == 'ncc':
            # Normalised Cross Correlation
            i = torch.arange(1, pxy.shape[0] + 1, device=device, dtype=torch.float32)
            j = torch.arange(1, pxy.shape[1] + 1, device=device, dtype=torch.float32)
            m1 = torch.sum(py * i[..., None])
            m2 = torch.sum(px * j[None, ...])
            sig1 = torch.sqrt(torch.sum(py[..., 0] * (i - m1)**2))
            sig2 = torch.sqrt(torch.sum(px[0, ...] * (j - m2)**2))
            i, j = torch.meshgrid(i - m1, j - m2)
            ncc = torch.sum(pxy * i * j) / (sig1 * sig2)
            c = -ncc
        else:
            raise ValueError('Unknown cost function: {}'.format(cost_fun))
    elif cost_fun in _costs_edge:
        # Normalised Joint Total Variation
        # M Brudfors, Y Balbastre, J Ashburner (2020).
        # "Groupwise Multimodal Image Registration using Joint Total Variation".
        # in MIUA 2020.
        jtv.sqrt_()
        if cost_fun == 'njtv':
            njtv += torch.sqrt(N) * jtv
            res = njtv
            c = torch.sum(njtv)
        else:
            res = jtv
            c = torch.sum(jtv)
    else:
        raise ValueError('Unknown cost function: {}'.format(cost_fun))

    # _ = show_slices(res, fig_num=1, cmap='coolwarm')  # Can be uncommented for testing

    if was_numpy:
        # Back to numpy array
        c = c.cpu().numpy()

    if return_res:
        return c, res
    return c
def write_data(options):
    """Write the updated and/or resliced moving and fixed images to disk.

    The first ``struct.Linear`` in ``options.transformations`` gives the
    affine part, and the first ``struct.FFD`` / ``struct.Diffeo`` gives the
    non-linear displacement. Moving images are written with the forward
    transform. Fixed images are written with ``inv=True`` and the inverse
    displacement, which is only computed when some fixed image is to be
    written. All computations happen on the CPU in float.

    Parameters
    ----------
    options
        Parsed registration options; reads ``transformations``, ``losses``
        and ``verbose``.
    """
    backend = dict(dtype=torch.float, device='cpu')

    # Only invert the displacement if a fixed image must be written
    need_inv = any(loss.fixed and (loss.fixed.resliced or loss.fixed.updated)
                   for loss in options.losses)

    # affine matrix
    lin = None
    for trf in options.transformations:
        if isinstance(trf, struct.Linear):
            q = trf.dat.to(**backend)
            B = trf.basis.to(**backend)
            lin = linalg.expm(q, B)
            if torch.is_tensor(trf.shift):
                # include shift: t <- t + (R - I) @ shift
                shift = trf.shift.to(**backend)
                # Spatial dimension taken from the (dim+1, dim+1) matrix
                # instead of a hard-coded 3, so 2D transforms work too.
                eye = torch.eye(lin.shape[-1] - 1, **backend)
                # expm is a custom autograd.Function: copy before in-place update
                lin = lin.clone()
                lin[:-1, -1] += torch.matmul(lin[:-1, :-1] - eye, shift)
            break

    # non-linear displacement field (and its inverse, if needed)
    d = None
    inv_d = None
    d_aff = None
    for trf in options.transformations:
        if isinstance(trf, struct.FFD):
            d = trf.dat.to(**backend)
            d = ffd_exp(d, trf.shape, returns='disp')
            if need_inv:
                inv_d = grid_inv(d)
            d_aff = trf.affine.to(**backend)
            break
        elif isinstance(trf, struct.Diffeo):
            d = trf.dat.to(**backend)
            if need_inv:
                # inverse computed from the velocity, before exponentiation
                inv_d = spatial.exp(d[None], displacement=True, inverse=True)[0]
            d = spatial.exp(d[None], displacement=True)[0]
            d_aff = trf.affine.to(**backend)
            break

    # loop over image pairs
    for match in options.losses:

        moving = match.moving
        fixed = match.fixed
        prm = dict(interpolation=moving.interpolation,
                   bound=moving.bound,
                   extrapolate=moving.extrapolate,
                   device='cpu',
                   verbose=options.verbose)
        nonlin = dict(disp=d, affine=d_aff)
        if moving.updated:
            update(moving, moving.updated, lin=lin, nonlin=nonlin, **prm)
        if moving.resliced:
            reslice(moving, moving.resliced, like=fixed, lin=lin,
                    nonlin=nonlin, **prm)
        if not fixed:
            continue
        prm = dict(interpolation=fixed.interpolation,
                   bound=fixed.bound,
                   extrapolate=fixed.extrapolate,
                   device='cpu',
                   verbose=options.verbose)
        nonlin = dict(disp=inv_d, affine=d_aff)
        if fixed.updated:
            update(fixed, fixed.updated, inv=True, lin=lin, nonlin=nonlin,
                   **prm)
        if fixed.resliced:
            reslice(fixed, fixed.resliced, inv=True, like=moving, lin=lin,
                    nonlin=nonlin, **prm)
def forward():
    """Forward pass up to the loss.

    Builds the current affine matrix (first ``struct.Linear`` in
    ``options.transformations``) and displacement field (first free
    ``struct.FFD`` / ``struct.Diffeo``), adds their regularisation terms, then
    warps every moving image onto its fixed image at each pyramid level and
    accumulates the matching losses (averaged over levels).

    ``options`` and ``backend`` are not defined here; they are read from the
    enclosing scope.

    Returns
    -------
    loss
        Sum of the transformation penalties and per-level matching losses.
    """

    loss = 0

    # affine matrix
    A = None
    for trf in options.transformations:
        # NOTE(review): the loop breaks at the first struct.Linear, so
        # update() is not called on transformations listed after it --
        # confirm this is intended.
        trf.update()
        if isinstance(trf, struct.Linear):
            q = trf.optdat.to(**backend)
            # print(q.tolist())
            B = trf.basis.to(**backend)
            A = linalg.expm(q, B)
            if torch.is_tensor(trf.shift):
                # include shift: t <- t + (R - I) @ shift
                shift = trf.shift.to(**backend)
                eye = torch.eye(options.dim, **backend)
                A = A.clone()  # needed because expm is a custom autograd.Function
                A[:-1, -1] += torch.matmul(A[:-1, :-1] - eye, shift)
            # penalties on the Lie parameters
            for loss1 in trf.losses:
                loss += loss1.call(q)
            break

    # non-linear displacement field
    d = None
    d_aff = None
    for trf in options.transformations:
        if not trf.isfree():
            continue
        if isinstance(trf, struct.FFD):
            d = trf.dat.to(**backend)
            d = ffd_exp(d, trf.shape, returns='disp')
            for loss1 in trf.losses:
                loss += loss1.call(d)
            d_aff = trf.affine.to(**backend)
            break
        elif isinstance(trf, struct.Diffeo):
            d = trf.dat.to(**backend)
            if not trf.smalldef:
                # penalty on velocity fields
                for loss1 in trf.losses:
                    loss += loss1.call(d)
            d = spatial.exp(d[None], displacement=True)[0]
            if trf.smalldef:
                # penalty on exponentiated transform
                for loss1 in trf.losses:
                    loss += loss1.call(d)
            d_aff = trf.affine.to(**backend)
            break

    # loop over image pairs
    for match in options.losses:
        if not match.fixed:
            continue
        nb_levels = len(match.fixed.dat)
        prm = dict(interpolation=match.moving.interpolation,
                   bound=match.moving.bound,
                   extrapolate=match.moving.extrapolate)
        # loop over pyramid levels
        for moving, fixed in zip(match.moving.dat, match.fixed.dat):
            moving_dat, moving_aff = moving
            fixed_dat, fixed_aff = fixed

            moving_dat = moving_dat.to(**backend)
            moving_aff = moving_aff.to(**backend)
            fixed_dat = fixed_dat.to(**backend)
            fixed_aff = fixed_aff.to(**backend)

            # affine-corrected moving space
            if A is not None:
                Ms = affine_matmul(A, moving_aff)
            else:
                Ms = moving_aff

            if d is not None:
                # fixed to param
                Mt = affine_lmdiv(d_aff, fixed_aff)
                if samespace(Mt, d.shape[:-1], fixed_dat.shape[1:]):
                    # fixed grid coincides with the displacement lattice
                    g = smalldef(d)
                else:
                    # otherwise sample the displacement at the fixed voxels
                    g = affine_grid(Mt, fixed_dat.shape[1:])
                    g = g + pull_grid(d, g)
                # param to moving
                Ms = affine_lmdiv(Ms, d_aff)
                g = affine_matvec(Ms, g)
            else:
                # fixed to moving
                Mt = fixed_aff
                Ms = affine_lmdiv(Ms, Mt)
                g = affine_grid(Ms, fixed_dat.shape[1:])

            # pull moving image
            warped_dat = pull(moving_dat, g, **prm)
            # each pyramid level contributes 1/nb_levels of the match loss
            loss += match.call(warped_dat, fixed_dat) / float(nb_levels)

            # import matplotlib.pyplot as plt
            # plt.subplot(1, 2, 1)
            # plt.imshow(fixed_dat[0, :, :, fixed_dat.shape[-1]//2].detach())
            # plt.axis('off')
            # plt.subplot(1, 2, 2)
            # plt.imshow(warped_dat[0, :, :, warped_dat.shape[-1]//2].detach())
            # plt.axis('off')
            # plt.show()

    return loss