def regulariser(self, v=None):
    if v is None:
        v = self.dat
    return spatial.regulariser_grid(v, **self.penalty,
                                    factor=self.factor / py.prod(self.shape),
                                    voxel_size=self.voxel_size)
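# Hedged usage sketch (not from the original source; assumes a host object
# `model` exposing the `dat`, `shape`, `penalty`, `factor` and `voxel_size`
# attributes used above). `regulariser` returns the momentum L @ v, so the
# quadratic penalty energy is recovered as:
#
#     Lv = model.regulariser()                 # momentum of the stored field
#     energy = 0.5 * (model.dat * Lv).sum()    # 0.5 * v' L v
#
# Note that `factor` is divided by `py.prod(self.shape)` (the number of
# voxels), which makes the penalty a per-voxel average and keeps its weight
# comparable across image resolutions.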
def __call__(self, vel, grad=False):
    # This method performs the forward pass and computes
    # derivatives along the way.

    # select correct gradient mode
    if grad:
        vel.requires_grad_()
        if vel.grad is not None:
            vel.grad.zero_()
    if grad and not torch.is_grad_enabled():
        with torch.enable_grad():
            return self(vel, grad)
    elif not grad and torch.is_grad_enabled():
        with torch.no_grad():
            return self(vel, grad)

    dim = vel.shape[-1]
    in_line_search = not grad
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % 20 == 0

    # forward
    if self.id is None:
        self.id = spatial.identity_grid(vel.shape[-dim-1:-1],
                                        **utils.backend(vel))
    grid = self.id + vel
    warped = spatial.grid_pull(self.moving, grid,
                               bound='dct2', extrapolate=True)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, vel,
                    cat=iscat, dim=dim)

    # log-likelihood in observed space
    llx = self.loss.loss(warped, self.fixed)
    del warped

    # add regularization term
    vgrad = spatial.regulariser_grid(vel, **self.prm)
    llv = 0.5 * (vel * vgrad).sum()
    lll = llx + llv
    del vgrad

    # print objective
    llx = llx.item()
    llv = llv.item()
    ll = lll.item()
    if not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} '
                  f'= {ll:12.6g}', end='\r')
        else:
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} '
                  f'= {ll:12.6g} | {gain:12.6g}', end='\r')
        # keep history so that the gain can be computed at the next iteration
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [lll]
    if grad:
        lll.backward()
        out.append(vel.grad)
    vel.requires_grad_(False)
    return tuple(out) if len(out) > 1 else out[0]
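# Hedged driver sketch (assumption: `step` is an instance of the class above;
# `spatial_shape`, `dim`, `lr` and `max_iter` are user-chosen and do not
# appear in the original excerpt). Because this version relies on autograd,
# a plain first-order update is enough:
#
#     vel = torch.zeros([*spatial_shape, dim])     # zero displacement
#     for _ in range(max_iter):
#         ll, g = step(vel, grad=True)             # objective and d(ll)/d(vel)
#         vel = vel - lr * g                       # gradient descent step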
def __call__(self, vel, grad=False, hess=False, gradmov=False,
             hessmov=False, in_line_search=False):
    """
    Parameters
    ----------
    vel : (..., *spatial, dim) tensor, Displacement
    grad : Whether to compute and return the gradient wrt `vel`
    hess : Whether to compute and return the Hessian wrt `vel`
    gradmov : Whether to compute and return the gradient wrt `moving`
    hessmov : Whether to compute and return the Hessian wrt `moving`

    Returns
    -------
    ll : () tensor, Loss value (objective to minimize)
    g : (..., *spatial, dim) tensor, optional, Gradient wrt velocity
    h : (..., *spatial, ?) tensor, optional, Hessian wrt velocity
    gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
    hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving
    """
    # This method performs the forward pass and computes
    # derivatives along the way.
    dim = vel.shape[-1]
    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    # in_line_search = not grad and not hess
    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % logplot == 0

    # forward
    if self.id is None:
        self.id = spatial.identity_grid(vel.shape[-dim-1:-1],
                                        **utils.backend(vel))
    grid = self.id + vel
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, vel,
                    cat=iscat, dim=dim)

    # gradient/Hessian of the log-likelihood in observed space
    if not grad and not hess and not hessmov:
        llx = self.loss.loss(warped, self.fixed)
    elif not hess and not hessmov:
        llx, grad = self.loss.loss_grad(warped, self.fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
    else:
        llx, grad, hess = self.loss.loss_grad_hess(warped, self.fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
        if hessmov:
            hessmov = spatial.grid_push(hess, grid, **pullopt)
    del warped

    # compose with spatial gradients
    if grad is not False or hess is not False:
        mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
        if grad is not False:
            grad = jg(mugrad, grad)
        if hess is not False:
            hess = jhj(mugrad, hess)

    # add regularization term
    vgrad = spatial.regulariser_grid(vel, **self.prm, kernel=False)
    llv = 0.5 * (vel * vgrad).sum()
    if grad is not False:
        grad += vgrad
    del vgrad

    # print objective
    llx = llx.item()
    llv = llv.item()
    ll = llx + llv
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} '
                  f'= {ll:12.6g}', end='\r')
        else:
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} '
                  f'= {ll:12.6g} | {gain:12.6g}', end='\r')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [ll]
    if grad is not False:
        out.append(grad)
    if hess is not False:
        out.append(hess)
    if gradmov is not False:
        out.append(gradmov)
    if hessmov is not False:
        out.append(hessmov)
    return tuple(out) if len(out) > 1 else out[0]
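# Hedged Gauss-Newton sketch (assumptions: `step` is an instance of the class
# above, and a symmetric regularised solver for grid Hessians is available;
# the name `spatial.solve_grid_sym` and its signature are assumed, not taken
# from the original excerpt):
#
#     ll, g, h = step(vel, grad=True, hess=True)
#     delta = spatial.solve_grid_sym(h, g, **step.prm)   # solve (H + L) dv = g
#     vel = vel - delta      # optionally damped by a line search that
#                            # re-evaluates step(vel, in_line_search=True)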
def __call__(self, vel, grad=False, hess=False, gradmov=False,
             hessmov=False):
    # This method performs the forward pass and computes
    # derivatives along the way.
    dim = vel.shape[-1]
    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    in_line_search = not grad and not hess
    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % logplot == 0

    # forward
    if self.kernel is None:
        self.kernel = spatial.greens(vel.shape[-dim-1:-1], **self.prm,
                                     **utils.backend(vel))
    grid = spatial.shoot(vel, self.kernel, steps=self.steps, **self.prm)
    warped = spatial.grid_pull(self.moving, grid,
                               bound='dct2', extrapolate=True)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, vel,
                    cat=iscat, dim=dim)

    # gradient/Hessian of the log-likelihood in observed space
    if not grad and not hess:
        llx = self.loss.loss(warped, self.fixed)
    elif not hess:
        llx, grad = self.loss.loss_grad(warped, self.fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
    else:
        llx, grad, hess = self.loss.loss_grad_hess(warped, self.fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
        if hessmov:
            hessmov = spatial.grid_push(hess, grid, **pullopt)
    del warped

    # compose with spatial gradients
    if grad is not False or hess is not False:
        if self.mugrad is None:
            self.mugrad = spatial.diff(self.moving,
                                       dim=list(range(-dim, 0)),
                                       bound='dct2')
        if grad is not False:
            grad = grad.neg_()  # "final inverse" to "initial"
            grad = spatial.grid_push(grad, grid)
            grad = jg(self.mugrad, grad)
        if hess is not False:
            hess = spatial.grid_push(hess, grid)
            hess = jhj(self.mugrad, hess)

    # add regularization term
    vgrad = spatial.regulariser_grid(vel, **self.prm, kernel=True)
    llv = 0.5 * (vel * vgrad).sum()
    if grad is not False:
        grad += vgrad
    del vgrad

    # print objective
    llx = llx.item()
    llv = llv.item()
    ll = llx + llv
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} '
                  f'= {ll:12.6g}', end='\r')
        else:
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} '
                  f'= {ll:12.6g} | {gain:12.6g}', end='\r')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [ll]
    if grad is not False:
        out.append(grad)
    if hess is not False:
        out.append(hess)
    if gradmov is not False:
        out.append(gradmov)
    if hessmov is not False:
        out.append(hessmov)
    return tuple(out) if len(out) > 1 else out[0]
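# Hedged note (sketch, not from the original source): in this variant `vel`
# is an *initial* velocity; `spatial.shoot` integrates it over `self.steps`
# integration steps using the cached Greens kernel to produce the deformation
# `grid`. The data-term derivatives therefore live in the space of the final
# deformation, which is why they are negated and pushed back (`grid_push`)
# before being composed with the precomputed spatial gradients of the moving
# image. A Gauss-Newton driver would mirror the small-displacement case
# (again assuming a solver such as `spatial.solve_grid_sym`, name assumed):
#
#     ll, g, h = step(vel, grad=True, hess=True)
#     delta = spatial.solve_grid_sym(h, g, **step.prm)
#     vel = vel - delta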