def lddmm_matching(I, J, m=None, lddmm_steps=1000, lddmm_integration_steps=10,
                   reg_weight=1e-1, learning_rate_pose=2e-2, fluid_params=None,
                   progress_bar=True, diffeo_scale=None):
    """Match image I to J via LDDMM.

    Runs gradient descent on a momentum field ``m``. Each iteration does the
    following:
    - shoots ``m`` with ``lm.expmap`` to obtain a deformation ``h``;
    - regrids ``h`` to ``I``'s spatial shape when ``m`` lives on a different grid;
    - warps ``I`` with ``lm.interp``;
    - minimises ``mse_loss(Idef, J) + reg_weight * mean(metric.sharp(m) * m)``.

    The momentum is updated by stepping along ``metric.flat`` of its gradient.

    Args:
        I: moving image. ``I.shape[0]`` is used as the batch size and
            ``I.shape[2:]`` as the spatial shape (presumably laid out as
            (batch, channel, *spatial) -- confirm with callers).
        J: target image; moved to ``I.device``.
        m: optional initial momentum. It is modified in place by the
            optimisation. When None, a zero momentum with ``I.dim() - 2``
            channels is allocated on ``I``'s device and dtype.
        lddmm_steps: number of gradient-descent iterations.
        lddmm_integration_steps: ``num_steps`` passed to ``lm.expmap``.
        reg_weight: weight of the regularisation term in the loss.
        learning_rate_pose: step size of the momentum update.
        fluid_params: parameters for ``lm.FluidMetric``; None means
            ``[1.0, .1, .01]``.
        progress_bar: wrap the iteration range in ``tqdm`` when true.
        diffeo_scale: only used when ``m`` is None. The default momentum is
            then allocated with every spatial size of ``I`` integer-divided
            by this factor; None means full resolution.

    Returns:
        Tuple ``(momentum, losses, matchterms, regterms)``:
        - ``momentum`` is the detached final momentum;
        - the other three are per-iteration lists of Python floats (total loss,
          matching term, regularisation term).

    Iteration stops early, after printing a message, if the loss becomes NaN.
    """
    if fluid_params is None:
        # Avoid a shared mutable default argument.
        fluid_params = [1.0, .1, .01]
    if m is None:
        spatial = list(I.shape[2:])
        if diffeo_scale is not None:
            spatial = [s // diffeo_scale for s in spatial]
        # One momentum component per spatial dimension (3 for volumes).
        m = torch.zeros([I.shape[0], I.dim() - 2] + spatial,
                        dtype=I.dtype, device=I.device)
    # The deformation must be resampled onto I's grid only when the
    # momentum was given on a different grid.
    do_regridding = m.shape[2:] != I.shape[2:]
    J = J.to(I.device)
    matchterms = []
    regterms = []
    losses = []
    metric = lm.FluidMetric(fluid_params)
    m.requires_grad_()
    iterations = range(lddmm_steps)
    if progress_bar:
        iterations = tqdm(iterations)
    for mit in iterations:
        if m.grad is not None:
            m.grad.detach_()
            m.grad.zero_()
        h = lm.expmap(metric, m, num_steps=lddmm_integration_steps)
        if do_regridding:
            h = lm.regrid(h, shape=I.shape[2:], displacement=True)
        Idef = lm.interp(I, h)
        regterm = (metric.sharp(m) * m).mean()
        matchterm = mse_loss(Idef, J)
        loss = matchterm + reg_weight * regterm
        loss.backward()
        matchterms.append(matchterm.item())
        regterms.append(regterm.item())
        losses.append(loss.item())
        if torch.isnan(loss).item():
            print(f"loss is NaN at iter {mit}")
            break
        with torch.no_grad():
            # Descend along the metric-flattened gradient of the momentum.
            p = metric.flat(m.grad)
            m.add_(p, alpha=-learning_rate_pose)
    return m.detach(), losses, matchterms, regterms
I.requires_grad_(False)
J.requires_grad_(False)
# Alternative setting: fluid_params = [1e-2, .0, .01]
fluid_params = [5e-2, .0, .01]
# Integer downsampling factor for the momentum grid; None = full resolution.
diffeo_scale = None

# lddmm_matching regrids the deformation onto I's grid whenever the momentum's
# spatial shape differs from I's. A coarse momentum is therefore requested by
# passing an explicitly sized zero initial momentum as `m`.
m_init = None
if diffeo_scale is not None:
    coarse_shape = [s // diffeo_scale for s in I.shape[2:]]
    m_init = torch.zeros([I.shape[0], I.dim() - 2] + coarse_shape,
                         dtype=I.dtype, device=I.device)

# lddmm_matching returns (momentum, losses, matchterms, regterms).
mmatch, losses_match, _, _ = lddmm_matching(I, J, m=m_init,
                                            fluid_params=fluid_params)

if args.plot:
    sl = I.shape[3] // 2
    metric = lm.FluidMetric(fluid_params)
    # Re-shoot the optimised momentum with 10 steps, matching
    # lddmm_matching's default lddmm_integration_steps.
    h = lm.expmap(metric, mmatch, num_steps=10)
    if diffeo_scale is not None:
        h = lm.regrid(h, shape=I.shape[2:], displacement=True)
    Idef = lm.interp(I, h)
    plt.plot(losses_match, '-o')
    plt.figure()
    # Side by side: moving image, warped moving image, target image.
    for col, img in enumerate((I, Idef, J), start=1):
        plt.subplot(1, 3, col)
        plt.imshow(img[0, 0, :, sl, :].cpu().numpy().squeeze(), cmap='gray')
    plt.figure()
    ds = 1 if diffeo_scale is None else diffeo_scale
    lm.quiver(mmatch[:, :2, :, sl // ds, :] / fluid_params[2])
def test_regrid_displacement_gradcheck(bs, dim):
    """Gradcheck lm.regrid in displacement mode when upsampling by one voxel per axis."""
    imsh = tuple([bs, dim] + [res] * dim)
    # Allocate directly on the GPU so I is a leaf tensor. `.cuda()` after
    # `requires_grad=True` yields a non-leaf copy and wastes a CPU allocation.
    I = torch.randn(imsh, dtype=torch.float64, device='cuda',
                    requires_grad=True)
    outshape = [res + 1] * dim
    foo = lambda J: lm.regrid(J, shape=outshape, displacement=True)
    catch_gradcheck("Failed regrid displacement gradcheck", foo, (I, ))
def test_regrid_identity(bs, dim, disp):
    """Regridding onto the input's own spatial shape must return the input unchanged."""
    imsh = tuple([bs, dim] + [res] * dim)
    # Allocate directly on the GPU so I is a leaf tensor, not a `.cuda()` copy.
    I = torch.randn(imsh, dtype=torch.float64, device='cuda',
                    requires_grad=True)
    outshape = imsh[2:]
    Ir = lm.regrid(I, shape=outshape, displacement=disp)
    assert torch.allclose(I, Ir), "Failed regrid identity check"