def _atlas_align(dat, mat, rigid=True, pth_atlas=None):
    r"""Affinely align images to some atlas space.

    The atlas is appended (internally) as the last image of the group and
    used as the fixed image of a 'CSO' (translation + rotation + isotropic
    scaling) alignment.

    Parameters
    ----------
    dat : [N, ...], tensor_like
        List of image volumes. The caller's list is not modified.
    mat : [N, ...], tensor_like
        List of affine matrices. The caller's list is not modified.
    rigid : bool, default=True
        If True, only the rigid part (first 6 Lie parameters, 'SE' group)
        is returned in `mat_a`; otherwise rigid + isotropic scaling ('CSO').
    pth_atlas : str, optional
        Path to atlas image to match to. Uses nitorch's T1w atlas
        (``fetch_data('atlas_t1')``) by default.

    Returns
    -------
    mat_a : (N, 4, 4) tensor_like
        Transformation aligning to MNI space as M_mni\M_mov.
    mat_mni : (4, 4), tensor_like
        Affine matrix of MNI image.
    dim_mni : (3,), tuple, list, tensor_like
        Image dimensions of MNI image.
    mat_cso : (N, 4, 4) tensor_like
        CSO transformation.

    """
    if pth_atlas is None:
        # Get path to nitorch's T1w intensity atlas
        pth_atlas = fetch_data('atlas_t1')
    # Get number of input images
    N = len(dat)
    dat_mni, mat_mni, _ = _format_input(pth_atlas, device=dat[0].device,
                                        rand=True, cutoff=(0.0005, 0.9995))
    # Append the atlas to *copies* of the inputs so that the caller's lists
    # are not permanently extended with the atlas image.
    dat = list(dat) + [dat_mni[0]]
    mat = list(mat) + [mat_mni[0]]
    # Align to MNI atlas (the atlas, at index N, is the fixed image).
    group = 'CSO'
    _, mat_mni, dim_mni, q = _affine_align(dat, mat, group=group,
                                           samp=(3, 1.5), cost_fun='nmi',
                                           fix=N, verbose=False,
                                           mean_space=False)
    # Drop the atlas' own parameters
    q = q[:N, ...]
    # Get matrix representation
    mat_cso = expm(q, affine_basis(group=group))
    if rigid:
        # Extract only rigid part (translations + rotations)
        group = 'SE'
        q = q[..., :6]
    # Get matrix representation
    mat_a = expm(q, affine_basis(group=group))

    return mat_a, mat_mni, dim_mni, mat_cso
def _init_reg(x, sett):
    """Initialise registration.

    Optionally co-registers all observations to the first one (index 0) and
    optionally aligns that observation to an atlas, applying the resulting
    transform to every observation. The updated orientation matrices are
    written back to ``x[c][n].mat`` and each observation gets a zero
    ``rigid_q`` vector sized by the rigid ('SE') basis.

    Parameters
    ----------
    x : list of list
        ``x[c][n]`` is observation ``n`` of channel ``c``; each object
        exposes ``.dat`` and ``.mat`` (and receives ``.rigid_q``).
    sett : settings object
        Reads ``device``, ``do_coreg``, ``do_atlas_align`` and
        ``atlas_rigid``; sets ``rigid_basis``.

    Returns
    -------
    x, sett
        The same (in-place updated) objects.

    """
    # Flatten (channel, observation) pairs, in channel-major order; this is
    # the order used for `imgs` and for the rows of the coreg transforms.
    obs = [xn for xc in x for xn in xc]
    # Total number of observations
    N = len(obs)
    # Set rigid affine basis
    sett.rigid_basis = affine_basis(group='SE', device=sett.device,
                                    dtype=torch.float64)
    fix = 0  # Fixed image index
    # Make input for nitorch affine align
    imgs = [[xn.dat, xn.mat] for xn in obs]

    if sett.do_coreg and N > 1:
        # Align images, pairwise, to fixed image (fix)
        t0 = _print_info('init-reg', sett, 'co', 'begin', N)
        mat_a = affine_align(imgs, fix=fix, device=sett.device)[1]
        # Apply coreg transform: mat <- mat_a[i] \ mat
        # (torch.linalg.solve replaces the removed Tensor.solve).
        for i, img in enumerate(imgs):
            img[1] = torch.linalg.solve(mat_a[i, ...], img[1])
        _print_info('init-reg', sett, 'co', 'finished', N, t0)

    if sett.do_atlas_align:
        # Align fixed image to atlas space, and apply transformation to
        # all images
        t0 = _print_info('init-reg', sett, 'atlas', 'begin', N)
        imgs1 = [imgs[fix]]
        _, mat_a, _, _ = atlas_align(imgs1, rigid=sett.atlas_rigid,
                                     device=sett.device)
        _print_info('init-reg', sett, 'atlas', 'finished', N, t0)
        # Apply atlas registration transform: mat <- mat_a \ mat
        for img in imgs:
            img[1] = torch.linalg.solve(mat_a, img[1])

    for xn, img in zip(obs, imgs):
        # Modify image affine (label uses the same as the image, so no need
        # to modify that one)
        xn.mat = img[1]
        # Init rigid parameters (for unified rigid registration)
        xn.rigid_q = torch.zeros(sett.rigid_basis.shape[0],
                                 device=sett.device, dtype=torch.float64)

    return x, sett
def free(self):
    """Free the next batch/ladder of parameters.

    Parameters are released progressively, one stage per call:
    translations -> + rotations -> + isotropic scaling -> full affine.
    The currently-optimised prefix lives in ``self.optdat`` (an
    ``nn.Parameter``); ``self.dat`` holds the full parameter vector and
    ``self.basis`` the matching affine basis.
    Does nothing if ``self.freeable()`` is False.
    """
    if not self.freeable():
        return
    nb_prm = len(self.optdat) if hasattr(self, 'optdat') else 0
    nb_t = self.dim                           # translations
    nb_r = self.dim * (self.dim - 1) // 2     # rotations
    nb_z = self.dim                           # zooms
    # Rebuild `dat` from the latest optimised values, detached from any graph
    self.dat = self.dat.detach()
    if hasattr(self, 'optdat'):
        self.optdat = self.optdat.detach()
        self.dat = torch.cat([self.optdat.detach(), self.dat[nb_prm:]])
    if nb_prm == 0:
        print('Free translations')
        self.optdat = torch.nn.Parameter(self.dat[:nb_t], requires_grad=True)
        self.dat = torch.cat([self.optdat, self.dat[nb_t:]])
        self.basis = spatial.affine_basis('T', self.dim)
    elif nb_prm == nb_t:
        print('Free rotations')
        self.optdat = torch.nn.Parameter(self.dat[:nb_t + nb_r],
                                         requires_grad=True)
        self.dat = torch.cat([self.optdat, self.dat[nb_t + nb_r:]])
        self.basis = spatial.affine_basis('SE', self.dim)
    elif nb_prm == nb_t + nb_r:
        print('Free isotropic scaling')
        self.optdat = torch.nn.Parameter(self.dat[:nb_t + nb_r + 1],
                                         requires_grad=True)
        self.dat = torch.cat([self.optdat, self.dat[nb_t + nb_r + 1:]])
        self.basis = spatial.affine_basis('CSO', self.dim)
    elif nb_prm == nb_t + nb_r + 1:
        print('Free full affine')
        # Spread the single isotropic-scaling parameter over the `nb_z`
        # zoom parameters. The original code hard-coded three slots, which
        # in 2D overwrote a parameter past the zoom block.
        zoom = self.dat[nb_t + nb_r] / nb_z ** 0.5
        self.dat[nb_t + nb_r:nb_t + nb_r + nb_z] = zoom
        self.optdat = torch.nn.Parameter(self.dat, requires_grad=True)
        self.dat = self.optdat
        self.basis = spatial.affine_basis('Aff+', self.dim)
class Translation(Linear):
    """Linear transform model restricted to translations ('T' group)."""

    name = 'translation'
    basis = spatial.affine_basis('T', 3)

    @staticmethod
    def nb_prm(dim):
        # One translation per spatial dimension
        return dim
def __call__(self, logaff, grad=False, hess=False, gradmov=False,
             hessmov=False, in_line_search=False):
    """Evaluate the affine registration objective and (optionally) its
    analytic derivatives.

    logaff : (..., nb) tensor, Lie parameters
    grad : Whether to compute and return the gradient wrt `logaff`
    hess : Whether to compute and return the Hessian wrt `logaff`
           (when True, the gradient is always computed and returned too)
    gradmov : Whether to compute and return the gradient wrt `moving`
    hessmov : Whether to compute and return the Hessian wrt `moving`
    in_line_search : If True, disables plotting, printing and the
                     `self.n_iter` / `self.ll_prev` / `self.ll_max` updates.

    Returns
    -------
    ll : float, loss value (objective to minimize), converted with `.item()`
    g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
    h : (..., logaff) tensor, optional, Hessian wrt Lie parameters
        (diagonal matrix built from absolute row sums, see below)
    gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
    hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

    NOTE(review): `gradmov`/`hessmov` are only computed when `grad` or
    `hess` is requested; otherwise the truthy flag itself is appended
    to the output tuple.
    """
    # This loop performs the forward pass, and computes
    # derivatives along the way.

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    # Plot roughly 20 times over the whole optimisation
    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % logplot == 0

    # jitter
    # if not hasattr(self, '_fixed'):
    #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
    #                                 jitter=True,
    #                                 **utils.backend(self.fixed))
    #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
    #     del idj
    # fixed = self._fixed
    fixed = self.fixed

    # forward
    # Lazily build the basis tensor from its name on first call
    if not torch.is_tensor(self.basis):
        self.basis = spatial.affine_basis(self.basis, self.dim,
                                          **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    # Derivatives of expm wrt the Lie parameters (no autograd needed)
    with torch.no_grad():
        _, gaff = linalg._expm(logaff, self.basis,
                               grad_X=True, hess_X=False)

    # Voxel-to-voxel transform: affine_moving \ (aff @ affine_fixed)
    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    # /!\ derivatives are not "homogeneous" (they do not have a one
    # on the bottom right): we should *not* use affine_matmul and
    # such (I only lost a day...)
    gaff = torch.matmul(gaff, self.affine_fixed)
    gaff = linalg.lmdiv(self.affine_moving, gaff)
    # haff = torch.matmul(haff, self.affine_fixed)
    # haff = linalg.lmdiv(self.affine_moving, haff)
    # Identity grid over the fixed lattice, cached on first call
    if self.id is None:
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff), jitter=False)
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, cat=iscat, dim=self.dim)

    # gradient/Hessian of the log-likelihood in observed space
    if not grad and not hess:
        llx = self.loss.loss(warped, fixed)
    elif not hess:
        llx, grad = self.loss.loss_grad(warped, fixed)
        if gradmov:
            # Push the observed-space gradient back to the moving lattice
            gradmov = spatial.grid_push(grad, grid, **pullopt)
    else:
        llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
        if hessmov:
            hessmov = spatial.grid_push(hess, grid, **pullopt)
    del warped

    # compose with spatial gradients + dot product with grid
    if grad is not False or hess is not False:
        mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
        grad = jg(mugrad, grad)
        if hess is not False:
            hess = jhj(mugrad, hess)
            grad, hess = regutils.affine_grid_backward(grad, hess, grid=self.id)
        else:
            grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
        # Flatten the (dim, dim+1) matrix derivatives into vectors
        dim2 = self.dim * (self.dim + 1)
        grad = grad.reshape([*grad.shape[:-2], dim2])
        gaff = gaff[..., :-1, :]
        gaff = gaff.reshape([*gaff.shape[:-2], dim2])
        # Chain rule: matrix-space gradient -> Lie-parameter gradient
        grad = linalg.matvec(gaff, grad)
        if hess is not False:
            hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
            # haff = haff[..., :-1, :, :-1, :]
            # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
            hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
            # Replace the Hessian by a diagonal matrix of absolute row sums
            # (presumably a positive-definite majoriser -- confirm)
            hess = hess.abs().sum(-1).diag_embed()
        del mugrad

    # print objective
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(
                f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                end='\n')
        else:
            # Relative decrease, normalised by distance to the largest loss seen
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(
                f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                end='\n')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    # Only the requested quantities are returned, in a fixed order
    out = [ll]
    if grad is not False:
        out.append(grad)
    if hess is not False:
        out.append(hess)
    if gradmov is not False:
        out.append(gradmov)
    if hessmov is not False:
        out.append(hessmov)
    return tuple(out) if len(out) > 1 else out[0]
def __call__(self, logaff, grad=False, hess=False, in_line_search=False):
    """Evaluate the affine registration objective; gradient via autograd.

    Parameters
    ----------
    logaff : (..., nb) tensor
        Lie parameters.
    grad : bool, default=False
        Whether to compute and return the gradient wrt `logaff`
        (obtained by calling `backward()` on the loss).
    hess : bool, default=False
        Accepted but unused: no Hessian is computed or returned.
    in_line_search : bool, default=False
        If True, disables plotting, printing and the `self.n_iter` /
        `self.ll_prev` / `self.ll_max` updates.

    Returns
    -------
    ll : () tensor
        Loss value (objective to minimize), returned as a tensor.
    g : (..., nb) tensor, optional
        Gradient wrt Lie parameters (only if `grad`).
    """
    # This loop performs the forward pass, and computes
    # derivatives along the way.

    # select correct gradient mode
    if grad:
        logaff.requires_grad_()
        # Clear any gradient accumulated by a previous call
        if logaff.grad is not None:
            logaff.grad.zero_()
    # Re-enter under the autograd mode that matches `grad`
    if grad and not torch.is_grad_enabled():
        with torch.enable_grad():
            return self(logaff, grad, in_line_search=in_line_search)
    elif not grad and torch.is_grad_enabled():
        with torch.no_grad():
            return self(logaff, grad, in_line_search=in_line_search)

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    # Plot roughly 20 times over the whole optimisation
    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % logplot == 0

    # jitter
    # idj = spatial.identity_grid(self.fixed.shape[-self.dim:], jitter=True,
    #                             **utils.backend(self.fixed))
    # fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
    # del idj
    fixed = self.fixed

    # forward
    # Lazily build the basis tensor from its name on first call
    if not torch.is_tensor(self.basis):
        self.basis = spatial.affine_basis(self.basis, self.dim,
                                          **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    # Voxel-to-voxel transform: affine_moving \ (aff @ affine_fixed)
    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    # Identity grid over the fixed lattice, cached on first call
    if self.id is None:
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff))
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, cat=iscat, dim=self.dim)

    # gradient/Hessian of the log-likelihood in observed space
    llx = self.loss.loss(warped, fixed)
    del warped

    # print objective
    lll = llx  # keep the tensor for backward(); llx becomes a float below
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(
                f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                end='\n')
        else:
            # Relative decrease, normalised by distance to the largest loss seen
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(
                f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                end='\n')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [lll]
    if grad is not False:
        lll.backward()
        grad = logaff.grad.clone()
        out.append(grad)
        # Undo the requires_grad_() set at the top of the call
        logaff.requires_grad_(False)
    return tuple(out) if len(out) > 1 else out[0]
def register(fixed=None, moving=None, dim=None, loss='mse', basis='CSO',
             optim='ogm', max_iter=500, lr=1, ls=6, plot=False,
             klosure=RegisterStep, logaff=None, verbose=True):
    """Affine registration between two images using Lie groups.

    Parameters
    ----------
    fixed : (..., K, *spatial) tensor, optional
        Fixed image. If `fixed` or `moving` is None, a demo pair is
        generated with `phantoms.demo_register` ("circle to square").
    moving : (..., K, *spatial) tensor, optional
        Moving image.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions.
    loss : {'mse', 'cat'} or OptimizationLoss, default='mse'
        'mse': Mean-squared error
        'cat': Categorical cross-entropy
        Converted with `losses.make_loss`.
    basis : str, default='CSO'
        Name of the affine Lie group; its basis is built with
        `spatial.affine_basis`.
    optim : str, default='ogm'
        Optimizer name, passed to `regutils.make_iteroptim_affine`.
        Documented choices:
        'gn'       : Gauss-Newton
        'gd'       : Gradient descent
        'momentum' : Gradient descent with momentum
        'nesterov' : Nesterov-accelerated gradient descent
        'ogm'      : Optimized gradient descent (Kim & Fessler)
        'lbfgs'    : Limited-memory BFGS
    max_iter : int, default=500
        Maximum number of optimizer iterations.
    lr : float, default=1
        Learning rate.
    ls : int, default=6
        Number of line search iterations.
    plot : bool, default=False
        Plot progress.
    klosure : class, default=RegisterStep
        Objective factory, called as
        ``klosure(moving, fixed, loss, basis=..., verbose=..., plot=...,
        max_iter=...)``.
    logaff : tensor, optional
        Initial Lie parameters. Defaults to zeros of length ``len(basis)``.
    verbose : bool, default=True
        Print a per-iteration table of the objective.

    Returns
    -------
    logaff : tensor
        Optimized Lie parameters (coefficients of `basis`).

    """
    # If no inputs provided: demo "circle to square"
    if fixed is None or moving is None:
        fixed, moving = phantoms.demo_register(cat=(loss == 'cat'))

    # init tensors
    fixed, moving = utils.to_max_backend(fixed, moving)
    dim = dim or (fixed.dim() - 1)
    basis = spatial.affine_basis(basis, dim, **utils.backend(fixed))
    if logaff is None:
        logaff = torch.zeros(len(basis), **utils.backend(fixed))

    # init optimizer
    optim = regutils.make_iteroptim_affine(optim, lr, ls, max_iter)

    # init loss
    loss = losses.make_loss(loss, dim)

    # optimize
    if verbose:
        print(
            f'{"it":3s} | {"fit":^12s} + {"reg":^12s} = {"obj":^12s} | {"gain":^12s}'
        )
        print('-' * 63)
    closure = klosure(moving, fixed, loss, basis=basis, verbose=verbose,
                      plot=plot, max_iter=optim.max_iter)
    logaff = optim.iter(logaff, closure)
    if verbose:
        print('')
    return logaff
def diffeo(source, target, group='SE', image_loss=None, vel_loss=None,
           pull=None, preproc=False, max_iter=1000, device=None,
           origin='center', init=None, lr=1e-4, optim_affine=True,
           scheduler=ReduceLROnPlateau):
    """Joint affine + diffeomorphic registration.

    Parameters
    ----------
    source : path or tensor or (tensor, affine)
        Moving image; after `prepare` it has shape (batch, channel, *spatial).
    target : path or tensor or (tensor, affine)
        Fixed image; after `prepare` it has shape (batch, channel, *spatial).
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
        Affine Lie group.
    image_loss : float or callable(moved, target), default=MutualInfoLoss()
        A loss function, or a factor multiplying the mutual-information
        loss (None means a factor of 1).
    vel_loss : float or callable(vel), default=BendingLoss(bound='dft')
        A loss function of the velocity, or a factor multiplying the
        bending-energy loss (None means a factor of 1).
    pull : dict, optional
        Interpolation options (the dict is not modified):
        interpolation : int or str, default='linear'
        bound : bound_like, default='dct2'
        extrapolate : bool, default=False
    preproc : bool, default=False
        Currently unused: both images are always rescaled to [0, 1].
    max_iter : int, default=1000
    device : device, optional
    origin : {'native', 'center'}, default='center'
        Rotate about the world origin or about the centre of the target
        field of view.
    init : tensor_like, default=0
        Initial affine parameters, reshaped to (batch, nb_prm).
    lr : float or (float, float), default=1e-4
        Learning rate(s). The first value is Adam's default (used by the
        affine group, or by the velocity when `optim_affine` is False);
        the second is the velocity's learning rate when both are optimised.
    optim_affine : bool, default=True
        Whether the affine parameters are optimised.
    scheduler : class or None, default=ReduceLROnPlateau
        Instantiated as ``scheduler(optim, cooldown=5)`` and stepped every
        10 iterations.

    Returns
    -------
    q : (batch, nb_prm) tensor
        Parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    vel : (batch, *spatial, D) tensor
        Initial velocity of the diffeomorphic transform.
        The full warp is `(aff @ aff_src).inv() @ aff_trg @ exp(vel)`
    moved : tensor
        Source image moved to target space.

    """
    # Copy so the caller's dict is not modified
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2

    # Rescale to [0, 1] (previously only the source was rescaled because
    # of a typo in the target assignment)
    source = rescale(source)
    target = rescale(target)

    # Shift origin to the centre of the target's *spatial* field of view
    if origin == 'center':
        shift = torch.as_tensor(target.shape[2:], **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=optim_affine)
    velocity = torch.zeros([batch, *target.shape[2:], dim], **backend)
    velocity = nn.Parameter(velocity, requires_grad=True)

    def warp(q, vel):
        """Pull `source` to target space through exp(vel) then expm(q)."""
        grid = spatial.exp(vel)
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff, grid)
        return spatial.grid_pull(source, grid, **pull_opt)

    # Prepare loss and optimizer.
    # Each wrapper binds its own factor: sharing one `factor` variable made
    # both lambdas use the velocity-loss factor (late binding).
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        image_factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)
    if not callable(vel_loss):
        vel_loss_fn = nni.BendingLoss(bound='dft')
        vel_factor = 1. if vel_loss is None else vel_loss
        vel_loss = lambda x: vel_factor * vel_loss_fn(core.utils.last2channel(x))

    lr = core.utils.make_list(lr, 2)
    opt_prm = [{'params': parameters}, {'params': velocity, 'lr': lr[1]}] \
        if optim_affine else [velocity]
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    if scheduler is not None:
        scheduler = scheduler(optim, cooldown=5)

    # Optim loop
    loss_avg = 0
    for n_iter in range(1, max_iter + 1):
        optim.zero_grad(set_to_none=True)
        moved = warp(parameters, velocity)
        loss_val = image_loss(moved, target) + vel_loss(velocity)
        loss_val.backward()
        optim.step()

        with torch.no_grad():
            loss_avg += loss_val
            # Every 10 iterations: step the scheduler on the mean loss, report
            if n_iter % 10 == 0:
                loss_avg /= 10
                if scheduler is not None:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(loss_avg)
                    else:
                        scheduler.step()
                print('{:4d} {:12.6f} | lr={:g}'.format(
                    n_iter, loss_avg.item(), optim.param_groups[0]['lr']),
                    end='\r')
                loss_avg = 0

    print('')
    with torch.no_grad():
        moved = warp(parameters, velocity)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            # Convert the centred transform back to the native origin:
            # T(-s) @ A @ T(s)
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
        aff.requires_grad_(False)
    return parameters, aff, velocity, moved
class Affine(Linear):
    """Linear transform model over the full affine group ('Aff+')."""

    name = 'affine'
    basis = spatial.affine_basis('Aff+', 3)

    @staticmethod
    def nb_prm(dim):
        # A dim x (dim + 1) block of free parameters
        return dim * (dim + 1)
def _mean_space(Mat, Dim, vx=None):
    """Compute a (mean) model space from individual spaces.

    Each subject's orientation is first snapped (by axis permutations and
    flips) to the closest axial orientation, the matrices are averaged with
    `meanm`, shears are removed if present, voxel size is imposed, and the
    field of view is enlarged to cover every input image.

    Args:
        Mat (torch.tensor): N subjects' orientation matrices (N, 4, 4).
        Dim (torch.tensor): N subjects' dimensions (N, 3).
        vx (torch.tensor|tuple|float, optional): Voxel size (3,), defaults
            to None (estimate from input). Non-finite entries are replaced
            by the voxel size of the mean matrix. NOTE(review): a `list`
            is not converted to a tensor and would fail at `vx.clone()`.

    Returns:
        mat (torch.tensor): Mean orientation matrix (4, 4).
        dim (torch.tensor): Mean dimensions (3,).
        vx (torch.tensor): Mean voxel size (3,).

    Authors:
        John Ashburner, as part of the SPM12 software.

    """
    device = Mat.device
    dtype = Mat.dtype
    N = Mat.shape[0]  # Number of subjects
    inf = float('inf')
    one = torch.tensor(1.0, device=device, dtype=dtype)
    # Normalise `vx` to a (3,) tensor (inf = "estimate from input")
    if vx is None:
        vx = torch.tensor([inf, inf, inf], device=device, dtype=dtype)
    if isinstance(vx, float) or isinstance(vx, int):
        vx = (vx, ) * 3
    if isinstance(vx, tuple) and len(vx) == 3:
        vx = torch.tensor([vx[0], vx[1], vx[2]], device=device, dtype=dtype)

    # NOTE(review): `dtype` was taken from Mat, so this does not convert to
    # float64; it casts Dim to Mat's dtype.
    Mat = Mat.type(dtype)
    Dim = Dim.type(dtype)

    # Get affine basis
    basis = 'SE'
    dim = 3 if Dim[0, 2] > 1 else 2
    B = affine_basis(basis, dim, device=device, dtype=dtype)

    # Find combination of 90 degree rotations and flips that brings all
    # the matrices closest to axial
    Mat0 = Mat.clone()
    pmatrix = torch.tensor(
        [[0, 1, 2], [1, 0, 2], [2, 0, 1], [2, 1, 0], [0, 2, 1], [1, 2, 0]],
        device=device)

    for n in range(N):  # Loop over subjects
        vx1 = voxel_size(Mat[n, ...])
        # Rotation part of the matrix, with voxel scaling removed
        R = Mat[n, ...].mm(
            torch.diag(torch.cat((vx1, one[..., None]))).inverse())[:-1, :-1]
        minss = inf
        minR = torch.eye(3, dtype=dtype, device=device)
        for i in range(6):  # Permute (= 'rotate + flip') axes
            R1 = torch.zeros((3, 3), dtype=dtype, device=device)
            R1[pmatrix[i, 0], 0] = 1
            R1[pmatrix[i, 1], 1] = 1
            R1[pmatrix[i, 2], 2] = 1
            for j in range(8):  # Mirror (= 'flip') axes
                # Bits of j select the sign (+1/-1) of each axis
                fd = [(j & 1) * 2 - 1, (j & 2) - 1, (j & 4) / 2 - 1]
                F = torch.diag(torch.tensor(fd, dtype=dtype, device=device))
                R2 = F.mm(R1)
                # Squared distance between R @ R2^-1 and the identity
                ss = torch.sum((R.mm(R2.inverse()) -
                                torch.eye(3, dtype=dtype, device=device))**2)
                if ss < minss:
                    minss = ss
                    minR = R2
        rdim = torch.abs(minR.mm(Dim[n, ...][..., None] - 1))
        R2 = minR.inverse()
        # Translation column accompanying the permutation/flip
        # (presumably keeps the flipped field of view in place -- confirm
        # against SPM's spm_mean_space/spm_get_space conventions)
        R22 = R2.mm((torch.div(
            torch.sum(R2, dim=0, keepdim=True).t(), 2, rounding_mode='floor') - 1) * rdim)
        minR = torch.cat((R2, R22), dim=1)
        minR = torch.cat(
            (minR, torch.tensor([0, 0, 0, 1], device=device, dtype=dtype)[None, ...]),
            dim=0)
        Mat[n, ...] = Mat[n, ...].mm(minR)

    # Average of the matrices in Mat
    mat = meanm(Mat)

    # If average involves shears, then find the closest matrix that does not
    # require them.
    C_ix = torch.tensor(
        [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15],
        device=device)  # column-major ordering from (4, 4) tensor
    # NOTE(review): C_ix is not used below.
    p = _imatrix(mat)
    # p[9:12] are tested as the shear parameters of the decomposition
    if torch.sum(p[[9, 10, 11]]**2) > 1e-8:
        # Basis for the three zooms (diagonal entries)
        B2 = torch.zeros((3, 4, 4), device=device, dtype=dtype)
        B2[0, 0, 0] = 1
        B2[1, 1, 1] = 1
        B2[2, 2, 2] = 1

        # Gauss-Newton fit of M = expm(rigid) @ expm(zooms) to `mat`
        # NOTE(review): 6 rigid + 3 zoom parameters assume a 3D basis B.
        p = torch.zeros(9, device=device, dtype=dtype)
        for n_iter in range(10000):
            # Rotations + Translations
            R, dR = _expm(p[[0, 1, 2, 3, 4, 5]], B, grad_X=True)
            # Zooms
            Z, dZ = _expm(p[[6, 7, 8]], B2, grad_X=True)

            M = R.mm(Z)
            dM = torch.zeros((4, 4, 9), device=device, dtype=dtype)
            for n in range(6):
                dM[..., n] = dR[n, ...].mm(Z)
            for n in range(3):
                dM[..., 6 + n] = R.mm(dZ[n, ...])
            dM = dM.reshape((16, 9))
            d = M.flatten() - mat.flatten()
            gr = dM.t().mm(d[..., None])
            Hes = dM.t().mm(dM)
            p = p - lmdiv(Hes, gr)[:, 0]
            if torch.sum(gr**2) < 1e-8:
                break
        mat = M.clone()

    # Set required voxel size
    vx_out = vx.clone()
    vx = voxel_size(mat)
    vx_out[~torch.isfinite(vx_out)] = vx[~torch.isfinite(vx_out)]
    mat = mat.mm(torch.cat((vx_out / vx, one[..., None])).diag())
    vx = voxel_size(mat)

    # Ensure that the FoV covers all images, with a few voxels to spare
    # NOTE(review): `off` below is zero, so no extra margin is added beyond
    # the ceil/floor of the extreme corners.
    mn_all = torch.zeros([3, N], device=device, dtype=dtype)
    mx_all = torch.zeros([3, N], device=device, dtype=dtype)
    for n in range(N):
        dm = Dim[n, ...]
        # The 8 corners of subject n's lattice, 1-based voxel coordinates
        corners = torch.tensor([[1, dm[0], 1, dm[0], 1, dm[0], 1, dm[0]],
                                [1, 1, dm[1], dm[1], 1, 1, dm[1], dm[1]],
                                [1, 1, 1, 1, dm[2], dm[2], dm[2], dm[2]],
                                [1, 1, 1, 1, 1, 1, 1, 1]],
                               device=device, dtype=dtype)
        # Map the corners into the mean space's voxel coordinates
        M = lmdiv(mat, Mat0[n])
        vx1 = M[:-1, :].mm(corners)
        mx_all[..., n] = torch.max(vx1, dim=1)[0]
        mn_all[..., n] = torch.min(vx1, dim=1)[0]
    mx = mx_all.max(dim=1)[0]
    mn = mn_all.min(dim=1)[0]
    mx = torch.ceil(mx)
    mn = torch.floor(mn)

    # Make output dimensions and orientation matrix
    dim = mx - mn + 1  # Output dimensions
    off = torch.tensor([0, 0, 0], device=device, dtype=dtype)
    # Translate so that voxel 1 of the output maps to the minimum corner
    mat = mat.mm(
        torch.tensor([[1, 0, 0, mn[0] - (off[0] + 1)],
                      [0, 1, 0, mn[1] - (off[1] + 1)],
                      [0, 0, 1, mn[2] - (off[2] + 1)],
                      [0, 0, 0, 1]], device=device, dtype=dtype))

    return mat, dim, vx
def make_basis(name, dim, **backend):
    """Return the affine Lie-algebra basis for group `name` in `dim` dims.

    `name` is normalised with `convert_basis` before being handed to
    `spatial.affine_basis`; `backend` keyword arguments (presumably
    dtype/device) are forwarded unchanged.
    """
    return spatial.affine_basis(convert_basis(name), dim, **backend)
def diffeo(source, target, group='SE', origin='center',
           image_loss=None, vel_loss=None, pull=None,
           optim_affine=True, max_iter=1000, lr=0.1, min_lr=1e-7,
           init=None, device=None):
    """Diffeomorphic registration

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as
       well as the same interpolation order. The advantage is that
       it simplifies the signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
        The source (moving) image, with shape (batch, channel, *spatial).
    target : tensor or (tensor, affine)
        The target (fixed) image, with shape (batch, channel, *spatial).
    group : {'tr', 'rot', 'rigid', 'sim', 'lin', 'aff'}, default='SE'
        Affine sub-group to optimize (converted with
        `affine_group_converter`; the default 'SE' is the rigid group).
    origin : {'native', 'center'}, default='center'
        Whether to rotate about the origin of the world-space
        ('native') or the center of the target field-of-view ('center').
        When the origin of the world-space is far off (say you are
        registering smaller blocks cropped from a larger MRI), it can
        be beneficial to rotate about the center of the FOV.
    image_loss : float or callable(mov, fix) -> loss, default=MutualInfoLoss()
        Either a factor to multiply the mutual-information loss with, or
        a loss function that takes two inputs of shape
        (batch, channel, *spatial).
    vel_loss : float or callable(vel) -> loss, default=BendingLoss()
        Either a factor to multiply the bending loss with, or a loss
        function that takes the velocity field, of shape
        (batch, *spatial, dim).
    pull : dict, optional
        Interpolation options (the dict is not modified):
        interpolation : int, default=1
            Interpolation order
        bound : bound_like, default='dct2'
            Boundary condition
        extrapolate : bool, default=False
            Extrapolate out-of-bound data using the boundary conditions.
    optim_affine : bool, default=True
        Whether the affine parameters are optimized.
    max_iter : int, default=1000
        Maximum number of iterations
    lr : float or (float, float), default=0.1
        Initial learning rate(s). With a pair, the first value applies to
        the affine parameters and the second to the velocity; when
        `optim_affine` is False only the first value is used.
    min_lr : float or (float, float), default=1e-7
        Minimum learning rate, per optimizer group in the same order as
        `lr`. The optimization is stopped once every group's learning
        rate falls below its minimum.
    init : ([batch], nb_prm) tensor_like, default=0
        Initial guess for the affine parameters.
    device : {'cpu', 'cuda', 'cuda:<id>'}, optional
        Backend to use

    Returns
    -------
    q : (batch, nb_prm) tensor
        Parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    vel : (batch, *shape, D) tensor
        Initial velocity
    moved : tensor
        Source image moved to target space.

    """
    group = affine_group_converter(group)
    # Copy so the caller's dict is not modified
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2

    # Shift origin to the centre of the target's *spatial* field of view
    # (the batch and channel dimensions must not enter the shift).
    if origin == 'center':
        shift = torch.as_tensor(target.shape[2:], **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=optim_affine)
    velocity = torch.zeros([batch, *target.shape[2:], dim], **backend)
    velocity = nn.Parameter(velocity, requires_grad=True)

    def warp(q, vel):
        """Pull `source` to target space through exp(vel) then expm(q)."""
        grid = spatial.exp(vel)
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff, grid)
        return spatial.grid_pull(source, grid, **pull_opt)

    # Prepare loss and optimizer.
    # Each wrapper binds its own factor: sharing one `factor` variable made
    # both lambdas use the velocity-loss factor (late binding).
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        image_factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)
    if not callable(vel_loss):
        vel_loss_fn = nni.BendingLoss(bound='dft')
        vel_factor = 1. if vel_loss is None else vel_loss
        vel_loss = lambda x: vel_factor * vel_loss_fn(core.utils.last2channel(x))

    lr = core.utils.make_list(lr, 2)
    min_lr = core.utils.make_list(min_lr, 2)
    opt_prm = [{'params': parameters}, {'params': velocity, 'lr': lr[1]}] \
        if optim_affine else [velocity]
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    scheduler = ReduceLROnPlateau(optim)

    def forward():
        moved = warp(parameters, velocity)
        return image_loss(moved, target) + vel_loss(velocity)

    # Optim loop
    loss_avg = 0
    for n_iter in range(1, max_iter + 1):
        optim.zero_grad(set_to_none=True)
        loss_val = forward()
        loss_val.backward()
        # Adam does not need a closure; passing one only re-ran the forward
        # pass and discarded the result.
        optim.step()

        with torch.no_grad():
            loss_avg += loss_val
            if n_iter % 10 == 0:
                loss_avg /= 10
                scheduler.step(loss_avg)
                print('{:4d} {:12.6f} | lr={:g} '
                      .format(n_iter, loss_avg.item(),
                              optim.param_groups[0]['lr']),
                      end='\r')
                loss_avg = 0
                if (optim.param_groups[0]['lr'] < min_lr[0] and
                        (len(optim.param_groups) == 1 or
                         optim.param_groups[1]['lr'] < min_lr[1])):
                    print('\nConverged.')
                    break

    print('')
    with torch.no_grad():
        moved = warp(parameters, velocity)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            # Convert the centred transform back to the native origin:
            # T(-s) @ A @ T(s)
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
    return (parameters.detach(), aff.detach(), velocity.detach(),
            moved.detach())
def affine(source, target, group='SE', loss=None, pull=None, preproc=True,
           max_iter=1000, device=None, origin='center', init=None, lr=0.1,
           scheduler=ReduceLROnPlateau):
    """Affine registration

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as
       well as the same interpolation order. The advantage is that
       it simplifies the signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
    target : tensor or (tensor, affine)
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
    loss : callable(moved, target), default=MutualInfoLoss()
    pull : dict, optional
        Interpolation options (the dict is not modified):
        interpolation : int, default=1
        bound : bound_like, default='dct2'
        extrapolate : bool, default=False
    preproc : bool, default=True
        Rescale both images to [0, 1] before registration.
    max_iter : int, default=1000
    device : device, optional
    origin : {'native', 'center'}, default='center'
        Rotate about the world origin or about the centre of the target
        field of view.
    init : tensor_like, default=0
        Initial parameters, reshaped to (batch, nb_prm).
    lr : float, default=0.1
    scheduler : class or None, default=ReduceLROnPlateau
        Instantiated as ``scheduler(optim)`` and stepped every 10
        iterations.

    Returns
    -------
    q : (batch, nb_prm) tensor
        Parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    moved : tensor
        Source image moved to target space.

    """
    # Copy so the caller's dict is not modified
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2

    # Rescale to [0, 1]
    if preproc:
        source = rescale(source)
        target = rescale(target)

    # Shift origin to the centre of the target's *spatial* field of view
    if origin == 'center':
        shift = torch.as_tensor(target.shape[2:], **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        # Clone so the affines returned by `prepare` (possibly the caller's
        # own tensors) are not modified in place.
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=True)
    identity = spatial.identity_grid(target.shape[2:], **backend)
    # Index that broadcasts (batch, D+1, D+1) matrices over spatial dims
    expd = (slice(None), ) + (None, ) * dim + (slice(None), slice(None))

    def warp(q):
        """Pull `source` to target space through expm(q)."""
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff[expd], identity)
        return spatial.grid_pull(source, grid, **pull_opt)

    # Prepare loss and optimizer
    if loss is None:
        loss = nni.MutualInfoLoss()

    optim = torch.optim.Adam([parameters], lr=lr)
    if scheduler is not None:
        scheduler = scheduler(optim)

    # Optim loop
    for n_iter in range(1, max_iter + 1):
        optim.zero_grad(set_to_none=True)
        moved = warp(parameters)
        loss_val = loss(moved, target)
        loss_val.backward()
        optim.step()

        if n_iter % 10 == 0:
            if scheduler is not None:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(loss_val)
                else:
                    scheduler.step()
            with torch.no_grad():
                print('{:4d} {:12.6f} | lr={:g}'.format(
                    n_iter, loss_val.item(), optim.param_groups[0]['lr']),
                    end='\r')

    print('')
    with torch.no_grad():
        moved = warp(parameters)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            # Convert the centred transform back to the native origin:
            # T(-s) @ A @ T(s)
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
        aff.requires_grad_(False)
    return parameters, aff, moved
class Rigid(Linear):
    """Rigid linear transform (Lie group 'SE': translations + rotations)."""

    name = 'rigid'
    basis = spatial.affine_basis('SE', 3)

    @staticmethod
    def nb_prm(dim):
        """Return the number of rigid parameters in `dim` dimensions.

        That is `dim` translations plus `dim * (dim - 1) / 2` rotations,
        i.e. `dim * (dim + 1) / 2` parameters in total.
        """
        return dim * (dim + 1) // 2
def _affine_align(dat, mat, cost_fun='nmi', group='SE', mean_space=False,
                  samp=(3, 1.5), optimiser='powell', fix=0, verbose=False,
                  fov=None, mx_int=1023, raw=False, jitter=False, fwhm=7.0):
    """Affine registration of a collection of images.

    Parameters
    ----------
    dat : [N, ...], tensor_like
        List of image volumes.
    mat : [N, ...], tensor_like
        List of affine matrices.
    cost_fun : str, default='nmi'
        Pairwise methods:
            * 'nmi'  : Normalised Mutual Information
            * 'mi'   : Mutual Information
            * 'ncc'  : Normalised Cross Correlation
            * 'ecc'  : Entropy Correlation Coefficient
        Groupwise methods:
            * 'njtv' : Normalised Joint Total variation
            * 'jtv'  : Joint Total variation
    group : str, default='SE'
        * 'T'   : Translations
        * 'SO'  : Special Orthogonal (rotations)
        * 'SE'  : Special Euclidean (translations + rotations)
        * 'D'   : Dilations (translations + isotropic scalings)
        * 'CSO' : Conformal Special Orthogonal
                  (translations + rotations + isotropic scalings)
        * 'SL'  : Special Linear (rotations + isovolumic zooms + shears)
        * 'GL+' : General Linear [det>0] (rotations + zooms + shears)
        * 'Aff+': Affine [det>0] (translations + rotations + zooms + shears)
    mean_space : bool, default=False
        Optimise a mean-space fit, only available if cost_fun='njtv'.
    samp : (float, ), default=(3, 1.5)
        Optimisation sampling steps (mm), visited in order (coarse to fine).
    optimiser : str, default='powell'
        'powell' : Optimisation method.
    fix : int, default=0
        Index of the image used as fixed image, not used if mean_space=True.
    verbose : bool, default=False
        Show registration results.
    fov : (2,) tuple, default=None
        A tuple with affine matrix (tensor_like) and dimensions (tuple) of
        mean space.
    mx_int : int, default=1023
        This parameter sets the max intensity in the images, which decides
        how many bins to use in the joint image histograms
        (e.g, mx_int=1023 -> H.shape = (1024, 1024)). This is only done if
        cost_fun is histogram-based.
    raw : bool, default=False
        Do no processing of input images -> work on raw data.
    jitter : bool, default=False
        Add random jittering to resampling grid.
    fwhm : float, default=7
        Full-width at half max of Gaussian kernel, for smoothing histogram.

    Returns
    ----------
    mat_a : (N, 4, 4), tensor_like
        Affine alignment matrices (identity for every image when N < 2).
    mat_fix : (4, 4), tensor_like
        Affine matrix of fixed image.
    dim_fix : (3,), tuple, list, tensor_like
        Image dimensions of fixed image.
    q : (N, Nq), tensor_like
        Lie parameters (all zeros when N < 2).

    Raises
    ------
    ValueError
        If mean_space=True is requested with a histogram-based cost.

    """
    with torch.no_grad():
        device = dat[0].device
        # Parse algorithm options
        opt = {'optimiser': optimiser,
               'cost_fun': cost_fun,
               'samp': samp,
               'fix': fix,
               'mean_space': mean_space,
               'verbose': verbose,
               'fov': fov,
               'group': group,
               'raw': raw,
               'jitter': jitter,
               'mx_int': mx_int,
               'fwhm': fwhm}
        if not isinstance(opt['samp'], (list, tuple)):
            opt['samp'] = (opt['samp'], )
        # Some very basic sanity checks
        N = len(dat)  # Number of input scans
        mov = list(range(N))  # Indices of images
        if opt['cost_fun'] in _costs_hist and opt['mean_space']:
            raise ValueError('Option mean_space=True not defined for {} cost!'.format(opt['cost_fun']))
        # Get affine basis
        B = affine_basis(group=opt['group'], device=device)
        Nq = B.shape[0]
        # Load data
        dat = _data_loader(dat, mat, opt)
        # Define fixed image space (mat_fix, dim_fix, vx_fix)
        if opt['mean_space']:
            # Use a mean-space
            if opt['fov']:
                # Mean-space given
                mat_fix = opt['fov'][0]
                dim_fix = opt['fov'][1]
            else:
                # Compute mean-space
                mat_fix, dim_fix = _get_mean_space(dat, mat)
            arg_grid = dim_fix
        else:
            # Use one of the input images
            mat_fix = mat[opt['fix']]
            dim_fix = dat[opt['fix']].shape[:3]
            mov.remove(opt['fix'])
            arg_grid = dat[opt['fix']]
        # Get voxel size of fixed image
        vx_fix = voxel_size(mat_fix)
        # Initial guess for registration parameter
        q = torch.zeros((N, Nq), dtype=torch.float64, device=device)
        if N >= 2:
            # Do registration, one sub-sampling level at a time
            for s in opt['samp']:
                # Get possibly sub-sampled fixed image, and its resampling grid
                dat_fix, grid = _get_dat_grid(
                    arg_grid, vx_fix, s, jitter=opt['jitter'], device=device)
                # Do optimisation
                q, _ = _fit_q(q, dat_fix, grid, mat_fix, dat, mat, mov,
                              B, s, opt)
        # To matrix form. With fewer than two images nothing is optimised,
        # q is still zero and every matrix is the identity. The same
        # 4-tuple is returned in both cases so callers can always unpack q.
        mat_a = torch.zeros((N, 4, 4), dtype=torch.float64, device=device)
        for n in range(N):
            mat_a[n, ...] = expm(q[n, ...], basis=B)
        return mat_a, mat_fix, dim_fix, q
class Similitude(Linear):
    """Similitude linear transform (Lie group 'CSO': translations +
    rotations + isotropic scaling)."""

    name = 'similitude'
    basis = spatial.affine_basis('CSO', 3)

    @staticmethod
    def nb_prm(dim):
        """Return the number of similitude parameters in `dim` dimensions.

        The rigid count `dim * (dim + 1) / 2` plus one isotropic scale.
        """
        return dim * (dim + 1) // 2 + 1
def _test_cost(dat, mat, cost_fun='nmi', group='SE', mean_space=False,
               samp=2, ix_par=0, jitter=False, x_step=0.1, x_mn_mx=30,
               verbose=False, mx_int=1023, raw=False, fwhm=7.0):
    """Check cost function behaviour by keeping one image fixed and
    re-aligning a second image by modifying one of the affine parameters.
    Plots cost vs. alignment when finished.

    Parameters
    ----------
    dat : [2, ...], tensor_like
        List of exactly two image volumes; dat[0] is kept fixed and dat[1]
        is the image whose alignment is perturbed.
    mat : [2, ...], tensor_like
        List of the two corresponding affine matrices.
    cost_fun : str, default='nmi'
        Cost function to evaluate.
    group : str, default='SE'
        Unused: the full 12-parameter 'Aff+' basis is always used.
    mean_space : bool, default=False
        Evaluate the cost in a mean space (padded to contain the extreme
        misalignments of the moving image) instead of the fixed image space.
    samp : float or sequence of float, default=2
        Sampling step (mm); if a sequence, only its last element is used.
    ix_par : int, default=0
        Index of the 'Aff+' parameter that is varied.
    jitter : bool, default=False
        Add random jittering to resampling grid.
    x_step : float, default=0.1
        Step between consecutive parameter values.
    x_mn_mx : float, default=30
        Parameter values are swept over [-x_mn_mx, x_mn_mx).
    verbose : bool, default=False
        Display the result returned by the cost computation at every value.
    mx_int : int, default=1023
        Max intensity, sets the number of bins of the joint histograms for
        histogram-based costs.
    raw : bool, default=False
        Do no processing of input images -> work on raw data.
    fwhm : float, default=7.0
        Full-width at half max of Gaussian kernel, for smoothing histogram.

    Raises
    ------
    ValueError
        If not exactly two images are given, or if mean_space=True is
        requested with a histogram-based cost.

    """
    with torch.no_grad():
        device = dat[0].device
        # Parse algorithm options
        opt = {'cost_fun': cost_fun,
               'samp': samp,
               'mean_space': mean_space,
               'verbose': verbose,
               'raw': raw,
               'mx_int': mx_int,
               'fwhm': fwhm}
        if not isinstance(opt['samp'], (list, tuple)):
            opt['samp'] = (opt['samp'], )
        # Some very basic sanity checks
        N = len(dat)
        if N != 2:
            raise ValueError('N != 2')
        mov = list(range(N))  # Indices of images
        fix_img = 0
        mov_img = 1
        if opt['cost_fun'] in _costs_hist and opt['mean_space']:
            raise ValueError('Option mean_space=True not defined for {} cost!'.format(opt['cost_fun']))
        # Load data
        dat = _data_loader(dat, mat, opt)
        # Get full 12 parameter affine basis
        B = affine_basis(group='Aff+', device=device)
        Nq = B.shape[0]
        # Range of parameter
        x = torch.arange(start=-x_mn_mx, end=x_mn_mx, step=x_step,
                         dtype=torch.float32)
        if opt['mean_space']:
            # Use mean-space, so make sure that maximum misalignment of the
            # moving image is represented in the input to _get_mean_space()
            q_mn = torch.zeros(Nq, dtype=torch.float64, device=device)
            q_mx = torch.zeros(Nq, dtype=torch.float64, device=device)
            q_mn[ix_par] = -x_mn_mx
            q_mx[ix_par] = x_mn_mx
            mat1 = [expm(q_mn, B).mm(mat[mov_img]),
                    expm(q_mx, B).mm(mat[mov_img])]
            # Compute mean-space (extra entries are the moving image's
            # dimensions, paired with the two extreme matrices)
            dim_extra = [torch.tensor(dat[mov_img].shape, dtype=torch.float32,
                                      device=device)
                         for _ in range(2)]
            mat_fix, dim_fix = _get_mean_space(dat + dim_extra, mat + mat1)
            arg_grid = dim_fix
        else:
            mat_fix = mat[fix_img]
            dim_fix = dat[fix_img].shape[:3]
            mov.remove(fix_img)
            arg_grid = dat[fix_img]
        # Get voxel size of fixed image
        vx_fix = voxel_size(mat_fix)
        # Initial guess
        q = torch.zeros((N, Nq), dtype=torch.float64, device=device)
        # Get subsampled fixed image and its resampling grid
        dat_fix, grid = _get_dat_grid(arg_grid, vx_fix,
                                      samp=opt['samp'][-1],
                                      jitter=jitter, device=device)
        # Iterate over a range of values
        costs = np.zeros(len(x))
        fig_ax = None  # Used for visualisation
        for i, xi in enumerate(x):
            # Perturb the *moving* image's parameter; the fixed image keeps
            # its (zero) parameters, as the docstring promises.
            q[mov_img, ix_par] = xi
            # Compute cost
            costs[i], res = _compute_cost(
                q, grid, dat_fix, mat_fix, dat, mat, mov, opt['cost_fun'],
                B, opt['mx_int'], opt['fwhm'], return_res=True)
            if opt['verbose']:
                fig_ax = show_slices(res, fig_ax=fig_ax, fig_num=1,
                                     cmap='coolwarm', title='x=' + str(xi))
        # Plot results
        if plt is None:
            return
        fig, ax = plt.subplots(num=2)
        ax.plot(x, costs)
        ax.set(xlabel='Value q[' + str(ix_par) + ']', ylabel='Cost',
               title=opt['cost_fun'].upper() + ' cost function (mean_space=' + str(opt['mean_space']) + ')')
        ax.grid()
        plt.show()
def basis(self):
    """Return the affine basis, building it on first access and caching it.

    The basis is created from `self._basis_name` and `self.dim`, on the
    backend of `self.dat`, then stored in `self._basis` for later calls.
    """
    cached = self._basis
    if cached is None:
        cached = spatial.affine_basis(
            self._basis_name, self.dim, **utils.backend(self.dat))
        self._basis = cached
    return cached
def ffd(source, target, grid_shape=10, group='SE', image_loss=None,
        def_loss=None, pull=None, preproc=True, max_iter=1000, device=None,
        origin='center', init=None, lr=1e-4, optim_affine=True,
        scheduler=ReduceLROnPlateau):
    """FFD (= cubic spline) registration

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as
       well as the same interpolation order. The advantage is that
       it simplifies the signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
    target : tensor or (tensor, affine)
    grid_shape : int or sequence of int, default=10
        Shape of the coarse displacement grid (per spatial dimension).
        It is upsampled to the target lattice with cubic interpolation.
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
        Lie group of the affine component.
    image_loss : callable or float, optional
        If callable, called as `image_loss(moved, target)`. Otherwise a
        MutualInfoLoss is used, weighted by this value (1 if None).
    def_loss : callable or float, optional
        If callable, called as `def_loss(disp)` on the displacement field
        of the first batch element. Otherwise a BendingLoss(bound='dft')
        is used, weighted by this value (1 if None).
    pull : dict, optional
        Resampling options (the caller's dict is not modified):
        interpolation : int or str, default='linear'
        bound : bound_like, default='dft'
        extrapolate : bool, default=False
    preproc : bool, default=True
        Rescale source and target to [0, 1].
    max_iter : int, default=1000
    device : device, optional
    origin : {'native', 'center'}, default='center'
        If 'center', world coordinates are shifted during optimisation so
        that the centre of the target lattice is the origin.
    init : tensor_like, optional
        Initial affine parameters, reshaped to (batch, nb_prm). Zero if None.
    lr : float or (float, float), default=1e-4
        Learning rates of the (grid, affine) parameters, expanded to two
        values with `make_list`.
    optim_affine : bool, default=True
        Whether the affine parameters are optimised.
    scheduler : Scheduler class, default=ReduceLROnPlateau
        Instantiated as `scheduler(optim, cooldown=5)` and stepped every
        10 iterations (with the mean loss for ReduceLROnPlateau).

    Returns
    -------
    affine_parameters : (batch, nb_prm) tensor
        Affine Lie parameters.
    aff : (D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    grid_parameters : (batch, *grid_shape, D) tensor
        Coarse displacement grid parameters.
    moved : tensor
        Source image moved to target space (with the final parameters).

    """
    # Copy, so that filling in defaults does not mutate the caller's dict
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dft')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target], device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2

    # Rescale to [0, 1]
    if preproc:
        source = rescale(source)
        target = rescale(target)

    # Shift origin
    if origin == 'center':
        # Clone so the caller's affines are not modified in place (and a
        # single affine shared by source and target is not shifted twice).
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        # Centre of the *spatial* lattice (skip batch and channel dims)
        shift = torch.as_tensor(target.shape[2:], **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        affine_parameters = torch.as_tensor(init, **backend).clone().detach()
        affine_parameters = affine_parameters.reshape([batch, nb_prm])
    else:
        affine_parameters = torch.zeros([batch, nb_prm], **backend)
    affine_parameters = nn.Parameter(affine_parameters, requires_grad=optim_affine)
    grid_shape = core.pyutils.make_list(grid_shape, dim)
    grid_parameters = torch.zeros([batch, *grid_shape, dim], **backend)
    grid_parameters = nn.Parameter(grid_parameters, requires_grad=True)

    def pull_source(q, grid):
        # Compose the affine with the dense grid and resample the source
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        expd = (slice(None), ) + (None, ) * dim + (slice(None), slice(None))
        grid = spatial.affine_matvec(aff[expd], grid)
        moved = spatial.grid_pull(source, grid, **pull_opt)
        return moved

    def exp(prm):
        # Upsample the coarse displacement and add it to the identity grid
        disp = spatial.resize_grid(prm, type='displacement',
                                   shape=target.shape[2:],
                                   interpolation=3, bound='dft')
        grid = disp + spatial.identity_grid(target.shape[2:], **backend)
        return disp, grid

    # Prepare loss and optimizer.
    # Each weight gets its own name: a shared local would be captured by
    # reference, so the second assignment would silently rescale both losses.
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        image_factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)
    if not callable(def_loss):
        def_loss_fn = nni.BendingLoss(bound='dft')
        def_factor = 1. if def_loss is None else def_loss
        def_loss = lambda x: def_factor * def_loss_fn(core.utils.last2channel(x))

    lr = core.utils.make_list(lr, 2)
    if optim_affine:
        opt_prm = [{'params': affine_parameters, 'lr': lr[1]},
                   {'params': grid_parameters, 'lr': lr[0]}]
    else:
        opt_prm = [grid_parameters]
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    if scheduler is not None:
        scheduler = scheduler(optim, cooldown=5)

    # Optim loop
    loss_avg = 0
    nb_avg = 0
    for n_iter in range(max_iter):
        zero_grad_([affine_parameters, grid_parameters])
        disp, grid = exp(grid_parameters)
        moved = pull_source(affine_parameters, grid)
        loss_val = image_loss(moved, target) + def_loss(disp[0])
        loss_val.backward()
        optim.step()

        with torch.no_grad():
            loss_avg += loss_val
            nb_avg += 1
            if n_iter % 10 == 0:
                # Divide by the number of accumulated iterations (the first
                # window contains a single iteration, not 10)
                loss_avg /= nb_avg
                if scheduler is not None:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(loss_avg)
                    else:
                        scheduler.step()
                print('{:4d} {:12.6f} | lr={:g}'.format(
                    n_iter, loss_avg.item(), optim.param_groups[0]['lr']),
                    end='\r')
                loss_avg = 0
                nb_avg = 0
    print('')

    with torch.no_grad():
        # Rebuild the grid from the final parameters: the loop's last grid
        # predates the final optimizer step (and does not exist if
        # max_iter == 0).
        _, grid = exp(grid_parameters)
        moved = pull_source(affine_parameters, grid)
        aff = core.linalg.expm(affine_parameters, basis)
        if origin == 'center':
            # Undo the origin shift: t <- t - s + L s
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
        aff.requires_grad_(False)
    return affine_parameters, aff, grid_parameters, moved