def test_fluid_flat_gradcheck(bs, dim):
    fluid_params = [.1, .01, .001]
    defsh = tuple([bs, dim] + [res] * dim)
    v = torch.randn(defsh, dtype=torch.float64, requires_grad=True).cuda()
    metric = lm.FluidMetric(fluid_params)
    catch_gradcheck(
        f"Failed fluid flat gradcheck with batch size {bs} dim {dim}",
        metric.flat, (v,))
def test_fluid_inverse(bs, dim):
    fluid_params = [.1, .01, .001]
    defsh = tuple([bs, dim] + [res] * dim)
    m = torch.randn(defsh, dtype=torch.float64, requires_grad=False).cuda()
    metric = lm.FluidMetric(fluid_params)
    v = metric.sharp(m)
    vm = metric.flat(v)
    assert torch.allclose(vm, m, atol=1e-3), \
        f"Failed fluid inverse check with batch size {bs} dim {dim}"
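# A minimal sketch of how these fluid-metric tests might be driven; pytest
# parametrization and the specific batch sizes and dimensions below are
# assumptions, not taken from the original test configuration. `res` is the
# module-level grid resolution used by the tests above.
import pytest

@pytest.mark.parametrize("bs", [1, 2])
@pytest.mark.parametrize("dim", [2, 3])
def test_fluid_metric_suite(bs, dim):
    # run both checks above for each (batch size, dimension) pair
    test_fluid_flat_gradcheck(bs, dim)
    test_fluid_inverse(bs, dim)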
def lddmm_matching(I, J, m=None, lddmm_steps=1000, lddmm_integration_steps=10,
                   reg_weight=1e-1, learning_rate_pose=2e-2,
                   fluid_params=[1.0, .1, .01], progress_bar=True):
    """Matching image I to J via LDDMM"""
    if m is None:
        defsh = [I.shape[0], 3] + list(I.shape[2:])
        m = torch.zeros(defsh, dtype=I.dtype).to(I.device)
    # if the momentum lives on a coarser grid than the image, the deformation
    # must be regridded before interpolating
    do_regridding = m.shape[2:] != I.shape[2:]
    J = J.to(I.device)
    matchterms = []
    regterms = []
    losses = []
    metric = lm.FluidMetric(fluid_params)
    m.requires_grad_()
    pb = range(lddmm_steps)
    if progress_bar:
        pb = tqdm(pb)
    for mit in pb:
        if m.grad is not None:
            m.grad.detach_()
            m.grad.zero_()
        m.requires_grad_()
        h = lm.expmap(metric, m, num_steps=lddmm_integration_steps)
        if do_regridding:
            h = lm.regrid(h, shape=I.shape[2:], displacement=True)
        Idef = lm.interp(I, h)
        regterm = (metric.sharp(m) * m).mean()
        matchterm = mse_loss(Idef, J)
        matchterms.append(matchterm.detach().item())
        regterms.append(regterm.detach().item())
        loss = matchterm + reg_weight * regterm
        loss.backward()
        loss.detach_()
        with torch.no_grad():
            losses.append(loss.detach())
            # precondition the gradient with the metric before the update
            p = metric.flat(m.grad).detach()
            if torch.isnan(losses[-1]).item():
                print(f"loss is NaN at iter {mit}")
                break
            m.add_(p, alpha=-learning_rate_pose)
    return m.detach(), [l.item() for l in losses], matchterms, regterms
_, I = oasis_ds_std[0]
_, J = oasis_ds_std[10]
I = I.unsqueeze(1).to('cuda')
J = J.unsqueeze(1).to(I.device)
I.requires_grad_(False)
J.requires_grad_(False)
fluid_params = [5e-2, .0, .01]
diffeo_scale = None
mmatch, losses_match, matchterms_match, regterms_match = lddmm_matching(
    I, J, fluid_params=fluid_params)
if args.plot:
    sl = I.shape[3] // 2
    metric = lm.FluidMetric(fluid_params)
    hsmall = lm.expmap(metric, mmatch, num_steps=10)
    h = hsmall
    if diffeo_scale is not None:
        h = lm.regrid(h, shape=I.shape[2:], displacement=True)
    Idef = lm.interp(I, h)
    plt.plot(losses_match, '-o')
    plt.figure()
    plt.subplot(1, 3, 1)
    plt.imshow(I[0, 0, :, sl, :].cpu().numpy().squeeze(), cmap='gray')
    plt.subplot(1, 3, 2)
    plt.imshow(Idef[0, 0, :, sl, :].cpu().numpy().squeeze(), cmap='gray')
    plt.subplot(1, 3, 3)
    plt.imshow(J[0, 0, :, sl, :].cpu().numpy().squeeze(), cmap='gray')
    plt.figure()
    if diffeo_scale is None:
def test_expmap_zero(bs, dim, step, params):
    defsh = tuple([bs, dim] + [res] * dim)
    m = torch.zeros(defsh, dtype=torch.float64, requires_grad=False).cuda()
    metric = lm.FluidMetric(params)
    h = lm.expmap(metric, m, num_steps=step)
    assert torch.allclose(m, h), "Failed expmap of zero is identity check"
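# A hypothetical companion check, not from the original suite: interpolating
# an image through the deformation produced by zero momentum should return
# the image essentially unchanged, since expmap of zero is the identity.
# The grid size `res` and the lm.interp call follow their usage elsewhere in
# this file.
def test_interp_identity(bs, step, params):
    imsh = tuple([bs, 1] + [res] * 3)
    defsh = tuple([bs, 3] + [res] * 3)
    I = torch.randn(imsh, dtype=torch.float64).cuda()
    m = torch.zeros(defsh, dtype=torch.float64).cuda()
    metric = lm.FluidMetric(params)
    h = lm.expmap(metric, m, num_steps=step)
    Idef = lm.interp(I, h)
    assert torch.allclose(Idef, I, atol=1e-6), \
        "Failed interp through identity deformation check"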
def lddmm_atlas(dataset,
                I0=None,
                num_epochs=500,
                batch_size=10,
                lddmm_steps=1,
                lddmm_integration_steps=5,
                reg_weight=1e2,
                learning_rate_pose=2e2,
                learning_rate_image=1e4,
                fluid_params=[0.1, 0., .01],
                device='cuda',
                momentum_preconditioning=True,
                momentum_pattern='oasis_momenta/momentum_{}.pth',
                gpu=None,
                world_size=1,
                rank=0):
    if world_size > 1:
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    else:
        sampler = None
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size,
                            num_workers=8, pin_memory=True, shuffle=False)
    if I0 is None:
        # initialize base image to mean
        I0 = batch_average(dataloader, dim=0)
        I0 = I0.view(1, 1, *I0.squeeze().shape)
    I = I0.clone()
    I = I.to(device)
    image_optimizer = torch.optim.SGD([I], lr=learning_rate_image,
                                      weight_decay=0)
    metric = lm.FluidMetric(fluid_params)
    losses = []
    reg_terms = []
    iter_losses = []
    epbar = range(num_epochs)
    if rank == 0:
        epbar = tqdm(epbar, desc='epoch')
    # per-subject momenta, kept in pinned host memory between epochs
    ms = torch.zeros(len(dataset), 3, *I0.shape[-3:], dtype=I0.dtype).pin_memory()
    for epoch in epbar:
        epoch_loss = 0.0
        epoch_reg_term = 0.0
        itbar = dataloader
        I.requires_grad_(True)
        image_optimizer.zero_grad()
        for it, (ix, img) in enumerate(itbar):
            m = ms[ix, ...].detach()
            m = m.to(device)
            img = img.to(device)
            for lit in range(lddmm_steps):
                # compute the image gradient only in the last step; this
                # enables taking multiple LDDMM steps per image update
                I.requires_grad_(lit == lddmm_steps - 1)
                m.requires_grad_(True)
                if m.grad is not None:
                    m.grad.detach_()
                    m.grad.zero_()
                h = lm.expmap(metric, m, num_steps=lddmm_integration_steps)
                Idef = lm.interp(I, h)
                v = metric.sharp(m)
                regterm = reg_weight * (v * m).sum()
                loss = (mse_loss(Idef, img, reduction='sum') + regterm) \
                        / (img.numel())
                loss.backward()
                # scale so that reducing the losses yields an accurate MSE for
                # the entire dataset
                with torch.no_grad():
                    li = (loss * (img.shape[0] / len(dataloader.dataset))).detach()
                    p = m.grad
                    if momentum_preconditioning:
                        p = metric.flat(p)
                    m.add_(p, alpha=-learning_rate_pose)
                    if world_size > 1:
                        all_reduce(li)
                    iter_losses.append(li.item())
                    m = m.detach()
                    del p
            with torch.no_grad():
                epoch_loss += li
                ri = (regterm * (img.shape[0]
                                 / (img.numel() * len(dataloader.dataset)))).detach()
                epoch_reg_term += ri
                ms[ix, ...] = m.detach().cpu()
            del m, h, Idef, v, loss, regterm, img
        with torch.no_grad():
            if world_size > 1:
                all_reduce(epoch_loss)
                all_reduce(epoch_reg_term)
                all_reduce(I.grad)
                I.grad = I.grad / world_size
            # average the image gradient over iterations
            I.grad = I.grad / len(dataloader)
        image_optimizer.step()
        losses.append(epoch_loss.item())
        reg_terms.append(epoch_reg_term.item())
        if rank == 0:
            epbar.set_postfix(epoch_loss=epoch_loss.item(),
                              epoch_reg=epoch_reg_term.item())
    return I.detach(), ms.detach(), losses, iter_losses
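# A minimal usage sketch for the atlas builder above, assuming a single-GPU
# run on the oasis_ds_std dataset used earlier; the epoch, batch, and fluid
# settings are illustrative assumptions, not values from the original
# experiments.
Iatlas, momenta, atlas_epoch_losses, atlas_iter_losses = lddmm_atlas(
    oasis_ds_std,
    num_epochs=100,
    batch_size=10,
    fluid_params=[5e-2, .0, .01],
    device='cuda',
    world_size=1,
    rank=0)
if args.plot:
    plt.figure()
    plt.plot(atlas_epoch_losses, '-o')
    plt.title('LDDMM atlas epoch loss')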
def deep_lddmm_atlas(dataset,
                     I0=None,
                     fluid_params=[1e-1, 0., .01],
                     num_epochs=500,
                     batch_size=2,
                     reg_weight=.001,
                     dropout=None,
                     closed_form_image=False,
                     image_update_freq=10,  # how many iters between image updates
                     momentum_net=None,
                     momentum_preconditioning=True,
                     lddmm_integration_steps=5,
                     learning_rate_pose=1e-5,
                     learning_rate_image=1e6,
                     resume_checkpoint=False,
                     checkpoint_every=0,
                     checkpoint_pattern='checkpoints/{epoch}.pth',
                     gpu=None,
                     world_size=1,
                     rank=0):
    print(locals())
    import os
    from torch.utils.data import DataLoader, TensorDataset
    from torch.nn.functional import mse_loss
    if gpu is None:
        device = 'cpu'
    else:
        device = f'cuda:{gpu}'
    if world_size > 1:
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    else:
        sampler = None
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size,
                            num_workers=8, pin_memory=True, shuffle=False)
    if I0 is None:
        # initialize base image to mean
        I0 = batch_average(dataloader, dim=0)
        I0 = I0.view(1, 1, *I0.squeeze().shape)
    epoch_losses = []
    iter_losses = []
    if momentum_net is None:
        momentum_net = MomentumPredictor(img_size=I0.shape, dropout=dropout)
    momentum_net = momentum_net.to(device)
    num_params = sum(p.numel() for p in momentum_net.parameters())
    print(f"Momentum network has {num_params} parameters")
    if world_size > 1:
        momentum_net = DistributedDataParallel(momentum_net, device_ids=[gpu],
                                               output_device=gpu)
    I = I0.to(device).detach()
    start_epoch = 0
    if resume_checkpoint:
        # find the first missing checkpoint and resume from the one before it
        while True:
            cpfile = checkpoint_pattern.format(epoch=start_epoch)
            if not os.path.isfile(cpfile):
                if start_epoch > 0:
                    # load previous state
                    cpfile = checkpoint_pattern.format(epoch=start_epoch - 1)
                    I, sd = torch.load(cpfile, map_location=device)
                    momentum_net.load_state_dict(sd)
                break
            start_epoch += 1
    pose_optimizer = torch.optim.Adam(
        momentum_net.parameters(),
        # below we roughly compensate for scaling the loss; the goal is to
        # have learning rates that are independent of the number and size of
        # minibatches, but it's tricky to accomplish
        lr=learning_rate_pose * len(dataloader),
        weight_decay=1e-4)
    image_optimizer = torch.optim.SGD([I], lr=learning_rate_image,
                                      weight_decay=0)
    metric = lm.FluidMetric(fluid_params)
    epbar = range(start_epoch, num_epochs)
    if rank == 0:
        epbar = tqdm(epbar, desc='epoch')
    for epoch in epbar:
        epoch_loss = 0.0
        epoch_reg_term = 0.0
        itbar = dataloader
        if epoch > 1:
            # start using gradients for the image after one epoch
            closed_form_image = False
        if closed_form_image:
            splatI = torch.zeros_like(I)
            splatw = torch.zeros_like(I)
            splatI.requires_grad_(False)
            splatw.requires_grad_(False)
        if not closed_form_image:
            if I.grad is not None:
                image_optimizer.zero_grad()
            I.requires_grad_(True)
        for it, (ix, img) in enumerate(itbar):
            pose_optimizer.zero_grad()
            img = img.detach().to(I.device)
            m = momentum_net(img)
            if momentum_preconditioning:
                # precondition the backpropagated momentum gradient with the metric
                m.register_hook(metric.flat)
            h = lm.expmap(metric, m, num_steps=lddmm_integration_steps)
            Idef = lm.interp(I, h)
            v = metric.sharp(m)
            reg_term = torch.zeros((), device=I.device)
            if reg_weight > 0:
                reg_term = reg_weight * (v * m).sum()
            loss = (mse_loss(Idef, img, reduction='sum') + reg_term) \
                    / (img.numel())
            loss.backward()
            li = (loss * (img.shape[0] / len(dataloader.dataset))).detach()
            epoch_loss += li
            ri = (reg_term * (img.shape[0]
                              / (img.numel() * len(dataloader.dataset)))).detach()
            epoch_reg_term += ri
            iter_losses.append(li.item())
            pose_optimizer.step()
            del loss, reg_term, v, m, h, Idef, img
        with torch.no_grad():
            if world_size > 1:
                all_reduce(epoch_loss)
                all_reduce(epoch_reg_term)
                all_reduce(I.grad)
                I.grad = I.grad / world_size
            # average the image gradient over iterations
            I.grad = I.grad / len(dataloader)
        image_optimizer.step()
        epoch_losses.append(epoch_loss.item())
        if rank == 0:
            epbar.set_postfix(epoch_loss=epoch_loss.item(),
                              epoch_reg=epoch_reg_term.item())
    return I.detach(), momentum_net, epoch_losses, iter_losses
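# A minimal single-process usage sketch for the deep atlas builder above; the
# epoch count, batch size, and fluid parameters are illustrative assumptions,
# and the momentum network is left to the MomentumPredictor default.
Ideep, momentum_net, deep_epoch_losses, deep_iter_losses = deep_lddmm_atlas(
    oasis_ds_std,
    num_epochs=100,
    batch_size=2,
    fluid_params=[5e-2, .0, .01],
    gpu=0,
    world_size=1,
    rank=0)
if args.plot:
    plt.figure()
    plt.plot(deep_iter_losses)
    plt.title('deep LDDMM atlas per-iteration loss')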