def forward(self, source, target, source_affine=None, target_affine=None):
    """Register `source` to `target`, then resample the original source.

    Parameters
    ----------
    source : (sX, sY, sZ) tensor or str
    target : (tX, tY, tZ) tensor or str
    source_affine : (4, 4) tensor, optional
    target_affine : (4, 4) tensor, optional

    Returns
    -------
    warped : (tX, tY, tZ) tensor
        Source warped to target
    velocity : (vX, vY, vZ, 3) tensor
        Stationary velocity field
    affine : (4, 4) tensor, optional
        Affine of the velocity space

    """
    # --- preprocessing: load both inputs, reslice source to target ------
    if self.verbose:
        print('Preprocessing... ', end='', flush=True)
    (source, source_affine,
     source_orig, source_affine_orig) = self.load(source, source_affine)
    (target, target_affine,
     target_orig, target_affine_orig) = self.load(target, target_affine)
    source = spatial.reslice(source, source_affine,
                             target_affine, target.shape)
    if self.verbose:
        print('done.', flush=True)
        print('Registering... ', end='', flush=True)

    # --- registration on (batch, channel, *spatial) tensors --------------
    source = source[None, None]
    target = target[None, None]
    warped, vel, grid = super().forward(source, target)
    if self.verbose:
        print('done.', flush=True)
    # the low-resolution inputs and warped image are not needed anymore
    del source, target, warped

    # --- turn the transformation grid into a displacement (in place) -----
    vel = vel[0]
    disp = grid[0]
    disp -= spatial.identity_grid(disp.shape[:-1],
                                  dtype=disp.dtype, device=disp.device)

    # --- resample the displacement on the original target lattice --------
    orig_to_vel = target_affine.inverse().matmul(target_affine_orig)
    sampling = spatial.affine_grid(orig_to_vel, target_orig.shape)
    disp = spatial.grid_pull(utils.movedim(disp, -1, 0), sampling,
                             bound='nearest', extrapolate=True)
    # add the sampling coordinates back to get a transformation grid
    transfo = utils.movedim(disp, 0, -1).add_(sampling)

    # --- express coordinates in original source voxels and resample ------
    vel_to_source = source_affine_orig.inverse().matmul(target_affine)
    transfo = spatial.affine_matvec(vel_to_source, transfo)
    warped = spatial.grid_pull(source_orig, transfo)
    return warped, vel, target_affine
def pull1d(img, grid, grad=False, **kwargs):
    """Pull an image by a transform along the last dimension

    Parameters
    ----------
    img : (K, *spatial) tensor, Image
    grid : (*spatial) tensor or None, Sampling grid
        If None, no resampling is performed and `img` is returned as is.
    grad : bool, Sample gradients
    kwargs : dict
        Options forwarded to ``grid_pull`` / ``grid_grad``.
        `bound` defaults to 'dft' and `extrapolate` to True.

    Returns
    -------
    warped_img : (K, *spatial) tensor
    warped_grad : (K, *spatial) tensor, if `grad`

    """
    if grid is None:
        if not grad:
            # Same return convention as the resampling path below:
            # a bare tensor (not a tuple) when no gradient is requested.
            return img
        bound = kwargs.get('bound', 'dft')
        return img, diff1d(img, dim=-1, bound=bound, side='c')

    kwargs.setdefault('extrapolate', True)
    kwargs.setdefault('bound', 'dft')
    # insert a singleton axis so that the last axis is the only one sampled
    img, grid = img.unsqueeze(-2), grid.unsqueeze(-1)
    warped = grid_pull(img, grid, **kwargs).squeeze(-2)
    if not grad:
        return warped
    warped_grad = grid_grad(img, grid, **kwargs)
    warped_grad = warped_grad.squeeze(-1).squeeze(-2)
    return warped, warped_grad
def load_and_pull(volume, aff, shape, dtype=None, device=None):
    """Load a volume and, if needed, resample it to a given space.

    Parameters
    ----------
    volume : Volume3D
    aff : (D+1,D+1) tensor
    shape : (D,) tuple
    dtype : optional
        Defaults to the data type of `aff`.
    device : optional
        Defaults to the device of `aff`.

    Returns
    -------
    dat : tensor

    """
    opt = {'dtype': dtype or aff.dtype, 'device': device or aff.device}
    aff = aff.to(**opt)
    eye = torch.eye(aff.shape[-1], **opt)
    fdata = volume.fdata(cache=False, **opt)
    # transform between the two affines, as computed by `lmdiv`
    vol_aff = volume.affine.to(**opt)
    aff = core.linalg.lmdiv(vol_aff, aff)

    # nothing to do if both spaces coincide
    if torch.allclose(aff, eye) and tuple(shape) == tuple(fdata.shape):
        return fdata

    grid = spatial.affine_grid(aff, shape)
    pulled = spatial.grid_pull(fdata[None, None, ...], grid[None, ...])
    return pulled[0, 0]
def _resample_inplane(x, sett):
    """Force in-plane resolution of observed data to be greater or equal
    to recon vx.

    Each observation whose voxel size is smaller than `sett.vx` along some
    axis is resampled with a diagonal scaling; its `dat`, `mat`, `dim`
    (and first label, if any) are updated in place. `x` is returned.
    """
    if not (sett.force_inplane_res and sett.max_iter > 0):
        return x

    eye = torch.eye(4, device=sett.device, dtype=torch.float64)
    for channel in x:
        for obs in channel:
            # get image data
            dat = obs.dat[None, None, ...]
            mat_x = obs.mat
            dim_x = torch.as_tensor(obs.dim, device=sett.device,
                                    dtype=torch.float64)
            vx_x = voxel_size(mat_x)

            # scaling matrix: never upsample (scales are at least one)
            scale = eye.clone()
            for axis in range(3):
                scale[axis, axis] = sett.vx / vx_x[axis]
                if scale[axis, axis] < 1.0:
                    scale[axis, axis] = 1
            if float((eye - scale).abs().sum()) < 1e-4:
                # (almost) identity: leave this observation untouched
                continue

            # make grid
            mat_x = mat_x.matmul(scale)
            dim_x = scale[:3, :3].inverse().mm(dim_x[:, None])
            dim_x = dim_x.floor().squeeze().cpu().int().tolist()
            grid = affine_grid(scale.type(dat.dtype), dim_x)

            # resample
            dat = grid_pull(dat, grid[None, ...], bound='zero',
                            extrapolate=False, interpolation=0)
            # do label
            if obs.label is not None:
                obs.label[0] = _warp_label(obs.label[0], grid)

            # assign
            obs.dat = dat[0, 0, ...]
            obs.mat = mat_x
            obs.dim = dim_x

    return x
def transform_pointset_dense(points, grid, type='grid', bound='dct2'):
    """Transform a pointset

    Points must already be expressed in "grid voxels" coordinates.

    Parameters
    ----------
    points : (n, dim) tensor
        Set of coordinates, in voxel space
    grid : (*spatial, dim) tensor
        Dense transformation or displacement grid, in voxel space
    type : {'grid', 'disp'}, default='grid'
        Transformation or displacement
    bound : str, default='dct2'
        Boundary conditions for out-of-bounds data

    Returns
    -------
    points : (n, dim) tensor
        Transformed coordinates

    """
    ndim = grid.shape[-1]
    # prepend singleton axes so that the pointset can be used as a
    # sampling grid, and put the field components in the channel axis
    coord = utils.unsqueeze(points, 0, ndim)
    field = utils.movedim(grid, -1, 0)[None]
    sampled = spatial.grid_pull(field, coord, bound=bound, extrapolate=True)
    sampled = utils.movedim(sampled, 1, -1)

    # a displacement is added to the input; a transformation replaces it
    moved = coord + sampled if type == 'disp' else sampled

    # NOTE(review): the squeeze below targets axis -2, which looks like the
    # axis holding the `n` points at this stage -- confirm that
    # `utils.squeeze` removes the intended singleton axes when n > 1.
    return utils.squeeze(moved, -2, ndim - 1).squeeze(0)
def _init_y_dat(x, y, sett):
    """Make initial guesses of reconstructed image(s).

    Every observation is resampled onto the output lattice (defined by
    `y[0].dim` and `y[0].mat`) with ``grid_pull(..., interpolation=1)``,
    with averaging if more than one observation per channel.

    Args:
        x: Observations, indexed as `x[c][n]` (channel, repeat); each entry
            exposes `dat` and `mat`.
        y: Outputs, indexed as `y[c]`; `y[c].dat` is overwritten.
        sett: Settings; only `sett.device` is read here.

    Returns:
        y: The same container, with `dat` initialised for each channel.

    """
    dim_y = y[0].dim
    mat_y = y[0].mat
    for c, observations in enumerate(x):
        dat_y = torch.zeros(dim_y, dtype=torch.float32, device=sett.device)
        # number of observations that contributed to each output voxel
        sm = torch.zeros_like(dat_y)
        for obs in observations:
            # Get image data
            dat = obs.dat[None, None, ...]
            # Make output grid: mat = mat_x \ mat_y.
            # `Tensor.solve` has been deprecated and then removed from
            # PyTorch; prefer `torch.linalg.solve` when it exists and keep
            # the legacy call only as a fallback for old releases.
            if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'solve'):
                mat = torch.linalg.solve(obs.mat, mat_y)
            else:
                mat = mat_y.solve(obs.mat)[0]
            grid = affine_grid(mat.type(dat.dtype), dim_y)
            # Do resampling, keeping values within the observed range
            mn = torch.min(dat)
            mx = torch.max(dat)
            dat = grid_pull(dat, grid[None, ...], bound='zero',
                            extrapolate=False, interpolation=1)
            dat[dat < mn] = mn
            dat[dat > mx] = mx
            # only voxels that round to a nonzero value count as observed
            sm = sm + (dat[0, 0, ...].round() != 0)
            dat_y = dat_y + dat[0, 0, ...]
        # avoid dividing by zero where nothing was observed
        sm[sm == 0] = 1
        y[c].dat = dat_y / sm

    return y
def warp_label(label, grid):
    """Warp label image according to grid.

    Parameters
    ----------
    label : (batch, channel, *spatial) tensor
        Either hard labels (any non floating-point data type; a single
        channel is assumed) or soft/one-hot labels (half, float or double).
    grid : tensor
        Sampling grid passed to ``spatial.grid_pull``.

    Returns
    -------
    label_w : tensor
        Hard labels in: warped hard labels, with the same data type and the
        same set of label values as the input.
        Soft labels in: warped labels, renormalised across channels.

    """
    dtype_seg = label.dtype
    is_soft = dtype_seg in (torch.half, torch.float, torch.double)

    if is_soft:
        label_w = label
    else:
        # hard labels to one-hot labels
        u_labels = label.unique()
        shape_w = (label.shape[0], len(u_labels)) + tuple(label.shape[2:])
        label_w = torch.zeros(shape_w, device=label.device,
                              dtype=torch.float32)
        for i, value in enumerate(u_labels):
            # Write the indicator of `value` into channel `i`. A slice keeps
            # the channel axis, so the (batch, 1, *spatial) mask fits.
            label_w[:, i:i + 1] = (label == value).to(label_w.dtype)

    # warp
    label_w = spatial.grid_pull(label_w, grid, bound='dct2',
                                extrapolate=True, interpolation=1)

    if is_soft:
        # normalise one-hot labels
        label_w = label_w / (label_w.sum(dim=1, keepdim=True) + eps())
    else:
        # one-hot labels to hard labels: argmax gives the *index* of the
        # winning channel, which is mapped back to the original label value
        winner = label_w.argmax(dim=1, keepdim=True)
        label_w = u_labels[winner].type(dtype_seg)

    return label_w
def pull(q, grid):
    """Warp `source` by an affine (Lie parameters `q`) applied to `grid`.

    Relies on `basis`, `target_aff`, `source_aff`, `dim`, `source` and
    `pull_opt` from the enclosing scope.
    """
    # matrix exponential of the Lie parameters, composed with both affines
    mat = core.linalg.expm(q, basis)
    mat = spatial.affine_lmdiv(source_aff,
                               spatial.affine_matmul(mat, target_aff))
    # insert `dim` singleton axes between the first axis and the matrix
    # axes so that the matrices broadcast against the grid
    index = (slice(None), *([None] * dim), slice(None), slice(None))
    coord = spatial.affine_matvec(mat[index], grid)
    return spatial.grid_pull(source, coord, **pull_opt)
def pull(q, vel):
    """Warp `source` by an affine (Lie parameters `q`) composed with the
    grid returned by ``spatial.exp(vel)``.

    Relies on `basis`, `target_aff`, `source_aff`, `source` and `pull_opt`
    from the enclosing scope.
    """
    coord = spatial.exp(vel)
    # matrix exponential of the Lie parameters, composed with both affines
    mat = spatial.affine_matmul(core.linalg.expm(q, basis), target_aff)
    mat = spatial.affine_lmdiv(source_aff, mat)
    coord = spatial.affine_matvec(mat, coord)
    return spatial.grid_pull(source, coord, **pull_opt)
def warp_image(image, grid):
    """Warp image according to grid.

    Sampling uses ``interpolation=1``, ``bound='dct2'`` and extrapolation.
    """
    opt = dict(bound='dct2', extrapolate=True, interpolation=1)
    return spatial.grid_pull(image, grid, **opt)
def _deform_1d(img, disp, grad=False):
    """Deform an image by a 1D displacement.

    Returns ``(warped, None)``, or ``(warped, gradient)`` if `grad`.
    Both outputs have the first axis of `img` moved back in first position.
    """
    opt = dict(bound=BND, extrapolate=True)
    img = utils.movedim(img, 0, -2)
    # displacement -> sampling grid with a single trailing component
    grid = spatial.add_identity_grid(disp.unsqueeze(-1))

    warped = spatial.grid_pull(img, grid, **opt)
    warped = utils.movedim(warped, -2, 0)
    if not grad:
        return warped, None

    gradient = spatial.grid_grad(img, grid, **opt).squeeze(-1)
    return warped, utils.movedim(gradient, -2, 0)
def _crop_y(y, sett):
    """Crop output images FOV to a fixed dimension

    Args:
        y (_output()): _output data.
        sett: Settings; `crop`, `device` and `fov` are read.

    Returns:
        y (_output()): Cropped output data (returned unchanged when
            `sett.crop` is false).

    """
    if not sett.crop:
        return y
    device = sett.device
    # Output image information
    mat_y = y[0].mat
    vx_y = voxel_size(mat_y)
    # Define cropped FOV
    mat_mu, dim_mu = _bb_atlas('atlas_t1', fov=sett.fov,
                               dtype=torch.float64, device=device)
    # Modulate atlas with voxel size
    one = torch.ones(1, dtype=torch.float64, device=device)
    mat_vx = torch.diag(torch.cat((vx_y, one)))
    mat_mu = mat_mu.mm(mat_vx)
    dim_mu = mat_vx[:3, :3].inverse().mm(dim_mu[:, None]).floor().squeeze()
    # Make output grid: M = mat_y \ mat_mu.
    # `Tensor.solve` has been deprecated and then removed from PyTorch;
    # prefer `torch.linalg.solve` when it exists and keep the legacy call
    # only as a fallback for old releases.
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'solve'):
        M = torch.linalg.solve(mat_y, mat_mu)
    else:
        M = mat_mu.solve(mat_y)[0]
    M = M.type(y[0].dat.dtype)
    grid = affine_grid(M, dim_mu)[None, ...]
    # The cropped dimensions are the same for every channel
    dim_out = tuple(dim_mu.int().tolist())
    # Crop
    for out in y:
        out.dat = grid_pull(out.dat[None, None, ...], grid, bound='zero',
                            extrapolate=False, interpolation=0)[0, 0, ...]
        # Do labels?
        if out.label is not None:
            out.label = grid_pull(out.label[None, None, ...], grid,
                                  bound='zero', extrapolate=False,
                                  interpolation=0)[0, 0, ...]
        out.mat = mat_mu
        out.dim = dim_out

    return y
def eval_position(self, t):
    """Evaluate the position at a given (batched) time"""
    batch_shape = t.shape
    # convert (0, 1) to (0, n)
    nb_segments = len(self.waypoints) - 1
    time = t.flatten().clamp(0, 1) * nb_segments

    # interpolate
    coeff = self.coeff.T                                    # [D, K]
    pos = grid_pull(coeff, time.unsqueeze(-1),              # [N, 1]
                    interpolation=self.order, bound=self.bound)
    pos = pos.T                                             # [N, D]
    return pos.reshape([*batch_shape, pos.shape[-1]])
def smart_pull_grid(vel, grid, type='disp', *args, **kwargs):
    """Interpolate a velocity/grid/displacement field.

    Notes
    -----
    Defaults differ from grid_pull:
        - bound -> dft
        - extrapolate -> True

    Parameters
    ----------
    vel : ([batch], *spatial, ndim) tensor
        Velocity
    grid : ([batch], *spatial, ndim) tensor
        Transformation field
    type : {'disp', 'grid'}, default='disp'
        If 'grid', the identity is subtracted from `vel` before sampling
        and added back to the result.
    kwargs : dict
        Options to ``grid_pull``

    Returns
    -------
    pulled_vel : ([batch], *spatial, ndim) tensor
        Velocity

    """
    if grid is None or vel is None:
        return vel
    kwargs.setdefault('bound', 'dft')
    kwargs.setdefault('extrapolate', True)
    ndim = vel.shape[-1]
    is_grid = type == 'grid'

    if is_grid:
        # transformation -> displacement
        identity = spatial.identity_grid(vel.shape[-ndim - 1:-1],
                                         **utils.backend(vel))
        vel = vel - identity

    # components go to the channel axis; add missing batch axes
    vel = utils.movedim(vel, -1, -ndim - 1)
    vel_no_batch = vel.dim() == ndim + 1
    if vel_no_batch:
        vel = vel[None]
    if grid.dim() == ndim + 1:
        grid = grid[None]

    vel = spatial.grid_pull(vel, grid, *args, **kwargs)
    vel = utils.movedim(vel, -ndim - 1, -1)
    # NOTE(review): the batch axis is dropped whenever `vel` came in
    # unbatched, even if `grid` was batched -- confirm this is intended.
    if vel_no_batch:
        vel = vel[0]

    if is_grid:
        # displacement -> transformation (on the output lattice)
        identity = spatial.identity_grid(vel.shape[-ndim - 1:-1],
                                         **utils.backend(vel))
        vel += identity
    return vel
def eval_radius(self, t):
    """Evaluate the radius at a given (batched) time"""
    radius = self.radius
    if not torch.is_tensor(radius):
        # non-tensor radius: returned as is, whatever the time
        return radius

    batch_shape = t.shape
    # convert (0, 1) to (0, n)
    nb_segments = len(self.waypoints) - 1
    time = t.flatten().clamp(0, 1) * nb_segments

    # interpolate
    coeff = self.coeff_radius                               # [K]
    value = grid_pull(coeff, time.unsqueeze(-1),            # [N, 1]
                      interpolation=self.order, bound=self.bound)
    return value.reshape(batch_shape)
def _jhistc_backward(g, x, w=None, order=0, bound='replicate',
                     extrapolate=True, gradx=True, gradw=False):
    """Compute derivative of the joint histogram.

    The input must already be a soft mapping to bins indices.

    Parameters
    ----------
    g : (b, bins, bins) tensor
    x : (b, n, 2) tensor
    w : ([b], n) tensor, optional
    order : int, default=0
    bound : str, default='replicate'
    extrapolate : bool, default=True
    gradx : bool, default=True
    gradw : bool, default=False

    Returns
    -------
    gx : (b, n, 2) tensor, if gradx
    gw : ([b], n) tensor, if gradw
        None if `gradw` but no `w` was provided.

    """
    # the boolean is translated to the integer code passed downstream
    opt = dict(interpolation=order, bound=bound,
               extrapolate=1 if extrapolate else 2)
    coord = x.unsqueeze(-3)     # make 2d spatial
    hist = g.unsqueeze(-3)      # add channel dimension

    out = []
    if gradx:
        gx = grid_grad(hist, coord, **opt).squeeze(-3).squeeze(-3)
        if w is not None:
            gx *= w.unsqueeze(-1)
        out.append(gx)
    if gradw:
        if w is None:
            out.append(None)
        else:
            gw = grid_pull(hist, coord, **opt)
            out.append(gw.squeeze(-2).squeeze(-2))  # drop spatial + channel

    if len(out) == 1:
        return out[0]
    return tuple(out)
def intensity_to_rgb(image, min=None, max=None, colormap='gray', n=256,
                     eq=False):
    """Colormap an intensity image

    Parameters
    ----------
    image : (*batch, H, W) tensor
        A (batch of) 2d image
    min : tensor_like, optional
        Minimum value. Should be broadcastable to batch.
        Default: min of image for each batch element.
    max : tensor_like, optional
        Maximum value. Should be broadcastable to batch.
        Default: max of image for each batch element.
    colormap : str or (K, 3) tensor, default='gray'
        A colormap or the name of a matplotlib colormap.
    n : int, default=256
        Number of color levels to use.
    eq : bool or {'linear', 'quadratic', 'log', None}, default=False
        Apply histogram equalization.
        If 'quadratic' or 'log', the histogram of the transformed
        signal is equalized.

    Returns
    -------
    rgb : (*batch, H, W, 3) tensor
        A (batch of) of RGB image.

    """
    intensity = torch.as_tensor(image).detach()
    intensity = intensity_preproc(intensity, min=min, max=max, eq=eq)
    out_shape = intensity.shape

    # map: intensities are rescaled (in place) to colormap indices, which
    # are then used as 1d sampling coordinates into the colormap
    cmap = _get_colormap_intensity(colormap, n,
                                   intensity.dtype, intensity.device)
    index = intensity.mul_(n - 1).clamp_(0, n - 1).reshape([1, -1, 1])
    cmap = cmap.T.reshape([1, 3, -1])
    rgb = spatial.grid_pull(cmap, index)

    # color channel goes last
    rgb = rgb.reshape([3, *out_shape])
    return utils.movedim(rgb, 0, -1)
def forward(self, x, grid):
    """Sample `x` at the coordinates in `grid`.

    Parameters
    ----------
    x : (batch, channel, *spatial_in) tensor
        Input image to deform
    grid : (batch, *spatial_out, len(spatial_in)) tensor
        Transformation grid

    Returns
    -------
    pulled : (batch, channel, *spatial_out) tensor
        Deformed image.

    """
    # options stored on the module, in the positional order of the call
    options = (self.interpolation, self.bound, self.extrapolate)
    return spatial.grid_pull(x, grid, *options)
def pull1d(img, grid, dim, grad=False, **kwargs):
    """Pull an image by a transform along dimension `dim`.

    Parameters
    ----------
    img : tensor, Image
    grid : tensor or None, Sampling grid
        If None, the image is not resampled.
    dim : int, Dimension along which to sample
    grad : bool, Sample gradients
    kwargs : dict
        Options forwarded to ``grid_pull`` / ``grid_grad``.
        `bound` defaults to 'dft' and `extrapolate` to True.

    Returns
    -------
    warped_img : tensor
    warped_grad : tensor, or None if not `grad`

    """
    if grid is None:
        if not grad:
            return img, None
        bound = kwargs.get('bound', 'dft')
        return img, spatial.diff1d(img, dim=dim, bound=bound, side='c')

    kwargs.setdefault('extrapolate', True)
    kwargs.setdefault('bound', 'dft')
    # the sampled dimension goes last, next to a singleton axis
    img_last = core.utils.movedim(img, dim, -1).unsqueeze(-2)
    grid_last = core.utils.movedim(grid, dim, -1).unsqueeze(-1)

    warped = spatial.grid_pull(img_last, grid_last, **kwargs).squeeze(-2)
    warped = core.utils.movedim(warped, -1, dim)
    if not grad:
        return warped, None

    gradient = spatial.grid_grad(img_last, grid_last, **kwargs)
    gradient = gradient.squeeze(-1).squeeze(-2)
    return warped, core.utils.movedim(gradient, -1, dim)
def _warp_label(label, grid):
    """Warp a label image.

    Each unique label value is warped as a binary (float) mask with
    ``grid_pull(..., interpolation=1)``; every output voxel receives the
    value whose warped mask is the largest there. Voxels where no mask is
    strictly positive keep the value zero.

    Parameters
    ----------
    label : tensor
        Label image.
    grid : tensor
        Sampling grid; its first three dimensions define the output shape.

    Returns
    -------
    warped : tensor
        Warped labels, with the data type of `label`.

    Raises
    ------
    ValueError
        If `label` contains more than 255 unique values.

    """
    u = label.unique()
    if u.numel() > 255:
        raise ValueError('Too many label values.')
    warped = torch.zeros(grid.shape[:3], device=label.device,
                         dtype=label.dtype)
    # Running maximum of the warped masks. It is created from the first
    # warped mask so that it shares its floating-point type: a buffer with
    # the label data type cannot hold fractional values and does not
    # support masked assignment from a float tensor.
    best = None
    for value in u:
        mask = (label == value).float()
        prob = grid_pull(mask[None, None, ...], grid[None, ...],
                         bound='zero', extrapolate=False,
                         interpolation=1)[0, 0, ...]
        if best is None:
            best = torch.zeros_like(prob)
        better = prob > best
        best[better] = prob[better]
        warped[better] = value

    return warped
def _reslice_dat_3d(dat, affine, dim_out, interpolation='linear',
                    bound='zero', extrapolate=False):
    """Reslice 3D image data.

    Parameters
    ----------
    dat : (Xi, Yi, Zi), tensor_like
        Input image data.
    affine : (4, 4), tensor_like
        Affine transformation that maps from voxels in output image to
        voxels in input image.
    dim_out : (Xo, Yo, Zo), list or tuple
        Output image dimensions.
    interpolation : str, default='linear'
        Interpolation order.
    bound : str, default='zero'
        Boundary condition.
    extrapolate : bool, default=False
        Extrapolate out-of-bounds data.

    Returns
    -------
    dat : (dim_out), tensor_like
        Resliced image data.

    Raises
    ------
    ValueError
        If `dat` is not three-dimensional.

    """
    if len(dat.shape) != 3:
        raise ValueError('Input error: len(dat.shape) != 3')

    # sampling grid, in the data type of the image
    grid = affine_grid(affine, dim_out).type(dat.dtype)
    # add (and then remove) batch and channel dimensions
    resliced = grid_pull(dat[None, None, ...], grid[None, ...],
                         bound=bound, interpolation=interpolation,
                         extrapolate=extrapolate)
    return resliced[0, 0, ...]
def smart_pull(tensor, grid):
    """Pull iff grid is defined (+ add/remove batch dim).

    Parameters
    ----------
    tensor : (channels, *input_shape) tensor
        Input volume
    grid : (*output_shape, D) tensor or None
        Sampling grid

    Returns
    -------
    pulled : (channels, *output_shape) tensor
        Sampled volume (the input itself if `grid` is None)

    """
    if grid is not None:
        batched = spatial.grid_pull(tensor[None, ...], grid[None, ...])
        return batched[0]
    return tensor
def eval_grad_position(self, t):
    """Evaluate position and its gradient wrt time"""
    batch_shape = t.shape
    opt = dict(interpolation=self.order, bound=self.bound)

    # convert (0, 1) to (0, n)
    time = t.flatten().clamp(0, 1) * (len(self.waypoints) - 1)

    # interpolate
    coeff = self.coeff.T                                    # [D, K]
    time = time.unsqueeze(-1)                               # [N, 1]
    pos = grid_pull(coeff, time, **opt).T                   # [N, D]
    grad = grid_grad(coeff, time, **opt).squeeze(-1).T      # [N, D]

    pos = pos.reshape([*batch_shape, pos.shape[-1]])
    grad = grad.reshape([*batch_shape, grad.shape[-1]])
    # chain rule for the (0, 1) -> (0, n) rescaling of time
    grad *= (len(self.waypoints) - 1)
    return pos, grad
def slice_to(self, stack, cache_result=False, recompute=True):
    """Resample this volume for a stack and smooth it along the last
    dimension.

    Parameters
    ----------
    stack : object
        Its `layout` and `slice_width` attributes are read.
    cache_result : bool, default=False
        Store the result in `self._sliced` (also forwarded to `self.exp`).
    recompute : bool, default=True
        If False and a cached result exists, return the cached result
        (also forwarded to `self.exp`).

    Returns
    -------
    sliced : tensor
        Resampled and smoothed data.

    """
    aff = self.exp(cache_result=cache_result, recompute=recompute)

    # Use the cached result when allowed. This used to fall through to
    # `return sliced` with `sliced` never assigned.
    if not recompute and hasattr(self, '_sliced'):
        return self._sliced

    aff = spatial.affine_matmul(aff, self.affine)
    aff_reorient = spatial.affine_reorient(self.affine, self.shape,
                                           stack.layout)
    aff = spatial.affine_lmdiv(aff_reorient, aff)
    grid = spatial.affine_grid(aff, self.shape)
    sliced = spatial.grid_pull(self.dat, grid, bound=self.bound,
                               extrapolate=self.extrapolate)

    # smooth along the last dimension only, with a width given by the
    # slice width expressed in voxels of the reoriented space
    fwhm = [0] * self.dim
    fwhm[-1] = stack.slice_width / spatial.voxel_size(aff_reorient)[-1]
    sliced = spatial.smooth(sliced, fwhm, dim=self.dim, bound=self.bound)

    # TODO: per-slice processing (one affine per entry of `stack.slices`)
    # was left unfinished: the loop that stood here called
    # `spatial.affine_matmul(stack.affine, )` and
    # `spatial.affine_lmdiv(aff_reorient, )` with a single operand and
    # never used their results, so it was dropped until it is completed.

    if cache_result:
        self._sliced = sliced
    return sliced
def smart_pull_jac(jac, grid, *args, **kwargs):
    """Interpolate a jacobian field.

    Notes
    -----
    Defaults differ from grid_pull:
        - bound -> dft
        - extrapolate -> True

    Parameters
    ----------
    jac : ([batch], *spatial_in, ndim, ndim) tensor
        Jacobian field
    grid : ([batch], *spatial_out, ndim) tensor
        Transformation field
    kwargs : dict
        Options to ``grid_pull``

    Returns
    -------
    pulled_jac : ([batch], *spatial_out, ndim, ndim) tensor
        Jacobian field

    """
    if grid is None or jac is None:
        return jac
    kwargs.setdefault('bound', 'dft')
    kwargs.setdefault('extrapolate', True)
    ndim = jac.shape[-1]

    # collapse matrix into a single axis, used as the channel axis
    flat = jac.reshape([*jac.shape[:-2], ndim * ndim])
    flat = utils.movedim(flat, -1, -ndim - 1)

    # add missing batch axes
    jac_no_batch = flat.dim() == ndim + 1
    if jac_no_batch:
        flat = flat[None]
    if grid.dim() == ndim + 1:
        grid = grid[None]

    flat = spatial.grid_pull(flat, grid, *args, **kwargs)

    # restore the matrix layout
    flat = utils.movedim(flat, -ndim - 1, -1)
    pulled = flat.reshape([*flat.shape[:-1], ndim, ndim])
    # NOTE(review): the batch axis is dropped whenever `jac` came in
    # unbatched, even if `grid` was batched -- confirm this is intended.
    return pulled[0] if jac_no_batch else pulled
def pull(image, grid, interpolation=1, bound='dct2', extrapolate=False):
    """Sample a multi-channel image

    Parameters
    ----------
    image : (channel, *inshape) tensor
    grid : (*outshape, dim) tensor
    interpolation : default=1
    bound : default='dct2'
    extrapolate : default=False

    Returns
    -------
    imageout : (channel, *outshape)

    """
    opt = dict(interpolation=interpolation, bound=bound,
               extrapolate=extrapolate)
    # add a batch dimension for the call, drop it afterwards
    batched = grid_pull(image[None], grid[None], **opt)
    return batched[0]
def pull_grid(gridin, grid, interpolation=1, bound='dft', extrapolate=True):
    """Sample a displacement field.

    Parameters
    ----------
    gridin : (*inshape, dim) tensor
    grid : (*outshape, dim) tensor
    interpolation : default=1
    bound : default='dft'
    extrapolate : default=True

    Returns
    -------
    gridout : (*outshape, dim) tensor

    """
    opt = dict(interpolation=interpolation, bound=bound,
               extrapolate=extrapolate)
    # components are sampled as channels: (1, dim, *inshape)
    channels = movedim(gridin, -1, 0)[None]
    pulled = grid_pull(channels, grid[None], **opt)
    # drop the batch dimension and move the components back last
    return movedim(pulled[0], 0, -1)
def forward(self, x, grid, **overload):
    """Sample `x` at the coordinates in `grid`.

    Parameters
    ----------
    x : (batch, channel, *spatial_in) tensor
        Input image to deform
    grid : (batch, *spatial_out, len(spatial_in)) tensor
        Transformation grid
    overload : dict
        All parameters defined at build time can be overridden
        at call time ('interpolation', 'bound', 'extrapolate').

    Returns
    -------
    pulled : (batch, channel, *spatial_out) tensor
        Deformed image.

    """
    # call-time value if provided, else the value stored on the module
    names = ('interpolation', 'bound', 'extrapolate')
    options = [overload.get(name, getattr(self, name)) for name in names]
    return spatial.grid_pull(x, grid, *options)
def __call__(self, logaff, grad=False, hess=False, gradmov=False, hessmov=False, in_line_search=False):
    """Evaluate the objective and, optionally, its derivatives.

    Parameters
    ----------
    logaff : (..., nb) tensor, Lie parameters
    grad : Whether to compute and return the gradient wrt `logaff`
    hess : Whether to compute and return the Hessian wrt `logaff`
    gradmov : Whether to compute and return the gradient wrt `moving`
    hessmov : Whether to compute and return the Hessian wrt `moving`
    in_line_search : If True, plotting and printing are disabled.

    Returns
    -------
    ll : float, loss value (objective to minimize), obtained with `.item()`
    g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
    h : (..., logaff) tensor, optional, Hessian wrt Lie parameters
        (returned as a diagonal matrix built with `diag_embed`)
    gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
    hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

    A single value is returned when nothing but the loss is requested,
    a tuple otherwise.

    Notes
    -----
    `hessmov` is only computed when `hess` is requested, and `gradmov`
    only when `grad` or `hess` is requested; otherwise the flag itself is
    appended to the output if it is not `False`.
    NOTE(review): the nesting of this method was reconstructed from a
    flattened source; please double check block boundaries.
    """

    # This loop performs the forward pass, and computes
    # derivatives along the way.

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    # plot roughly 20 times over `max_iter` iterations, never in line search
    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
        and (self.n_iter - 1) % logplot == 0

    # jitter
    # if not hasattr(self, '_fixed'):
    #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
    #                                 jitter=True,
    #                                 **utils.backend(self.fixed))
    #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
    #     del idj
    # fixed = self._fixed
    fixed = self.fixed

    # forward
    # the basis is built lazily (and stored) on first call
    if not torch.is_tensor(self.basis):
        self.basis = spatial.affine_basis(self.basis, self.dim, **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    # derivative of the matrix exponential, computed outside of autograd
    with torch.no_grad():
        _, gaff = linalg._expm(logaff, self.basis, grad_X=True, hess_X=False)

    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    # /!\ derivatives are not "homogeneous" (they do not have a one
    #     on the bottom right): we should *not* use affine_matmul and
    #     such (I only lost a day...)
    gaff = torch.matmul(gaff, self.affine_fixed)
    gaff = linalg.lmdiv(self.affine_moving, gaff)
    # haff = torch.matmul(haff, self.affine_fixed)
    # haff = linalg.lmdiv(self.affine_moving, haff)

    # the identity grid is built lazily (and stored) on first call
    if self.id is None:
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff), jitter=False)
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, cat=iscat, dim=self.dim)

    # gradient/Hessian of the log-likelihood in observed space
    # (from here on, `grad`/`hess` hold tensors instead of flags
    #  when they were requested)
    if not grad and not hess:
        llx = self.loss.loss(warped, fixed)
    elif not hess:
        llx, grad = self.loss.loss_grad(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
    else:
        llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
        if hessmov:
            hessmov = spatial.grid_push(hess, grid, **pullopt)
    del warped

    # compose with spatial gradients + dot product with grid
    if grad is not False or hess is not False:
        mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
        grad = jg(mugrad, grad)
        if hess is not False:
            hess = jhj(mugrad, hess)
            grad, hess = regutils.affine_grid_backward(grad, hess, grid=self.id)
        else:
            grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
        # flatten the last two dimensions of both derivatives
        # (after dropping the last row of `gaff`)
        dim2 = self.dim * (self.dim + 1)
        grad = grad.reshape([*grad.shape[:-2], dim2])
        gaff = gaff[..., :-1, :]
        gaff = gaff.reshape([*gaff.shape[:-2], dim2])
        grad = linalg.matvec(gaff, grad)
        if hess is not False:
            hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
            # haff = haff[..., :-1, :, :-1, :]
            # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
            hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
            # keep a diagonal matrix made of the absolute row sums
            hess = hess.abs().sum(-1).diag_embed()
        del mugrad

    # print objective
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}', end='\n')
        else:
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}', end='\n')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    # only outputs that are not (still) `False` are returned
    out = [ll]
    if grad is not False:
        out.append(grad)
    if hess is not False:
        out.append(hess)
    if gradmov is not False:
        out.append(gradmov)
    if hessmov is not False:
        out.append(hessmov)
    return tuple(out) if len(out) > 1 else out[0]
def __call__(self, logaff, grad=False, hess=False, in_line_search=False):
    """Evaluate the objective and, optionally, its gradient by autograd.

    Parameters
    ----------
    logaff : (..., nb) tensor, Lie parameters
    grad : Whether to compute and return the gradient wrt `logaff`
    hess : Accepted but not used in this implementation
    in_line_search : If True, plotting and printing are disabled.

    Returns
    -------
    ll : tensor, loss value (objective to minimize), as returned by
        `self.loss.loss`
    g : tensor, optional, Gradient wrt Lie parameters (a clone of
        `logaff.grad` after backpropagation)

    Notes
    -----
    `logaff` is modified in place: its `requires_grad` flag is toggled and
    its `.grad` buffer is zeroed / populated.
    NOTE(review): the nesting of this method was reconstructed from a
    flattened source; please double check block boundaries (in particular
    where `logaff.requires_grad_(False)` sits).
    """

    # This loop performs the forward pass, and computes
    # derivatives along the way.

    # select correct gradient mode
    if grad:
        logaff.requires_grad_()
        if logaff.grad is not None:
            logaff.grad.zero_()
    # if the ambient autograd mode does not match what is requested,
    # call ourselves again within the right context
    if grad and not torch.is_grad_enabled():
        with torch.enable_grad():
            return self(logaff, grad, in_line_search=in_line_search)
    elif not grad and torch.is_grad_enabled():
        with torch.no_grad():
            return self(logaff, grad, in_line_search=in_line_search)

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    # plot roughly 20 times over `max_iter` iterations, never in line search
    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
        and (self.n_iter - 1) % logplot == 0

    # jitter
    # idj = spatial.identity_grid(self.fixed.shape[-self.dim:], jitter=True,
    #                             **utils.backend(self.fixed))
    # fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
    # del idj
    fixed = self.fixed

    # forward
    # the basis is built lazily (and stored) on first call
    if not torch.is_tensor(self.basis):
        self.basis = spatial.affine_basis(self.basis, self.dim, **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    # the identity grid is built lazily (and stored) on first call
    if self.id is None:
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff))
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, cat=iscat, dim=self.dim)

    # gradient/Hessian of the log-likelihood in observed space
    llx = self.loss.loss(warped, fixed)
    del warped

    # print objective
    # `lll` keeps the tensor (needed for backward); `llx`/`ll` are numbers
    lll = llx
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}', end='\n')
        else:
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}', end='\n')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [lll]
    if grad is not False:
        # backpropagate and return a copy of the accumulated gradient
        lll.backward()
        grad = logaff.grad.clone()
        out.append(grad)
    logaff.requires_grad_(False)
    return tuple(out) if len(out) > 1 else out[0]