def _warp_image1(image, target, shape=None, affine=None, nonlin=None,
                 backward=False, reslice=False):
    """Returns the warped image, with channel dimension last.

    Parameters
    ----------
    image
        Image object providing `.affine`, `.dat` and `.pull(grid)`.
    target : tensor
        Affine matrix of the output lattice; used as the right-most factor of
        the voxel-to-voxel transform.
    shape : sequence[int], optional
        Output lattice shape. Needed whenever a sampling grid must be built.
    affine : optional
        Affine transformation model (uses `.exp` and `.position`), or None.
    nonlin : optional
        Nonlinear transformation model (uses `.exp`/`.iexp`, `.affine`,
        `.shape`, `.add_identity`), or None.
    backward : bool, default=False
        Apply the inverse transformation.
    reslice : bool, default=False
        When there is no nonlinear component, resample the image onto the
        output lattice instead of returning its data untouched.

    Returns
    -------
    warped : tensor
        Warped data with the leading (channel) dimension moved last, or
        dropped if it is a singleton.
    """
    # build transform
    aff_right = target
    aff_left = spatial.affine_inv(image.affine)
    aff = None
    if affine:
        # exp = affine.iexp if backward else affine.exp
        exp = affine.exp
        aff = exp(recompute=False, cache_result=True)
        if backward:
            aff = spatial.affine_inv(aff)
    if nonlin:
        if affine:
            # Place the affine on the side given by its position (swapped when
            # going backward).
            if affine.position[0].lower() in ('ms' if backward else 'fs'):
                aff_right = spatial.affine_matmul(aff, aff_right)
            if affine.position[0].lower() in ('fs' if backward else 'ms'):
                aff_left = spatial.affine_matmul(aff_left, aff)
        exp = nonlin.iexp if backward else nonlin.exp
        phi = exp(recompute=False, cache_result=True)
        aff_left = spatial.affine_matmul(aff_left, nonlin.affine)
        aff_right = spatial.affine_lmdiv(nonlin.affine, aff_right)
        if _almost_identity(aff_right) and nonlin.shape == shape:
            # Output lattice == nonlin lattice: no resampling of phi needed.
            phi = nonlin.add_identity(phi)
        else:
            tmp = spatial.affine_grid(aff_right, shape)
            phi = regutils.smart_pull_grid(phi, tmp).add_(tmp)
            del tmp
        if not _almost_identity(aff_left):
            phi = spatial.affine_matvec(aff_left, phi)
    else:
        # no nonlin: single affine even if position == 'symmetric'
        if reslice:
            # Without an affine model there is no `aff` to compose with:
            # the transform is just aff_left @ aff_right.
            if aff is None:
                aff = aff_right
            else:
                aff = spatial.affine_matmul(aff, aff_right)
            aff = spatial.affine_matmul(aff_left, aff)
            phi = spatial.affine_grid(aff, shape)
        else:
            phi = None
    # warp image
    if phi is not None:
        warped = image.pull(phi)
    else:
        warped = image.dat
    # move the channel dimension last (squeeze it when singleton)
    if len(warped) == 1:
        warped = warped[0]
    else:
        warped = utils.movedim(warped, 0, -1)
    return warped
def load_and_pull(volume, aff, shape, dtype=None, device=None):
    """Load a volume and resample it onto the lattice described by `aff`/`shape`.

    Parameters
    ----------
    volume : Volume3D
        Input volume (provides `.fdata(...)` and `.affine`).
    aff : (D+1,D+1) tensor
        Voxel-to-world affine of the output lattice.
    shape : (D,) tuple
        Output lattice shape.
    dtype, device : optional
        Backend of the output; default to those of `aff`.

    Returns
    -------
    dat : tensor
        Resampled data. The loaded data is returned as-is when the output
        lattice coincides with the input one.
    """
    backend = dict(dtype=dtype or aff.dtype, device=device or aff.device)
    aff = aff.to(**backend)
    eye = torch.eye(aff.shape[-1], **backend)
    fdata = volume.fdata(cache=False, **backend)
    # Voxel-to-voxel map: output lattice -> input lattice.
    vox2vox = core.linalg.lmdiv(volume.affine.to(**backend), aff)

    same_lattice = (torch.allclose(vox2vox, eye)
                    and tuple(shape) == tuple(fdata.shape))
    if same_lattice:
        return fdata

    grid = spatial.affine_grid(vox2vox, shape)
    return spatial.grid_pull(fdata[None, None, ...], grid[None, ...])[0, 0]
def build_from_target(target):
    """Compose all transformations, starting from the final orientation.

    Returns a sampling grid over `target.shape`. It is obtained by applying
    `options.transformations` in reverse order to the voxel-to-world grid of
    `target`. Relies on `backend`, `options`, `Linear` and `helpers` from the
    enclosing scope.
    """
    target_affine = target.affine.to(**backend)
    grid = spatial.affine_grid(target_affine, target.shape)

    for trf in reversed(options.transformations):
        if isinstance(trf, Linear):
            grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            continue

        # Nonlinear transform: displacement field living on its own lattice.
        mat = trf.affine.to(**backend)
        if trf.inv:
            # Resample the field to the target voxel size, then invert it.
            factor = spatial.voxel_size(mat) / spatial.voxel_size(target_affine)
            disp, mat = spatial.resize_grid(trf.dat[None], factor, affine=mat,
                                            interpolation=trf.spline)
            disp = spatial.grid_inv(disp[0], type='disp')
            order = 1
        else:
            disp, order = trf.dat, trf.spline

        # world -> field voxels, add displacement, field voxels -> world
        grid = spatial.affine_matvec(spatial.affine_inv(mat), grid)
        grid += helpers.pull_grid(disp, grid, interpolation=order)
        grid = spatial.affine_matvec(mat, grid)

    return grid
def _msk_fov(dat, mat, mat0, dim0):
    """Mask field-of-view (FOV) of image data according to other image's FOV.

    Voxels of `dat` that map outside the other image's lattice are set to
    zero. `dat` is modified in place and also returned.

    Parameters
    ----------
    dat : (X, Y, Z), tensor
        Image data.
    mat : (4, 4), tensor
        Image's affine.
    mat0 : (4, 4), tensor
        Other image's affine.
    dim0 : (3, ), list/tuple
        Other image's dimensions.

    Returns
    -------
    dat : (X, Y, Z), tensor
        Masked image data.
    """
    # Voxel-to-voxel map from this image into the other one (mat0\mat).
    vox2vox = lmdiv(mat0, mat)
    coords = affine_grid(vox2vox, dat.shape)

    # NOTE(review): the bounds [1, dim0] look 1-based (MATLAB-style); confirm
    # they match the index convention of `affine_grid`.
    inside = None
    for axis in range(3):
        c = coords[..., axis]
        axis_inside = (c >= 1) & (c <= dim0[axis])
        inside = axis_inside if inside is None else inside & axis_inside

    dat[~inside] = 0
    return dat
def forward(self, affine, **overload):
    """Build a dense sampling grid from an affine matrix.

    Parameters
    ----------
    affine : (batch, ndim[+1], ndim+1) tensor
        Affine matrix
    overload : dict
        All parameters of the module can be overridden at call time
        (here: `shape` and `shift`).

    Returns
    -------
    grid : (batch, *shape, ndim) tensor
        Dense transformation grid
    """
    nb_dim = affine.shape[-1] - 1
    info = {'dtype': affine.dtype, 'device': affine.device}
    shape = make_list(overload.get('shape', self.shape), nb_dim)

    if overload.get('shift', self.shift):
        # (nb_dim, nb_dim+1) translation by -shape/2. The affine is conjugated
        # with it so that it acts about that point rather than about voxel 0.
        offset = -torch.as_tensor(shape, **info)[:, None] / 2
        affine_shift = torch.cat((torch.eye(nb_dim, **info), offset), dim=1)
        affine = spatial.affine_matmul(affine, affine_shift)
        affine = spatial.affine_lmdiv(affine_shift, affine)

    return spatial.affine_grid(affine, shape)
def _init_y_dat(x, y, sett):
    """ Make initial guesses of reconstucted image(s) using b-spline interpolation,
    with averaging if more than one observation per channel.

    Each observation is resampled (first-order, zero outside the FOV) onto the
    lattice of `y[0]` and clipped to its own intensity range. Per channel, the
    contributions are summed and divided by the number of observations that
    are non-zero (after rounding) at each voxel.
    """
    dim_y = y[0].dim
    mat_y = y[0].mat
    for c, x_c in enumerate(x):
        acc = torch.zeros(dim_y, dtype=torch.float32, device=sett.device)
        count = torch.zeros_like(acc)
        for obs in x_c:
            img = obs.dat[None, None, ...]
            # mat_x\mat_y: output voxels -> observation voxels
            vox2vox = mat_y.solve(obs.mat)[0]
            grid = affine_grid(vox2vox.type(img.dtype), dim_y)
            lo = torch.min(img)
            hi = torch.max(img)
            img = grid_pull(img, grid[None, ...], bound='zero',
                            extrapolate=False, interpolation=1)
            # Clip interpolation overshoot to the observed intensity range.
            img[img < lo] = lo
            img[img > hi] = hi
            img = img[0, 0, ...]
            count = count + (img.round() != 0)
            acc = acc + img
        count[count == 0] = 1  # avoid division by zero where nothing was observed
        y[c].dat = acc / count
    return y
def _resample_inplane(x, sett):
    """Force in-plane resolution of observed data to be greater or equal to recon vx.

    Only active when `sett.force_inplane_res` is set and `sett.max_iter > 0`.
    For each observation, every axis whose voxel size is smaller than
    `sett.vx` is downsampled (nearest neighbour). Its data, affine, dimensions
    and (if present) label are updated in place.
    """
    if not (sett.force_inplane_res and sett.max_iter > 0):
        return x

    eye4 = torch.eye(4, device=sett.device, dtype=torch.float64)
    for x_c in x:
        for obs in x_c:
            img = obs.dat[None, None, ...]
            mat_x = obs.mat
            dim_x = torch.as_tensor(obs.dim, device=sett.device,
                                    dtype=torch.float64)
            vx_x = voxel_size(mat_x)

            # Per-axis zoom factor, never below one (no upsampling).
            zoom = eye4.clone()
            for i in range(3):
                zoom[i, i] = sett.vx / vx_x[i]
                if zoom[i, i] < 1.0:
                    zoom[i, i] = 1
            if float((eye4 - zoom).abs().sum()) < 1e-4:
                continue  # nothing to do for this observation

            mat_x = mat_x.matmul(zoom)
            dim_x = (zoom[:3, :3].inverse().mm(dim_x[:, None])
                     .floor().squeeze().cpu().int().tolist())
            grid = affine_grid(zoom.type(img.dtype), dim_x)

            img = grid_pull(img, grid[None, ...], bound='zero',
                            extrapolate=False, interpolation=0)
            if obs.label is not None:
                obs.label[0] = _warp_label(obs.label[0], grid)

            obs.dat = img[0, 0, ...]
            obs.mat = mat_x
            obs.dim = dim_x
    return x
def forward(self, affine, shape=None):
    """Build a dense sampling grid from an affine matrix.

    Parameters
    ----------
    affine : (batch, ndim[+1], ndim+1) tensor
        Affine matrix
    shape : sequence[int], default=self.shape
        Output lattice shape.

    Returns
    -------
    grid : (batch, *shape, ndim) tensor
        Dense transformation grid
    """
    nb_dim = affine.shape[-1] - 1
    backend = {'dtype': affine.dtype, 'device': affine.device}
    shape = shape or self.shape
    if self.shift:
        # Translation moving the centre of the lattice, (shape - 1) / 2, to the
        # origin. Conjugating the affine with it makes it act about the centre.
        # The column must be *assigned*: `Tensor.sub/div/neg` are out-of-place,
        # so chaining them on a slice without assignment has no effect.
        affine_shift = torch.eye(nb_dim + 1, **backend)
        affine_shift[:nb_dim, -1] = \
            torch.as_tensor(shape, **backend).sub(1).div(2).neg()
        affine = spatial.affine_matmul(affine, affine_shift)
        affine = spatial.affine_lmdiv(affine_shift, affine)
    grid = spatial.affine_grid(affine, shape)
    return grid
def forward(self, source, target, source_affine=None, target_affine=None):
    """Register `source` to `target` and warp the original source.

    Parameters
    ----------
    source : (sX, sY, sZ) tensor or str
    target : (tX, tY, tZ) tensor or str
    source_affine : (4, 4) tensor, optional
    target_affine : (4, 4) tensor, optional

    Returns
    -------
    warped : (tX, tY, tZ) tensor
        Source warped to target
    velocity : (vX, vY, vZ, 3) tensor
        Stationary velocity field
    affine : (4, 4) tensor, optional
        Affine of the velocity space
    """
    if self.verbose:
        print('Preprocessing... ', end='', flush=True)
    source, source_affine, source_orig, source_affine_orig = \
        self.load(source, source_affine)
    target, target_affine, target_orig, target_affine_orig = \
        self.load(target, target_affine)
    # Bring the (preprocessed) source onto the (preprocessed) target lattice.
    source = spatial.reslice(source, source_affine, target_affine, target.shape)
    if self.verbose:
        print('done.', flush=True)
        print('Registering... ', end='', flush=True)

    warped, vel, grid = super().forward(source[None, None], target[None, None])
    if self.verbose:
        print('done.', flush=True)
    del source, target, warped

    vel = vel[0]
    # Transformation -> displacement, in voxels of the preprocessed target.
    disp = grid[0]
    disp -= spatial.identity_grid(disp.shape[:-1], dtype=disp.dtype,
                                  device=disp.device)

    # Resample the displacement onto the original target lattice.
    orig_to_proc = target_affine.inverse() @ target_affine_orig
    orig_grid = spatial.affine_grid(orig_to_proc, target_orig.shape)
    disp = spatial.grid_pull(utils.movedim(disp, -1, 0), orig_grid,
                             bound='nearest', extrapolate=True)
    coords = utils.movedim(disp, 0, -1).add_(orig_grid)

    # Preprocessed-target voxels -> original-source voxels, then sample.
    proc_to_source = source_affine_orig.inverse() @ target_affine
    coords = spatial.affine_matvec(proc_to_source, coords)
    warped = spatial.grid_pull(source_orig, coords)
    return warped, vel, target_affine
def _init_y_label(x, y, sett):
    """Make initial guess of labels.

    For each channel, only the first observation is consulted. If it carries
    a label image, that label is warped onto the lattice of `y[0]` and stored
    in `y[c].label`.
    """
    dim_y = y[0].dim
    mat_y = y[0].mat
    for c, x_c in enumerate(x):
        first = x_c[0]
        if first.label is None:
            continue
        vox2vox = first.dat.new_empty(0)  # placeholder, replaced below
        vox2vox = mat_y.solve(first.mat)[0]  # mat_x\mat_y
        grid = affine_grid(vox2vox.type(first.dat.dtype), dim_y)
        y[c].label = _warp_label(first.label[0], grid)
    return y
def _reslice_dat_3d(dat, affine, dim_out, interpolation='linear', bound='zero', extrapolate=False): """Reslice 3D image data. Parameters ---------- dat : (Xi, Yi, Zi), tensor_like Input image data. affine : (4, 4), tensor_like Affine transformation that maps from voxels in output image to voxels in input image. dim_out : (Xo, Yo, Zo), list or tuple Output image dimensions. interpolation : str, default='linear' Interpolation order. bound : str, default='zero' Boundary condition. extrapolate : bool, default=False Extrapolate out-of-bounds data. Returns ------- dat : (dim_out), tensor_like Resliced image data. """ if len(dat.shape) != 3: raise ValueError('Input error: len(dat.shape) != 3') grid = affine_grid(affine, dim_out).type(dat.dtype) grid = grid[None, ...] dat = dat[None, None, ...] dat = grid_pull(dat, grid, bound=bound, interpolation=interpolation, extrapolate=extrapolate) dat = dat[0, 0, ...] return dat
def slice_to(self, stack, cache_result=False, recompute=True):
    """Resample this volume into the slice geometry of `stack` and smooth it.

    Parameters
    ----------
    stack
        Target stack (uses `.layout` and `.slice_width`).
    cache_result : bool, default=False
        Store the result in `self._sliced`.
    recompute : bool, default=True
        Recompute even if a cached result exists. When False and a cached
        result is available, the cached result is returned.

    Returns
    -------
    sliced : tensor
        Volume resampled into the reoriented lattice and smoothed along its
        last dimension by `stack.slice_width` (in voxels of that lattice).
    """
    aff = self.exp(cache_result=cache_result, recompute=recompute)
    if recompute or not hasattr(self, '_sliced'):
        aff = spatial.affine_matmul(aff, self.affine)
        aff_reorient = spatial.affine_reorient(self.affine, self.shape,
                                               stack.layout)
        aff = spatial.affine_lmdiv(aff_reorient, aff)
        aff = spatial.affine_grid(aff, self.shape)
        sliced = spatial.grid_pull(self.dat, aff, bound=self.bound,
                                   extrapolate=self.extrapolate)
        # Slice profile: smooth only along the last (slice) dimension.
        fwhm = [0] * self.dim
        fwhm[-1] = stack.slice_width / spatial.voxel_size(aff_reorient)[-1]
        sliced = spatial.smooth(sliced, fwhm, dim=self.dim, bound=self.bound)
        # TODO: per-slice extraction over `stack.slices` is not implemented.
        # The previous draft loop called `affine_matmul(stack.affine, )` and
        # `affine_lmdiv(aff_reorient, )` with a missing argument. It raised a
        # TypeError whenever `stack.slices` was non-empty, and its results
        # were never used.
    else:
        # Reuse the cached result (previously `sliced` was unbound here).
        sliced = self._sliced
    if cache_result:
        self._sliced = sliced
    return sliced
def propagate_grad(self, g, h, moving, phi, left=None, right=None, inv=False):
    """Convert derivatives wrt warped image in loss space to
    to derivatives wrt parameters

    parameters:
        g (tensor) : gradient wrt warped image
        h (tensor) : hessian wrt warped image
        moving (Image) : moving image
        phi (tensor) : dense (exponentiated) displacement field
        left (matrix) : left affine
        right (matrix) : right affine
        inv (bool) : whether we're in a backward symmetric pass

    returns:
        g (tensor) : pushed gradient
        h (tensor) : pushed hessian
        gmu (tensor) : rotated spatial gradients
    """
    if inv:
        g = g.neg_()  # in place: the caller's tensor is negated too

    ndim = phi.shape[-1]
    fixed_shape = g.shape[-ndim:]

    # differentiate wrt δ in: Left o Phi o (Id + δ) o Right
    # we'll then propagate them through Phi by scaling and squaring
    right_grid = None if right is None else spatial.affine_grid(right, fixed_shape)
    g = regutils.smart_push(g, right_grid, shape=self.shape)
    h = regutils.smart_push(h, right_grid, shape=self.shape)
    del right_grid

    # Full transform (without Right) on the lattice of this model.
    warp = spatial.identity_grid(self.shape, **utils.backend(phi))
    warp += phi
    if left is not None:
        warp = spatial.affine_matvec(left, warp)
    mugrad = moving.pull_grad(warp, rotate=False)
    del warp

    mugrad = _rotate_grad(mugrad, left, phi)
    return g, h, mugrad
def _crop_y(y, sett):
    """ Crop output images FOV to a fixed dimension

    The cropped field of view is an atlas bounding box (`_bb_atlas`) rescaled
    to the output voxel size. Data and labels are resampled into it with
    nearest-neighbour interpolation and zero outside the original FOV.

    Args:
        y (_output()): _output data.
        sett: Settings (uses `crop`, `device`, `fov`).

    Returns:
        y (_output()): Cropped output data.

    """
    if not sett.crop:
        return y

    device = sett.device
    # Output image information
    mat_y = y[0].mat
    vx_y = voxel_size(mat_y)

    # Define cropped FOV
    mat_mu, dim_mu = _bb_atlas('atlas_t1', fov=sett.fov,
                               dtype=torch.float64, device=device)
    # Modulate atlas with voxel size
    mat_vx = torch.diag(torch.cat((
        vx_y, torch.ones(1, dtype=torch.float64, device=device))))
    mat_mu = mat_mu.mm(mat_vx)
    dim_mu = mat_vx[:3, :3].inverse().mm(dim_mu[:, None]).floor().squeeze()
    # Integer output shape, computed once. It is used both for the sampling
    # grid (which expects integer sizes, as in `_resample_inplane`) and for
    # the `.dim` attribute of every channel.
    dim_out = tuple(dim_mu.int().tolist())

    # Make output grid: cropped-FOV voxels -> voxels of y (mat_y\mat_mu)
    M = mat_mu.solve(mat_y)[0].type(y[0].dat.dtype)
    grid = affine_grid(M, dim_out)[None, ...]

    def _pull(vol):
        # Nearest-neighbour resampling into the cropped FOV.
        return grid_pull(vol[None, None, ...], grid, bound='zero',
                         extrapolate=False, interpolation=0)[0, 0, ...]

    # Crop
    for c in range(len(y)):
        y[c].dat = _pull(y[c].dat)
        # Do labels?
        if y[c].label is not None:
            y[c].label = _pull(y[c].label)
        y[c].mat = mat_mu
        y[c].dim = dim_out

    return y
def compute_grid(self, mat_native, dim_native):
    """Computes resampling grid for pulling/pushing from/to common space.

    Side effect: `self.mean_mat` is cast to the dtype and device of
    `mat_native`.

    Parameters
    ----------
    mat_native : (1, dim + 1, dim + 1) tensor
        Native image affine matrix.
    dim_native : [3, ] sequence
        Native image dimensions.

    Returns
    ----------
    grid : (batch, *spatial, dim) tensor
        Resampling grid.
    """
    self.mean_mat = self.mean_mat.type(mat_native.dtype).to(mat_native.device)
    # mean_mat\mat_native: native voxels -> common-space voxels
    native_to_common = mat_native.solve(self.mean_mat)[0]
    return spatial.affine_grid(native_to_common, dim_native)
def smart_grid(aff, shape, inshape=None, force=False):
    """Generate a sampling grid iff it is not the identity.

    Parameters
    ----------
    aff : (D+1, D+1) tensor
        Affine transformation matrix (voxels to voxels)
    shape : (D,) tuple[int]
        Output shape
    inshape : (D,) tuple[int], optional
        Input shape (defaults to `shape`)
    force : bool, default=False
        Always build the grid, even when it would be the identity.

    Returns
    -------
    grid : (*shape, D) tensor or None
        Sampling grid, or None if `aff` is (close to) the identity, the
        input and output shapes match, and `force` is False.
    """
    backend = dict(dtype=aff.dtype, device=aff.device)
    identity = torch.eye(aff.shape[-1], **backend)
    inshape = inshape or shape
    # Compare shapes as tuples: a list, a tuple and a torch.Size describing
    # the same shape must be treated as equal (plain `==` says (2,3) != [2,3]).
    if (not force and torch.allclose(aff, identity)
            and tuple(shape) == tuple(inshape)):
        return None
    return spatial.affine_grid(aff, shape)
def fit(x, y, sett):
    """ Fit model.

    This runs the iterative denoising/super-resolution algorithm and,
    at the end, writes the reconstructed images to disk. If the maximum number
    of iterations are set to zero, the initial guesses of the reconstructed
    images will be written to disk (acquired with b-spline interpolation), no
    denoising/super-resolution will be applied.

    Args:
        x: Observed data, indexed as x[channel][observation].
        y: Reconstruction data, one entry per channel.
        sett: Algorithm settings.

    Returns:
        dat_y (torch.tensor): Reconstructed image data as float32, (dim_y, C).
        mat_y (torch.tensor): Reconstructed affine matrix, (4, 4).
        pth_y ([str, ...]): Paths to reconstructed images.
        R (torch.tensor): Rigid matrices (N, 4, 4).
        label (torch.tensor): Reconstructed label image, (dim_y).
        pth_label str: Path to reconstructed label image.

    """
    with torch.no_grad():
        # Total number of observations
        N = sum(len(xn) for xn in x)

        # Sanity check scaling parameter
        if not isinstance(sett.reg_scl, torch.Tensor):
            sett.reg_scl = torch.tensor(sett.reg_scl, dtype=torch.float32,
                                        device=sett.device)
            sett.reg_scl = sett.reg_scl.reshape(1)

        # Defines a coarse-to-fine scaling of regularisation
        sett = _get_sched(N, sett)

        # For visualisation
        fig_ax_nll = None
        fig_ax_jtv = None

        # Scale lambda
        cnt_scl = 0
        for c in range(len(x)):
            y[c].lam = sett.reg_scl[cnt_scl] * y[c].lam0

        if sett.max_iter > 0:
            # Get ADMM step-size
            rho = _step_size(x, y, sett, verbose=True)
            # Get ADMM variables
            z, w = _admm_aux(y, sett)

            # ----------
            # ITERATE:
            # Updates model in an alternating fashion, until a convergence
            # threshold is met on the model negative log-likelihood.
            # ----------
            obj = torch.zeros(sett.max_iter, 3, dtype=torch.float64,
                              device=sett.device)
            # for holding rhs in y-update, and jtv in u-update
            tmp = torch.zeros_like(y[0].dat)
            t_iter = timer() if sett.do_print else 0
            # To ensure we do, at least, a fixed number of iterations for each scale
            cnt_scl_iter = 0
            # Patience counters: consecutive "converged" iterations before
            # stopping (countdown0), and before moving to the next
            # regularisation scale (countdown1). The first iteration always
            # resets them to 6; they are initialised here only to make that
            # explicit.
            countdown0 = 6
            countdown1 = 6
            for n_iter in range(sett.max_iter):
                if n_iter == 0:
                    t00 = _print_info('fit-start', sett, len(x), N)  # PRINT

                # ----------
                # UPDATE: image
                # ----------
                y, z, w, tmp, obj = _update_admm(x, y, z, w, rho, tmp, obj,
                                                 n_iter, sett)

                # Show JTV
                if sett.show_jtv:
                    fig_ax_jtv = show_slices(img=tmp, fig_ax=fig_ax_jtv,
                                             title='JTV', cmap='coolwarm',
                                             fig_num=98)

                # ----------
                # Check convergence
                # ----------
                if sett.plot_conv:  # Plot algorithm convergence
                    fig_ax_nll = plot_convergence(
                        vals=obj[:n_iter + 1, :], fig_ax=fig_ax_nll,
                        fig_num=99,
                        legend=['-ln(p(y|x))', '-ln(p(x|y))', '-ln(p(y))'])
                gain = get_gain(obj[:n_iter + 1, 0],
                                monotonicity='decreasing')
                t_iter = _print_info('fit-ll', sett, n_iter, obj[n_iter, :],
                                     gain, t_iter)
                # Converged? Only at the finest scale, after enough iterations,
                # and for 6 consecutive iterations.
                if (cnt_scl >= (sett.reg_scl.numel() - 1)
                        and cnt_scl_iter > 20
                        and ((gain.abs() < sett.tolerance)
                             or (n_iter >= (sett.max_iter - 1)))):
                    countdown0 -= 1
                    if countdown0 == 0:
                        _ = _print_info('fit-finish', sett, t00, n_iter)
                        break  # Finished
                else:
                    countdown0 = 6

                # ----------
                # UPDATE: even/odd scaling
                # ----------
                if sett.scaling:
                    t0 = _print_info('fit-update', sett, 's', n_iter)  # PRINT
                    # Do update
                    x, _ = _update_scaling(x, y, sett, max_niter_gn=1,
                                           num_linesearch=6, verbose=0)
                    _ = _print_info('fit-done', sett, t0)  # PRINT
                    # Print parameter estimates
                    _ = _print_info('scl-param', sett, x, t0)

                # ----------
                # UPDATE: rigid_q
                # ----------
                if (sett.unified_rigid and n_iter > 0
                        and (n_iter % sett.rigid_mod) == 0):
                    t0 = _print_info('fit-update', sett, 'q', n_iter)  # PRINT
                    x, _ = _update_rigid(x, y, sett, mean_correct=False,
                                         max_niter_gn=1, num_linesearch=6,
                                         verbose=0, samp=sett.rigid_samp)
                    _ = _print_info('fit-done', sett, t0)  # PRINT
                    # Print parameter estimates
                    _ = _print_info('reg-param', sett, x, t0)

                # ----------
                # Coarse-to-fine scaling of regularisation
                # ----------
                if (cnt_scl + 1 < len(sett.reg_scl) and cnt_scl_iter > 16
                        and gain.abs() < 1e-3):
                    countdown1 -= 1
                    if countdown1 == 0:
                        cnt_scl_iter = 0
                        cnt_scl += 1
                        # Coarse-to-fine scaling of lambda
                        for c in range(len(x)):
                            y[c].lam = sett.reg_scl[cnt_scl] * y[c].lam0
                        # Also update ADMM step-size
                        rho = _step_size(x, y, sett)
                else:
                    countdown1 = 6

                cnt_scl_iter += 1
        else:
            # No iterations: no JTV is computed, but `_write_data(..., jtv=tmp)`
            # below still needs the name (previously a NameError for
            # max_iter == 0). NOTE(review): this passes an all-zero JTV image;
            # confirm `_write_data` handles it as intended.
            tmp = torch.zeros_like(y[0].dat)

        # ----------
        # Some post-processing
        # ----------
        if sett.clean_fov:
            # Zero outside FOV in reconstructed data
            for c in range(len(x)):
                msk_fov = torch.ones(y[c].dim, dtype=torch.bool,
                                     device=sett.device)
                for n in range(len(x[c])):
                    # Map to voxels in low-res image
                    M = x[c][n].po.rigid.mm(x[c][n].mat).solve(
                        y[c].mat)[0].inverse()
                    grid = affine_grid(M.type(x[c][n].dat.dtype),
                                       y[c].dim)[None, ...]
                    # Mask of low-res image FOV projected into high-res space
                    # NOTE(review): bounds [1, dim] look 1-based; confirm
                    # against the index convention of `affine_grid`.
                    msk_fov = msk_fov & \
                        (grid[0, ..., 0] >= 1) & (grid[0, ..., 0] <= x[c][n].dim[0]) & \
                        (grid[0, ..., 1] >= 1) & (grid[0, ..., 1] <= x[c][n].dim[1]) & \
                        (grid[0, ..., 2] >= 1) & (grid[0, ..., 2] <= x[c][n].dim[2])
                    # if x[c][n].ct:
                    #     # Resample low-res image into high-res space
                    #     dat_c = grid_pull(x[c][n].dat[None, None, ...],
                    #                       grid, bound=sett.bound,
                    #                       extrapolate=False,
                    #                       interpolation=sett.interpolation)[0, 0, ...]
                    #     # Set voxels inside the FOV that are positive in the
                    #     # low-res data but negative in the high-res, to the
                    #     # their original values
                    #     msk = msk_fov & (dat_c >= 0) & (y[c].dat < 0)
                    #     y[c].dat[msk] = tmp[msk]
                # Zero voxels outside projected FOV
                y[c].dat[~msk_fov] = 0.0

        # # Possibly crop reconstructed data
        # y = _crop_y(y, sett)

        # ----------
        # Get rigid matrices
        # ----------
        R = torch.zeros((N, 4, 4), device=sett.device, dtype=torch.float64)
        cnt = 0
        for c in range(len(x)):
            for n in range(len(x[c])):
                R[cnt, ...] = _expm(x[c][n].rigid_q, sett.rigid_basis)
                cnt += 1

        # ----------
        # Possibly write reconstruction results to disk
        # ----------
        dat_y, pth_y, label, pth_label = _write_data(x, y, sett, jtv=tmp)

        return dat_y, y[0].mat, pth_y, R, label, pth_label
def do_affine(self, logaff, grad=False, hess=False, in_line_search=False):
    """Forward pass for updating the affine component (nonlin is not None)

    Parameters
    ----------
    logaff : tensor
        Lie-algebra parameters of the affine, passed to `self.affine.exp`/`exp2`.
    grad, hess : bool
        Also return the gradient / Hessian wrt `logaff`.
    in_line_search : bool
        During a line search, results are not cached, loss states are
        restored and iteration bookkeeping is skipped.

    Returns
    -------
    sumloss, [sumgrad], [sumhess]
        Weighted sum of all loss terms (+ self.llv), followed by the
        requested derivatives. A bare `sumloss` is returned if none are
        requested.
    """

    sumloss = None
    sumgrad = None
    sumhess = None

    # ==============================================================
    #                     EXPONENTIATE TRANSFORMS
    # ==============================================================
    logaff0 = logaff
    aff_pos = self.affine.position[0].lower()
    # Inverse transforms are only needed if some loss is evaluated backward.
    if any(loss.backward for loss in self.losses):
        aff0, iaff0, gaff0, igaff0 = \
            self.affine.exp2(logaff0, grad=True,
                             cache_result=not in_line_search)
        phi0, iphi0 = self.nonlin.exp2(cache_result=True, recompute=False)
    else:
        iaff0, igaff0, iphi0 = None, None, None
        aff0, gaff0 = self.affine.exp(logaff0, grad=True,
                                      cache_result=not in_line_search)
        phi0 = self.nonlin.exp(cache_result=True, recompute=False)

    has_printed = False
    for loss in self.losses:

        moving, fixed, factor = loss.moving, loss.fixed, loss.factor
        if loss.backward:
            phi00, aff00, gaff00 = iphi0, iaff0, igaff0
        else:
            phi00, aff00, gaff00 = phi0, aff0, gaff0

        # ----------------------------------------------------------
        # build left and right affine matrices
        # ----------------------------------------------------------
        # right: fixed voxels -> nonlin voxels (affine included if its
        # position is 'f'ixed or 's'ymmetric); gaff_* hold the matching
        # derivatives of the affine wrt its parameters.
        aff_right, gaff_right = fixed.affine, None
        if aff_pos in 'fs':
            gaff_right = gaff00 @ aff_right
            gaff_right = linalg.lmdiv(self.nonlin.affine, gaff_right)
            aff_right = aff00 @ aff_right
        aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
        # left: nonlin voxels -> moving voxels (affine included if its
        # position is 'm'oving or 's'ymmetric).
        aff_left, gaff_left = self.nonlin.affine, None
        if aff_pos in 'ms':
            gaff_left = gaff00 @ aff_left
            gaff_left = linalg.lmdiv(moving.affine, gaff_left)
            aff_left = aff00 @ aff_left
        aff_left = linalg.lmdiv(moving.affine, aff_left)

        # ----------------------------------------------------------
        # build full transform
        # ----------------------------------------------------------
        # Skip building a grid when the right map is (almost) identity and
        # lattices match.
        if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
            right = None
            phi = spatial.add_identity_grid(phi00)
        else:
            right = spatial.affine_grid(aff_right, fixed.shape)
            phi = regutils.smart_pull_grid(phi00, right)
            phi += right
        phi_right = phi  # Phi o Right, in nonlin voxels
        if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
            left = None
        else:
            left = spatial.affine_grid(aff_left, self.nonlin.shape)
            phi = spatial.affine_matvec(aff_left, phi)

        # ----------------------------------------------------------
        # forward pass
        # ----------------------------------------------------------
        warped, mask = moving.pull(phi, mask=True)
        if fixed.masked:
            if mask is None:
                mask = fixed.mask
            else:
                mask = mask * fixed.mask

        # Preview at most once per call, only at high verbosity.
        do_print = not (has_printed or self.verbose < 3 or in_line_search
                        or loss.backward)
        if do_print:
            has_printed = True
            if moving.previewed:
                preview = moving.pull(phi, preview=True, dat=False)
            else:
                preview = warped
            init = spatial.affine_lmdiv(moving.affine, fixed.affine)
            if _almost_identity(init) and moving.shape == fixed.shape:
                init = moving.dat
            else:
                init = spatial.affine_grid(init, fixed.shape)
                init = moving.pull(init, preview=True, dat=False)
            self.mov2fix(fixed.dat, init, preview, dim=fixed.dim,
                         title=f'(affine) {self.n_iter:03d}')

        # ----------------------------------------------------------
        # derivatives wrt moving
        # ----------------------------------------------------------
        g = h = None
        loss_args = (warped, fixed.dat)
        loss_kwargs = dict(dim=fixed.dim, mask=mask)
        # Loss objects may be stateful; restore their state after a
        # line-search probe.
        state = loss.loss.get_state()
        if not grad and not hess:
            llx = loss.loss.loss(*loss_args, **loss_kwargs)
        elif not hess:
            llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
        else:
            llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
        del loss_args, loss_kwargs
        if in_line_search:
            loss.loss.set_state(state)

        # ----------------------------------------------------------
        # chain rule -> derivatives wrt Lie parameters
        # ----------------------------------------------------------

        def compose_grad(g, h, g_mu, g_aff):
            """Chain rule from image derivatives to Lie-parameter derivatives.

            g, h : gradient/Hessian of loss wrt moving image
            g_mu : spatial gradients of moving image
            g_aff : gradient of affine matrix wrt Lie parameters
            returns g, h: gradient/Hessian of loss wrt Lie parameters
            """
            # Note that `h` can be `None`, but the functions I
            # use deal with this case correctly.
            dim = g_mu.shape[-1]
            # NOTE(review): bare `jg`/`jhj` here, whereas do_vel uses
            # `regutils.jg`/`regutils.jhj`; confirm they are imported at
            # module level.
            g = jg(g_mu, g)
            h = jhj(g_mu, h)
            g, h = regutils.affine_grid_backward(g, h)
            dim2 = dim * (dim + 1)
            g = g.reshape([*g.shape[:-2], dim2])
            g_aff = g_aff[..., :-1, :]
            g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
            g = linalg.matvec(g_aff, g)
            if h is not None:
                h = h.reshape([*h.shape[:-4], dim2, dim2])
                h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
                # h = h.abs().sum(-1).diag_embed()
            return g, h

        if grad or hess:
            g0, g = g, None
            h0, h = h, None
            if aff_pos in 'ms':
                # Affine on the moving side: push derivatives back through
                # Phi o Right onto the nonlin lattice.
                g_left = regutils.smart_push(g0, phi_right, shape=self.nonlin.shape)
                h_left = regutils.smart_push(h0, phi_right, shape=self.nonlin.shape)
                # NOTE(review): `left` is None when aff_left is ~identity;
                # confirm `pull_grad(None)` is supported.
                mugrad = moving.pull_grad(left, rotate=False)
                g_left, h_left = compose_grad(g_left, h_left, mugrad, gaff_left)
                g, h = g_left, h_left
            if aff_pos in 'fs':
                # Affine on the fixed side: spatial gradients must be
                # transported through the Jacobian of Left o Phi.
                g_right, h_right = g0, h0
                mugrad = moving.pull_grad(phi, rotate=False)
                # NOTE(review): uses `phi0` rather than `phi00`. For
                # backward losses phi00 is the inverse field; confirm that
                # phi0 is intended here.
                jac = spatial.grid_jacobian(phi0, right, type='disp', extrapolate=False)
                jac = torch.matmul(aff_left[:-1, :-1], jac)
                mugrad = linalg.matvec(jac.transpose(-1, -2), mugrad)
                g_right, h_right = compose_grad(g_right, h_right, mugrad, gaff_right)
                g = g_right if g is None else g.add_(g_right)
                h = h_right if h is None else h.add_(h_right)

            if loss.backward:
                g = g.neg_()
            sumgrad = (g.mul_(factor) if sumgrad is None else
                       sumgrad.add_(g, alpha=factor))
            if hess:
                sumhess = (h.mul_(factor) if sumhess is None else
                           sumhess.add_(h, alpha=factor))
        sumloss = (llx.mul_(factor) if sumloss is None else
                   sumloss.add_(llx, alpha=factor))

    # TODO add regularization term
    lla = 0

    # ==============================================================
    #                           VERBOSITY
    # ==============================================================
    llx = sumloss.item()
    sumloss += lla
    # Regulariser value of the nonlinear component, as stored on self.
    sumloss += self.llv
    self.loss_value = sumloss.item()
    # NOTE(review): all_ll / ll_prev / ll_max / n_iter are only updated
    # inside this verbose branch, so they do not advance when verbose is 0.
    if self.verbose and (self.verbose > 1 or not in_line_search):
        ll = sumloss.item()
        llv = self.llv
        if in_line_search:
            line = '(search) | '
        else:
            line = '(affine) | '
        line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
        if not in_line_search:
            if self.ll_prev is not None:
                gain = self.ll_prev - ll
                # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                line += f' | {gain:12.6g}'
            self.all_ll.append(ll)
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)
            self.n_iter += 1
        print(line, end='\r')

    # ==============================================================
    #                           RETURN
    # ==============================================================
    out = [sumloss]
    if grad:
        out.append(sumgrad)
    if hess:
        out.append(sumhess)
    return tuple(out) if len(out) > 1 else out[0]
def do_vel(self, vel, grad=False, hess=False, in_line_search=False):
    """Forward pass for updating the nonlinear component

    Parameters
    ----------
    vel : tensor
        Velocity / displacement parameters, passed to `self.nonlin.exp`/`exp2`.
    grad, hess : bool
        Also return the gradient / Hessian wrt `vel`.
    in_line_search : bool
        During a line search, results are not cached, loss states are
        restored and iteration bookkeeping is skipped.

    Returns
    -------
    sumloss, [sumgrad], [sumhess]
        Weighted sum of all loss terms + regulariser (llv) + self.lla,
        followed by the requested derivatives. A bare `sumloss` is returned
        if none are requested.
    """

    sumloss = None
    sumgrad = None
    sumhess = None

    # ==============================================================
    #                     EXPONENTIATE TRANSFORMS
    # ==============================================================
    if self.affine:
        aff0, iaff0 = self.affine.exp2(cache_result=True, recompute=False)
        aff_pos = self.affine.position[0].lower()
    else:
        # No affine: 'x' matches neither 'fs' nor 'ms' below.
        aff_pos = 'x'
        aff0 = iaff0 = torch.eye(self.nonlin.dim + 1)
    vel0 = vel
    # Inverse transform only needed if some loss is evaluated backward.
    if any(loss.backward for loss in self.losses):
        phi0, iphi0 = self.nonlin.exp2(vel0, recompute=True,
                                       cache_result=not in_line_search)
        ivel0 = -vel0
    else:
        phi0 = self.nonlin.exp(vel0, recompute=True,
                               cache_result=not in_line_search)
        iphi0 = ivel0 = None
    # Match the affines to the dtype/device of the field.
    aff0 = aff0.to(phi0)
    iaff0 = iaff0.to(phi0)

    # ==============================================================
    #                     ACCUMULATE DERIVATIVES
    # ==============================================================

    has_printed = False
    for loss in self.losses:

        # ==========================================================
        #                     ONE LOSS COMPONENT
        # ==========================================================
        moving, fixed, factor = loss.moving, loss.fixed, loss.factor
        if loss.backward:
            phi00, aff00, vel00 = iphi0, iaff0, ivel0
        else:
            phi00, aff00, vel00 = phi0, aff0, vel0

        # ----------------------------------------------------------
        # build left and right affine
        # ----------------------------------------------------------
        # right: fixed voxels -> nonlin voxels
        aff_right = fixed.affine
        if aff_pos in 'fs':  # affine position: fixed or symmetric
            aff_right = aff00 @ aff_right
        aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
        # left: nonlin voxels -> moving voxels
        aff_left = self.nonlin.affine
        if aff_pos in 'ms':  # affine position: moving or symmetric
            aff_left = aff00 @ self.nonlin.affine
        aff_left = linalg.lmdiv(moving.affine, aff_left)

        # ----------------------------------------------------------
        # build full transform
        # ----------------------------------------------------------
        # aff_right / aff_left are set to None when (almost) identity with
        # matching lattices; propagate_grad receives them in that form.
        if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
            aff_right = None
            phi = spatial.add_identity_grid(phi00)
            disp = phi00
        else:
            phi = spatial.affine_grid(aff_right, fixed.shape)
            disp = regutils.smart_pull_grid(phi00, phi)
            phi += disp
        if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
            aff_left = None
        else:
            phi = spatial.affine_matvec(aff_left, phi)

        # ----------------------------------------------------------
        # forward pass
        # ----------------------------------------------------------
        warped, mask = moving.pull(phi, mask=True)
        if fixed.masked:
            if mask is None:
                mask = fixed.mask
            else:
                mask = mask * fixed.mask

        # Preview at most once per call, only at high verbosity.
        do_print = not (has_printed or self.verbose < 3 or in_line_search
                        or loss.backward)
        if do_print:
            has_printed = True
            if moving.previewed:
                preview = moving.pull(phi, preview=True, dat=False)
            else:
                preview = warped
            init = spatial.affine_lmdiv(moving.affine, fixed.affine)
            if _almost_identity(init) and moving.shape == fixed.shape:
                init = moving.dat
            else:
                init = spatial.affine_grid(init, fixed.shape)
                init = moving.pull(init, preview=True, dat=False)
            self.mov2fix(fixed.dat, init, preview, disp, dim=fixed.dim,
                         title=f'(nonlin) {self.n_iter:03d}')

        # ----------------------------------------------------------
        # derivatives wrt moving
        # ----------------------------------------------------------
        g = h = None
        loss_args = (warped, fixed.dat)
        loss_kwargs = dict(dim=fixed.dim, mask=mask)
        # Loss objects may be stateful; restore their state after a
        # line-search probe.
        state = loss.loss.get_state()
        if not grad and not hess:
            llx = loss.loss.loss(*loss_args, **loss_kwargs)
        elif not hess:
            llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
        else:
            llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
        del loss_args, loss_kwargs
        if in_line_search:
            loss.loss.set_state(state)

        # ----------------------------------------------------------
        # chain rule -> derivatives wrt phi
        # ----------------------------------------------------------
        if grad or hess:

            g, h, mugrad = self.nonlin.propagate_grad(
                g, h, moving, phi00, aff_left, aff_right,
                inv=loss.backward)
            g = regutils.jg(mugrad, g)
            h = regutils.jhj(mugrad, h)
            if isinstance(self.nonlin, SVFModel):
                # propagate backward by scaling and squaring
                g, h = spatial.exp_backward(vel00, g, h,
                                            steps=self.nonlin.steps)

            sumgrad = (g.mul_(factor) if sumgrad is None else
                       sumgrad.add_(g, alpha=factor))
            if hess:
                sumhess = (h.mul_(factor) if sumhess is None else
                           sumhess.add_(h, alpha=factor))
        sumloss = (llx.mul_(factor) if sumloss is None else
                   sumloss.add_(llx, alpha=factor))

    # ==============================================================
    #                       REGULARIZATION
    # ==============================================================
    # Quadratic regulariser: llv = 0.5 * <vel, L vel>, gradient = L vel.
    vgrad = self.nonlin.regulariser(vel0)
    llv = 0.5 * vel0.flatten().dot(vgrad.flatten())
    if grad:
        sumgrad += vgrad
    del vgrad

    # ==============================================================
    #                           VERBOSITY
    # ==============================================================
    llx = sumloss.item()
    sumloss += llv
    sumloss += self.lla
    self.loss_value = sumloss.item()
    # NOTE(review): self.llv (read by do_affine) and all_ll / ll_prev /
    # ll_max / n_iter are only updated inside this verbose branch, so they
    # stay stale when verbose is 0.
    if self.verbose and (self.verbose > 1 or not in_line_search):
        llv = llv.item()
        ll = sumloss.item()
        lla = self.lla
        if in_line_search:
            line = '(search) | '
        else:
            line = '(nonlin) | '
        line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
        if not in_line_search:
            if self.ll_prev is not None:
                gain = self.ll_prev - ll
                # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                line += f' | {gain:12.6g}'
            self.llv = llv
            self.all_ll.append(ll)
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)
            self.n_iter += 1
        print(line, end='\r')

    # ==============================================================
    #                           RETURN
    # ==============================================================
    out = [sumloss]
    if grad:
        out.append(sumgrad)
    if hess:
        out.append(sumhess)
    return tuple(out) if len(out) > 1 else out[0]
def _update_scaling(x, y, sett, max_niter_gn=1, num_linesearch=4, verbose=0):
    """Update the even/odd slice scaling parameter using Gauss-Newton.

    For every non-CT repeat of every channel, the reconstruction ``y[c]`` is
    resampled with the current rigid transform, blurred/decimated with the
    projection kernel, and the (log) scaling ``po.scl`` applied by
    ``_apply_scaling`` along ``dim_thick`` is optimised.

    Args:
        x (list[list]): Observed images, indexed ``x[c][n]``.
        y (list): Reconstructions, one per channel.
        sett: Settings (reads ``device``, ``rigid_basis``, ``bound``,
            ``interpolation``).
        max_niter_gn (int, optional): Max Gauss-Newton iterations,
            defaults to 1.
        num_linesearch (int, optional): Max line-search iterations,
            defaults to 4. 0 disables the line search.
        verbose (int, optional): 0: silent, 1: print convergence info,
            2: also show images. Defaults to 0.

    Returns:
        x (list[list]): Same object, with ``po.scl`` updated in place.
        sll (torch.tensor): Accumulated negative log-likelihood.
    """
    sll = torch.tensor(0, device=sett.device, dtype=torch.float64)
    ll = torch.tensor(0, device=sett.device, dtype=torch.float64)
    for c in range(len(x)):  # Loop over channels
        for n_x in range(len(x[c])):  # Loop over repeats
            obs = x[c][n_x]
            if obs.ct:
                # Do not optimise scaling for CT data
                continue
            # Parameters
            tau = obs.tau
            scl = obs.po.scl
            smo_ker = obs.po.smo_ker
            dim_thick = obs.po.dim_thick
            ratio = obs.po.ratio
            dim = obs.po.dim_yx
            mat_yx = obs.po.mat_yx
            mat_y = obs.po.mat_y
            rigid = _expm(obs.rigid_q, sett.rigid_basis)
            mat = rigid.mm(mat_yx).solve(mat_y)[0]  # mat_y\rigid*mat_yx
            # Observed data
            dat_x = obs.dat
            msk = dat_x != 0
            # Get even/odd data
            xo = _even_odd(dat_x, 'odd', dim_thick)
            mo = _even_odd(msk, 'odd', dim_thick)
            xo = xo[mo]
            xe = _even_odd(dat_x, 'even', dim_thick)
            me = _even_odd(msk, 'even', dim_thick)
            xe = xe[me]
            # Get reconstruction (without scaling)
            grid = affine_grid(mat.type(torch.float32), dim, jitter=False)
            dat_y = grid_pull(y[c].dat[None, None, ...], grid[None, ...],
                              bound=sett.bound,
                              interpolation=sett.interpolation,
                              extrapolate=False)
            dat_y = F.conv3d(dat_y, smo_ker, stride=ratio)[0, 0, ...]
            # Apply scaling
            dat_y = _apply_scaling(dat_y, scl, dim_thick)
            for n_gn in range(max_niter_gn):  # Loop over Gauss-Newton iterations
                # Log-likelihood
                ll = 0.5 * tau * torch.sum((dat_x[msk] - dat_y[msk])**2,
                                           dtype=torch.float64)
                if verbose >= 2:  # Show images
                    show_slices(torch.stack((dat_x, dat_y, (dat_x - dat_y)**2), 3),
                                fig_num=666, colorbar=False, flip=False)
                # Get even/odd data
                yo = _even_odd(dat_y, 'odd', dim_thick)
                yo = yo[mo]
                ye = _even_odd(dat_y, 'even', dim_thick)
                ye = ye[me]
                # Gradient
                gr = tau * (torch.sum(ye * (xe - ye), dtype=torch.float64)
                            - torch.sum(yo * (xo - yo), dtype=torch.float64))
                # Hessian
                Hes = tau * (torch.sum(ye**2, dtype=torch.float64)
                             + torch.sum(yo**2, dtype=torch.float64))
                # Compute Gauss-Newton update step
                Update = gr / Hes
                # Do update..
                old_scl = scl.clone()
                old_ll = ll.clone()
                # Prediction at the current scaling: every trial step (and a
                # rejected one) starts from it, so rescalings never compound.
                old_dat_y = dat_y
                armijo = torch.tensor(1.0, device=sett.device, dtype=old_scl.dtype)
                if num_linesearch == 0:
                    # ..without a line-search
                    scl = old_scl - armijo * Update
                    # Keep the prediction in sync with the new scaling so a
                    # further Gauss-Newton iteration does not reuse stale data
                    dat_y = _apply_scaling(old_dat_y, scl - old_scl, dim_thick)
                    if verbose >= 1:
                        print('c={}, n={}, gn={} | exp(s)={}'.format(
                            c, n_x, n_gn, round(scl.exp(), 5)))
                else:
                    # ..using a line-search
                    for n_ls in range(num_linesearch):
                        # Take step
                        scl = old_scl - armijo * Update
                        # Apply the scaling increment to the *unmodified*
                        # prediction (cloned in case the helper works in place)
                        dat_y = _apply_scaling(old_dat_y.clone(),
                                               scl - old_scl, dim_thick)
                        # Compute matching term
                        ll = 0.5 * tau * torch.sum(
                            (dat_x[msk] - dat_y[msk])**2, dtype=torch.float64)
                        if verbose >= 2:  # Show images
                            show_slices(torch.stack(
                                (dat_x, dat_y, (dat_x - dat_y)**2), 3),
                                fig_num=666, colorbar=False, flip=False)
                        # Matching improved?
                        if ll < old_ll:
                            # Better fit!
                            if verbose >= 1:
                                print(
                                    'c={}, n={}, gn={}, ls={} | :) ll={:0.2f}, ll-oll={:0.2f} | exp(s)={} armijo={}'
                                    .format(c, n_x, n_gn, n_ls, ll, ll - old_ll,
                                            round(scl.exp(), 5), round(armijo, 4)))
                            break
                        # Worse fit: reject the step and restore the state
                        ll_trial = ll
                        scl = old_scl
                        ll = old_ll
                        dat_y = old_dat_y
                        armijo *= 0.5
                        if verbose >= 1 and n_ls == num_linesearch - 1:
                            # Report the rejected trial, not the restored value
                            print(
                                'c={}, n={}, gn={}, ls={} | :( ll={:0.2f}, ll-oll={:0.2f} | exp(s)={} armijo={}'
                                .format(c, n_x, n_gn, n_ls, ll_trial,
                                        ll_trial - old_ll,
                                        round(old_scl.exp(), 5), round(armijo, 4)))
            # Update scaling in projection operator
            obs.po.scl = scl
            # Accumulate neg log-lik
            sll += ll
    return x, sll
def do_affine_only(self, logaff, grad=False, hess=False, in_line_search=False):
    """Forward pass for updating the affine component (nonlin is None)

    Parameters
    ----------
    logaff : tensor
        Lie-algebra parameters of the affine transform (exponentiated with
        ``self.affine.exp2``).
    grad : bool, default=False
        Also return the gradient of the loss wrt ``logaff``.
    hess : bool, default=False
        Also return the Hessian of the loss wrt ``logaff``.
    in_line_search : bool, default=False
        When True, loss states are restored after evaluation and the
        iteration bookkeeping (``n_iter``, ``all_ll``, ...) is not updated.

    Returns
    -------
    loss : scalar tensor
    grad : tensor, if ``grad``
    hess : tensor, if ``hess``
    """

    def compose_grad(g, h, g_mu, g_aff):
        """
        g, h : gradient/Hessian of loss wrt moving image
        g_mu : spatial gradients of moving image
        g_aff : gradient of affine matrix wrt Lie parameters
        returns g, h: gradient/Hessian of loss wrt Lie parameters
        """
        # Note that `h` can be `None`, but the functions I
        # use deal with this case correctly.
        dim = g_mu.shape[-1]
        g = jg(g_mu, g)
        h = jhj(g_mu, h)
        g, h = regutils.affine_grid_backward(g, h)
        dim2 = dim * (dim + 1)
        g = g.reshape([*g.shape[:-2], dim2])
        g_aff = g_aff[..., :-1, :]
        g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
        g = linalg.matvec(g_aff, g)
        if h is not None:
            h = h.reshape([*h.shape[:-4], dim2, dim2])
            h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
        return g, h

    sumloss = None
    sumgrad = None
    sumhess = None

    # ==============================================================
    #                     EXPONENTIATE TRANSFORMS
    # ==============================================================
    logaff0 = logaff
    aff0, iaff0, gaff0, igaff0 = self.affine.exp2(logaff0, grad=True)
    has_printed = False
    for loss in self.losses:

        moving, fixed, factor = loss.moving, loss.fixed, loss.factor
        # backward losses warp the fixed image, hence the inverse affine
        if loss.backward:
            aff00, gaff00 = iaff0, igaff0
        else:
            aff00, gaff00 = aff0, gaff0

        # ----------------------------------------------------------
        # build full transform (fixed voxels -> moving voxels)
        # ----------------------------------------------------------
        aff = aff00 @ fixed.affine
        aff = linalg.lmdiv(moving.affine, aff)
        gaff = gaff00 @ fixed.affine
        gaff = linalg.lmdiv(moving.affine, gaff)
        phi = spatial.affine_grid(aff, fixed.shape)

        # ----------------------------------------------------------
        # forward pass
        # ----------------------------------------------------------
        warped, mask = moving.pull(phi, mask=True)
        if fixed.masked:
            if mask is None:
                mask = fixed.mask
            else:
                mask = mask * fixed.mask

        do_print = not (has_printed or self.verbose < 3 or in_line_search
                        or loss.backward)
        if do_print:
            has_printed = True
            if moving.previewed:
                preview = moving.pull(phi, preview=True, dat=False)
            else:
                preview = warped
            init = spatial.affine_lmdiv(moving.affine, fixed.affine)
            if _almost_identity(init) and moving.shape == fixed.shape:
                init = moving.preview
            else:
                init = spatial.affine_grid(init, fixed.shape)
                init = moving.pull(init, preview=True, dat=False)
            self.mov2fix(fixed.preview, init, preview, dim=fixed.dim,
                         title=f'(affine) {self.n_iter:03d}')

        # ----------------------------------------------------------
        # derivatives wrt moving
        # ----------------------------------------------------------
        g = h = None
        loss_args = (warped, fixed.dat)
        loss_kwargs = dict(dim=fixed.dim, mask=mask)
        state = loss.loss.get_state()
        if not grad and not hess:
            llx = loss.loss.loss(*loss_args, **loss_kwargs)
        elif not hess:
            llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
        else:
            llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
        del loss_args, loss_kwargs
        if in_line_search:
            # trial evaluations must not alter the loss's internal state
            loss.loss.set_state(state)

        # ----------------------------------------------------------
        # chain rule -> derivatives wrt Lie parameters
        # ----------------------------------------------------------
        if grad or hess:
            # compose with spatial gradients
            mugrad = moving.pull_grad(phi, rotate=False)
            g, h = compose_grad(g, h, mugrad, gaff)

            if loss.backward:
                g = g.neg_()
            sumgrad = (g.mul_(factor) if sumgrad is None else
                       sumgrad.add_(g, alpha=factor))
            if hess:
                sumhess = (h.mul_(factor) if sumhess is None else
                           sumhess.add_(h, alpha=factor))
        sumloss = (llx.mul_(factor) if sumloss is None else
                   sumloss.add_(llx, alpha=factor))

    # TODO add regularization term
    lla = 0

    # ==============================================================
    #                           VERBOSITY
    # ==============================================================
    llx = sumloss.item()
    sumloss += lla
    ll = sumloss.item()
    self.loss_value = ll
    if self.verbose and (self.verbose > 1 or not in_line_search):
        if in_line_search:
            line = '(search) | '
        else:
            line = '(affine) | '
        line += f'{self.n_iter:03d} | {llx:12.6g} + {lla:12.6g} = {ll:12.6g}'
        if not in_line_search:
            if self.ll_prev is not None:
                gain = self.ll_prev - ll
                line += f' | {gain:12.6g}'
            # NOTE(review): the loss history and iteration counter are only
            # updated when verbose; confirm nothing else relies on them.
            self.all_ll.append(ll)
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)
            self.n_iter += 1
        print(line, end='\r')

    # ==============================================================
    #                           RETURN
    # ==============================================================
    out = [sumloss]
    if grad:
        out.append(sumgrad)
    if hess:
        out.append(sumhess)
    return tuple(out) if len(out) > 1 else out[0]
def _update_rigid_channel(xc, yc, sett, max_niter_gn=1, num_linesearch=4,
                          verbose=0, samp=3, c=1):
    """ Updates the rigid parameters for all images of one channel.

    For each repeat ``xc[n]``, a Gauss-Newton step (optionally with a
    backtracking line search) is taken on its rigid parameters
    ``rigid_q``, using the matching term computed by ``_rigid_match``.

    Args:
        xc (list): Observed images of this channel (repeats).
        yc: Reconstruction of this channel (``yc.dat`` is read).
        sett: Settings (reads ``method``, ``rigid_basis``, ``profile_ip``,
            ``profile_tp``, ``gap``, ...).
        max_niter_gn (int, optional): Max Gauss-Newton iterations,
            defaults to 1.
        num_linesearch (int, optional): Max line-search iterations,
            defaults to 4. 0 disables the line search.
        verbose (bool, optional): Show registration results, defaults to 0.
            0: No verbose
            1: Print convergence info to console
            2: Plot registration results using matplotlib
        samp (int, optional): Sub-sample data, defaults to 3.
        c (int, optional): Channel index (only used in printed messages).

    Returns:
        xc (list): Same object, with ``rigid_q`` and ``po.rigid`` updated.
        sll (torch.tensor): Accumulated negative log-likelihood.

    """
    # Parameters
    device = yc.dat.device
    method = sett.method
    num_q = sett.rigid_basis.shape[0]
    # Maps (d1, d2) to the packed index of the symmetric 3x3 Hessian
    # returned by `_rigid_match` (0..5 = xx, yy, zz, xy, xz, yz).
    lkp = [[0, 3, 4], [3, 1, 5], [4, 5, 2]]
    one = torch.tensor(1.0, device=device, dtype=torch.float64)
    sll = torch.tensor(0, device=device, dtype=torch.float64)
    for n_x in range(len(xc)):  # Loop over repeats
        # Lowres image data (overwritten by both branches below)
        dat_x = xc[n_x].dat[None, None, ...]
        # Parameters
        q = xc[n_x].rigid_q
        tau = xc[n_x].tau
        # Step size; persists across Gauss-Newton iterations of this repeat
        armijo = torch.tensor(1, device=device, dtype=q.dtype)
        po = _proj_info(xc[n_x].po.dim_y, xc[n_x].po.mat_y, xc[n_x].po.dim_x,
                        xc[n_x].po.mat_x, rigid=xc[n_x].po.rigid,
                        prof_ip=sett.profile_ip, prof_tp=sett.profile_tp,
                        gap=sett.gap, device=device, scl=xc[n_x].po.scl,
                        samp=samp)
        # Method
        # NOTE(review): `dim`/`mat` stay unbound if `method` is neither value.
        if method == 'super-resolution':
            dim = po.dim_yx
            mat = po.mat_yx
        elif method == 'denoising':
            dim = po.dim_x
            mat = po.mat_x
        # Do sub-sampling?
        if samp > 0 and po.D_x is not None:
            # Lowres
            grid = affine_grid(po.D_x.type(torch.float32), po.dim_x)
            dat_x = grid_pull(xc[n_x].dat[None, None, ...], grid[None, ...],
                              bound='zero', extrapolate=False,
                              interpolation=0)[0, 0, ...]
            if n_x == 0 and po.D_y is not None:
                # Highres (only for superres)
                grid = affine_grid(po.D_y.type(dtype=torch.float32), po.dim_y)
                dat_y = grid_pull(yc.dat[None, None, ...], grid[None, ...],
                                  bound='zero', extrapolate=False,
                                  interpolation=0)
            else:
                # NOTE(review): for n_x > 0 this resets dat_y to the
                # full-resolution data even when po.D_y is set, whereas
                # n_x == 0 uses the sub-sampled copy -- confirm intended.
                dat_y = yc.dat[None, None, ...]
        else:
            dat_x = xc[n_x].dat
            dat_y = yc.dat[None, None, ...]
        # Pre-compute super-resolution Hessian (CtC)?
        CtC = None
        if method == 'super-resolution':
            CtC = F.conv3d(torch.ones(( 1, 1, ) + dim, device=device,
                                      dtype=torch.float32),
                           po.smo_ker, stride=po.ratio)
            CtC = F.conv_transpose3d(CtC, po.smo_ker,
                                     stride=po.ratio)[0, 0, ...]
        # Get identity grid
        id_x = identity_grid(dim, dtype=torch.float32, device=device,
                             jitter=False)
        for n_gn in range(max_niter_gn):  # Loop over Gauss-Newton iterations
            # Differentiate Rq w.r.t. q (store in d_rigid_q)
            rigid, d_rigid = _expm(q, sett.rigid_basis, grad_X=True)
            d_rigid = d_rigid.permute(
                (1, 2, 0))  # make compatible with old affine_basis
            d_rigid_q = torch.zeros(4, 4, num_q, device=device,
                                    dtype=torch.float64)
            for i in range(num_q):
                d_rigid_q[:, :, i] = d_rigid[:, :, i].mm(mat).solve(
                    po.mat_y)[0]  # mat_y\d_rigid*mat
            # Compute gradient and Hessian
            gr = torch.zeros(num_q, 1, device=device, dtype=torch.float64)
            Hes = torch.zeros(num_q, num_q, device=device,
                              dtype=torch.float64)
            # Compute matching-term part (log-likelihood)
            ll, gr_m, Hes_m = _rigid_match(dat_x, dat_y, po, tau, rigid, sett,
                                           diff=True, verbose=verbose,
                                           CtC=CtC)
            # Multiply with d_rigid_q (chain-rule): dAff[i][d] is the
            # derivative of voxel coordinate d wrt parameter i, evaluated on
            # the identity grid
            dAff = []
            for i in range(num_q):
                dAff.append([])
                for d in range(3):
                    dAff[i].append(d_rigid_q[d, 0, i] * id_x[:, :, :, 0] + \
                                   d_rigid_q[d, 1, i] * id_x[:, :, :, 1] + \
                                   d_rigid_q[d, 2, i] * id_x[:, :, :, 2] + \
                                   d_rigid_q[d, 3, i])
            # Add d_rigid_q to gradient
            for d in range(3):
                for i in range(num_q):
                    gr[i] += torch.sum(gr_m[:, :, :, d] * dAff[i][d],
                                       dtype=torch.float64)
            # Add d_rigid_q to Hessian (upper triangle only)
            for d1 in range(3):
                for d2 in range(3):
                    for i1 in range(num_q):
                        tmp1 = Hes_m[:, :, :, lkp[d1][d2]] * dAff[i1][d1]
                        for i2 in range(i1, num_q):
                            Hes[i1, i2] += torch.sum(tmp1 * dAff[i2][d2],
                                                     dtype=torch.float64)
            # Fill in missing triangle
            for i1 in range(num_q):
                for i2 in range(i1 + 1, num_q):
                    Hes[i2, i1] = Hes[i1, i2]
            # # Regularise diagonal of Hessian
            # Hes += 1e-5*Hes.diag().max()*torch.eye(num_q, dtype=Hes.dtype, device=device)
            # Compute Gauss-Newton update step (solves Hes @ Update = gr)
            Update = gr.solve(Hes)[0][:, 0]
            # Do update..
            old_ll = ll.clone()
            old_q = q.clone()
            old_rigid = rigid.clone()
            if num_linesearch == 0:
                # ..without a line-search
                q = old_q - armijo * Update
                rigid = _expm(q, sett.rigid_basis)
                if verbose >= 1:
                    print('c={}, n={}, gn={} | q={}'.format(
                        c, n_x, n_gn, round(q, 7).tolist()))
            else:
                # ..using a line-search
                for n_ls in range(num_linesearch):
                    # Take step
                    q = old_q - armijo * Update
                    # Compute matching term
                    rigid = _expm(q, sett.rigid_basis)
                    ll = _rigid_match(dat_x, dat_y, po, tau, rigid, sett,
                                      verbose=verbose)[0]
                    # Matching improved?
                    if ll < old_ll:
                        # Better fit! Grow the step, capped at 1
                        armijo = torch.min(1.25 * armijo, one)
                        if verbose >= 1:
                            print(
                                'c={}, n={}, gn={}, ls={} | :) ll={:0.2f}, ll-oll={:0.2f} | q={} armijo={}'
                                .format(c, n_x, n_gn, n_ls, ll, ll - old_ll,
                                        round(q, 7).tolist(),
                                        round(armijo, 4)))
                        break
                    else:
                        # Reset parameters
                        ll = old_ll
                        q = old_q
                        rigid = old_rigid
                        armijo *= 0.5
                        # NOTE(review): `ll` was just reset, so `ll - old_ll`
                        # below always prints 0.
                        if n_ls == num_linesearch - 1 and verbose >= 1:
                            print(
                                'c={}, n={}, gn={}, ls={} | :( ll={:0.2f}, ll-oll={:0.2f} | q={} armijo={}'
                                .format(c, n_x, n_gn, n_ls, ll, ll - old_ll,
                                        round(q, 7).tolist(),
                                        round(armijo, 4)))
        # Assign
        xc[n_x].rigid_q = q
        xc[n_x].po.rigid = rigid
        # Accumulate neg log-lik
        sll += ll
    return xc, sll
def _rigid_match(dat_x, dat_y, po, tau, rigid, sett, CtC=None, diff=False,
                 verbose=0):
    """Compute the rigid matching term and, optionally, its derivatives.

    Args:
        dat_x (torch.tensor): Observed data (X0, Y0, Z0).
        dat_y (torch.tensor): Reconstructed data, with two leading
            singleton dimensions (1, 1, X1, Y1, Z1).
        po (ProjOp): Projection operator.
        tau (torch.tensor): Noise precision.
        rigid (torch.tensor): Rigid transformation matrix (4, 4).
        sett: Settings (reads ``method``, ``bound``, ``interpolation``).
        CtC (torch.tensor, optional): CtC(ones), used for the super-res
            Hessian. Required when ``diff`` and method is super-resolution.
            Defaults to None.
        diff (bool, optional): Compute derivatives, defaults to False.
        verbose (bool, optional): Show registration results, defaults to 0.
            0: No verbose
            1: Print convergence info to console
            2: Plot registration results using matplotlib

    Returns:
        ll (torch.tensor): Negative log-likelihood (float64).
        gr (torch.tensor or None): Gradient (*dim, 3), None if not ``diff``.
        Hes (torch.tensor or None): Packed symmetric Hessian (*dim, 6)
            ordered xx, yy, zz, xy, xz, yz; None if not ``diff``.

    Raises:
        ValueError: If ``sett.method`` is not 'super-resolution' or
            'denoising'.
    """
    # Projection info
    mat_y = po.mat_y
    ratio = po.ratio
    smo_ker = po.smo_ker
    dim_thick = po.dim_thick
    scl = po.scl
    # Init output
    gr = None
    Hes = None
    extrapolate = False
    if sett.method == 'super-resolution':
        dim = po.dim_yx
        mat = po.mat_yx
    elif sett.method == 'denoising':
        dim = po.dim_x
        mat = po.mat_x
    else:
        raise ValueError(f'Unknown method: {sett.method}')
    # Get grid
    mat = rigid.mm(mat).solve(mat_y)[0]  # mat_y\rigid*mat
    grid = affine_grid(mat.type(torch.float32), dim, jitter=False)
    # Warp y and compute spatial derivatives
    dat_yx = grid_pull(dat_y, grid[None, ...], bound=sett.bound,
                       extrapolate=extrapolate,
                       interpolation=sett.interpolation)[0, 0, ...]
    if sett.method == 'super-resolution':
        dat_yx = F.conv3d(dat_yx[None, None, ...], smo_ker,
                          stride=ratio)[0, 0, ...]
        if scl != 0:
            dat_yx = _apply_scaling(dat_yx, scl, dim_thick)
    if diff:
        gr = grid_grad(dat_y, grid[None, ...], bound=sett.bound,
                       extrapolate=extrapolate,
                       interpolation=sett.interpolation)[0, 0, ...]
    if verbose >= 2:  # Show images
        show_slices(torch.stack((dat_x, dat_yx, (dat_x - dat_yx)**2), 3),
                    fig_num=666, colorbar=False, flip=False)
    # Mask of observed (non-zero) voxels
    msk = (dat_x != 0)
    # Compute matching term
    ll = 0.5 * tau * torch.sum((dat_x[msk] - dat_yx[msk])**2,
                               dtype=torch.float64)
    if diff:
        # Residual, zeroed outside voxels where both images are non-zero
        res = dat_yx - dat_x
        msk = msk & (dat_yx != 0)
        res[~msk] = 0
        # Hessian: outer products of spatial gradients, packed
        Hes = torch.zeros(dim + (6, ), device=dat_x.device,
                          dtype=torch.float32)
        pairs = ((0, 0), (1, 1), (2, 2), (0, 1), (0, 2), (1, 2))
        for k, (i, j) in enumerate(pairs):
            Hes[..., k] = gr[..., i] * gr[..., j]
        if sett.method == 'super-resolution':
            Hes *= CtC[..., None]
            # bring the low-res residual back to the high-res grid
            res = F.conv_transpose3d(res[None, None, ...], smo_ker,
                                     stride=ratio)[0, 0, ...]
        # Gradient
        gr *= res[..., None]
    return ll, gr, Hes
def get_oriented_slice(image, dim=-1, index=None, affine=None,
                       space=None, bbox=None, interpolation=1,
                       transpose_sagittal=False, return_index=False,
                       return_mat=False):
    """Sample a slice in a RAS system

    Parameters
    ----------
    image : (..., *shape3)
    dim : int or str, default=-1
        Index of spatial dimension to sample in the visualization space
        If RAS: -1 = axial / -2 = coronal / -3 = sagittal
        Strings are matched on their first letter ('a', 'c', 's').
    index : sequence[float], optional
        Cursor position in world (millimetric) coordinates; it is mapped to
        visualisation voxels with ``inv(space)``.
        Default: centre of the bounding box.
    affine : (4, 4) tensor, optional
        Orientation matrix of the image
    space : (4, 4) tensor, optional
        Orientation matrix of the visualisation space.
        Default: RAS with minimum voxel size of all inputs.
    bbox : (2, D) tensor_like, optional
        Bounding box: min and max coordinates (in millimetric
        visualisation space). Default: bounding box of the input image.
    interpolation : {0, 1}, default=1
        Interpolation order.
    transpose_sagittal : bool, default=False
        For sagittal slices, swap the two in-plane axes (and flip the
        anterior-posterior one).
    return_index : bool, default=False
        Also return the cursor position in slice voxel coordinates.
    return_mat : bool, default=False
        Also return the orientation matrix of the sampled slice.

    Returns
    -------
    slice : (..., *shape2) tensor
        Slice in the visualisation space.
    index : tuple of 2 values, if `return_index`
    mat : (4, 4) tensor, if `return_mat`

    Raises
    ------
    ValueError
        If `dim` does not designate an axial/coronal/sagittal slice.
    """
    # preproc dim
    if isinstance(dim, str):
        letter = dim.lower()[0]
        dim = {'a': -1, 'c': -2, 's': -3}.get(letter, letter)

    backend = utils.backend(image)

    # compute default space (mn/mx are in voxels)
    affine, shape = _get_default_native(affine, image.shape[-3:])
    space, mn, mx = _get_default_space(affine, [shape], space, bbox)
    affine, shape = (affine[0], shape[0])

    # compute default cursor (in voxels)
    if index is None:
        index = (mx + mn) / 2
    else:
        # world -> visualisation voxels; cast so that integer or
        # lower-precision inputs match the matrix's dtype/device
        index = torch.as_tensor(index, dtype=space.dtype, device=space.device)
        index = spatial.affine_matvec(spatial.affine_inv(space), index)

    # include slice to volume matrix
    shape = tuple(((mx-mn) + 1).round().int().tolist())
    if dim == -1:
        # axial
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 1, 0, - mn[1] + 1],
                 [0, 0, 1, - index[2]],
                 [0, 0, 0, 1]]
        shape = shape[:-1]
        index = (index[0] - mn[0] + 1, index[1] - mn[1] + 1)
    elif dim == -2:
        # coronal
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 0, 1, - mn[2] + 1],
                 [0, 1, 0, - index[1]],
                 [0, 0, 0, 1]]
        shape = (shape[0], shape[2])
        index = (index[0] - mn[0] + 1, index[2] - mn[2] + 1)
    elif dim == -3:
        # sagittal
        if not transpose_sagittal:
            shift = [[0, 0, 1, - mn[2] + 1],
                     [0, 1, 0, - mn[1] + 1],
                     [1, 0, 0, - index[0]],
                     [0, 0, 0, 1]]
            shape = (shape[2], shape[1])
            index = (index[2] - mn[2] + 1, index[1] - mn[1] + 1)
        else:
            shift = [[0, -1, 0, mx[1] + 1],
                     [0, 0, 1, - mn[2] + 1],
                     [1, 0, 0, - index[0]],
                     [0, 0, 0, 1]]
            shape = (shape[1], shape[2])
            index = (mx[1] + 1 - index[1], index[2] - mn[2] + 1)
    else:
        raise ValueError(f'Unknown dimension {dim}')
    shift = utils.as_tensor(shift, **backend)

    # sample
    space = spatial.affine_rmdiv(space, shift)
    affine = spatial.affine_lmdiv(affine, space)
    affine = affine.to(**backend)
    grid = spatial.affine_grid(affine, [*shape, 1])
    *channel, s0, s1, s2 = image.shape
    imshape = (s0, s1, s2)
    image = image.reshape([1, -1, *imshape])
    image = spatial.grid_pull(image, grid[None], interpolation=interpolation,
                              bound='dct2', extrapolate=False)
    image = image.reshape([*channel, *shape])

    outputs = [image]
    if return_index:
        outputs.append(index)
    if return_mat:
        outputs.append(space)
    return tuple(outputs) if len(outputs) > 1 else outputs[0]
def write_data(options):
    """Compose the transformations in ``options`` and write the result.

    With a target, the composed transformation is written as a dense
    displacement field (in voxels if ``output_unit`` starts with 'v',
    else in mm). Without a target, the single (linear) transform is
    written as an LTA file.

    Raises
    ------
    RuntimeError
        If several transformations are given without a target.
    """

    backend = dict(dtype=torch.float32, device=options.device)

    def load_field(trf):
        """Read the dense field stored in ``trf.file`` into ``trf``."""
        f = io.volumes.map(trf.file)
        trf.affine = f.affine
        trf.shape = squeeze_to_nd(f.shape, 3, 1)
        trf.dat = f.fdata(**backend).reshape(trf.shape)
        trf.shape = trf.shape[:3]

    # Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, Velocity):
            load_field(trf)
            trf.dat = spatial.exp(trf.dat[None], displacement=True,
                                  inverse=trf.inv)[0]
            trf.inv = False
            # NOTE(review): `build_from_target` interpolates with
            # `trf.spline`, not `trf.order` -- confirm which is intended.
            trf.order = 1
        elif isinstance(trf, Displacement):
            load_field(trf)
            if trf.unit == 'mm':
                # convert mm displacement to vox displacement: displacements
                # are vectors, so only the inverse of the linear part of the
                # voxel-to-world matrix applies (no translation)
                ilin = spatial.affine_inv(trf.affine)[:-1, :-1].to(**backend)
                trf.dat = torch.matmul(ilin, trf.dat[..., None])[..., 0]
                trf.unit = 'vox'

    def build_from_target(target):
        """Compose all transformations, starting from the final orientation"""
        grid = spatial.affine_grid(target.affine.to(**backend), target.shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            else:
                mat = trf.affine.to(**backend)
                if trf.inv:
                    # invert the displacement at the target resolution
                    vx0 = spatial.voxel_size(mat)
                    vx1 = spatial.voxel_size(target.affine.to(**backend))
                    factor = vx0 / vx1
                    disp, mat = spatial.resize_grid(trf.dat[None], factor,
                                                    affine=mat,
                                                    interpolation=trf.spline)
                    disp = spatial.grid_inv(disp[0], type='disp')
                    order = 1
                else:
                    disp = trf.dat
                    order = trf.spline
                imat = spatial.affine_inv(mat)
                grid = spatial.affine_matvec(imat, grid)
                grid += helpers.pull_grid(disp, grid, interpolation=order)
                grid = spatial.affine_matvec(mat, grid)
        return grid

    if options.target:
        # If target is provided, we build a dense transformation field
        grid = build_from_target(options.target)
        oaffine = options.target.affine
        if options.output_unit[0] == 'v':
            grid = spatial.affine_matvec(spatial.affine_inv(oaffine), grid)
            grid = grid - spatial.identity_grid(grid.shape[:-1],
                                                **utils.backend(grid))
        else:
            grid = grid - spatial.affine_grid(
                oaffine.to(**utils.backend(grid)), grid.shape[:-1])
        io.volumes.savef(grid, options.output.format(ext='.nii.gz'),
                         affine=oaffine)
    else:
        if len(options.transformations) > 1:
            raise RuntimeError('Something weird happened: '
                               'multiple transforms and no target')
        io.transforms.savef(options.transformations[0].affine,
                            options.output.format(ext='.lta'))
def forward():
    """Forward pass up to the loss

    Closure: reads ``options`` and ``backend`` from the enclosing scope.

    The loss is the sum of
    - the penalties attached to the first ``struct.Linear`` transform
      (evaluated on its parameters ``q``),
    - the penalties attached to the first free FFD/Diffeo transform,
    - for each image pair with a fixed image, the matching term at each
      pyramid level divided by the number of levels.

    Returns
    -------
    loss : scalar tensor (or 0 if nothing contributes)
    """
    loss = 0

    # affine matrix (only the first Linear transform is used)
    A = None
    for trf in options.transformations:
        trf.update()
        if isinstance(trf, struct.Linear):
            q = trf.optdat.to(**backend)
            B = trf.basis.to(**backend)
            A = linalg.expm(q, B)
            if torch.is_tensor(trf.shift):
                # include shift
                shift = trf.shift.to(**backend)
                eye = torch.eye(options.dim, **backend)
                A = A.clone()  # needed because expm is a custom autograd.Function
                A[:-1, -1] += torch.matmul(A[:-1, :-1] - eye, shift)
            for loss1 in trf.losses:
                loss += loss1.call(q)
            break

    # non-linear displacement field (only the first free FFD/Diffeo)
    d = None
    d_aff = None
    for trf in options.transformations:
        if not trf.isfree():
            continue
        if isinstance(trf, struct.FFD):
            d = trf.dat.to(**backend)
            d = ffd_exp(d, trf.shape, returns='disp')
            for loss1 in trf.losses:
                loss += loss1.call(d)
            d_aff = trf.affine.to(**backend)
            break
        elif isinstance(trf, struct.Diffeo):
            d = trf.dat.to(**backend)
            if not trf.smalldef:
                # penalty on velocity fields
                for loss1 in trf.losses:
                    loss += loss1.call(d)
                d = spatial.exp(d[None], displacement=True)[0]
            if trf.smalldef:
                # penalty on exponentiated transform
                for loss1 in trf.losses:
                    loss += loss1.call(d)
            d_aff = trf.affine.to(**backend)
            break

    # loop over image pairs
    for match in options.losses:
        if not match.fixed:
            continue
        nb_levels = len(match.fixed.dat)
        prm = dict(interpolation=match.moving.interpolation,
                   bound=match.moving.bound,
                   extrapolate=match.moving.extrapolate)
        # loop over pyramid levels
        for moving, fixed in zip(match.moving.dat, match.fixed.dat):
            moving_dat, moving_aff = moving
            fixed_dat, fixed_aff = fixed

            moving_dat = moving_dat.to(**backend)
            moving_aff = moving_aff.to(**backend)
            fixed_dat = fixed_dat.to(**backend)
            fixed_aff = fixed_aff.to(**backend)

            # affine-corrected moving space
            if A is not None:
                Ms = affine_matmul(A, moving_aff)
            else:
                Ms = moving_aff

            if d is not None:
                # fixed to param
                Mt = affine_lmdiv(d_aff, fixed_aff)
                if samespace(Mt, d.shape[:-1], fixed_dat.shape[1:]):
                    g = smalldef(d)
                else:
                    g = affine_grid(Mt, fixed_dat.shape[1:])
                    g = g + pull_grid(d, g)
                # param to moving
                Ms = affine_lmdiv(Ms, d_aff)
                g = affine_matvec(Ms, g)
            else:
                # fixed to moving
                Mt = fixed_aff
                Ms = affine_lmdiv(Ms, Mt)
                g = affine_grid(Ms, fixed_dat.shape[1:])

            # pull moving image
            warped_dat = pull(moving_dat, g, **prm)
            # average the matching term over pyramid levels
            loss += match.call(warped_dat, fixed_dat) / float(nb_levels)

    return loss
def reslice(moving, fname, like, inv=False, lin=None, nonlin=None,
            interpolation=1, bound='dct2', extrapolate=False, device=None,
            verbose=True):
    """Apply the linear and non-linear components of the transform and
    reslice the image to the target space.

    Notes
    -----
    .. The shape and general orientation of the moving image is kept untouched.
    .. The linear transform is composed with the original orientation matrix.
    .. The non-linear component is "wrapped" in the input space, where it
       is applied.
    .. This function writes a new file (it does not modify the input files
       in place).

    Parameters
    ----------
    moving : ImageFile
        An object describing a moving image.
    fname : list of str
        Output filename for each input file of the moving image
        (since images can be encoded over multiple volumes)
    like : ImageFile
        An object describing the target space
    inv : bool, default=False
        True if we are warping the fixed image to the moving space.
        In that case, `moving` should be a `FixedImageFile` and `like` a
        `MovingImageFile`. Else it should be a `MovingImageFile` and `like`
        a `FixedImageFile`.
    lin : (4, 4) tensor, optional
        Linear (or rather affine) transformation. It is composed with the
        target side when `inv` is True, with the moving side otherwise.
    nonlin : dict, optional
        Non-linear displacement field, with keys:
        disp : (..., 3) tensor
            Displacement field (in voxels)
        affine : (4, 4) tensor
            Orientation matrix of the displacement field
    interpolation : int, default=1
        Forced to 0 when ``moving.type == 'labels'``.
    bound : str, default='dct2'
    extrapolate : bool, default = False
    device : default='cpu'
    verbose : bool, default=True
        Print the input/output filenames.
    """
    if not nonlin:
        nonlin = dict(disp=None, affine=None)
    opts = dict(interpolation=interpolation, bound=bound,
                extrapolate=extrapolate)

    # Orientation matrices of the space we sample from (`src`) and of the
    # space we sample into (`dst`), with the linear part folded into one.
    src = moving.affine.to(device)
    dst = like.affine.to(device)
    if lin is not None:
        if inv:
            dst = affine_matmul(lin, dst)
        else:
            src = affine_matmul(lin, src)

    field = nonlin['disp']
    if field is None:
        # purely linear: target voxels -> source voxels
        g = affine_grid(affine_lmdiv(src, dst), like.shape)
    else:
        field_affine = nonlin['affine'].to(device)
        # target voxels -> displacement-field voxels
        dst_to_field = affine_lmdiv(field_affine, dst)
        if samespace(dst_to_field, field.shape[:-1], like.shape):
            g = smalldef(field.to(device))
        else:
            g = affine_grid(dst_to_field, like.shape)
            g += pull_grid(field.to(device), g)
        # displacement-field voxels -> source voxels
        field_to_src = affine_lmdiv(src, field_affine)
        g = affine_matvec(field_to_src, g)

    if moving.type == 'labels':
        opts['interpolation'] = 0

    for file, ofname in zip(moving.files, fname):
        if verbose:
            print(f'Resliced: {file.fname}\n'
                  f' -> {ofname}')
        volume = io.volumes.loadf(file.fname, rand=True, device=device)
        volume = volume.reshape([*file.shape, file.channels])
        if g is not None:
            # channels first for pulling, then back to channels last
            volume = utils.movedim(pull(utils.movedim(volume, -1, 0), g, **opts),
                                   0, -1)
        io.savef(volume.cpu(), ofname, like=file.fname,
                 affine=like.affine.cpu())
def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
                   basis='affine', fwhm=None, joint=False, prm=None,
                   max_iter_gn=100, max_iter_em=32, max_line_search=6,
                   progressive=False, verbose=1):
    """Fit an affine transform between images and a tissue probability map
    (Gauss-Newton with backtracking line search on the mutual information).

    Parameters
    ----------
    dat : (B, J|1, *spatial) tensor
        Images.
    tpm : (B|1, K, *spatial) tensor
        Tissue probability map, warped towards `dat`.
    affine : (4, 4) tensor, optional
        Orientation matrix of `dat`. Default: `spatial.affine_default`.
    affine_tpm : (4, 4) tensor, optional
        Orientation matrix of `tpm`. Default: `spatial.affine_default`.
    weights : (B, 1, *spatial) tensor, optional
        Forwarded to the EM and derivative routines.
    basis : {'translation', 'rotation', 'rigid', 'similitude', 'affine'}
    fwhm : float, default=J/32
    joint : bool, default=False
        Fit a single set of parameters shared by the whole batch.
    prm : (B|1, F) tensor, optional
        Initial parameters. Default: zeros.
    max_iter_gn : int, default=100
    max_iter_em : int, default=32
    max_line_search : int, default=6
    progressive : bool, default=False
        First fit the simpler bases (T -> SE -> CSO -> Aff+) and use each
        result to initialise the next one.
    verbose : int, default=1

    Returns
    -------
    mi : (B,) tensor
        Mutual information of the accepted parameters (divided by `Nm`).
    aff : (B, 4, 4) tensor
        Affine matrix corresponding to the returned `prm`.
    prm : (B, F) tensor
    """
    dim = dat.dim() - 2

    # ------------------------------------------------------------------
    #       RECURSIVE PROGRESSIVE FIT
    # ------------------------------------------------------------------
    if progressive:
        nb_se = dim * (dim + 1) // 2
        nb_aff = dim * (dim + 1)
        basis_recursion = {'Aff+': 'CSO', 'CSO': 'SE', 'SE': 'T'}
        basis_nb_feat = {'Aff+': nb_aff, 'CSO': nb_se + 1, 'SE': nb_se}
        basis = convert_basis(basis)
        next_basis = basis_recursion.get(basis, None)
        if next_basis:
            # fit the simpler basis first (itself progressively) ...
            *_, prm = fit_affine_tpm(dat, tpm, affine, affine_tpm, weights,
                                     basis=next_basis, fwhm=fwhm, joint=joint,
                                     prm=prm, max_iter_gn=max_iter_gn,
                                     max_iter_em=max_iter_em,
                                     max_line_search=max_line_search,
                                     progressive=True, verbose=verbose)
            # ... then embed its parameters in the current basis
            B = len(dat)
            F = basis_nb_feat[basis]
            prm0 = prm
            prm = prm0.new_zeros([1 if joint else B, F])
            if basis == 'SE':
                prm[:, :dim] = prm0[:, :dim]
            else:
                prm[:, :nb_se] = prm0[:, :nb_se]
                if basis == 'Aff+':
                    # spread the isotropic scale over the `dim` scales;
                    # keep a trailing axis so it broadcasts per batch element
                    prm[:, nb_se:nb_se + dim] = (prm0[:, nb_se:nb_se + 1]
                                                 * (dim**(-0.5)))
    basis_name = basis

    # ------------------------------------------------------------------
    #       PREPARE
    # ------------------------------------------------------------------
    B = len(dat)
    if affine is None:
        affine = spatial.affine_default(dat.shape[-dim:])
    if affine_tpm is None:
        affine_tpm = spatial.affine_default(tpm.shape[-dim:])
    affine = affine.to(**utils.backend(tpm))
    affine_tpm = affine_tpm.to(**utils.backend(tpm))
    shape = dat.shape[-dim:]

    tpm = tpm.to(dat.device)
    basis = make_basis(basis, dim, **utils.backend(tpm))
    F = len(basis)

    if prm is None:
        prm = tpm.new_zeros([1 if joint else B, F])
    aff, gaff = linalg._expm(prm, basis, grad_X=True)

    em_opt = dict(fwhm=fwhm, max_iter=max_iter_em, weights=weights,
                  verbose=verbose - 2)
    drv_opt = dict(weights=weights)
    pull_opt = dict(bound='replicate', extrapolate=True)

    # ------------------------------------------------------------------
    #       OPTIMIZE
    # ------------------------------------------------------------------
    prior = None
    mi = torch.as_tensor(-float('inf'))
    delta = torch.zeros_like(prm)
    for n_iter in range(max_iter_gn):

        # --------------------------------------------------------------
        #       LINE SEARCH
        # --------------------------------------------------------------
        # Last accepted state, restored if no step improves the MI
        prior0, prm0, mi0, aff0 = prior, prm, mi, aff
        armijo = 1
        success = False
        for n_ls in range(max_line_search):

            # --- take a step ------------------------------------------
            prm = prm0 - armijo * delta

            # --- build transformation field ---------------------------
            aff, gaff = linalg._expm(prm, basis, grad_X=True)
            phi = lmdiv(affine_tpm, mm(aff, affine))
            phi = spatial.affine_grid(phi, shape)

            # --- warp TPM ---------------------------------------------
            mov = spatial.grid_pull(tpm, phi, **pull_opt)

            # --- mutual info ------------------------------------------
            mi, Nm, prior = em_prior(mov, dat, prior0, **em_opt)
            mi = mi / Nm

            success = mi.sum() > mi0.sum()
            if verbose >= 2:
                end = '\n' if verbose >= 3 else '\r'
                happy = ':D' if success else ':('
                print(f'(search) | {n_ls:02d} | {mi.mean():12.6g} | {happy}',
                      end=end)
            if success:
                break
            armijo *= 0.5

        # --------------------------------------------------------------
        #       DID IT WORK?
        # --------------------------------------------------------------
        if not success:
            # restore the accepted state, including its matrix, so that
            # the returned `aff` matches the returned `prm`
            prior, prm, mi, aff = prior0, prm0, mi0, aff0
            break

        if verbose >= 1:
            space = ' ' * max(0, 6 - len(basis_name))
            end = '\n' if verbose >= 2 else '\r'
            print(f'({basis_name[:6]}){space} | {n_iter:02d} | {mi.mean():12.6g}',
                  end=end)

        if mi.mean() - mi0.mean() < 1e-5:
            break

        # --------------------------------------------------------------
        #       GAUSS-NEWTON
        # --------------------------------------------------------------
        # --- derivatives ----------------------------------------------
        g, h = derivatives_intensity(mov, dat, prior, **drv_opt)

        # --- chain rule -----------------------------------------------
        gmov = spatial.grid_grad(tpm, phi, **pull_opt)
        if joint and len(mov) == 1:
            g = g.sum(0, keepdim=True)
            h = h.sum(0, keepdim=True)
        else:
            gmov = gmov.expand([B, *gmov.shape[1:]])
        gaff = lmdiv(affine_tpm, mm(gaff, affine))
        g, h = chain_rule(g, h, gmov, gaff, maj=False)
        del gmov
        if joint and len(g) > 1:
            g = g.sum(0, keepdim=True)
            h = h.sum(0, keepdim=True)

        # --- Gauss-Newton ---------------------------------------------
        delta = lmdiv(h, g.unsqueeze(-1)).squeeze(-1)

    if verbose == 1:
        print('')
    return mi, aff, prm
def compose(self, orient_in, deformation, orient_mean, affine=None,
            orient_out=None, shape_out=None):
    """Composes a deformation defined in a mean space to an image space.

    The returned grid is the chain
    ``left_affine o deformation o right_affine`` where
    ``right_affine = orient_mean \\ orient_out`` maps output voxels to
    mean-space voxels and
    ``left_affine = inv(orient_in) @ affine @ orient_mean`` maps mean-space
    voxels to input voxels.

    Parameters
    ----------
    orient_in : (4, 4) tensor
        Orientation of the input image
    deformation : (*shape_mean, 3) tensor
        Random deformation (voxel-to-voxel transform in mean space)
    orient_mean : (4, 4) tensor
        Orientation of the mean space (where the deformation is)
    affine : (4, 4) tensor, default=identity
        Random affine
    orient_out : (4, 4) tensor, default=orient_in
        Orientation of the output image
    shape_out : sequence[int], default=shape_mean
        Shape of the output image

    Returns
    -------
    grid : (*shape_out, 3)
        Voxel-to-voxel transform
    """
    if orient_out is None:
        orient_out = orient_in
    if shape_out is None:
        shape_out = deformation.shape[:-1]
    if affine is None:
        affine = torch.eye(4, 4, device=orient_in.device,
                           dtype=orient_in.dtype)
    # Compare shapes as plain int tuples: a torch.Size never compares equal
    # to a list, which would otherwise force the dense (slow) path.
    shape_mean = tuple(int(s) for s in deformation.shape[:-1])
    shape_out = tuple(int(s) for s in shape_out)
    orient_in, affine, deformation, orient_mean, orient_out \
        = utils.to_max_backend(orient_in, affine, deformation,
                               orient_mean, orient_out)
    backend = utils.backend(deformation)
    eye = torch.eye(4, **backend)

    # Compose deformation on the right (output voxels -> mean voxels)
    right_affine = spatial.affine_lmdiv(orient_mean, orient_out)
    if shape_mean == shape_out and torch.allclose(right_affine, eye):
        # The mean space and output space are identical: the diffeo can be
        # used directly, no resampling needed.
        trf = deformation
    else:
        # The mean space and native space are not the same: we must compose
        # the diffeo with a dense affine transform. We write the diffeo as
        # an identity plus a displacement:
        #   (id + disp)(aff) = aff + disp(aff)
        disp = deformation - spatial.identity_grid(shape_mean, **backend)
        trf = spatial.affine_grid(right_affine, shape_out)
        disp = spatial.grid_pull(utils.movedim(disp, -1, 0)[None],
                                 trf[None], bound='dft', extrapolate=True)
        disp = utils.movedim(disp[0], 0, -1)
        trf = trf + disp  # add displacement

    # Compose deformation on the left: the outputs of the diffeo are
    # mean-space voxels, so we compose with `in \ (aff(mean))`.
    left_affine = spatial.affine_matmul(spatial.affine_inv(orient_in), affine)
    left_affine = spatial.affine_matmul(left_affine, orient_mean)
    trf = spatial.affine_matvec(left_affine, trf)
    return trf
def _proj_apply(operator, dat, po, method='super-resolution', bound='zero', interpolation='linear'): """ Applies operator A, At or AtA (for denoising or super-resolution). Args: operator (string): Either 'A', 'At', 'AtA' or 'none'. dat (torch.tensor()): Image data (1, 1, X_in, Y_in, Z_in). po (_proj_op()): Encodes projection operator, has the following fields: po.mat_x: Low-res affine matrix. po.mat_y: High-res affine matrix. po.mat_yx: Intermediate affine matrix. po.dim_x: Low-res image dimensions. po.dim_y: High-res image dimensions. po.dim_yx: Intermediate image dimensions. po.ratio: The ratio (low-res voxel_size)/(high-res voxel_size). po.smo_ker: Smoothing kernel (slice-profile). method (string): Either 'denoising' or 'super-resolution' (default). bound (str, optional): Bound for nitorch push/pull, defaults to 'zero'. interpolation (int, optional): Interpolation order, defaults to linear. Returns: dat (torch.tensor()): Projected image data (1, 1, X_out, Y_out, Z_out). """ # Sanity check if operator not in ['A', 'At', 'AtA', 'none']: raise ValueError('Undefined operator') if method not in ['denoising', 'super-resolution']: raise ValueError('Undefined method') if operator == 'none': # No projection return dat # Get data type and device dtype = dat.dtype device = dat.device # Parse required projection info mat_x = po.mat_x mat_y = po.mat_y mat_yx = po.mat_yx rigid = po.rigid dim_x = po.dim_x dim_y = po.dim_y dim_yx = po.dim_yx ratio = po.ratio smo_ker = po.smo_ker scl = po.scl dim_thick = po.dim_thick if method == 'super-resolution': dim = dim_yx mat = rigid.mm(mat_yx).solve(mat_y)[0] # mat_y\rigid*mat_yx elif method == 'denoising': dim = dim_x mat = rigid.mm(mat_x).solve(mat_y)[0] # mat_y\rigid*mat_x # Smoothing operator if len(ratio) == 3: # 3D conv = lambda x: F.conv3d(x, smo_ker, stride=ratio) conv_transpose = lambda x: F.conv_transpose3d(x, smo_ker, stride=ratio) else: # 2D conv = lambda x: F.conv2d(x, smo_ker, stride=ratio) conv_transpose = lambda x: 
F.conv_transpose2d(x, smo_ker, stride=ratio) # Get grid grid = affine_grid(mat.type(dat.dtype), dim, jitter=False)[None, ...] # Apply projection if method == 'super-resolution': extrapolate = False if operator == 'A': dat = grid_pull(dat, grid, bound=bound, extrapolate=extrapolate, interpolation=interpolation) dat = conv(dat) if scl != 0: dat = _apply_scaling(dat, scl, dim_thick) elif operator == 'At': if scl != 0: dat = _apply_scaling(dat, scl, dim_thick) dat = conv_transpose(dat) dat = grid_push(dat, grid, shape=dim_y, bound=bound, extrapolate=extrapolate, interpolation=interpolation) elif operator == 'AtA': dat = grid_pull(dat, grid, bound=bound, extrapolate=extrapolate, interpolation=interpolation) dat = conv(dat) if scl != 0: dat = _apply_scaling(dat, 2 * scl, dim_thick) dat = conv_transpose(dat) dat = grid_push(dat, grid, shape=dim_y, bound=bound, extrapolate=extrapolate, interpolation=interpolation) elif method == 'denoising': extrapolate = False if operator == 'A': dat = grid_pull(dat, grid, bound=bound, extrapolate=extrapolate, interpolation=interpolation) elif operator == 'At': dat = grid_push(dat, grid, shape=dim_y, bound=bound, extrapolate=extrapolate, interpolation=interpolation) elif operator == 'AtA': dat = grid_pull(dat, grid, bound=bound, extrapolate=extrapolate, interpolation=interpolation) dat = grid_push(dat, grid, shape=dim_y, bound=bound, extrapolate=extrapolate, interpolation=interpolation) return dat