Esempio n. 1
0
def test_fluid_flat_gradcheck(bs, dim):
    """Gradient-check ``FluidMetric.flat`` on a random float64 field.

    Args:
        bs: batch size of the random field.
        dim: number of spatial dimensions; the field has ``dim`` channels and
            ``res`` samples along each spatial axis (``res`` is a module-level
            setting defined outside this snippet).
    """
    fluid_params = [.1, .01, .001]
    defsh = tuple([bs, dim] + [res] * dim)
    # Allocate directly on the GPU with requires_grad=True so that ``v`` is a
    # leaf tensor. Chaining ``.cuda()`` after ``requires_grad=True`` returns a
    # non-leaf copy and wastes a host-side allocation.
    v = torch.randn(defsh, dtype=torch.float64, device='cuda',
                    requires_grad=True)
    metric = lm.FluidMetric(fluid_params)
    # NOTE(review): catch_gradcheck is defined elsewhere; presumably it runs
    # torch.autograd.gradcheck and reports the message on failure.
    catch_gradcheck(
        f"Failed fluid flat gradcheck with batch size {bs} dim {dim}",
        metric.flat, (v, ))
Esempio n. 2
0
def test_fluid_inverse(bs, dim):
    """Check that ``flat`` undoes ``sharp`` for the fluid metric.

    A random float64 momentum field is pushed through ``metric.sharp`` and
    then ``metric.flat``. The round trip must reproduce the input to within
    an absolute tolerance of 1e-3.
    """
    shape = tuple([bs, dim] + [res] * dim)
    momentum = torch.randn(shape, dtype=torch.float64,
                           requires_grad=False).cuda()
    fluid = lm.FluidMetric([.1, .01, .001])
    round_trip = fluid.flat(fluid.sharp(momentum))
    ok = torch.allclose(round_trip, momentum, atol=1e-3)
    assert ok, f"Failed fluid inverse check with batch size {bs} dim {dim}"
def lddmm_matching(I,
                   J,
                   m=None,
                   lddmm_steps=1000,
                   lddmm_integration_steps=10,
                   reg_weight=1e-1,
                   learning_rate_pose=2e-2,
                   fluid_params=None,
                   progress_bar=True):
    """Matching image I to J via LDDMM.

    Runs ``lddmm_steps`` iterations of gradient descent on a momentum field
    ``m``. Each step's gradient is preconditioned by ``metric.flat``. The
    objective is the MSE between ``lm.interp(I, h)`` and ``J`` plus
    ``reg_weight`` times ``mean(metric.sharp(m) * m)``, where
    ``h = lm.expmap(metric, m)``.

    Args:
        I: moving image. The code reads the batch size from dim 0 and the
            spatial grid from dims 2 onward, i.e. (batch, channel, *spatial).
        J: target image; moved to ``I.device``.
        m: optional initial momentum of shape (batch, n_spatial, *grid).
            Defaults to zeros on ``I``'s grid. When its grid differs from
            ``I``'s, the deformation is regridded to the image shape before
            interpolation. The caller's tensor is copied, never modified.
        lddmm_steps: number of gradient-descent iterations.
        lddmm_integration_steps: ``num_steps`` passed to ``lm.expmap``.
        reg_weight: weight of the regularisation term.
        learning_rate_pose: step size of the momentum update.
        fluid_params: ``lm.FluidMetric`` parameters; defaults to
            ``[1.0, .1, .01]``.
        progress_bar: wrap the iteration range in ``tqdm`` when True.

    Returns:
        ``(m, losses, matchterms, regterms)``: the detached final momentum and
        per-iteration lists (Python floats) of total loss, match term and
        regularisation term. If the loss becomes NaN, iteration stops before
        the momentum is updated with that step's gradient.
    """
    import math

    if fluid_params is None:
        fluid_params = [1.0, .1, .01]
    # One momentum channel per spatial dimension of the image.
    spatial_dims = I.dim() - 2
    if m is None:
        defsh = [I.shape[0], spatial_dims] + list(I.shape[2:])
        m = torch.zeros(defsh, dtype=I.dtype, device=I.device)
    else:
        # Private leaf copy: the update below is in place, and requires_grad_
        # plus m.grad only work as intended on a leaf tensor.
        m = m.detach().clone()
    do_regridding = m.shape[2:] != I.shape[2:]
    J = J.to(I.device)
    matchterms = []
    regterms = []
    losses = []
    metric = lm.FluidMetric(fluid_params)
    m.requires_grad_()
    pb = range(lddmm_steps)
    if progress_bar:
        pb = tqdm(pb)
    for mit in pb:
        if m.grad is not None:
            m.grad.detach_()
            m.grad.zero_()
        h = lm.expmap(metric, m, num_steps=lddmm_integration_steps)
        if do_regridding:
            # Momentum lives on a different grid than the image: resample the
            # deformation onto the image grid before warping.
            h = lm.regrid(h, shape=I.shape[2:], displacement=True)
        Idef = lm.interp(I, h)
        regterm = (metric.sharp(m) * m).mean()
        matchterm = mse_loss(Idef, J)
        matchterms.append(matchterm.detach().item())
        regterms.append(regterm.detach().item())
        loss = matchterm + reg_weight * regterm
        loss.backward()
        with torch.no_grad():
            loss_value = loss.item()
            losses.append(loss_value)
            if math.isnan(loss_value):
                print(f"loss is NaN at iter {mit}")
                break
            # Precondition the Euclidean gradient with the metric's flat map.
            p = metric.flat(m.grad).detach()
            m.add_(p, alpha=-learning_rate_pose)
    return m.detach(), losses, matchterms, regterms
 # NOTE(review): top-level script fragment (1-space indent as scraped) that
 # relies on names defined outside this snippet: oasis_ds_std, args, lm, plt
 # and lddmm_matching. It is cut off after the final `if` below.
 _, I = oasis_ds_std[0]
 _, J = oasis_ds_std[10]
 # Add a channel axis at dim 1 and move both images to the GPU. They are
 # fixed inputs to the registration, so no gradients are tracked for them.
 I = I.unsqueeze(1).to('cuda')
 J = J.unsqueeze(1).to(I.device)
 I.requires_grad_(False)
 J.requires_grad_(False)
 #fluid_params=[1e-2,.0,.01]
 fluid_params = [5e-2, .0, .01]
 diffeo_scale = None
 # NOTE(review): the lddmm_matching defined in this file takes no
 # `diffeo_scale` keyword and returns four values (momentum, losses,
 # matchterms, regterms). This call and its two-name unpacking would fail
 # against it -- confirm which version of lddmm_matching this targets.
 mmatch, losses_match = lddmm_matching(I,
                                       J,
                                       fluid_params=fluid_params,
                                       diffeo_scale=diffeo_scale)
 if args.plot:
     # Re-integrate the fitted momentum, warp I with it, then plot the loss
     # curve and the middle slice along axis 3 of I, warped I and J.
     sl = I.shape[3] // 2
     metric = lm.FluidMetric(fluid_params)
     hsmall = lm.expmap(metric, mmatch, num_steps=10)
     h = hsmall
     if diffeo_scale is not None:
         h = lm.regrid(h, shape=I.shape[2:], displacement=True)
     Idef = lm.interp(I, h)
     plt.plot(losses_match, '-o')
     plt.figure()
     plt.subplot(1, 3, 1)
     plt.imshow(I[0, 0, :, sl, :].cpu().numpy().squeeze(), cmap='gray')
     plt.subplot(1, 3, 2)
     plt.imshow(Idef[0, 0, :, sl, :].cpu().numpy().squeeze(), cmap='gray')
     plt.subplot(1, 3, 3)
     plt.imshow(J[0, 0, :, sl, :].cpu().numpy().squeeze(), cmap='gray')
     plt.figure()
     if diffeo_scale is None:  # NOTE(review): body truncated in source
Esempio n. 5
0
def test_expmap_zero(bs, dim, step, params):
    """Exponentiating a zero momentum field must give the identity.

    ``lm.expmap`` is applied with ``step`` integration steps to an all-zero
    float64 field. The result is required to equal that zero field, which
    presumably means the identity is represented as a zero displacement.
    """
    zero_field = torch.zeros(tuple([bs, dim] + [res] * dim),
                             dtype=torch.float64,
                             requires_grad=False).cuda()
    result = lm.expmap(lm.FluidMetric(params), zero_field, num_steps=step)
    assert torch.allclose(zero_field, result), \
        "Failed expmap of zero is identity check"
def lddmm_atlas(dataset,
        I0=None,
        num_epochs=500,
        batch_size=10,
        lddmm_steps=1,
        lddmm_integration_steps=5,
        reg_weight=1e2,
        learning_rate_pose = 2e2,
        learning_rate_image = 1e4,
        fluid_params=None,
        device='cuda',
        momentum_preconditioning=True,
        momentum_pattern='oasis_momenta/momentum_{}.pth',
        gpu=None,
        world_size=1,
        rank=0):
    """Build an LDDMM atlas by alternating momentum and image updates.

    A momentum field is kept for every dataset item. Each epoch refines each
    item's momentum with ``lddmm_steps`` gradient steps. The atlas image ``I``
    receives gradient only from the last of those steps; that gradient is
    accumulated over the epoch, averaged, and applied once with SGD.

    Args:
        dataset: dataset yielding ``(ix, img)`` pairs. ``ix`` indexes the
            per-item momentum buffer (presumably the item's position in the
            dataset -- confirm against the dataset class).
        I0: optional initial atlas; defaults to ``batch_average`` over the
            dataloader. It is reshaped to ``(1, 1, *spatial)`` and copied, so
            the caller's tensor is never modified.
        num_epochs: number of atlas image updates.
        batch_size: dataloader batch size.
        lddmm_steps: momentum gradient steps per minibatch; must be >= 1.
        lddmm_integration_steps: ``num_steps`` passed to ``lm.expmap``.
        reg_weight: weight of the ``sum(sharp(m) * m)`` regulariser.
        learning_rate_pose: step size of the momentum update.
        learning_rate_image: SGD learning rate for the atlas image.
        fluid_params: ``lm.FluidMetric`` parameters; defaults to
            ``[0.1, 0., .01]``.
        device: device that the atlas, momenta and images are moved to.
        momentum_preconditioning: if True, precondition the momentum gradient
            with ``metric.flat``.
        momentum_pattern, gpu: accepted but not used by this function.
        world_size, rank: distributed settings. With ``world_size > 1`` a
            ``DistributedSampler`` is used, and losses and the image gradient
            are all-reduced.

    Returns:
        ``(I, ms, losses, iter_losses)``: the detached atlas, the host-side
        momentum buffer, per-epoch losses and per-step losses.

    Raises:
        ValueError: if ``lddmm_steps < 1``.
    """
    if lddmm_steps < 1:
        # The image gradient and the loss bookkeeping come from the last LDDMM
        # step, so at least one step is required.
        raise ValueError("lddmm_steps must be at least 1")
    if fluid_params is None:
        fluid_params = [0.1, 0., .01]
    if world_size > 1:
        sampler = DistributedSampler(dataset,
                                     num_replicas=world_size,
                                     rank=rank)
    else:
        sampler = None
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size,
                            num_workers=8, pin_memory=True, shuffle=False)
    if I0 is None:
        # initialize base image to mean
        I0 = batch_average(dataloader, dim=0)
    I0 = I0.view(1, 1, *I0.squeeze().shape)
    # The atlas is optimised in place, so it must be a leaf tensor that does
    # not alias (or track gradients through) the caller's I0.
    I = I0.detach().clone().to(device)
    image_optimizer = torch.optim.SGD([I],
                                      lr=learning_rate_image,
                                      weight_decay=0)
    metric = lm.FluidMetric(fluid_params)
    losses = []
    reg_terms = []
    iter_losses = []
    epbar = range(num_epochs)
    if rank == 0:
        epbar = tqdm(epbar, desc='epoch')
    spatial_shape = I0.shape[2:]
    # One momentum field (one channel per spatial dimension) per dataset item,
    # kept in pinned host memory between epochs.
    ms = torch.zeros(len(dataset), len(spatial_shape), *spatial_shape,
                     dtype=I0.dtype).pin_memory()
    for epoch in epbar:
        epoch_loss = 0.0
        epoch_reg_term = 0.0
        I.requires_grad_(True)
        image_optimizer.zero_grad()
        for ix, img in dataloader:
            m = ms[ix, ...].detach().to(device)
            img = img.to(device)
            for lit in range(lddmm_steps):
                # compute image gradient in last step only
                I.requires_grad_(lit == lddmm_steps - 1)
                # enables taking multiple LDDMM steps per image update
                m.requires_grad_(True)
                if m.grad is not None:
                    m.grad.detach_()
                    m.grad.zero_()
                h = lm.expmap(metric, m, num_steps=lddmm_integration_steps)
                Idef = lm.interp(I, h)
                v = metric.sharp(m)
                regterm = reg_weight * (v * m).sum()
                loss = (mse_loss(Idef, img, reduction='sum') + regterm) \
                        / (img.numel())
                loss.backward()
                with torch.no_grad():
                    # Scale so that summing over all minibatches (and ranks)
                    # yields a dataset-wide mean loss.
                    li = (loss * (img.shape[0] / len(dataloader.dataset))).detach()
                    p = m.grad
                    if momentum_preconditioning:
                        p = metric.flat(p)
                    m.add_(p, alpha=-learning_rate_pose)
                    if world_size > 1:
                        all_reduce(li)
                    iter_losses.append(li.item())
                    m = m.detach()
                    del p
            with torch.no_grad():
                epoch_loss += li
                ri = (regterm * (img.shape[0] /
                                 (img.numel() * len(dataloader.dataset)))).detach()
                epoch_reg_term += ri
                ms[ix, ...] = m.detach().cpu()
            del m, h, Idef, v, loss, regterm, img
        with torch.no_grad():
            if world_size > 1:
                all_reduce(epoch_loss)
                all_reduce(epoch_reg_term)
                all_reduce(I.grad)
                I.grad = I.grad / world_size
            # average over iterations
            I.grad = I.grad / len(dataloader)
        image_optimizer.step()
        losses.append(epoch_loss.item())
        reg_terms.append(epoch_reg_term.item())
        if rank == 0:
            epbar.set_postfix(epoch_loss=epoch_loss.item(),
                              epoch_reg=epoch_reg_term.item())
    return I.detach(), ms.detach(), losses, iter_losses
def deep_lddmm_atlas(
        dataset,
        I0=None,
        fluid_params=None,
        num_epochs=500,
        batch_size=2,
        reg_weight=.001,
        dropout=None,
        closed_form_image=False,
        image_update_freq=10,  # currently unused
        momentum_net=None,
        momentum_preconditioning=True,
        lddmm_integration_steps=5,
        learning_rate_pose=1e-5,
        learning_rate_image=1e6,
        resume_checkpoint=False,
        checkpoint_every=0,
        checkpoint_pattern='checkpoints/{epoch}.pth',
        gpu=None,
        world_size=1,
        rank=0):
    """Learn an atlas image jointly with a network that predicts momenta.

    For each minibatch, ``momentum_net`` predicts a momentum field. The atlas
    ``I`` is warped by ``lm.expmap`` of that momentum and compared to the
    images with a summed MSE plus ``reg_weight * sum(sharp(m) * m)``. The
    network is updated with Adam every minibatch. The atlas gradient is
    accumulated over the epoch, averaged, and applied once per epoch with SGD.

    Args:
        dataset: dataset yielding ``(ix, img)`` pairs.
        I0: optional initial atlas; defaults to ``batch_average`` over the
            dataloader. It is reshaped to ``(1, 1, *spatial)`` and copied, so
            the caller's tensor is never modified.
        fluid_params: ``lm.FluidMetric`` parameters; defaults to
            ``[1e-1, 0., .01]``.
        num_epochs: last epoch index (exclusive) to train to.
        batch_size: dataloader batch size.
        reg_weight: weight of the momentum regulariser; 0 disables it.
        dropout: passed to ``MomentumPredictor`` when ``momentum_net`` is None.
        closed_form_image, image_update_freq, checkpoint_every: accepted for
            compatibility but not implemented. The atlas is always updated by
            gradient once per epoch, and no checkpoints are written here.
        momentum_net: optional network; wrapped in ``DistributedDataParallel``
            when ``world_size > 1``.
        momentum_preconditioning: if True, the gradient w.r.t. the predicted
            momentum is passed through ``metric.flat`` before reaching the net.
        lddmm_integration_steps: ``num_steps`` passed to ``lm.expmap``.
        learning_rate_pose: Adam learning rate per minibatch, scaled by the
            number of minibatches.
        learning_rate_image: SGD learning rate for the atlas image.
        resume_checkpoint: if True, look for ``checkpoint_pattern`` files for
            epochs 0, 1, ... and resume after the last consecutive one. Each
            file must hold ``(image, momentum_net state_dict)``.
        checkpoint_pattern: format string with an ``{epoch}`` field.
        gpu: CUDA device index, or None for CPU.
        world_size, rank: distributed settings.

    Returns:
        ``(I, momentum_net, epoch_losses, iter_losses)``.
    """
    if fluid_params is None:
        fluid_params = [1e-1, 0., .01]
    print(locals())
    import os
    from torch.nn.functional import mse_loss
    from torch.utils.data import DataLoader
    if gpu is None:
        device = 'cpu'
    else:
        device = f'cuda:{gpu}'
    if world_size > 1:
        sampler = DistributedSampler(dataset,
                                     num_replicas=world_size,
                                     rank=rank)
    else:
        sampler = None
    dataloader = DataLoader(dataset,
                            sampler=sampler,
                            batch_size=batch_size,
                            num_workers=8,
                            pin_memory=True,
                            shuffle=False)
    if I0 is None:
        # initialize base image to mean
        I0 = batch_average(dataloader, dim=0)
    I0 = I0.view(1, 1, *I0.squeeze().shape)
    epoch_losses = []
    iter_losses = []
    if momentum_net is None:
        momentum_net = MomentumPredictor(img_size=I0.shape, dropout=dropout)
    momentum_net = momentum_net.to(device)
    print(
        f"Momentum network has {sum([p.numel() for p in momentum_net.parameters()])} parameters"
    )
    if world_size > 1:
        momentum_net = DistributedDataParallel(momentum_net,
                                               device_ids=[gpu],
                                               output_device=gpu)
    start_epoch = 0
    if resume_checkpoint:
        # Checkpoints are expected for consecutive epochs 0, 1, ...; resume
        # right after the last one that exists.
        while os.path.isfile(checkpoint_pattern.format(epoch=start_epoch)):
            start_epoch += 1
        if start_epoch > 0:  # load previous state
            cpfile = checkpoint_pattern.format(epoch=start_epoch - 1)
            # The loaded image replaces I0 so training resumes from it rather
            # than from the freshly initialised atlas.
            I0, sd = torch.load(cpfile, map_location=device)
            momentum_net.load_state_dict(sd)
    # Private leaf copy: the optimizer updates the atlas in place, so it must
    # not alias the caller's I0.
    I = I0.detach().clone().to(device)
    pose_optimizer = torch.optim.Adam(
        momentum_net.parameters(),
        # below we roughly compensate for scaling the loss
        # the goal is to have learning rates that are independent of
        # number and size of minibatches, but it's tricky to accomplish
        lr=learning_rate_pose * len(dataloader),
        weight_decay=1e-4)
    image_optimizer = torch.optim.SGD([I],
                                      lr=learning_rate_image,
                                      weight_decay=0)
    metric = lm.FluidMetric(fluid_params)
    epbar = range(start_epoch, num_epochs)
    if rank == 0:
        epbar = tqdm(epbar, desc='epoch')
    for epoch in epbar:
        epoch_loss = 0.0
        epoch_reg_term = 0.0
        # The atlas gradient is accumulated over one epoch only; clear the
        # previous epoch's (already applied) gradient.
        image_optimizer.zero_grad()
        I.requires_grad_(True)
        for ix, img in dataloader:
            pose_optimizer.zero_grad()
            img = img.detach().to(I.device)
            m = momentum_net(img)
            if momentum_preconditioning:
                # A tensor returned by the hook replaces the gradient w.r.t. m
                # before it flows back into the network.
                m.register_hook(metric.flat)
            h = lm.expmap(metric, m, num_steps=lddmm_integration_steps)
            Idef = lm.interp(I, h)
            # Start from a zero tensor (not the int 0) so the .detach() in the
            # bookkeeping below works even when regularisation is disabled.
            reg_term = torch.zeros((), dtype=Idef.dtype, device=Idef.device)
            if reg_weight > 0:
                v = metric.sharp(m)
                reg_term = reg_weight * (v * m).sum()
                del v
            loss = (mse_loss(Idef, img, reduction='sum') + reg_term) \
                    / (img.numel())
            loss.backward()
            # Scale so that summing over all minibatches (and ranks) yields a
            # dataset-wide mean loss.
            li = (loss * (img.shape[0] / len(dataloader.dataset))).detach()
            epoch_loss += li
            ri = (reg_term *
                  (img.shape[0] /
                   (img.numel() * len(dataloader.dataset)))).detach()
            epoch_reg_term += ri
            iter_losses.append(li.item())
            pose_optimizer.step()
            del loss, reg_term, m, h, Idef, img
        with torch.no_grad():
            if world_size > 1:
                all_reduce(epoch_loss)
                all_reduce(epoch_reg_term)
                all_reduce(I.grad)
                I.grad = I.grad / world_size
            # average over iterations
            I.grad = I.grad / len(dataloader)
        image_optimizer.step()
        epoch_losses.append(epoch_loss.item())
        if rank == 0:
            epbar.set_postfix(epoch_loss=epoch_loss.item(),
                              epoch_reg=epoch_reg_term.item())
    return I.detach(), momentum_net, epoch_losses, iter_losses