def smart_pull_grid(vel, grid, type='disp', *args, **kwargs):
    """Interpolate a velocity/grid/displacement field.

    Notes
    -----
    Defaults differ from grid_pull:
    - bound -> dft
    - extrapolate -> True

    Parameters
    ----------
    vel : ([batch], *spatial, ndim) tensor
        Velocity
    grid : ([batch], *spatial, ndim) tensor
        Transformation field
    type : {'disp', 'grid'}, default='disp'
        Whether `vel` is a displacement or a transformation field
    kwargs : dict
        Options to ``grid_pull``

    Returns
    -------
    pulled_vel : ([batch], *spatial, ndim) tensor
        Velocity
    """
    if grid is None or vel is None:
        return vel
    kwargs.setdefault('bound', 'dft')
    kwargs.setdefault('extrapolate', True)
    dim = vel.shape[-1]
    if type == 'grid':
        # convert transformation -> displacement before resampling
        id = spatial.identity_grid(vel.shape[-dim - 1:-1], **utils.backend(vel))
        vel = vel - id
    # grid_pull expects channels first: treat the coordinate dim as channels
    vel = utils.movedim(vel, -1, -dim - 1)
    vel_no_batch = vel.dim() == dim + 1
    grid_no_batch = grid.dim() == dim + 1
    if vel_no_batch:
        vel = vel[None]
    if grid_no_batch:
        grid = grid[None]
    vel = spatial.grid_pull(vel, grid, *args, **kwargs)
    vel = utils.movedim(vel, -dim - 1, -1)
    if vel_no_batch:
        vel = vel[0]
    if type == 'grid':
        # convert back displacement -> transformation
        id = spatial.identity_grid(vel.shape[-dim - 1:-1], **utils.backend(vel))
        vel += id
    return vel
def make_data(shape, device, dtype):
    id = identity_grid(shape, dtype=dtype, device=device)
    id = id[None, ...]  # add batch dimension
    disp = torch.rand(id.shape, device=device, dtype=dtype)
    grid = id + disp
    vol = torch.rand((1, 1) + shape, device=device, dtype=dtype)
    return vol, grid
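# The `identity_grid` helper used throughout this file comes from the
# surrounding library (presumably `nitorch.spatial`). As a rough sketch of
# what it is assumed to return -- a (*shape, dim) tensor of voxel
# coordinates -- a minimal pure-torch stand-in could look like this
# (hypothetical, for illustration only):
def _identity_grid_sketch(shape, dtype=None, device=None):
    """Return a (*shape, len(shape)) tensor g with g[i, j, ...] = (i, j, ...)."""
    import torch
    axes = [torch.arange(s, dtype=dtype or torch.float32, device=device)
            for s in shape]
    return torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)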
def smart_grad(tensor, grid, **opt):
    """Pull gradients iff grid is defined (+ add/remove batch dim).

    Parameters
    ----------
    tensor : (channels, *input_shape) tensor
        Input volume
    grid : (*output_shape, D) tensor or None
        Sampling grid

    Returns
    -------
    pulled : (channels, *output_shape) tensor
        Sampled volume
    """
    # if grid is None:
    #     opt.pop('extrapolate', None)
    #     opt.pop('interpolation', None)
    #     return spatial.diff(tensor, dim=3, **opt)
    if grid is None:
        grid = spatial.identity_grid(tensor.shape[-3:], dtype=tensor.dtype,
                                     device=tensor.device)
    out = spatial.grid_grad(tensor, grid, **opt)
    return out
def exp(self, velocity, displacement=False):
    """Generate a deformation grid from tangent parameters.

    Parameters
    ----------
    velocity : (batch, *spatial, nb_dim)
        Stationary velocity field
    displacement : bool, default=False
        Return a displacement field (voxel to shift) rather than
        a transformation field (voxel to voxel).

    Returns
    -------
    grid : (batch, *spatial, nb_dim)
        Deformation grid (transformation or displacement).
    """
    backend = dict(dtype=velocity.dtype, device=velocity.device)

    # generate grid
    shape = velocity.shape[1:-1]
    velocity_small = self.resize(velocity, type='displacement')
    grid = self.velexp(velocity_small)
    grid = self.resize(grid, shape=shape, type='grid')

    if displacement:
        grid = grid - spatial.identity_grid(grid.shape[1:-1], **backend)
    return grid
def exp(prm):
    disp = spatial.resize_grid(prm, type='displacement',
                               shape=target.shape[2:],
                               interpolation=3, bound='dft')
    grid = disp + spatial.identity_grid(target.shape[2:], **backend)
    return disp, grid
def draw_curves(shape, s, mode='gaussian', tiny=0, **kwargs):
    """Draw multiple BSpline curves

    Parameters
    ----------
    shape : list[int]
    s : list[BSplineCurve]
    mode : {'binary', 'gaussian'}

    Returns
    -------
    x : (*shape) tensor
        Drawn curve
    lab : (*shape) tensor[int]
        Label of closest curve
    """
    s = list(s)
    x = identity_grid(shape, **utils.backend(s[0].waypoints))
    n = len(s)
    tiny = tiny / n
    l = x.new_zeros(shape, dtype=torch.long)
    if mode[0].lower() == 'b':
        s1 = s.pop(0)
        t, d = min_dist(x, s1, **kwargs)
        r = s1.eval_radius(t)
        c = d <= r
        l[c] = 1
        cnt = 1
        while s:
            cnt += 1
            s1 = s.pop(0)
            t, d = min_dist(x, s1, **kwargs)
            r = s1.eval_radius(t)
            c.bitwise_or_(d <= r)
            l[d <= r] = cnt
    else:
        s1 = s.pop(0)
        t, d = min_dist(x, s1, **kwargs)
        r = s1.eval_radius(t)
        c = dist_to_prob(d, r, tiny)
        l.fill_(1)
        cnt = 1
        p = c.clone()
        c = c.neg_().add_(1)
        while s:
            cnt += 1
            s1 = s.pop(0)
            t, d = min_dist(x, s1, **kwargs)
            r = s1.eval_radius(t)
            c1 = dist_to_prob(d, r, tiny)
            l[c1 > p] = cnt
            p = torch.maximum(c1, p)
            c.mul_(c1.neg_().add_(1))
        c = c.neg_().add_(1)
    return c, l
def gauss_kernel(f, dim):
    # FWHM -> standard deviation
    s = f / math.sqrt(8. * math.log(2.)) + 1E-7
    # kernel extent: about 4 standard deviations, forced to be odd
    shape = math.ceil(4 * s)
    shape = shape + (shape % 2 == 0)
    g = identity_grid([shape] * dim)
    g -= shape / 2
    g = g.square_().sum(-1)
    g *= (-0.5 / (s**2))
    g.exp_()
    g /= g.sum()
    return g
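# `gauss_kernel` converts a full-width at half-maximum `f` into a standard
# deviation via sigma = f / sqrt(8 ln 2). A self-contained 1-D sanity check
# of that relation (not part of the original code):
def _check_fwhm():
    import math
    import torch
    f = 4.0                                   # target FWHM in voxels
    s = f / math.sqrt(8. * math.log(2.))
    x = torch.arange(-10., 11.)
    g = torch.exp(-0.5 * (x / s) ** 2)
    half = (g >= 0.5 * g.max()).sum().item()  # voxels at or above half max
    assert abs(half - f) <= 1                 # width at half max ~ FWHM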
def _identity(x):
    """Build an identity grid with same shape/backend as a tensor.

    The grid is built such that coordinate zero is at the center of the
    FOV.
    """
    shape = x.shape[2:]
    backend = dict(dtype=x.dtype, device=x.device)
    grid = spatial.identity_grid(shape, **backend)
    grid -= torch.as_tensor(shape, **backend) / 2.
    grid /= torch.as_tensor(shape, **backend) / 2.
    grid = last2channel(grid[None, ...])
    return grid
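# `_identity` rescales voxel indices so that 0 sits at the center of the
# field of view and the extent is roughly [-1, 1). A tiny illustration of
# that convention along one axis (hypothetical, for illustration only):
def _check_centered_coords():
    import torch
    n = 5
    coords = torch.arange(n, dtype=torch.float32)
    coords = (coords - n / 2.) / (n / 2.)
    # -> tensor([-1.0000, -0.6000, -0.2000,  0.2000,  0.6000])
    assert coords[0] == -1 and coords.abs().max() <= 1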
def forward(self, batch=1, **overload):
    """
    Parameters
    ----------
    batch : int, default=1
        Batch shape.

    Other Parameters
    ----------------
    shape : sequence[int], optional
    device : torch.device, optional
    dtype : torch.dtype, optional

    Returns
    -------
    grid : (batch, *shape, 3) tensor
        Resampling grid
    """
    shape = overload.get('shape', self.grid.velocity.field.shape)
    dtype = overload.get('dtype', self.grid.velocity.field.dtype)
    device = overload.get('device', self.grid.velocity.field.device)
    backend = dict(dtype=dtype, device=device)

    if self.grid.velocity.field.amplitude == 0:
        grid = identity_grid(shape, **backend)
    else:
        grid = self.grid(batch, shape=shape, **backend)
    dtype = grid.dtype
    device = grid.device
    backend = dict(dtype=dtype, device=device)
    shape = grid.shape[1:-1]
    dim = len(shape)

    aff = self.affine(batch, dim=dim, **backend)

    # shift center of rotation
    aff_shift = torch.cat((
        torch.eye(dim, **backend),
        torch.as_tensor(shape, **backend)[:, None].sub_(1).div_(-2)),
        dim=1)
    aff_shift = as_euclidean(aff_shift)
    aff = affine_matmul(aff, aff_shift)
    aff = affine_lmdiv(aff_shift, aff)

    # compose
    aff = utils.unsqueeze(aff, dim=-3, ndim=dim)
    lin = aff[..., :dim, :dim]
    off = aff[..., :dim, -1]
    grid = linalg.matvec(lin, grid) + off

    return grid
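# The "shift center of rotation" block above conjugates the affine by a
# translation M that moves the FOV center to the origin: M \ A * M applies
# the rotation about the volume center instead of voxel (0, 0). A
# hypothetical 2-D sketch of the same trick with plain tensors:
def _centered_affine_sketch(theta, shape):
    import math
    import torch
    M = torch.eye(3)
    M[:2, 2] = torch.as_tensor([-(s - 1) / 2. for s in shape])
    A = torch.eye(3)
    A[0, 0] = A[1, 1] = math.cos(theta)
    A[0, 1], A[1, 0] = -math.sin(theta), math.sin(theta)
    # x -> A @ (x - c) + c, i.e. a rotation about the FOV center c
    return torch.linalg.inv(M) @ A @ M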
def forward(self, source, target, source_affine=None, target_affine=None):
    """
    Parameters
    ----------
    source : (sX, sY, sZ) tensor or str
    target : (tX, tY, tZ) tensor or str
    source_affine : (4, 4) tensor, optional
    target_affine : (4, 4) tensor, optional

    Returns
    -------
    warped : (tX, tY, tZ) tensor
        Source warped to target
    velocity : (vX, vY, vZ, 3) tensor
        Stationary velocity field
    affine : (4, 4) tensor, optional
        Affine of the velocity space
    """
    if self.verbose:
        print('Preprocessing... ', end='', flush=True)
    source, source_affine, source_orig, source_affine_orig \
        = self.load(source, source_affine)
    target, target_affine, target_orig, target_affine_orig \
        = self.load(target, target_affine)
    source = spatial.reslice(source, source_affine, target_affine,
                             target.shape)
    if self.verbose:
        print('done.', flush=True)
        print('Registering... ', end='', flush=True)
    source = source[None, None]
    target = target[None, None]
    warped, vel, grid = super().forward(source, target)
    if self.verbose:
        print('done.', flush=True)
    del source, target, warped
    vel = vel[0]
    grid = grid[0]

    grid -= spatial.identity_grid(grid.shape[:-1], dtype=grid.dtype,
                                  device=grid.device)
    right_affine = target_affine.inverse() @ target_affine_orig
    right_affine = spatial.affine_grid(right_affine, target_orig.shape)
    grid = spatial.grid_pull(utils.movedim(grid, -1, 0), right_affine,
                             bound='nearest', extrapolate=True)
    grid = utils.movedim(grid, 0, -1).add_(right_affine)
    left_affine = source_affine_orig.inverse() @ target_affine
    grid = spatial.affine_matvec(left_affine, grid)
    warped = spatial.grid_pull(source_orig, grid)

    return warped, vel, target_affine
def exp(self, velocity, affine=None, displacement=False):
    """Generate a deformation grid from tangent parameters.

    Parameters
    ----------
    velocity : (batch, *spatial, nb_dim)
        Stationary velocity field
    affine : (batch, nb_prm)
        Affine parameters
    displacement : bool, default=False
        Return a displacement field (voxel to shift) rather than
        a transformation field (voxel to voxel).

    Returns
    -------
    grid : (batch, *spatial, nb_dim)
        Deformation grid (transformation or displacement).
    """
    info = {'dtype': velocity.dtype, 'device': velocity.device}

    # generate grid
    shape = velocity.shape[1:-1]
    velocity_small = self.resize(velocity, type='displacement')
    grid = self.velexp(velocity_small)
    grid = self.resize(grid, shape=shape, type='grid')

    if affine is not None:
        # exponentiate
        affine_prm = affine
        affine = []
        for prm in affine_prm:
            affine.append(self.affexp(prm))
        affine = torch.stack(affine, dim=0)

        # shift center of rotation
        affine_shift = torch.cat(
            (torch.eye(self.dim, **info),
             -torch.as_tensor(shape, **info)[:, None] / 2),
            dim=1)
        affine = spatial.affine_matmul(affine, affine_shift)
        affine = spatial.affine_lmdiv(affine_shift, affine)

        # compose
        affine = unsqueeze(affine, dim=-3, ndim=self.dim)
        lin = affine[..., :self.dim, :self.dim]
        off = affine[..., :self.dim, -1]
        grid = matvec(lin, grid) + off

    if displacement:
        grid = grid - spatial.identity_grid(grid.shape[1:-1], **info)
    return grid
def smalldef(disp):
    """Transform a displacement grid into a transformation grid

    Parameters
    ----------
    disp : (*shape, dim)

    Returns
    -------
    grid : (*shape, dim)
    """
    id = identity_grid(disp.shape[:-1], dtype=disp.dtype, device=disp.device)
    return disp + id
def _draw_curves_inv(shape, s, tiny=0):
    """prod_k (1 - p_k)"""
    s = list(s)
    x = identity_grid(shape, **utils.backend(s[0].waypoints))
    s1 = s.pop(0)
    t, d = min_dist(x, s1)
    r = s1.eval_radius(t)
    c = dist_to_prob(d, r, tiny=tiny).neg_().add_(1)
    while s:
        s1 = s.pop(0)
        t, d = min_dist(x, s1)
        r = s1.eval_radius(t)
        c.mul_(dist_to_prob(d, r, tiny=tiny).neg_().add_(1))
    return c
def _laplacian_freq(shape, **backend):
    """Compute Fourier squared frequency on the lattice and its inverse."""
    dim = len(shape)
    shape = torch.as_tensor(shape, **backend)
    g = spatial.identity_grid(shape, **backend)
    g -= shape // 2
    g /= shape
    g = g.square_().sum(-1)
    if fft._torch_has_old_fft:
        g = g.unsqueeze(-1)
    g = fft.ifftshift(g, dim=list(range(dim)))
    ig = g.reciprocal()
    ig[(0,) * dim] = 0
    return g, ig
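# The frequency construction above (centered ramp divided by the shape,
# then ifftshift-ed) matches the standard FFT frequency layout. A
# self-contained check along one axis (not part of the original code):
def _check_freq(n=6):
    import torch
    centered = (torch.arange(n, dtype=torch.float64) - n // 2) / n
    assert torch.allclose(torch.fft.ifftshift(centered),
                          torch.fft.fftfreq(n, dtype=torch.float64))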
def roi_closing(label, radius=10, dim=None):
    """Performs a multi-label morphological closing.

    Parameters
    ----------
    label : (..., *spatial) tensor[int]
        Volume of labels.
    radius : float, default=10
        Radius of the structuring element (in voxels)
    dim : int, default=label.dim()
        Number of spatial dimensions

    Returns
    -------
    closed_label : tensor[int]
    """
    from scipy.ndimage import distance_transform_edt, binary_closing

    dim = dim or label.dim()
    closest_label = torch.zeros_like(label)
    closest_dist = label.new_full(label.shape, float('inf'),
                                  dtype=torch.float)
    dist = torch.empty_like(closest_dist)
    for l in label.unique():
        if l == 0:
            continue
        if label.dim() == dim:
            dist = torch.as_tensor(distance_transform_edt(label != l))
        elif label.dim() == dim + 1:
            for z in range(len(dist)):
                dist[z] = torch.as_tensor(
                    distance_transform_edt(label[z] != l))
        else:
            raise NotImplementedError
        closest_label[dist < closest_dist] = l
        closest_dist = torch.min(closest_dist, dist)

    # spherical structuring element of the requested radius
    struct = spatial.identity_grid([2 * radius + 1] * dim).sub_(radius)
    struct = struct.square().sum(-1).sqrt() <= radius
    struct = utils.unsqueeze(struct, 0, label.dim() - dim)
    mask = binary_closing(label > 0, struct)
    mask = torch.as_tensor(mask).bitwise_not_()
    closest_label[mask] = 0
    return closest_label
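# The structuring element built above is a discrete ball: voxels whose
# Euclidean distance to the center is at most `radius`. A 2-D sketch with
# radius 2 (illustrative only, pure torch):
def _disc_struct_sketch(r=2):
    import torch
    ax = torch.arange(2 * r + 1, dtype=torch.float32) - r
    g = torch.stack(torch.meshgrid(ax, ax, indexing='ij'), dim=-1)
    return g.square().sum(-1).sqrt() <= r   # (5, 5) boolean disc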
def voxelize_rois(rois, shape, roi_to_vox=None, device=None):
    """Create a volume of labels from a parametric ROI.

    Parameters
    ----------
    rois : dict
        Object returned by `read_asc`
    shape : sequence[int]
    roi_to_vox : (d+1, d+1) tensor

    Returns
    -------
    roi : (*shape) tensor[int]
    names : list[str]
    """
    out = torch.zeros(shape, dtype=torch.long)
    grid = spatial.identity_grid(shape[:2], device=device)
    if roi_to_vox is not None:
        roi_to_vox = roi_to_vox.to(device=device)
    names = list(rois['regions'].keys())
    for l, (name, contours) in enumerate(rois['regions'].items()):
        print(name)
        label = l + 1
        for i, contour in enumerate(contours):
            print(i + 1, '/', len(contours), end='\r')
            vertices = [[p['x'], p['y'], p['z']] for p in contour['points']]
            vertices = torch.as_tensor(vertices, device=device)
            if roi_to_vox is not None:
                vertices = spatial.affine_matvec(roi_to_vox, vertices)
            z = vertices[0, 2].round().int().item()
            if not (0 <= z < out.shape[-1]):
                print('Contour not in FOV. Skipping it...')
                continue
            vertices = vertices[:, :2]
            faces = [(i, i + 1 if i + 1 < len(vertices) else 0)
                     for i in range(len(vertices))]
            mask = is_inside(grid, vertices, faces).cpu()
            out[..., z][mask] = label
        print('')
    return out, names
def affine_grid_backward(*grad_hess, grid=None):
    """Converts ∇ wrt dense displacement into ∇ wrt affine matrix

    g = affine_grid_backward(g, [grid=None])
    g, h = affine_grid_backward(g, h, [grid=None])

    Parameters
    ----------
    grad : (..., *spatial, dim) tensor
        Gradient with respect to a dense displacement.
    hess : (..., *spatial, dim*(dim+1)//2) tensor, optional
        Hessian with respect to a dense displacement.
    grid : (*spatial, dim) tensor, optional
        Pre-computed identity grid

    Returns
    -------
    grad : (..., dim, dim+1) tensor
        Gradient with respect to an affine matrix
    hess : (..., dim, dim+1, dim, dim+1) tensor, optional
        Hessian with respect to an affine matrix
    """
    has_hess = len(grad_hess) > 1
    grad, *hess = grad_hess
    hess = hess.pop(0) if hess else None
    del grad_hess

    dim = grad.shape[-1]
    shape = grad.shape[-dim - 1:-1]
    batch = grad.shape[:-dim - 1]
    nvox = py.prod(shape)
    if grid is None:
        grid = spatial.identity_grid(shape, **utils.backend(grad))
    grid = grid.reshape([1, nvox, dim])
    grad = grad.reshape([-1, nvox, dim])
    if hess is not None:
        hess = hess.reshape([-1, nvox, dim * (dim + 1) // 2])
        grad, hess = _affine_grid_backward_gh(grid, grad, hess)
        hess = hess.reshape([*batch, dim, dim + 1, dim, dim + 1])
    else:
        grad = _affine_grid_backward_g(grid, grad)
    grad = grad.reshape([*batch, dim, dim + 1])
    return (grad, hess) if has_hess else grad
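# `affine_grid_backward` implements the chain rule through phi(x) = A [x; 1]:
# since d phi_i(x) / d A[i, j] = [x; 1]_j, the gradient wrt the affine is a
# sum over voxels of outer products between the voxel-wise gradient and the
# homogeneous coordinates. A hypothetical dense sketch of the gradient part:
def _affine_grad_sketch(grad):
    """grad : (*spatial, dim) -> (dim, dim+1), assuming phi(x) = A [x; 1]."""
    import torch
    dim = grad.shape[-1]
    shape = grad.shape[:-1]
    axes = [torch.arange(s, dtype=grad.dtype, device=grad.device)
            for s in shape]
    x = torch.stack(torch.meshgrid(*axes, indexing='ij'), -1).reshape(-1, dim)
    homo = torch.cat([x, torch.ones_like(x[:, :1])], dim=-1)  # (nvox, dim+1)
    return torch.einsum('vi,vj->ij', grad.reshape(-1, dim), homo)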
def ffd_exp(prm, shape, order=3, bound='dft', returns='disp'):
    """Transform FFD parameters into a displacement or transformation grid.

    Parameters
    ----------
    prm : (..., *spatial, dim)
        FFD parameters
    shape : sequence[int]
        Exponentiated shape
    order : int, default=3
        Spline order
    bound : str, default='dft'
        Boundary condition
    returns : {'disp', 'grid', 'disp+grid'}, default='disp'
        What to return:
        - 'disp' -> displacement grid
        - 'grid' -> transformation grid

    Returns
    -------
    disp : (..., *shape, dim), optional
        Displacement grid
    grid : (..., *shape, dim), optional
        Transformation grid
    """
    backend = dict(dtype=prm.dtype, device=prm.device)
    dim = prm.shape[-1]
    batch = prm.shape[:-(dim + 1)]
    prm = prm.reshape([-1, *prm.shape[-(dim + 1):]])
    disp = resize_grid(prm, type='displacement', shape=shape,
                       interpolation=order, bound=bound)
    disp = disp.reshape(batch + disp.shape[1:])
    grid = disp + identity_grid(shape, **backend)
    if 'disp' in returns and 'grid' in returns:
        return disp, grid
    elif 'disp' in returns:
        return disp
    elif 'grid' in returns:
        return grid
def propagate_grad(self, g, h, moving, phi, left=None, right=None,
                   inv=False):
    """Convert derivatives wrt warped image in loss space to
    derivatives wrt parameters

    parameters:
        g (tensor) : gradient wrt warped image
        h (tensor) : hessian wrt warped image
        moving (Image) : moving image
        phi (tensor) : dense (exponentiated) displacement field
        left (matrix) : left affine
        right (matrix) : right affine
        inv (bool) : whether we're in a backward symmetric pass

    returns:
        g (tensor) : pushed gradient
        h (tensor) : pushed hessian
        gmu (tensor) : rotated spatial gradients
    """
    if inv:
        g = g.neg_()

    # build bits of warp
    dim = phi.shape[-1]
    fixed_shape = g.shape[-dim:]
    moving_shape = moving.shape

    # differentiate wrt δ in: Left o Phi o (Id + δ) o Right
    # we'll then propagate them through Phi by scaling and squaring
    if right is not None:
        right = spatial.affine_grid(right, fixed_shape)
    g = regutils.smart_push(g, right, shape=self.shape)
    h = regutils.smart_push(h, right, shape=self.shape)
    del right

    phi_left = spatial.identity_grid(self.shape, **utils.backend(phi))
    phi_left += phi
    if left is not None:
        phi_left = spatial.affine_matvec(left, phi_left)
    mugrad = moving.pull_grad(phi_left, rotate=False)
    del phi_left

    mugrad = _rotate_grad(mugrad, left, phi)
    return g, h, mugrad
def draw_curves(shape, s, mode='gaussian', tiny=0, **kwargs):
    """Draw multiple BSpline curves

    Parameters
    ----------
    shape : list[int]
    s : list[BSplineCurve]
    mode : {'binary', 'gaussian'}

    Returns
    -------
    x : (*shape) tensor
        Drawn curve
    """
    s = list(s)
    x = identity_grid(shape, **utils.backend(s[0].waypoints))
    n = len(s)
    tiny = tiny / n
    if mode[0].lower() == 'b':
        s1 = s.pop(0)
        t, d = min_dist(x, s1, **kwargs)
        r = s1.eval_radius(t)
        c = d <= r
        while s:
            s1 = s.pop(0)
            t, d = min_dist(x, s1, **kwargs)
            r = s1.eval_radius(t)
            c.bitwise_or_(d <= r)
    else:
        s1 = s.pop(0)
        t, d = min_dist(x, s1, **kwargs)
        r = s1.eval_radius(t)
        c = dist_to_prob(d, r, tiny).neg_().add_(1)
        while s:
            s1 = s.pop(0)
            t, d = min_dist(x, s1, **kwargs)
            r = s1.eval_radius(t)
            c.mul_(dist_to_prob(d, r, tiny).neg_().add_(1))
        c = c.neg_().add_(1)
    return c
def _get_dat_grid(dat, vx, samp, jitter=True, device='cpu'):
    """Get sub-sampled image data, and resampling grid.

    Parameters
    ----------
    dat : (X0, Y0, Z0) tensor_like
        Fixed image data.
    vx : (3,) tensor_like
        Fixed voxel size.
    samp : int|float
        Sub-sampling level.
    jitter : bool, default=True
        Add random jittering to identity grid.

    Returns
    -------
    dat_samp : (X1, Y1, Z1) tensor_like
        Sub-sampled fixed image data.
    grid : (X1, Y1, Z1) tensor_like
        Sub-sampled image data's resampling grid.
    """
    if isinstance(dat, (list, tuple)):
        dat = torch.zeros(dat, dtype=torch.float32, device=device)
    # Modulate samp with voxel size
    device = dat.device
    samp = torch.tensor((samp,) * 3).float().to(device)
    samp = torch.clamp(samp / vx, 1)
    # Create grid of fixed image, possibly sub-sampled
    grid = identity_grid(dat.shape, dtype=torch.float32, device=device)
    if jitter:
        torch.manual_seed(0)
        grid += torch.rand_like(grid) * samp
    # Sub-sample
    samp = samp.round().int().tolist()
    grid = grid[::samp[0], ::samp[1], ::samp[2], ...]
    dat_samp = dat[::samp[0], ::samp[1], ::samp[2]]
    return dat_samp, grid
def dist_map(shape, dtype=None, device=None):
    """Return the squared distance between all pairs in a FOV.

    Parameters
    ----------
    shape : sequence[int]
    dtype : optional
    device : optional

    Returns
    -------
    dist : (prod(shape), prod(shape)) tensor
        Squared distance map
    """
    backend = dict(dtype=dtype, device=device)
    shape = py.make_tuple(shape)
    dim = len(shape)
    g = spatial.identity_grid(shape, **backend)
    g = g.reshape([-1, dim])
    g = (g[:, None, :] - g[None, :, :]).square_().sum(-1)
    return g
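# The pairwise map above is the broadcasted form of squared Euclidean
# distances; it should agree with `torch.cdist`. A small self-contained
# check (not part of the original code):
def _check_dist_map():
    import torch
    g = torch.stack(torch.meshgrid(torch.arange(3.), torch.arange(4.),
                                   indexing='ij'), -1).reshape(-1, 2)
    d1 = (g[:, None, :] - g[None, :, :]).square().sum(-1)
    d2 = torch.cdist(g, g).square()
    assert torch.allclose(d1, d2, atol=1e-5)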
def draw_curve(shape, s, mode='gaussian', tiny=0, **kwargs):
    """Draw a BSpline curve

    Parameters
    ----------
    shape : list[int]
    s : BSplineCurve
    mode : {'binary', 'gaussian'}

    Returns
    -------
    x : (*shape) tensor
        Drawn curve
    """
    x = identity_grid(shape, **utils.backend(s.waypoints))
    t, d = min_dist(x, s, **kwargs)
    r = s.eval_radius(t)
    if mode[0].lower() == 'b':
        return d <= r
    else:
        return dist_to_prob(d, r, tiny)
def add_identity(cls, disp):
    dim = disp.shape[-1]
    shape = disp.shape[-dim-1:-1]
    return spatial.identity_grid(shape, **utils.backend(disp)).add_(disp)
def __call__(self, logaff, grad=False, hess=False,
             gradmov=False, hessmov=False, in_line_search=False):
    """
    logaff : (..., nb) tensor, Lie parameters
    grad : Whether to compute and return the gradient wrt `logaff`
    hess : Whether to compute and return the Hessian wrt `logaff`
    gradmov : Whether to compute and return the gradient wrt `moving`
    hessmov : Whether to compute and return the Hessian wrt `moving`

    Returns
    -------
    ll : () tensor, loss value (objective to minimize)
    g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
    h : (..., logaff) tensor, optional, Hessian wrt Lie parameters
    gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
    hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving
    """
    # This function performs the forward pass, and computes
    # derivatives along the way.

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % logplot == 0

    # jitter
    # if not hasattr(self, '_fixed'):
    #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
    #                                 jitter=True,
    #                                 **utils.backend(self.fixed))
    #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
    #     del idj
    # fixed = self._fixed
    fixed = self.fixed

    # forward
    if not torch.is_tensor(self.basis):
        self.basis = spatial.affine_basis(self.basis, self.dim,
                                          **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    with torch.no_grad():
        _, gaff = linalg._expm(logaff, self.basis,
                               grad_X=True, hess_X=False)

    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    # /!\ derivatives are not "homogeneous" (they do not have a one
    # on the bottom right): we should *not* use affine_matmul and
    # such (I only lost a day...)
    gaff = torch.matmul(gaff, self.affine_fixed)
    gaff = linalg.lmdiv(self.affine_moving, gaff)
    # haff = torch.matmul(haff, self.affine_fixed)
    # haff = linalg.lmdiv(self.affine_moving, haff)
    if self.id is None:
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff),
                                        jitter=False)
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, cat=iscat,
                    dim=self.dim)

    # gradient/Hessian of the log-likelihood in observed space
    if not grad and not hess:
        llx = self.loss.loss(warped, fixed)
    elif not hess:
        llx, grad = self.loss.loss_grad(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
    else:
        llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
        if hessmov:
            hessmov = spatial.grid_push(hess, grid, **pullopt)
    del warped

    # compose with spatial gradients + dot product with grid
    if grad is not False or hess is not False:
        mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
        grad = jg(mugrad, grad)
        if hess is not False:
            hess = jhj(mugrad, hess)
            grad, hess = regutils.affine_grid_backward(grad, hess,
                                                       grid=self.id)
        else:
            grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
        dim2 = self.dim * (self.dim + 1)
        grad = grad.reshape([*grad.shape[:-2], dim2])
        gaff = gaff[..., :-1, :]
        gaff = gaff.reshape([*gaff.shape[:-2], dim2])
        grad = linalg.matvec(gaff, grad)
        if hess is not False:
            hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
            # haff = haff[..., :-1, :, :-1, :]
            # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
            hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
            hess = hess.abs().sum(-1).diag_embed()
        del mugrad

    # print objective
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} '
                  f'= {ll:12.6g}', end='\n')
        else:
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} '
                  f'= {ll:12.6g} | {gain:12.6g}', end='\n')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [ll]
    if grad is not False:
        out.append(grad)
    if hess is not False:
        out.append(hess)
    if gradmov is not False:
        out.append(gradmov)
    if hessmov is not False:
        out.append(hessmov)
    return tuple(out) if len(out) > 1 else out[0]
def __call__(self, logaff, grad=False, hess=False, in_line_search=False):
    """
    logaff : (..., nb) tensor, Lie parameters
    grad : Whether to compute and return the gradient wrt `logaff`
    hess : Unused in this autograd variant

    Returns
    -------
    ll : () tensor, loss value (objective to minimize)
    g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
    """
    # This function performs the forward pass, and computes
    # derivatives by autograd.

    # select correct gradient mode
    if grad:
        logaff.requires_grad_()
        if logaff.grad is not None:
            logaff.grad.zero_()
    if grad and not torch.is_grad_enabled():
        with torch.enable_grad():
            return self(logaff, grad, in_line_search=in_line_search)
    elif not grad and torch.is_grad_enabled():
        with torch.no_grad():
            return self(logaff, grad, in_line_search=in_line_search)

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % logplot == 0

    # jitter
    # idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
    #                             jitter=True, **utils.backend(self.fixed))
    # fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
    # del idj
    fixed = self.fixed

    # forward
    if not torch.is_tensor(self.basis):
        self.basis = spatial.affine_basis(self.basis, self.dim,
                                          **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    if self.id is None:
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff))
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, cat=iscat,
                    dim=self.dim)

    # gradient/Hessian of the log-likelihood in observed space
    llx = self.loss.loss(warped, fixed)
    del warped

    # print objective
    lll = llx
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} '
                  f'= {ll:12.6g}', end='\n')
        else:
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} '
                  f'= {ll:12.6g} | {gain:12.6g}', end='\n')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [lll]
    if grad is not False:
        lll.backward()
        grad = logaff.grad.clone()
        out.append(grad)
        logaff.requires_grad_(False)
    return tuple(out) if len(out) > 1 else out[0]
def fit_curves_cat(f, s, vx=1, max_iter=8, tol=1e-8, max_levels=4):
    """Fit the set of curves that maximizes a Categorical likelihood

    Parameters
    ----------
    f : (*shape) tensor
        Observed grid of binary labels or smooth probabilities.
    s : list[BSplineCurve]
        Initial curves (will be modified in-place)

    Returns
    -------
    s : list[BSplineCurve]
        Fitted curves
    """
    TINY = 1e-6
    fig = elem = None
    backend = utils.backend(s[0].coeff)

    max_iter_position = 8
    max_iter_radius = 4

    vx = utils.make_vector(vx, f.dim(), **backend)
    vx0 = vx.clone()
    n0 = f.numel()

    # Build pyramid by restriction
    shapes = [f.shape]
    images = [f]
    vxs = [vx]
    for n_level in range(max_levels - 1):
        shape = [pymath.ceil(s / 2) for s in shapes[-1]]
        if all(s == 1 for s in shape):
            break
        shapes.append(shape)
        images.append(restrict(f.unsqueeze(-1), shapes[-1]).squeeze(-1))
        vx = vx * (torch.as_tensor(shapes[-2], **backend) /
                   torch.as_tensor(shapes[-1], **backend))
        vxs.append(vx)
        for s1 in s:
            s1.restrict(shapes[-2], shapes[-1])

    start = time.time()

    shape = None
    level = len(images) + 1
    while images:
        level -= 1
        print('-' * 16, 'level', level, '-' * 16)

        if shape is not None:
            for s1 in s:
                s1.prolong(shape, shapes[-1])
        f, shape, vx = images.pop(-1), shapes.pop(-1), vxs.pop(-1)
        x = identity_grid(f.shape, **backend)
        scl = vx.prod() / vx0.prod()

        def get_nll(e):
            ie = (1 - e).log()
            e = e.log()
            if f.dtype is torch.bool:
                ll = e[f].sum(dtype=torch.double) \
                   + ie[~f].sum(dtype=torch.double)
            else:
                ll = (e * f).sum(dtype=torch.double) \
                   + (ie * (1 - f)).sum(dtype=torch.double)
            return -ll

        nll = float('inf')
        max_iter_level = max_iter * 2 ** ((level - 1) // 2)
        for n_iter in range(max_iter_level):

            nll0_prev = nll
            for n_curve in range(len(s)):

                s0 = s[n_curve]
                s1 = s[:n_curve] + s[n_curve + 1:]
                ie1 = _draw_curves_inv(f.shape, s1, TINY)

                for n_iter_position in range(max_iter_position):
                    t, d = min_dist(x, s0)
                    p = s0.eval_position(t).sub_(x)  # residuals
                    r = s0.eval_radius(t)
                    r = torch.as_tensor(r, **utils.backend(x))
                    e0 = dist_to_prob(d, r, TINY)
                    ome0 = 1 - e0
                    e = 1 - ome0 * ie1
                    nll_prev = nll
                    nll = get_nll(e)
                    lam = radius_to_prec(r)

                    # gradient of the categorical term
                    g = (1 - f / e) * e0 / ome0 * (-lam)
                    h = (e0 / ome0).square() * (1 - e) / e
                    g = g.unsqueeze(-1)
                    h = h.unsqueeze(-1)
                    lam = lam.unsqueeze(-1)
                    acc = 0.5
                    h = h * (lam * p).square()
                    if acc != 1:
                        h += (1 - acc) * g.abs()
                    g = g * p

                    # push
                    g = s0.push_position(g, t)
                    h = s0.push_position(h, t)
                    g *= scl
                    h *= scl
                    nll *= scl

                    g.div_(h)
                    s0.coeff -= g

                    wp = [ss.waypoints for ss in s]
                    fig, elem = plot_nll(nll, e, f, wp, fig, elem)
                    print('position', n_iter, n_curve, n_iter_position,
                          nll.item(), (nll_prev - nll).item() / n0)
                    s0.update_waypoints()
                    # if nll_prev - nll < tol * f.numel():
                    #     break

                if level < 3:
                    max_iter_radius_level = max_iter_radius
                else:
                    max_iter_radius_level = 0
                for n_iter_radius in range(max_iter_radius_level):
                    alpha = (2.355 / 2) ** 2
                    t, d = min_dist(x, s0)
                    r = s0.eval_radius(t)
                    r = torch.as_tensor(r, **utils.backend(x))
                    e0 = dist_to_prob(d, r)
                    ome0 = 1 - e0
                    e = 1 - ome0 * ie1
                    d = d.square_()
                    nll_prev = nll
                    nll = get_nll(e)

                    # gradient of the categorical term
                    alpha = alpha * d / r.pow(3)
                    g = (1 - f / e) * e0 / ome0 * alpha
                    h = e0 / ome0.square()
                    acc = 0
                    h *= alpha.square()
                    if acc != 1:
                        h += (1 - acc) * g.abs() * 3 / r

                    # push
                    g = s0.push_radius(g, t)
                    h = s0.push_radius(h, t)
                    g *= scl
                    h *= scl
                    nll *= scl

                    g.div_(h)
                    s0.coeff_radius -= g
                    s0.coeff_radius.clamp_min_(0.5)

                    wp = [ss.waypoints for ss in s]
                    fig, elem = plot_nll(nll, e, f, wp, fig, elem)
                    print('radius', n_iter, n_curve, n_iter_radius,
                          nll.item(), (nll_prev - nll).item() / n0)
                    s0.update_radius()
                    # if nll_prev - nll < tol * f.numel():
                    #     break

            if not n_iter % 10:
                print(n_iter, nll.item(), (nll0_prev - nll).item() / n0)
            # if abs(nll0_prev - nll) < tol * f.numel():
            #     print('Converged')
            #     break

    stop = time.time()
    print(stop - start)
def fit_curve_cat(f, s, lam=0, gamma=0, vx=1, max_iter=8, tol=1e-8,
                  max_levels=4):
    """Fit the curve that maximizes the categorical likelihood

    Parameters
    ----------
    f : (*shape) tensor
        Observed grid of binary labels or smooth probabilities.
    s : BSplineCurve
        Initial curve (will be modified in-place)

    Other Parameters
    ----------------
    lam : float, default=0
        Centerline regularization (bending)
    gamma : float, default=0
        Radius regularization (membrane)
    vx : float, default=1
        Voxel size
    max_iter : int, default=8
        Maximum number of iterations per level
        (This will be multiplied by 2 at each resolution level, such
        that more iterations are used at coarser levels).
    tol : float, default=1e-8
        Unused
    max_levels : int, default=4
        Number of multi-resolution levels.

    Returns
    -------
    s : BSplineCurve
        Fitted curve
    """
    TINY = 1e-6
    fig = elem = None

    max_iter_position = 8
    max_iter_radius = 4

    backend = utils.backend(s.coeff)
    vx = utils.make_vector(vx, f.dim(), **backend)
    vx0 = vx.clone()
    n0 = f.numel()

    # Build pyramid by restriction
    shapes = [f.shape]
    images = [f]
    vxs = [vx]
    for n_level in range(max_levels - 1):
        shape = [pymath.ceil(s / 2) for s in shapes[-1]]
        if all(s == 1 for s in shape):
            break
        shapes.append(shape)
        images.append(restrict(f.unsqueeze(-1), shapes[-1]).squeeze(-1))
        s.restrict(shapes[-2], shapes[-1])
        vx = vx * (torch.as_tensor(shapes[-2], **backend) /
                   torch.as_tensor(shapes[-1], **backend))
        vxs.append(vx)

    start = time.time()

    shape = None
    level = len(images) + 1
    while images:
        level -= 1
        print('-' * 16, 'level', level, '-' * 16)

        if shape is not None:
            s.prolong(shape, shapes[-1])
        f, shape, vx = images.pop(-1), shapes.pop(-1), vxs.pop(-1)
        scl = vx.prod() / vx0.prod()
        x = identity_grid(f.shape, **backend)

        if lam:
            L = lam * bending3(len(s.coeff), **backend)
            reg = L.matmul(s.coeff).mul_(vx.square())
            reg = 0.5 * (s.coeff * reg).sum(dtype=torch.double)
        else:
            reg = 0
        if gamma:
            Lr = gamma * membrane3(len(s.coeff_radius), **backend)
            Lr /= vx.prod().pow_(1 / len(vx)).square_()
            reg_radius = Lr.matmul(s.coeff_radius)
            reg_radius = 0.5 * (s.coeff_radius * reg_radius).sum(dtype=torch.double)
        else:
            reg_radius = 0

        def get_nll(e):
            ie = (1 - e).log()
            e = e.log()
            if f.dtype is torch.bool:
                ll = e[f].sum(dtype=torch.double) \
                   + ie[~f].sum(dtype=torch.double)
            else:
                ll = (e * f).sum(dtype=torch.double) \
                   + (ie * (1 - f)).sum(dtype=torch.double)
            ll = -ll
            return ll

        nll = float('inf')
        max_iter_level = max_iter * 2 ** ((level - 1) // 2)
        for n_iter in range(max_iter_level):

            nll0_prev = nll

            for n_iter_position in range(max_iter_position):
                t, d = min_dist(x, s)
                p = s.eval_position(t).sub_(x)  # residuals
                r = s.eval_radius(t)
                r = torch.as_tensor(r, **utils.backend(x))
                e = dist_to_prob(d, r, tiny=TINY)
                nll_prev = nll
                nll = get_nll(e)
                prec = radius_to_prec(r)

                # gradient of the categorical term
                omf = (1 - f) if f.dtype.is_floating_point else f.bitwise_not()
                ome = (1 - e)
                g = (omf / ome - 1) * (-prec)
                h = omf * e / ome.square()
                g = g.unsqueeze(-1)
                h = h.unsqueeze(-1)
                prec = prec.unsqueeze(-1)
                acc = 0.5
                h = h * (prec * p).square()
                if acc != 1:
                    h += (1 - acc) * g.abs()
                g = g * p

                # push
                g = s.push_position(g, t)
                h = s.push_position(h, t)

                # resolution scale
                g *= scl
                h *= scl
                nll *= scl

                # regularisation + solve
                if lam:
                    reg = L.matmul(s.coeff).mul_(vx.square())
                    g += reg
                    reg = 0.5 * (s.coeff * reg).sum(dtype=torch.double)
                    # h += L[1, :].abs().sum()
                    g = torch.stack([
                        linalg.lmdiv(h1.diag() + (v1 * v1) * L,
                                     g1[:, None])[:, 0]
                        for v1, g1, h1 in zip(vx, g.T, h.T)
                    ], -1)
                else:
                    g.div_(h)
                    reg = 0
                s.coeff.sub_(g)
                # s.coeff.clamp_min_(0)
                # for d, sz in enumerate(f.shape):
                #     s.coeff[:, d].clamp_max_(sz-1)

                fig, elem = plot_nll([nll, reg, reg_radius], e, f,
                                     s.waypoints, fig, elem)
                nll = nll + reg + reg_radius
                print('position', n_iter, n_iter_position, nll.item(),
                      (nll_prev - nll).item() / n0)
                s.update_waypoints()
                # if nll_prev - nll < tol * f.numel():
                #     break

            if level < 3:
                max_iter_radius_level = max_iter_radius
            else:
                max_iter_radius_level = 0
            for n_iter_radius in range(max_iter_radius_level):
                alpha = (2.355 / 2) ** 2
                t, d = min_dist(x, s)
                r = s.eval_radius(t)
                r = torch.as_tensor(r, **utils.backend(x))
                e = dist_to_prob(d, r, TINY)
                d = d.square_()
                nll_prev = nll
                nll = get_nll(e)

                # gradient of the categorical term
                omf = (1 - f) if f.dtype.is_floating_point else f.bitwise_not()
                ome = (1 - e)
                alpha = alpha * d / r.pow(3)
                g = (omf / ome - 1) * alpha
                acc = 0
                h = omf * e / ome.square()
                h *= alpha.square()
                if acc != 1:
                    h += (1 - acc) * g.abs() * 3 / r

                # push
                g = s.push_radius(g, t)
                h = s.push_radius(h, t)

                # resolution scale
                g *= scl
                h *= scl
                nll *= scl

                # regularisation + solve
                if gamma:
                    reg_radius = Lr.matmul(s.coeff_radius)
                    g += reg_radius
                    reg_radius = 0.5 * (s.coeff_radius * reg_radius).sum(dtype=torch.double)
                    g = linalg.lmdiv(h.diag() + Lr, g[:, None])[:, 0]
                else:
                    g.div_(h)
                    reg_radius = 0

                # solve
                s.coeff_radius -= g
                s.coeff_radius.clamp_min_(0.5)

                fig, elem = plot_nll([nll, reg, reg_radius], e, f,
                                     s.waypoints, fig, elem)
                nll = nll + reg + reg_radius
                print('radius', n_iter, n_iter_radius, nll.item(),
                      (nll_prev - nll).item() / n0)
                s.update_radius()
                # if nll_prev - nll < tol * f.numel():
                #     break

            if not n_iter % 10:
                print(n_iter, nll.item(), (nll0_prev - nll).item() / n0)
            # if nll0_prev - nll < tol * f.numel():
            #     print('Converged')
            #     break

    stop = time.time()
    print(stop - start)
def fit_curve_joint(f, s, max_iter=128, tol=1e-8):
    """Fit the curve that maximizes the joint probability p(f) * p(s)

    Parameters
    ----------
    f : (*shape) tensor
        Observed grid of binary labels or smooth probabilities.
    s : BSplineCurve
        Initial curve (will be modified in-place)
    max_iter : int, default=128
    tol : float, default=1e-8

    Returns
    -------
    s : BSplineCurve
        Fitted curve
    """
    x = identity_grid(f.shape, **utils.backend(s.coeff))

    max_iter_position = 10
    max_iter_radius = 3

    sumf = f.sum(dtype=torch.double)

    def get_nll(e):
        if f.dtype is torch.bool:
            return sumf + e.sum(dtype=torch.double) \
                 - 2 * e[f].sum(dtype=torch.double)
        else:
            return sumf + e.sum(dtype=torch.double) \
                 - 2 * (e * f).sum(dtype=torch.double)

    start = time.time()
    nll = float('inf')
    for n_iter in range(max_iter):

        nll0_prev = nll

        for n_iter_position in range(max_iter_position):
            t, d = min_dist(x, s)
            p = s.eval_position(t).sub_(x)  # residuals
            r = s.eval_radius(t)
            r = torch.as_tensor(r, **utils.backend(x))
            e = dist_to_prob(d, r)
            nll_prev = nll
            nll = get_nll(e)
            lam = radius_to_prec(r)

            # gradient of the categorical term
            g = e * (1 - 2 * f) * (-lam)
            g = g.unsqueeze(-1)
            lam = lam.unsqueeze(-1)
            # e = e.unsqueeze(-1)
            # h = g.abs() + e * (lam * p).square()
            h = g.abs() * (1 + lam * p.square())
            g = g * p

            # push
            g = s.push_position(g, t)
            h = s.push_position(h, t)
            g.div_(h)
            s.coeff -= g
            # print('position', n_iter, n_iter_position,
            #       nll.item(), (nll_prev - nll).item() / f.numel())
            if nll_prev - nll < tol * f.numel():
                break

        for n_iter_radius in range(max_iter_radius):
            alpha = (2.355 / 2) ** 2
            t, d = min_dist(x, s)
            r = s.eval_radius(t)
            r = torch.as_tensor(r, **utils.backend(x))
            e = dist_to_prob(d, r)
            d = d.square_()
            nll_prev = nll
            nll = get_nll(e)

            # gradient of the categorical term
            g = e * (1 - 2 * f) * (alpha * d / r.pow(3))
            h = g.abs() * (alpha * d / r.pow(3)) * (1 + 3 / r)

            # push
            g = s.push_radius(g, t)
            h = s.push_radius(h, t)
            g.div_(h)
            s.coeff_radius -= g
            s.coeff_radius.clamp_min_(0.5)
            # print('radius', n_iter, n_iter_radius,
            #       nll.item(), (nll_prev - nll).item() / f.numel())
            if nll_prev - nll < tol * f.numel():
                break

        if not n_iter % 10:
            print(n_iter, nll.item(), (nll0_prev - nll).item() / f.numel())
        if abs(nll0_prev - nll) < tol * f.numel():
            print('Converged')
            break

    stop = time.time()
    print(stop - start)
def forward(self, grid, **overload):
    """
    Parameters
    ----------
    grid : (N, *spatial, dim)
        Displacement grid
    overload : dict

    Returns
    -------
    aff : (N, dim+1, dim+1)
        Affine matrix that is closest to grid in the least-squares sense
    """
    shift = overload.get('shift', self.shift)
    grid = torch.as_tensor(grid)
    info = dict(dtype=grid.dtype, device=grid.device)
    nb_dim = grid.shape[-1]
    shape = grid.shape[1:-1]

    if shift:
        affine_shift = torch.cat((
            torch.eye(nb_dim, **info),
            -torch.as_tensor(shape, **info)[:, None] / 2),
            dim=1)
        affine_shift = spatial.as_euclidean(affine_shift)

    # the forward model is:
    #   phi(x) = M \ A * M * x
    # where phi is a *transformation* field, M is the shift matrix
    # and A is the affine matrix.
    # We can decompose phi(x) = x + d(x), where d is a *displacement*
    # field, yielding:
    #   d(x) = M \ A * M * x - x = (M \ A * M - I) * x := B * x
    # If we write `d(x)` and `x` as large vox*(dim+1) matrices `D`
    # and `G`, we have:
    #   D = G * B'
    # Therefore, the least squares B is obtained as:
    #   B' = inv(G' * G) * (G' * D)
    # Then, A is
    #   A = M * (B + I) / M
    #
    # Finally, we project the affine matrix to its tangent space:
    #   prm[k] = <log(A), B[k]>
    # where <X, Y> = trace(X' * Y) is the Frobenius inner product.

    def igg(identity):
        # Compute inv(g*g'), where g has homogeneous coordinates.
        # Instead of appending ones, we compute each element of
        # the block matrix ourselves:
        #     [[g'*g, g'*1],
        #      [1'*g, 1'*1]]
        # where 1'*1 = N, the number of voxels.
        g = identity.reshape([identity.shape[0], -1, nb_dim])
        nb_vox = torch.as_tensor([[[g.shape[1]]]], **info)
        sumg = g.sum(dim=1, keepdim=True)
        gg = torch.matmul(g.transpose(-1, -2), g)
        gg = torch.cat((gg, sumg), dim=1)
        sumg = sumg.transpose(-1, -2)
        sumg = torch.cat((sumg, nb_vox), dim=1)
        gg = torch.cat((gg, sumg), dim=2)
        return gg.inverse()

    def gd(identity, disp):
        # Compute g'*d, where g and d have homogeneous coordinates.
        #     [[g'*d, g'*1],
        #      [1'*d, 1'*1]]
        g = identity.reshape([identity.shape[0], -1, nb_dim])
        d = disp.reshape([disp.shape[0], -1, nb_dim])
        nb_vox = torch.as_tensor([[[g.shape[1]]]], **info)
        sumg = g.sum(dim=1, keepdim=True)
        sumd = d.sum(dim=1, keepdim=True)
        gd = torch.matmul(g.transpose(-1, -2), d)
        gd = torch.cat((gd, sumd), dim=1)
        sumg = sumg.transpose(-1, -2)
        sumg = torch.cat((sumg, nb_vox), dim=1)
        sumg = sumg.expand([d.shape[0], sumg.shape[1], sumg.shape[2]])
        gd = torch.cat((gd, sumg), dim=2)
        return gd

    def eye(d):
        x = torch.eye(d, **info)
        z = x.new_zeros([1, d], **info)
        x = torch.cat((x, z), dim=0)
        z = x.new_zeros([d + 1, 1], **info)
        x = torch.cat((x, z), dim=1)
        return x

    identity = spatial.identity_grid(shape, **info)[None, ...]
    affine = torch.matmul(igg(identity), gd(identity, grid))
    affine = affine.transpose(-1, -2) + eye(nb_dim)
    affine = affine[..., :-1, :]
    if shift:
        affine = spatial.as_euclidean(affine)
        affine = spatial.affine_matmul(affine_shift, affine)
        affine = spatial.as_euclidean(affine)
        affine = spatial.affine_rmdiv(affine, affine_shift)
    affine = spatial.affine_make_square(affine)
    return affine
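# A hypothetical numerical check of the normal-equation derivation above:
# apply a known 2-D affine A to an identity grid, keep the displacement
# d(x) = (A - I) x, and recover A from the least-squares fit (here via
# `torch.linalg.lstsq` instead of the explicit block inverse):
def _check_affine_fit():
    import torch
    dim, shape = 2, (8, 8)
    axes = [torch.arange(s, dtype=torch.float64) for s in shape]
    g = torch.stack(torch.meshgrid(*axes, indexing='ij'), -1).reshape(-1, dim)
    G = torch.cat([g, torch.ones(len(g), 1, dtype=torch.float64)], -1)
    A = torch.tensor([[1.1, 0.05, 2.0],
                      [-0.02, 0.95, -1.0],
                      [0.0, 0.0, 1.0]], dtype=torch.float64)
    D = G @ (A.T - torch.eye(3, dtype=torch.float64))[:, :dim]  # displacement
    B = torch.linalg.lstsq(G, D).solution                       # B' above
    A_rec = B.T + torch.eye(dim, dim + 1, dtype=torch.float64)
    assert torch.allclose(A_rec, A[:dim], atol=1e-6)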