def _deform_1d(img, disp, grad=False):
    """Warp an image along its last dimension with a 1D displacement
    (optionally also return the sampled spatial gradient)."""
    img = utils.movedim(img, 0, -2)
    disp = disp.unsqueeze(-1)
    disp = spatial.add_identity_grid(disp)
    wrp = spatial.grid_pull(img, disp, bound=BND, extrapolate=True)
    wrp = utils.movedim(wrp, -2, 0)
    if not grad:
        return wrp, None
    grd = spatial.grid_grad(img, disp, bound=BND, extrapolate=True)
    grd = utils.movedim(grd.squeeze(-1), -2, 0)
    return wrp, grd
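# Hedged usage sketch (not part of the original module; the example shapes are
# assumptions inferred from how `_deform_1d` moves dimensions around):
# `img` is (channels, ..., length) with the warped dimension last, and `disp`
# is a voxel displacement over that last dimension, e.g.
#
#   img = torch.randn(1, 64)             # one channel, 64 samples
#   disp = torch.full([64], 2.5)          # shift everything by 2.5 voxels
#   warped, _ = _deform_1d(img, disp)
#   warped, grd = _deform_1d(img, disp, grad=True)   # also d(img)/dx at the
#                                                    # sampled locations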
def add_identity(self, disp):
    """Add the identity (voxel indices) to a displacement field along
    `self.displacement_dim`, turning it into a sampling transformation."""
    disp = utils.movedim(disp, self.displacement_dim, -1)
    disp = spatial.add_identity_grid(disp.unsqueeze(-1)).squeeze(-1)
    disp = utils.movedim(disp, -1, self.displacement_dim)
    return disp
def add_identity_1d(grid):
    """Add the 1D identity (voxel indices) to a 1D displacement."""
    return add_identity_grid(grid[..., None])[..., 0]
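# Hedged illustration (my addition, not from the original source): with voxel
# coordinates i = 0..N-1 along the last dimension, `add_identity_1d(disp)`
# should be equivalent to `disp + torch.arange(N)` broadcast over that
# dimension, e.g.
#
#   disp = torch.zeros(5)
#   add_identity_1d(disp)        # -> tensor([0., 1., 2., 3., 4.])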
def do_affine(self, logaff, grad=False, hess=False, in_line_search=False):
    """Forward pass for updating the affine component (nonlin is not None)"""
    sumloss = None
    sumgrad = None
    sumhess = None

    # ==============================================================
    #                   EXPONENTIATE TRANSFORMS
    # ==============================================================
    logaff0 = logaff
    aff_pos = self.affine.position[0].lower()
    if any(loss.backward for loss in self.losses):
        aff0, iaff0, gaff0, igaff0 = \
            self.affine.exp2(logaff0, grad=True,
                             cache_result=not in_line_search)
        phi0, iphi0 = self.nonlin.exp2(cache_result=True, recompute=False)
    else:
        iaff0, igaff0, iphi0 = None, None, None
        aff0, gaff0 = self.affine.exp(logaff0, grad=True,
                                      cache_result=not in_line_search)
        phi0 = self.nonlin.exp(cache_result=True, recompute=False)

    has_printed = False
    for loss in self.losses:

        moving, fixed, factor = loss.moving, loss.fixed, loss.factor
        if loss.backward:
            phi00, aff00, gaff00 = iphi0, iaff0, igaff0
        else:
            phi00, aff00, gaff00 = phi0, aff0, gaff0

        # ----------------------------------------------------------
        # build left and right affine matrices
        # ----------------------------------------------------------
        aff_right, gaff_right = fixed.affine, None
        if aff_pos in 'fs':
            gaff_right = gaff00 @ aff_right
            gaff_right = linalg.lmdiv(self.nonlin.affine, gaff_right)
            aff_right = aff00 @ aff_right
        aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
        aff_left, gaff_left = self.nonlin.affine, None
        if aff_pos in 'ms':
            gaff_left = gaff00 @ aff_left
            gaff_left = linalg.lmdiv(moving.affine, gaff_left)
            aff_left = aff00 @ aff_left
        aff_left = linalg.lmdiv(moving.affine, aff_left)

        # ----------------------------------------------------------
        # build full transform
        # ----------------------------------------------------------
        if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
            right = None
            phi = spatial.add_identity_grid(phi00)
        else:
            right = spatial.affine_grid(aff_right, fixed.shape)
            phi = regutils.smart_pull_grid(phi00, right)
            phi += right
        phi_right = phi
        if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
            left = None
        else:
            left = spatial.affine_grid(aff_left, self.nonlin.shape)
            phi = spatial.affine_matvec(aff_left, phi)

        # ----------------------------------------------------------
        # forward pass
        # ----------------------------------------------------------
        warped, mask = moving.pull(phi, mask=True)
        if fixed.masked:
            if mask is None:
                mask = fixed.mask
            else:
                mask = mask * fixed.mask

        do_print = not (has_printed or self.verbose < 3
                        or in_line_search or loss.backward)
        if do_print:
            has_printed = True
            if moving.previewed:
                preview = moving.pull(phi, preview=True, dat=False)
            else:
                preview = warped
            init = spatial.affine_lmdiv(moving.affine, fixed.affine)
            if _almost_identity(init) and moving.shape == fixed.shape:
                init = moving.dat
            else:
                init = spatial.affine_grid(init, fixed.shape)
                init = moving.pull(init, preview=True, dat=False)
            self.mov2fix(fixed.dat, init, preview, dim=fixed.dim,
                         title=f'(affine) {self.n_iter:03d}')

        # ----------------------------------------------------------
        # derivatives wrt moving
        # ----------------------------------------------------------
        g = h = None
        loss_args = (warped, fixed.dat)
        loss_kwargs = dict(dim=fixed.dim, mask=mask)
        state = loss.loss.get_state()
        if not grad and not hess:
            llx = loss.loss.loss(*loss_args, **loss_kwargs)
        elif not hess:
            llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
        else:
            llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
        del loss_args, loss_kwargs
        if in_line_search:
            loss.loss.set_state(state)

        # ----------------------------------------------------------
        # chain rule -> derivatives wrt Lie parameters
        # ----------------------------------------------------------

        def compose_grad(g, h, g_mu, g_aff):
            """
            g, h   : gradient/Hessian of loss wrt moving image
            g_mu   : spatial gradients of moving image
            g_aff  : gradient of affine matrix wrt Lie parameters
            returns g, h : gradient/Hessian of loss wrt Lie parameters
            """
            # Note that `h` can be `None`, but the functions I
            # use deal with this case correctly.
            dim = g_mu.shape[-1]
            g = jg(g_mu, g)
            h = jhj(g_mu, h)
            g, h = regutils.affine_grid_backward(g, h)
            dim2 = dim * (dim + 1)
            g = g.reshape([*g.shape[:-2], dim2])
            g_aff = g_aff[..., :-1, :]
            g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
            g = linalg.matvec(g_aff, g)
            if h is not None:
                h = h.reshape([*h.shape[:-4], dim2, dim2])
                h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
                # h = h.abs().sum(-1).diag_embed()
            return g, h

        if grad or hess:
            g0, g = g, None
            h0, h = h, None
            if aff_pos in 'ms':
                g_left = regutils.smart_push(g0, phi_right,
                                             shape=self.nonlin.shape)
                h_left = regutils.smart_push(h0, phi_right,
                                             shape=self.nonlin.shape)
                mugrad = moving.pull_grad(left, rotate=False)
                g_left, h_left = compose_grad(g_left, h_left, mugrad, gaff_left)
                g, h = g_left, h_left
            if aff_pos in 'fs':
                g_right, h_right = g0, h0
                mugrad = moving.pull_grad(phi, rotate=False)
                jac = spatial.grid_jacobian(phi0, right, type='disp',
                                            extrapolate=False)
                jac = torch.matmul(aff_left[:-1, :-1], jac)
                mugrad = linalg.matvec(jac.transpose(-1, -2), mugrad)
                g_right, h_right = compose_grad(g_right, h_right, mugrad,
                                                gaff_right)
                g = g_right if g is None else g.add_(g_right)
                h = h_right if h is None else h.add_(h_right)

            if loss.backward:
                g = g.neg_()
            sumgrad = (g.mul_(factor) if sumgrad is None else
                       sumgrad.add_(g, alpha=factor))
            if hess:
                sumhess = (h.mul_(factor) if sumhess is None else
                           sumhess.add_(h, alpha=factor))
        sumloss = (llx.mul_(factor) if sumloss is None else
                   sumloss.add_(llx, alpha=factor))

    # TODO add regularization term
    lla = 0

    # ==============================================================
    #                          VERBOSITY
    # ==============================================================
    llx = sumloss.item()
    sumloss += lla
    sumloss += self.llv
    self.loss_value = sumloss.item()
    if self.verbose and (self.verbose > 1 or not in_line_search):
        ll = sumloss.item()
        llv = self.llv
        if in_line_search:
            line = '(search) | '
        else:
            line = '(affine) | '
        line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
        if not in_line_search:
            if self.ll_prev is not None:
                gain = self.ll_prev - ll
                # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                line += f' | {gain:12.6g}'
            self.all_ll.append(ll)
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)
            self.n_iter += 1
        print(line, end='\r')

    # ==============================================================
    #                            RETURN
    # ==============================================================
    out = [sumloss]
    if grad:
        out.append(sumgrad)
    if hess:
        out.append(sumhess)
    return tuple(out) if len(out) > 1 else out[0]
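# Hedged summary of the chain rule implemented by `compose_grad` above (my
# reading of the code, not an authoritative derivation): with L the loss,
# mu the moving image, phi(q) = A(q) x the affinely deformed coordinates and
# q the Lie parameters,
#
#   dL/dphi  = (d mu / d phi)^T (dL/d mu)        # jg(g_mu, g)
#   dL/dA    = affine_grid_backward(dL/dphi)     # accumulate over voxels
#   dL/dq[k] = <dA/dq[k], dL/dA>                 # matvec with flattened g_aff
#
# and the Hessian is propagated in the same Gauss-Newton fashion
# (jhj followed by the two-sided g_aff product). A typical call, assuming an
# outer Gauss-Newton / line-search driver owns `self`:
#
#   ll, g, h = self.do_affine(logaff, grad=True, hess=True)
#   # the driver would then take a (regularised) Newton step on logaff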
def do_vel(self, vel, grad=False, hess=False, in_line_search=False):
    """Forward pass for updating the nonlinear component"""
    sumloss = None
    sumgrad = None
    sumhess = None

    # ==============================================================
    #                   EXPONENTIATE TRANSFORMS
    # ==============================================================
    if self.affine:
        aff0, iaff0 = self.affine.exp2(cache_result=True, recompute=False)
        aff_pos = self.affine.position[0].lower()
    else:
        aff_pos = 'x'
        aff0 = iaff0 = torch.eye(self.nonlin.dim + 1)
    vel0 = vel
    if any(loss.backward for loss in self.losses):
        phi0, iphi0 = self.nonlin.exp2(vel0, recompute=True,
                                       cache_result=not in_line_search)
        ivel0 = -vel0
    else:
        phi0 = self.nonlin.exp(vel0, recompute=True,
                               cache_result=not in_line_search)
        iphi0 = ivel0 = None
    aff0 = aff0.to(phi0)
    iaff0 = iaff0.to(phi0)

    # ==============================================================
    #                   ACCUMULATE DERIVATIVES
    # ==============================================================
    has_printed = False
    for loss in self.losses:

        # ==========================================================
        #                    ONE LOSS COMPONENT
        # ==========================================================
        moving, fixed, factor = loss.moving, loss.fixed, loss.factor
        if loss.backward:
            phi00, aff00, vel00 = iphi0, iaff0, ivel0
        else:
            phi00, aff00, vel00 = phi0, aff0, vel0

        # ----------------------------------------------------------
        # build left and right affine
        # ----------------------------------------------------------
        aff_right = fixed.affine
        if aff_pos in 'fs':  # affine position: fixed or symmetric
            aff_right = aff00 @ aff_right
        aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
        aff_left = self.nonlin.affine
        if aff_pos in 'ms':  # affine position: moving or symmetric
            aff_left = aff00 @ self.nonlin.affine
        aff_left = linalg.lmdiv(moving.affine, aff_left)

        # ----------------------------------------------------------
        # build full transform
        # ----------------------------------------------------------
        if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
            aff_right = None
            phi = spatial.add_identity_grid(phi00)
            disp = phi00
        else:
            phi = spatial.affine_grid(aff_right, fixed.shape)
            disp = regutils.smart_pull_grid(phi00, phi)
            phi += disp
        if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
            aff_left = None
        else:
            phi = spatial.affine_matvec(aff_left, phi)

        # ----------------------------------------------------------
        # forward pass
        # ----------------------------------------------------------
        warped, mask = moving.pull(phi, mask=True)
        if fixed.masked:
            if mask is None:
                mask = fixed.mask
            else:
                mask = mask * fixed.mask

        do_print = not (has_printed or self.verbose < 3
                        or in_line_search or loss.backward)
        if do_print:
            has_printed = True
            if moving.previewed:
                preview = moving.pull(phi, preview=True, dat=False)
            else:
                preview = warped
            init = spatial.affine_lmdiv(moving.affine, fixed.affine)
            if _almost_identity(init) and moving.shape == fixed.shape:
                init = moving.dat
            else:
                init = spatial.affine_grid(init, fixed.shape)
                init = moving.pull(init, preview=True, dat=False)
            self.mov2fix(fixed.dat, init, preview, disp, dim=fixed.dim,
                         title=f'(nonlin) {self.n_iter:03d}')

        # ----------------------------------------------------------
        # derivatives wrt moving
        # ----------------------------------------------------------
        g = h = None
        loss_args = (warped, fixed.dat)
        loss_kwargs = dict(dim=fixed.dim, mask=mask)
        state = loss.loss.get_state()
        if not grad and not hess:
            llx = loss.loss.loss(*loss_args, **loss_kwargs)
        elif not hess:
            llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
        else:
            llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
        del loss_args, loss_kwargs
        if in_line_search:
            loss.loss.set_state(state)

        # ----------------------------------------------------------
        # chain rule -> derivatives wrt phi
        # ----------------------------------------------------------
        if grad or hess:
            g, h, mugrad = self.nonlin.propagate_grad(
                g, h, moving, phi00, aff_left, aff_right,
                inv=loss.backward)
            g = regutils.jg(mugrad, g)
            h = regutils.jhj(mugrad, h)
            if isinstance(self.nonlin, SVFModel):
                # propagate backward by scaling and squaring
                g, h = spatial.exp_backward(vel00, g, h,
                                            steps=self.nonlin.steps)
            sumgrad = (g.mul_(factor) if sumgrad is None else
                       sumgrad.add_(g, alpha=factor))
            if hess:
                sumhess = (h.mul_(factor) if sumhess is None else
                           sumhess.add_(h, alpha=factor))
        sumloss = (llx.mul_(factor) if sumloss is None else
                   sumloss.add_(llx, alpha=factor))

    # ==============================================================
    #                        REGULARIZATION
    # ==============================================================
    vgrad = self.nonlin.regulariser(vel0)
    llv = 0.5 * vel0.flatten().dot(vgrad.flatten())
    if grad:
        sumgrad += vgrad
    del vgrad

    # ==============================================================
    #                          VERBOSITY
    # ==============================================================
    llx = sumloss.item()
    sumloss += llv
    sumloss += self.lla
    self.loss_value = sumloss.item()
    if self.verbose and (self.verbose > 1 or not in_line_search):
        llv = llv.item()
        ll = sumloss.item()
        lla = self.lla
        if in_line_search:
            line = '(search) | '
        else:
            line = '(nonlin) | '
        line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
        if not in_line_search:
            if self.ll_prev is not None:
                gain = self.ll_prev - ll
                # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                line += f' | {gain:12.6g}'
            self.llv = llv
            self.all_ll.append(ll)
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)
            self.n_iter += 1
        print(line, end='\r')

    # ==============================================================
    #                            RETURN
    # ==============================================================
    out = [sumloss]
    if grad:
        out.append(sumgrad)
    if hess:
        out.append(sumhess)
    return tuple(out) if len(out) > 1 else out[0]
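# Hedged usage sketch (an assumption about the surrounding optimiser, not code
# from this module): `do_vel` mirrors `do_affine` but differentiates wrt the
# velocity / displacement field, so a driver would typically do
#
#   ll, g, h = self.do_vel(vel, grad=True, hess=True)
#   # solve (H + regulariser) delta = g with the nonlin model's solver,
#   # then line-search along -delta with
#   #   ll_trial = self.do_vel(trial_vel, in_line_search=True)
#
# During the line search only the loss is computed, and the loss state saved
# above is restored so that repeated trial evaluations do not drift.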