def greens(self, shape, **backend):
    """Return the Green's function of this object's regulariser.

    Parameters
    ----------
    shape : sequence[int]
        Spatial shape of the kernel.
    **backend
        Extra keyword arguments forwarded to `spatial.greens`
        (typically `dtype` / `device`).

    Returns
    -------
    kernel : tensor
        Whatever `spatial.greens` returns for the penalties stored on `self`.
    """
    return spatial.greens(
        shape,
        absolute=self.absolute,
        membrane=self.membrane,
        bending=self.bending,
        lame=self.lame,
        factor=self.factor,
        **backend,
    )
def set_kernel(self, kernel=None):
    """Store the Green's kernel used by this object, computing it if needed.

    Parameters
    ----------
    kernel : tensor, optional
        Precomputed kernel. If None, it is built with `spatial.greens` from
        `self.penalty`, `self.voxel_size`, the backend of `self.dat`, and
        `self.factor` divided by the number of voxels in `self.shape`.

    Returns
    -------
    self
    """
    if kernel is not None:
        self.kernel = kernel
        return self
    self.kernel = spatial.greens(
        self.shape,
        **self.penalty,
        factor=self.factor / py.prod(self.shape),
        voxel_size=self.voxel_size,
        **utils.backend(self.dat),
    )
    return self
def __init__(self, *args, **kwargs):
    """Initialise the parent, then precompute the regulariser's Green's kernel.

    When `self.dim1d` is set, the kernel is built for that single spatial
    axis only (shape and voxel size restricted to that axis); otherwise it
    covers the full `self.spatial_shape`.
    """
    super().__init__(*args, **kwargs)
    axis = self.dim1d
    if axis is not None:
        kernel_shape = [self.spatial_shape[axis]]
        kernel_vx = self.voxel_size[axis]
    else:
        kernel_shape = self.spatial_shape
        kernel_vx = self.voxel_size
    self.kernel = spatial.greens(
        kernel_shape,
        **self.reg_prm,
        factor=self.factor,
        voxel_size=kernel_vx,
        **utils.backend(self),
    )
def register(fixed=None, moving=None, dim=None, lam=1., loss='mse',
             optim='nesterov', hilbert=None, max_iter=500, sub_iter=16,
             lr=None, ls=0, plot=False, klosure=RegisterStep, kernel=None,
             **prm):
    """Nonlinear registration between two images using smooth displacements.

    If neither `fixed` nor `moving` is provided, a demo pair of phantoms
    ("circle to square") is generated and registered instead.

    Parameters
    ----------
    fixed : (..., K, *spatial) tensor, optional
        Fixed image
    moving : (..., K, *spatial) tensor, optional
        Moving image
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    lam : float, default=1
        Modulate regularisation. It is divided by the number of voxels
        and passed on as `factor` in the regularisation parameters.
    loss : {'mse', 'cat'} or OptimizationLoss, default='mse'
        'mse': Mean-squared error
        'cat': Categorical cross-entropy
    optim : {'relax', 'cg', 'gd', 'momentum', 'nesterov', 'lbfgs'}, default='nesterov'
        'relax'    : Gauss-Newton (linear system solved by relaxation)
        'cg'       : Gauss-Newton (linear system solved by conjugate gradient)
        'gd'       : Gradient descent
        'momentum' : Gradient descent with momentum
        'nesterov' : Nesterov-accelerated gradient descent
        'lbfgs'    : Limited-memory BFGS
    hilbert : bool, optional
        Use Hilbert (Green's kernel) preconditioning.
        By default, enabled when the optimizer does not require the Hessian.
    max_iter : int, default=500
        Maximum number of Gauss-Newton or Gradient descent iterations
    sub_iter : int, default=16
        Number of relax/cg iterations per GN step
    lr : float, optional
        Learning rate. If None, the optimizer built by
        `make_iteroptim_grid` picks its own value (TODO confirm default).
    ls : int, default=0
        Number of line search iterations.
    plot : bool, default=False
        Forwarded to the closure to enable plotting.
    klosure : callable, default=RegisterStep
        Factory for the objective closure, called as
        `klosure(moving, fixed, loss, plot=plot, max_iter=..., **prm)`.
    kernel : tensor, optional
        Precomputed Green's kernel. If provided, it is used as the
        preconditioner regardless of `hilbert`. If not provided and
        `hilbert` is enabled, it is computed from the regularisation
        parameters.

    Other Parameters
    ----------------
    absolute : float, default=1e-4
        Penalty on absolute displacements
    membrane : float, default=1e-3
        Penalty on first derivatives
    bending : float, default=0.2
        Penalty on second derivatives
    lame : (float, float), default=(0.05, 0.2)
        Penalty on zooms and shears
    (Missing entries are filled in by `defaults_velocity`; the values
    listed above are the historically documented ones.)

    Returns
    -------
    vel : (..., *spatial, dim) tensor
        Optimised field returned by the optimizer. It is the variable the
        closure is evaluated at (the default closure shoots it as a velocity).

    Raises
    ------
    ValueError
        If exactly one of `fixed` / `moving` is provided.
    """
    defaults_velocity(prm)

    # If no inputs provided: demo "circle to square"
    if fixed is None and moving is None:
        fixed, moving = phantoms.demo_register(cat=(loss == 'cat'))
    elif fixed is None or moving is None:
        # Previously a lone image was silently discarded in favour of the demo.
        raise ValueError('Both `fixed` and `moving` must be provided '
                         '(or neither, to run the demo).')

    # init tensors
    fixed, moving = utils.to_max_backend(fixed, moving)
    dim = dim or (fixed.dim() - 1)
    shape = fixed.shape[-dim:]
    # regularisation is normalised by the number of voxels
    lam = lam / py.prod(shape)
    prm['factor'] = lam
    velshape = [*fixed.shape[:-dim - 1], *shape, dim]
    vel = torch.zeros(velshape, **utils.backend(fixed))

    # init optimizer
    optim = regutils.make_iteroptim_grid(optim, lr, ls, max_iter, sub_iter, **prm)
    if hilbert is None:
        # second-order optimizers do not use the Hilbert preconditioner
        hilbert = not optim.requires_hess
    if hilbert and kernel is None:
        kernel = spatial.greens(shape, **prm, **utils.backend(fixed))
    if kernel is not None:
        optim.preconditioner = lambda x: spatial.greens_apply(x, kernel)

    # init loss
    loss = losses.make_loss(loss, dim)

    print(
        f'{"it":3s} | {"fit":^12s} + {"reg":^12s} = {"obj":^12s} | {"gain":^12s}'
    )
    print('-' * 63)
    closure = klosure(moving, fixed, loss, plot=plot, max_iter=optim.max_iter,
                      **prm)
    vel = optim.iter(vel, closure)
    print('')
    return vel
def __call__(self, vel, grad=False, hess=False, gradmov=False, hessmov=False):
    """Evaluate the registration objective (matching + regularisation) at `vel`.

    The forward pass shoots `vel` with the Green's kernel (computed lazily
    from `self.prm` on first use), warps the moving image, and evaluates the
    matching term and the regularisation term. Derivatives are computed along
    the way when requested.

    Parameters
    ----------
    vel : (..., *spatial, dim) tensor
        Field being optimised; its last dimension is the number of
        spatial dimensions.
    grad : bool, default=False
        Also return a gradient: the matching-term gradient pushed back
        through `grid`, composed with the moving image's spatial gradients
        (`jg`), plus the regulariser gradient.
    hess : bool, default=False
        Also return the matching-term Hessian pushed back through `grid`
        and composed with the spatial gradients (`jhj`). In this case a
        gradient is computed and returned as well.
    gradmov : bool, default=False
        Also return the matching-term gradient pushed with `grid_push`
        using `self.bound` / `self.extrapolate` (computed only when `grad`
        or `hess` is requested).
    hessmov : bool, default=False
        Same as `gradmov` for the Hessian (computed only when `hess` is
        requested).

    Returns
    -------
    ll : float
        Objective value (matching + regularisation).
    grad, hess, gradmov, hessmov : tensor
        Appended in this order, each one when it is not False.
        A bare float is returned when nothing is appended.
    """
    # This loop performs the forward pass, and computes
    # derivatives along the way.

    dim = vel.shape[-1]
    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    # evaluations without derivatives are treated as line-search probes:
    # no plotting, no printing, no iteration counting
    in_line_search = not grad and not hess
    logplot = max(self.max_iter // 20, 1)
    # NOTE(review): `n_iter` is only incremented below when `self.verbose`
    # is set, so with verbose off this plotting schedule never advances.
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % logplot == 0

    # forward
    if self.kernel is None:
        self.kernel = spatial.greens(vel.shape[-dim - 1:-1], **self.prm, **utils.backend(vel))
    grid = spatial.shoot(vel, self.kernel, steps=self.steps, **self.prm)
    warped = spatial.grid_pull(self.moving, grid, bound='dct2', extrapolate=True)

    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, vel, cat=iscat, dim=dim)

    # gradient/Hessian of the log-likelihood in observed space
    if not grad and not hess:
        llx = self.loss.loss(warped, self.fixed)
    elif not hess:
        llx, grad = self.loss.loss_grad(warped, self.fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
    else:
        # NOTE(review): reached whenever `hess` is requested, even with
        # grad=False, so a gradient is always returned alongside the Hessian.
        llx, grad, hess = self.loss.loss_grad_hess(warped, self.fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
        if hessmov:
            hessmov = spatial.grid_push(hess, grid, **pullopt)
    # NOTE(review): if gradmov/hessmov were requested but not computed in the
    # branch taken above, the raw flag (True) is what ends up returned.
    del warped

    # compose with spatial gradients
    if grad is not False or hess is not False:
        if self.mugrad is None:
            # spatial gradients of the moving image, cached across calls
            self.mugrad = spatial.diff(self.moving, dim=list(range(-dim, 0)), bound='dct2')
        if grad is not False:
            # in-place sign flip of the loss gradient (gradmov, if any, was
            # already pushed from the unflipped values above)
            grad = grad.neg_()
            # "final inverse" to "initial"
            grad = spatial.grid_push(grad, grid)
            grad = jg(self.mugrad, grad)
        if hess is not False:
            hess = spatial.grid_push(hess, grid)
            hess = jhj(self.mugrad, hess)

    # add regularization term
    # NOTE(review): only the gradient receives the regulariser contribution;
    # presumably the optimizer adds it to the Hessian itself — TODO confirm.
    vgrad = spatial.regulariser_grid(vel, **self.prm, kernel=True)
    llv = 0.5 * (vel * vgrad).sum()
    if grad is not False:
        grad += vgrad
    del vgrad

    # print objective
    llx = llx.item()
    llv = llv.item()
    ll = llx + llv
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(
                f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g}', end='\r')
        else:
            # relative decrease with respect to the largest objective seen so far
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(
                f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g} | {gain:12.6g}', end='\r')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [ll]
    if grad is not False:
        out.append(grad)
    if hess is not False:
        out.append(hess)
    if gradmov is not False:
        out.append(gradmov)
    if hessmov is not False:
        out.append(hessmov)
    return tuple(out) if len(out) > 1 else out[0]
def __call__(self, vel, grad=False):
    """Evaluate the registration objective, using autograd for the gradient.

    Parameters
    ----------
    vel : (..., *spatial, dim) tensor
        Field being optimised; its last dimension is the number of
        spatial dimensions.
    grad : bool, default=False
        If True, also return the gradient of the objective with respect
        to `vel`, obtained by backpropagation.

    Returns
    -------
    ll : scalar tensor
        Objective value (matching + regularisation / number of voxels).
    grad : tensor, if `grad`
        Gradient with respect to `vel`. A fresh tensor on every call: it is
        detached from `vel.grad`, so later calls never overwrite it in place.
    """
    # This loop performs the forward pass, and computes
    # derivatives along the way.

    # select correct gradient mode
    if grad:
        vel.requires_grad_()
        # Drop (rather than zero in place) any stale gradient, so that a
        # gradient tensor handed out earlier is never mutated by backward().
        vel.grad = None
    if grad and not torch.is_grad_enabled():
        with torch.enable_grad():
            return self(vel, grad)
    elif not grad and torch.is_grad_enabled():
        with torch.no_grad():
            return self(vel, grad)

    dim = vel.shape[-1]
    nvox = py.prod(vel.shape[-dim - 1:-1])

    # evaluations without gradient are treated as line-search probes
    in_line_search = not grad
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % 20 == 0

    # forward
    if self.kernel is None:
        self.kernel = spatial.greens(vel.shape[-dim - 1:-1], **self.prm,
                                     **utils.backend(vel))
    grid = spatial.shoot(vel, self.kernel, steps=self.steps, **self.prm)
    warped = spatial.grid_pull(self.moving, grid, bound='dct2',
                               extrapolate=True)

    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, vel, cat=iscat, dim=dim)

    # log-likelihood in observed space
    llx = self.loss.loss(warped, self.fixed)
    del warped

    # add regularization term
    vgrad = spatial.regulariser_grid(vel, **self.prm).div_(nvox)
    llv = 0.5 * (vel * vgrad).sum()
    lll = llx + llv
    del vgrad

    # print objective
    llx = llx.item()
    llv = llv.item()
    ll = lll.item()
    if not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(
                f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g}', end='\r')
        else:
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(
                f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g} | {gain:12.6g}', end='\r')
        # Track history so the "gain" column appears from the second
        # iteration on (previously `ll_prev` was never updated here).
        self.ll_prev = ll
        self.ll_max = max(getattr(self, 'll_max', ll), ll)

    out = [lll]
    if grad:
        lll.backward()
        out.append(vel.grad)
        # detach the returned gradient from `vel` so it stays valid
        vel.grad = None
        vel.requires_grad_(False)
    return tuple(out) if len(out) > 1 else out[0]
def forward(self, batch=1, **overload):
    """Sample a random vector field whose covariance is the regulariser's
    Green's function.

    White noise is drawn in the Fourier domain, multiplied by a square root
    of the Green's function (a matrix square root when Lamé penalties are
    active), scaled by the voxel size, and brought back to the spatial domain.

    Parameters
    ----------
    batch : int, default=1
        Batch size
    overload : dict
        Per-call overrides for `shape`, `mean`, `voxel_size`, `dtype`,
        `device` (each defaults to the attribute of the same name).

    Returns
    -------
    field : (batch, *shape, dim) tensor
        Generated random field, with `dim = len(shape)`.
    """
    # get arguments
    shape = overload.get('shape', self.shape)
    mean = overload.get('mean', self.mean)
    voxel_size = overload.get('voxel_size', self.voxel_size)
    dtype = overload.get('dtype', self.dtype)
    device = overload.get('device', self.device)
    backend = dict(dtype=dtype, device=device)

    nb_dim = len(shape)
    voxel_size = utils.make_vector(voxel_size, nb_dim, **backend)
    vx_list = voxel_size.tolist()  # plain list: used as a cache key
    lame = py.make_list(self.lame, 2)

    if (hasattr(self, '_greens')
            and self._voxel_size == vx_list
            and self._shape == shape):
        greens = self._greens.to(dtype=dtype, device=device)
    else:
        greens = spatial.greens(
            shape, absolute=self.absolute, membrane=self.membrane,
            bending=self.bending, lame=self.lame, voxel_size=vx_list,
            device=device, dtype=dtype)
        if any(lame):
            # Matrix square root of the (assumed symmetric PSD) kernel
            # G = U S U^T: A = U sqrt(S), so that A A^T = G. The *columns*
            # of U must be scaled, hence unsqueeze(-2). Scaling rows
            # (unsqueeze(-1)) would give covariance diag(S) instead of G.
            greens, scale, _ = torch.svd(greens)
            scale = scale.sqrt_()
            greens *= scale.unsqueeze(-2)
        else:
            greens = greens.sqrt_()
        if self.cache_greens:
            self._greens = greens
            self._voxel_size = vx_list
            self._shape = shape

    # sample white noise (real and imaginary parts)
    sample = torch.randn([2, batch, *shape, nb_dim], **backend)

    # multiply by square root of greens
    if greens.dim() > nb_dim:  # lame
        sample = linalg.matvec(greens, sample)
    else:
        sample = sample * greens.unsqueeze(-1)
    sample = sample / voxel_size.sqrt()
    sample = fft.complex(sample[0], sample[1])

    # inverse Fourier transform over the spatial dimensions
    dims = list(range(-nb_dim - 1, -1))
    sample = fft.real(fft.ifftn(sample, dim=dims))
    sample *= py.prod(shape)

    # add mean
    sample += mean
    return sample
def forward(self, batch=1, **overload):
    """Sample a multi-channel random scalar field.

    Each channel has its own absolute/membrane/bending penalties. White
    noise is drawn in the Fourier domain, multiplied by the square root of
    that channel's Green's function, and brought back to the spatial domain.

    Parameters
    ----------
    batch : int, default=1
        Batch size

    Other Parameters
    ----------------
    shape : sequence[int], optional
    channel : int, optional
    voxel_size : float or (dim,) vector_like, optional
    device : torch.device, optional
    dtype : torch.dtype, optional

    Returns
    -------
    field : (batch, channel, *shape) tensor
        Generated random field
    """
    # get arguments
    shape = overload.get('shape', self.shape)
    channel = overload.get('channel', self.channel)
    voxel_size = overload.get('voxel_size', self.voxel_size)
    dtype = overload.get('dtype', self.dtype)
    device = overload.get('device', self.device)
    backend = dict(dtype=dtype, device=device)

    nb_dim = len(shape)
    voxel_size = utils.make_vector(voxel_size, nb_dim, **backend)
    voxel_size = voxel_size.tolist()
    # The mean is added on every call, so it must not live inside the
    # cache-miss branch (a cache hit used to raise NameError).
    mean = utils.make_vector(self.mean, channel, **backend)

    if (hasattr(self, '_greens')
            and self._voxel_size == voxel_size
            and getattr(self, '_channel', None) == channel
            and self._shape == shape):
        greens = self._greens.to(dtype=dtype, device=device)
    else:
        absolute = utils.make_vector(self.absolute, channel, **backend)
        membrane = utils.make_vector(self.membrane, channel, **backend)
        bending = utils.make_vector(self.bending, channel, **backend)
        greens = torch.stack([
            spatial.greens(
                shape, absolute=absolute[c], membrane=membrane[c],
                bending=bending[c], lame=0, voxel_size=voxel_size,
                device=device, dtype=dtype)
            for c in range(channel)
        ])
        greens = greens.sqrt_()
        if self.cache_greens:
            self._greens = greens
            self._voxel_size = voxel_size
            self._channel = channel  # read by the cache check above
            self._shape = shape

    # sample white noise (real and imaginary parts)
    sample = torch.randn([2, batch, channel, *shape], **backend)
    # greens is (channel, *shape): it broadcasts directly against
    # (2, batch, channel, *shape). There is no trailing component dim here.
    sample *= greens
    sample = fft.complex(sample[0], sample[1])

    # inverse Fourier transform over the spatial dimensions
    dims = list(range(-nb_dim, 0))
    sample = fft.real(fft.ifftn(sample, dim=dims))
    sample *= py.prod(shape)

    # add mean (one value per channel)
    sample += utils.unsqueeze(mean, -1, len(shape))
    return sample
def forward(self, velocity, fwd=None, inv=None, voxel_size=None):
    """Shoot an initial velocity field into a forward and/or inverse warp.

    Parameters
    ----------
    velocity : (batch, *spatial, dim) tensor
        Initial velocity field.
    fwd : bool, default=self.fwd
        Return the forward field.
    inv : bool, default=self.inv
        Return the inverse field.
    voxel_size : sequence[float], default=self.voxel_size

    Returns
    -------
    forward : (batch, *spatial, dim) tensor, if `fwd`
        Forward displacement (if `self.displacement`) or transformation
        field.
    inverse : (batch, *spatial, dim) tensor, if `inv`
        Inverse displacement (if `self.displacement`) or transformation
        field.
    Both are returned as a list `[forward, inverse]` when both are
    requested; None is returned when neither is.
    """
    if fwd is None:
        fwd = self.fwd
    if inv is None:
        inv = self.inv
    if voxel_size is None:
        voxel_size = self.voxel_size

    penalties = dict(absolute=self.absolute, membrane=self.membrane,
                     bending=self.bending, lame=self.lame, factor=self.factor)
    shoot_opt = dict(steps=self.steps, displacement=self.displacement,
                     voxel_size=voxel_size, **penalties)
    greens_prm = dict(penalties, voxel_size=voxel_size,
                      shape=velocity.shape[1:-1])
    backend = utils.backend(velocity)

    # reuse the cached kernel only when every parameter matches
    if not self.cache_greens:
        greens = spatial.greens(**greens_prm, **backend)
    elif getattr(self, '_greens_prm', None) == greens_prm:
        greens = self._greens.to(**backend)
    else:
        greens = spatial.greens(**greens_prm, **backend)
        self._greens = greens
        self._greens_prm = greens_prm

    shoot = spatial.shoot_approx if self.approx else spatial.shoot
    if inv:
        y, iy = shoot(velocity, greens, return_inverse=True, **shoot_opt)
        results = [y, iy] if fwd else [iy]
    elif fwd:
        results = [shoot(velocity, greens, **shoot_opt)]
    else:
        results = []

    if not results:
        return None
    if len(results) == 1:
        return results[0]
    return results