def lddmm_matching(I,
                   J,
                   m=None,
                   lddmm_steps=1000,
                   lddmm_integration_steps=10,
                   reg_weight=1e-1,
                   learning_rate_pose=2e-2,
                   fluid_params=None,
                   progress_bar=True):
    """Match image I to image J via LDDMM.

    The momentum ``m`` is optimized by gradient descent on
    ``mse_loss(lm.interp(I, h), J) + reg_weight * (metric.sharp(m) * m).mean()``,
    where ``h = lm.expmap(metric, m, ...)``. Each raw gradient is passed
    through ``metric.flat`` before the update step.

    Args:
        I: Moving image. Its device and dtype are used throughout. Presumably
            shaped (batch, channel, *spatial) — TODO confirm.
        J: Target image; moved to ``I.device``.
        m: Optional initial momentum. Defaults to zeros of shape
            ``[I.shape[0], I.dim() - 2] + I.shape[2:]``. A supplied tensor is
            copied (onto ``I.device``), so the caller's tensor is not modified.
            If its spatial shape differs from ``I.shape[2:]``, the deformation
            is regridded to the image grid before interpolation.
        lddmm_steps: Number of gradient-descent iterations.
        lddmm_integration_steps: ``num_steps`` passed to ``lm.expmap``.
        reg_weight: Weight of the regularization term in the loss.
        learning_rate_pose: Step size of the momentum update.
        fluid_params: Parameters for ``lm.FluidMetric``; defaults to
            ``[1.0, .1, .01]``.
        progress_bar: Wrap the iteration loop in ``tqdm`` when True.

    Returns:
        ``(m, losses, matchterms, regterms)``: the detached optimized momentum
        and per-iteration float lists of the total loss, the matching term and
        the unweighted regularization term. If the loss becomes NaN the loop
        stops early; that iteration is still recorded in all three lists.
    """
    import math

    if fluid_params is None:
        fluid_params = [1.0, .1, .01]
    if m is None:
        # One momentum component per spatial dimension (3 for volumes).
        defsh = [I.shape[0], I.dim() - 2] + list(I.shape[2:])
        m = torch.zeros(defsh, dtype=I.dtype, device=I.device)
    else:
        # Private copy: the caller's tensor must not be flagged as requiring
        # grad or updated in place by the optimizer step below.
        m = m.detach().to(I.device, copy=True)
    do_regridding = m.shape[2:] != I.shape[2:]
    J = J.to(I.device)
    matchterms = []
    regterms = []
    losses = []
    metric = lm.FluidMetric(fluid_params)
    m.requires_grad_()
    iterations = range(lddmm_steps)
    if progress_bar:
        iterations = tqdm(iterations)
    for mit in iterations:
        # Drop the previous step's gradient so backward() does not accumulate.
        m.grad = None
        h = lm.expmap(metric, m, num_steps=lddmm_integration_steps)
        if do_regridding:
            # Momentum lives on a different grid; resample h onto I's grid.
            h = lm.regrid(h, shape=I.shape[2:], displacement=True)
        Idef = lm.interp(I, h)
        regterm = (metric.sharp(m) * m).mean()
        matchterm = mse_loss(Idef, J)
        matchterms.append(matchterm.item())
        regterms.append(regterm.item())
        loss = matchterm + reg_weight * regterm
        loss.backward()
        loss_value = loss.item()
        losses.append(loss_value)
        if math.isnan(loss_value):
            print(f"loss is NaN at iter {mit}")
            break
        with torch.no_grad():
            # Precondition the gradient with the metric, then take a step.
            p = metric.flat(m.grad)
            m.add_(p, alpha=-learning_rate_pose)
    return m.detach(), losses, matchterms, regterms
 # Driver fragment: match I to J with LDDMM and optionally plot the result.
 # Images are inputs only; no gradients are needed with respect to them.
 I.requires_grad_(False)
 J.requires_grad_(False)
 #fluid_params=[1e-2,.0,.01]
 fluid_params = [5e-2, .0, .01]
 # When not None, the deformation is regridded to I's spatial shape and the
 # quiver slice index is divided by this factor; presumably the momentum grid
 # is downsampled by diffeo_scale — TODO confirm.
 diffeo_scale = None
 # NOTE(review): the lddmm_matching shown in this file accepts no
 # `diffeo_scale` keyword and returns four values (m, losses, matchterms,
 # regterms), so this call and two-name unpacking would fail against it.
 # Presumably written for a different version of lddmm_matching — verify.
 mmatch, losses_match = lddmm_matching(I,
                                       J,
                                       fluid_params=fluid_params,
                                       diffeo_scale=diffeo_scale)
 if args.plot:
     # Middle index along axis 3 of I, used as the displayed slice.
     sl = I.shape[3] // 2
     # Rebuild the deformation from the optimized momentum for display.
     metric = lm.FluidMetric(fluid_params)
     hsmall = lm.expmap(metric, mmatch, num_steps=10)
     h = hsmall
     if diffeo_scale is not None:
         h = lm.regrid(h, shape=I.shape[2:], displacement=True)
     Idef = lm.interp(I, h)
     # Loss curve over iterations.
     plt.plot(losses_match, '-o')
     # Side by side: moving image, deformed moving image, target image.
     plt.figure()
     plt.subplot(1, 3, 1)
     plt.imshow(I[0, 0, :, sl, :].cpu().numpy().squeeze(), cmap='gray')
     plt.subplot(1, 3, 2)
     plt.imshow(Idef[0, 0, :, sl, :].cpu().numpy().squeeze(), cmap='gray')
     plt.subplot(1, 3, 3)
     plt.imshow(J[0, 0, :, sl, :].cpu().numpy().squeeze(), cmap='gray')
     plt.figure()
     # Map the image slice index onto the momentum grid.
     if diffeo_scale is None:
         ds = 1
     else:
         ds = diffeo_scale
     # Quiver of the first two momentum components, divided by fluid_params[2].
     lm.quiver(mmatch[:, :2, :, sl // ds, :] / fluid_params[2])
Beispiel #3
0
def test_regrid_displacement_gradcheck(bs, dim):
    """Gradcheck ``lm.regrid`` in displacement mode when upsampling by one voxel.

    A random float64 field of shape ``(bs, dim) + (res,) * dim`` is regridded to
    spatial shape ``(res + 1,) * dim``. ``catch_gradcheck`` (defined elsewhere)
    is given the failure message, the function and its inputs.
    """
    imsh = tuple([bs, dim] + [res] * dim)
    # Allocate on the GPU directly so I is a leaf tensor; calling .cuda() after
    # requires_grad=True would return a non-leaf copy instead.
    I = torch.randn(imsh, dtype=torch.float64, device="cuda", requires_grad=True)
    outshape = [res + 1] * dim

    def foo(J):
        return lm.regrid(J, shape=outshape, displacement=True)

    catch_gradcheck("Failed regrid displacement gradcheck", foo, (I, ))
Beispiel #4
0
def test_regrid_identity(bs, dim, disp):
    """Regridding a field to its own spatial shape must return it unchanged."""
    imsh = tuple([bs, dim] + [res] * dim)
    # Allocate on the GPU directly so I is a leaf tensor (not a .cuda() copy).
    I = torch.randn(imsh, dtype=torch.float64, device="cuda", requires_grad=True)
    outshape = imsh[2:]
    Ir = lm.regrid(I, shape=outshape, displacement=disp)
    # torch.allclose broadcasts, so a wrongly shaped result could still pass
    # the value check; compare shapes explicitly first.
    assert Ir.shape == I.shape, "Failed regrid identity check"
    assert torch.allclose(I, Ir), "Failed regrid identity check"