def call(self, x, y):
    """Mutual-information loss averaged across channels.

    Each channel of `x` is paired with the corresponding channel of `y`
    and scored with ``nn.MutualInfoLoss(patch_size=self.patch)``. The
    per-channel losses are averaged, then multiplied by `self.factor`.

    Parameters
    ----------
    x, y : tensor
        Inputs that are unbound along their first dimension; each slice
        gets two leading singleton dimensions before being scored.
        (Presumably (channel, *spatial) -- TODO confirm against callers.)
        A single-channel input is broadcast against a multi-channel one.

    Returns
    -------
    loss : tensor
        Channel-averaged mutual-information loss, scaled by `self.factor`.

    Raises
    ------
    ValueError
        If the channel counts differ and neither of them is 1.
    """
    xs = x.unbind(0)
    ys = y.unbind(0)
    nb_channels = max(len(xs), len(ys))
    # Broadcast singleton channel dimensions, like an element-wise loss
    # (e.g. MSE) would. Without this, `zip` silently drops channels while
    # the sum is still divided by `nb_channels`.
    if len(xs) == 1:
        xs = xs * nb_channels
    if len(ys) == 1:
        ys = ys * nb_channels
    if len(xs) != len(ys):
        raise ValueError('Cannot broadcast {} channels against {} channels'
                         .format(len(xs), len(ys)))
    mi = nn.MutualInfoLoss(patch_size=self.patch)
    loss = 0
    for x1, y1 in zip(xs, ys):
        # Average MI across channels to be consistent with how MSE works.
        loss += mi(x1[None, None], y1[None, None]) / nb_channels
    if self.factor != 1:
        loss = loss * self.factor
    return loss
def call(self, x, y):
    """Mutual-information loss averaged across channels.

    Each channel of `x` is paired with the corresponding channel of `y`
    and scored with a ``nn.MutualInfoLoss`` configured from `self.bins`,
    `self.patch`, `self.threshold` (second-input mask), `self.order` and
    `self.fwhm`. The per-channel losses are averaged, then multiplied by
    `self.factor`.

    Parameters
    ----------
    x, y : tensor
        Inputs that are unbound along their first dimension; each slice
        gets two leading singleton dimensions before being scored.
        (Presumably (channel, *spatial) -- TODO confirm against callers.)
        A single-channel input is broadcast against a multi-channel one.

    Returns
    -------
    loss : tensor
        Channel-averaged mutual-information loss, scaled by `self.factor`.

    Raises
    ------
    ValueError
        If the channel counts differ and neither of them is 1.
    """
    xs = x.unbind(0)
    ys = y.unbind(0)
    nb_channels = max(len(xs), len(ys))
    # Broadcast singleton channel dimensions, like an element-wise loss
    # (e.g. MSE) would. Without this, `zip` silently drops channels while
    # the sum is still divided by `nb_channels`.
    if len(xs) == 1:
        xs = xs * nb_channels
    if len(ys) == 1:
        ys = ys * nb_channels
    if len(xs) != len(ys):
        raise ValueError('Cannot broadcast {} channels against {} channels'
                         .format(len(xs), len(ys)))
    # The loss configuration does not depend on the channel: build it once.
    mi = nn.MutualInfoLoss(nb_bins=self.bins, patch_size=self.patch,
                           mask=[None, self.threshold], order=self.order,
                           fwhm=self.fwhm)
    loss = 0
    for x1, y1 in zip(xs, ys):
        # Average MI across channels to be consistent with how MSE works.
        loss += mi(x1[None, None], y1[None, None]) / nb_channels
    if self.factor != 1:
        loss = loss * self.factor
    return loss
def diffeo(source, target, group='SE', origin='center', image_loss=None,
           vel_loss=None, pull=None, optim_affine=True, max_iter=1000,
           lr=0.1, min_lr=1e-7, init=None, device=None):
    """Diffeomorphic registration

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as
       well as the same interpolation order. The advantage is
       that it simplifies the signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
        The source (moving) image, with shape (batch, channel, *spatial).
    target : tensor or (tensor, affine)
        The target (fixed) image, with shape (batch, channel, *spatial).
    group : str, default='SE'
        Affine sub-group to optimize (e.g. 'tr', 'rot', 'rigid', 'sim',
        'lin', 'aff', 'SE'). The name is normalized by
        `affine_group_converter`.
    origin : {'native', 'center'}, default='center'
        Whether to rotate about the origin of the world-space ('native')
        or the center of the target field-of-view ('center').
        When the origin of the world-space is far off (say you are
        registering smaller blocks cropped from a larger MRI), it can
        be beneficial to rotate about the center of the FOV.
    image_loss : callable(mov, fix) -> loss or float, default=MutualInfoLoss()
        Either a loss function that takes two inputs of shape
        (batch, channel, *spatial), or a factor to multiply the default
        mutual-information loss with.
    vel_loss : callable(vel) -> loss or float, default=BendingLoss()
        Either a factor to multiply the bending loss with, or a loss
        function that takes the velocity field, of shape
        (batch, *spatial, D).
    pull : dict, optional
        Interpolation options (the dict passed in is not modified).
        interpolation : int or str, default='linear'
            Interpolation order
        bound : bound_like, default='dct2'
            Boundary condition
        extrapolate : bool, default=False
            Extrapolate out-of-bound data using the boundary conditions.
    optim_affine : bool, default=True
        Whether to optimize the affine parameters jointly with the
        velocity field.
    max_iter : int, default=1000
        Maximum number of iterations
    lr : float or (float, float), default=0.1
        Initial learning rate of the (affine, velocity) parameters.
    min_lr : float or (float, float), default=1e-7
        Minimum learning rate of the (affine, velocity) parameters.
        The optimization is stopped once the learning rate of every
        optimized group has fallen below its minimum.
    init : ([batch], nb_prm) tensor_like, default=0
        Initial guess for the affine parameters.
    device : {'cpu', 'cuda', 'cuda:<id>'}, optional
        Backend to use

    Returns
    -------
    q : (batch, nb_prm) tensor
        Parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    vel : (batch, *shape, D) tensor
        Initial velocity
    moved : tensor
        Source image moved to target space.
    """
    group = affine_group_converter(group)
    # Copy so that the caller's dict is not modified.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2

    # Shift origin
    if origin == 'center':
        # Only the spatial dimensions define the field of view
        # (target.shape also contains batch and channel).
        shift = torch.as_tensor(target.shape[2:], **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=optim_affine)
    velocity = torch.zeros([batch, *target.shape[2:], dim], **backend)
    velocity = nn.Parameter(velocity, requires_grad=True)

    # Insert singleton spatial dimensions so that a batch of matrices
    # broadcasts against a batched grid (same indexing as `affine`/`ffd`).
    expd = (slice(None),) + (None,) * dim + (slice(None), slice(None))

    def warp(q, vel):
        # Compose exp(vel) with the affine and resample the source.
        grid = spatial.exp(vel)
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff[expd], grid)
        return spatial.grid_pull(source, grid, **pull_opt)

    # Prepare loss and optimizer.
    # Each default loss gets its own factor variable: the lambdas are
    # late-binding closures, so sharing one name would make the image loss
    # silently use the velocity-loss factor.
    if not callable(image_loss):
        image_factor = 1. if image_loss is None else image_loss
        image_loss_fn = nni.MutualInfoLoss()
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)
    if not callable(vel_loss):
        vel_factor = 1. if vel_loss is None else vel_loss
        vel_loss_fn = nni.BendingLoss(bound='dft')
        vel_loss = lambda v: vel_factor * vel_loss_fn(
            core.utils.last2channel(v))

    lr = core.utils.make_list(lr, 2)
    min_lr = core.utils.make_list(min_lr, 2)
    # Affine uses lr[0]/min_lr[0]; velocity always uses lr[1]/min_lr[1].
    opt_prm = []
    group_min_lr = []
    if optim_affine:
        opt_prm.append({'params': [parameters], 'lr': lr[0]})
        group_min_lr.append(min_lr[0])
    opt_prm.append({'params': [velocity], 'lr': lr[1]})
    group_min_lr.append(min_lr[1])
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    scheduler = ReduceLROnPlateau(optim)

    # Optim loop
    loss_avg = 0
    for n_iter in range(1, max_iter + 1):
        optim.zero_grad(set_to_none=True)
        moved = warp(parameters, velocity)
        loss_val = image_loss(moved, target) + vel_loss(velocity)
        loss_val.backward()
        # Gradients are already available: passing a closure to `step`
        # would only re-run the forward pass for nothing.
        optim.step()
        with torch.no_grad():
            loss_avg += loss_val
        if n_iter % 10 == 0:
            loss_avg /= 10
            scheduler.step(loss_avg)
            print('{:4d} {:12.6f} | lr={:g} '
                  .format(n_iter, loss_avg.item(),
                          optim.param_groups[0]['lr']), end='\r')
            loss_avg = 0
            if all(grp['lr'] < mlr
                   for grp, mlr in zip(optim.param_groups, group_min_lr)):
                print('\nConverged.')
                break
    print('')

    with torch.no_grad():
        moved = warp(parameters, velocity)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            # Undo the origin shift: express the affine in native world
            # coordinates.
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
    return (parameters.detach(), aff.detach(), velocity.detach(),
            moved.detach())
def affine(source, target, group='SE', loss=None, pull=None, preproc=True,
           max_iter=1000, device=None, origin='center', init=None, lr=0.1,
           scheduler=ReduceLROnPlateau):
    """Affine registration

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as
       well as the same interpolation order. The advantage is
       that it simplifies the signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
        The source (moving) image, with shape (batch, channel, *spatial).
    target : tensor or (tensor, affine)
        The target (fixed) image, with shape (batch, channel, *spatial).
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
        Affine sub-group to optimize.
    loss : callable(mov, fix) -> loss, default=MutualInfoLoss()
        Image similarity loss.
    pull : dict, optional
        Interpolation options (the dict passed in is not modified).
        interpolation : int or str, default='linear'
        bound : bound_like, default='dct2'
        extrapolate : bool, default=False
    preproc : bool, default=True
        Rescale both images with `rescale` before registration.
    max_iter : int, default=1000
        Number of iterations.
    device : device, optional
    origin : {'native', 'center'}, default='center'
        Rotate about the world origin ('native') or about the center
        of the target field-of-view ('center').
    init : tensor_like, default=0
        Initial affine parameters, reshaped to (batch, nb_prm).
    lr : float, default=0.1
        Learning rate of the Adam optimizer.
    scheduler : callable(optim) -> scheduler or None, default=ReduceLROnPlateau
        Learning-rate scheduler, stepped every 10 iterations.

    Returns
    -------
    q : (batch, nb_prm) tensor
        Parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    moved : tensor
        Source image moved to target space.
    """
    # Copy so that the caller's dict is not modified.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2

    # Rescale to [0, 1]
    if preproc:
        source = rescale(source)
        target = rescale(target)

    # Shift origin
    if origin == 'center':
        # Only the spatial dimensions define the field of view
        # (target.shape also contains batch and channel).
        shift = torch.as_tensor(target.shape[2:], **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        # Work on copies so the shift is not written into tensors that
        # `prepare` may share with the caller.
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=True)
    identity = spatial.identity_grid(target.shape[2:], **backend)
    # Insert singleton spatial dimensions so that a batch of matrices
    # broadcasts against the identity grid.
    expd = (slice(None),) + (None,) * dim + (slice(None), slice(None))

    def warp(q):
        # Resample the source through the affine built from `q`.
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff[expd], identity)
        return spatial.grid_pull(source, grid, **pull_opt)

    # Prepare loss and optimizer
    if loss is None:
        loss = nni.MutualInfoLoss()
    optim = torch.optim.Adam([parameters], lr=lr)
    if scheduler is not None:
        scheduler = scheduler(optim)

    # Optim loop
    for n_iter in range(1, max_iter + 1):
        optim.zero_grad(set_to_none=True)
        moved = warp(parameters)
        loss_val = loss(moved, target)
        loss_val.backward()
        optim.step()
        if n_iter % 10 == 0:
            if scheduler is not None:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(loss_val)
                else:
                    scheduler.step()
            print('{:4d} {:12.6f} | lr={:g}'.format(
                n_iter, loss_val.item(), optim.param_groups[0]['lr']),
                end='\r')
    print('')

    with torch.no_grad():
        moved = warp(parameters)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            # Undo the origin shift: express the affine in native world
            # coordinates.
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
    aff.requires_grad_(False)
    return parameters, aff, moved
def diffeo(source, target, group='SE', image_loss=None, vel_loss=None,
           pull=None, preproc=False, max_iter=1000, device=None,
           origin='center', init=None, lr=1e-4, optim_affine=True,
           scheduler=ReduceLROnPlateau):
    """Diffeomorphic (affine + exponentiated velocity) registration.

    Parameters
    ----------
    source : path or tensor or (tensor, affine)
        The source (moving) image, with shape (batch, channel, *spatial).
    target : path or tensor or (tensor, affine)
        The target (fixed) image, with shape (batch, channel, *spatial).
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
        Affine sub-group to optimize.
    image_loss : callable(mov, fix) -> loss or float, default=MutualInfoLoss()
        Either a loss function, or a factor to multiply the default
        mutual-information loss with.
    vel_loss : callable(vel) -> loss or float, default=BendingLoss()
        Either a loss function that takes the velocity field, of shape
        (batch, *spatial, D), or a factor to multiply the default
        bending loss with.
    pull : dict, optional
        Interpolation options (the dict passed in is not modified).
        interpolation : int or str, default='linear'
        bound : bound_like, default='dct2'
        extrapolate : bool, default=False
    preproc : bool, default=False
        Rescale both images with `rescale` before registration.
    max_iter : int, default=1000
        Number of iterations.
    device : device, optional
    origin : {'native', 'center'}, default='center'
        Rotate about the world origin ('native') or about the center
        of the target field-of-view ('center').
    init : tensor_like, default=0
        Initial affine parameters, reshaped to (batch, nb_prm).
    lr : float or (float, float), default=1e-4
        Learning rate of the (affine, velocity) parameters.
    optim_affine : bool, default=True
        Whether to optimize the affine parameters jointly with the
        velocity field.
    scheduler : callable(optim, cooldown) -> scheduler or None, default=ReduceLROnPlateau
        Learning-rate scheduler, built with `cooldown=5` and stepped
        every 10 iterations.

    Returns
    -------
    q : (batch, nb_prm) tensor
        Parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    vel : (batch, *shape, D) tensor
        Initial velocity of the diffeomorphic transform.
        The full warp is `(aff @ aff_src).inv() @ aff_trg @ exp(vel)`
    moved : tensor
        Source image moved to target space.
    """
    # Copy so that the caller's dict is not modified.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2

    # Rescale to [0, 1] -- both images, and only when requested
    # (previously a `targe` typo left the target unscaled).
    if preproc:
        source = rescale(source)
        target = rescale(target)

    # Shift origin
    if origin == 'center':
        # Only the spatial dimensions define the field of view
        # (target.shape also contains batch and channel).
        shift = torch.as_tensor(target.shape[2:], **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=optim_affine)
    velocity = torch.zeros([batch, *target.shape[2:], dim], **backend)
    velocity = nn.Parameter(velocity, requires_grad=True)

    # Insert singleton spatial dimensions so that a batch of matrices
    # broadcasts against a batched grid (same indexing as `affine`/`ffd`).
    expd = (slice(None),) + (None,) * dim + (slice(None), slice(None))

    def warp(q, vel):
        # Compose exp(vel) with the affine and resample the source.
        grid = spatial.exp(vel)
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff[expd], grid)
        return spatial.grid_pull(source, grid, **pull_opt)

    # Prepare loss and optimizer.
    # Each default loss gets its own factor variable: the lambdas are
    # late-binding closures, so sharing one name would make the image loss
    # silently use the velocity-loss factor.
    if not callable(image_loss):
        image_factor = 1. if image_loss is None else image_loss
        image_loss_fn = nni.MutualInfoLoss()
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)
    if not callable(vel_loss):
        vel_factor = 1. if vel_loss is None else vel_loss
        vel_loss_fn = nni.BendingLoss(bound='dft')
        vel_loss = lambda v: vel_factor * vel_loss_fn(
            core.utils.last2channel(v))

    lr = core.utils.make_list(lr, 2)
    # Affine uses lr[0]; velocity always uses lr[1].
    opt_prm = []
    if optim_affine:
        opt_prm.append({'params': [parameters], 'lr': lr[0]})
    opt_prm.append({'params': [velocity], 'lr': lr[1]})
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    if scheduler is not None:
        scheduler = scheduler(optim, cooldown=5)

    # Optim loop
    loss_avg = 0
    for n_iter in range(1, max_iter + 1):
        optim.zero_grad(set_to_none=True)
        moved = warp(parameters, velocity)
        loss_val = image_loss(moved, target) + vel_loss(velocity)
        loss_val.backward()
        optim.step()
        with torch.no_grad():
            loss_avg += loss_val
        if n_iter % 10 == 0:
            loss_avg /= 10
            if scheduler is not None:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(loss_avg)
                else:
                    scheduler.step()
            print('{:4d} {:12.6f} | lr={:g}'.format(
                n_iter, loss_avg.item(), optim.param_groups[0]['lr']),
                end='\r')
            loss_avg = 0
    print('')

    with torch.no_grad():
        moved = warp(parameters, velocity)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            # Undo the origin shift: express the affine in native world
            # coordinates.
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
    aff.requires_grad_(False)
    return parameters, aff, velocity, moved
def ffd(source, target, grid_shape=10, group='SE', image_loss=None,
        def_loss=None, pull=None, preproc=True, max_iter=1000, device=None,
        origin='center', init=None, lr=1e-4, optim_affine=True,
        scheduler=ReduceLROnPlateau):
    """FFD (= cubic spline) registration

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as
       well as the same interpolation order. The advantage is
       that it simplifies the signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
        The source (moving) image, with shape (batch, channel, *spatial).
    target : tensor or (tensor, affine)
        The target (fixed) image, with shape (batch, channel, *spatial).
    grid_shape : int or sequence[int], default=10
        Number of control points along each spatial dimension.
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
        Affine sub-group to optimize.
    image_loss : callable(mov, fix) -> loss or float, default=MutualInfoLoss()
        Either a loss function, or a factor to multiply the default
        mutual-information loss with.
    def_loss : callable(disp) -> loss or float, default=BendingLoss()
        Either a loss function that takes the dense displacement field,
        of shape (batch, *spatial, D), or a factor to multiply the
        default bending loss with.
    pull : dict, optional
        Interpolation options (the dict passed in is not modified).
        interpolation : int or str, default='linear'
        bound : bound_like, default='dft'
        extrapolate : bool, default=False
    preproc : bool, default=True
        Rescale both images with `rescale` before registration.
    max_iter : int, default=1000
        Number of iterations.
    device : device, optional
    origin : {'native', 'center'}, default='center'
        Rotate about the world origin ('native') or about the center
        of the target field-of-view ('center').
    init : tensor_like, default=0
        Initial affine parameters, reshaped to (batch, nb_prm).
    lr : float or (float, float), default=1e-4
        Learning rate of the (grid, affine) parameters.
    optim_affine : bool, default=True
        Whether to optimize the affine parameters jointly with the
        spline coefficients.
    scheduler : callable(optim, cooldown) -> scheduler or None, default=ReduceLROnPlateau
        Learning-rate scheduler, built with `cooldown=5` and stepped
        every 10 iterations.

    Returns
    -------
    q : (batch, nb_prm) tensor
        Affine parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    grid : (batch, *grid_shape, D) tensor
        Displacement coefficients on the control-point grid.
    moved : tensor
        Source image moved to target space.
    """
    # Copy so that the caller's dict is not modified.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dft')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2

    # Rescale to [0, 1]
    if preproc:
        source = rescale(source)
        target = rescale(target)

    # Shift origin
    if origin == 'center':
        # Only the spatial dimensions define the field of view
        # (target.shape also contains batch and channel).
        shift = torch.as_tensor(target.shape[2:], **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        # Work on copies so the shift is not written into tensors that
        # `prepare` may share with the caller.
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        affine_parameters = torch.as_tensor(init, **backend).clone().detach()
        affine_parameters = affine_parameters.reshape([batch, nb_prm])
    else:
        affine_parameters = torch.zeros([batch, nb_prm], **backend)
    affine_parameters = nn.Parameter(affine_parameters,
                                     requires_grad=optim_affine)
    grid_shape = core.pyutils.make_list(grid_shape, dim)
    grid_parameters = torch.zeros([batch, *grid_shape, dim], **backend)
    grid_parameters = nn.Parameter(grid_parameters, requires_grad=True)

    # Insert singleton spatial dimensions so that a batch of matrices
    # broadcasts against a batched grid.
    expd = (slice(None),) + (None,) * dim + (slice(None), slice(None))

    def warp(q, grid):
        # Resample the source through the affine composed with `grid`.
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff[expd], grid)
        return spatial.grid_pull(source, grid, **pull_opt)

    def to_dense(prm):
        # Upsample spline coefficients to a dense displacement and add
        # the identity to obtain a sampling grid.
        disp = spatial.resize_grid(prm, type='displacement',
                                   shape=target.shape[2:],
                                   interpolation=3, bound='dft')
        grid = disp + spatial.identity_grid(target.shape[2:], **backend)
        return disp, grid

    # Prepare loss and optimizer.
    # Each default loss gets its own factor variable: the lambdas are
    # late-binding closures, so sharing one name would make the image loss
    # silently use the deformation-loss factor.
    if not callable(image_loss):
        image_factor = 1. if image_loss is None else image_loss
        image_loss_fn = nni.MutualInfoLoss()
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)
    if not callable(def_loss):
        def_factor = 1. if def_loss is None else def_loss
        def_loss_fn = nni.BendingLoss(bound='dft')
        def_loss = lambda d: def_factor * def_loss_fn(
            core.utils.last2channel(d))

    lr = core.utils.make_list(lr, 2)
    # Grid uses lr[0]; affine uses lr[1].
    if optim_affine:
        opt_prm = [{'params': affine_parameters, 'lr': lr[1]},
                   {'params': grid_parameters, 'lr': lr[0]}]
    else:
        opt_prm = [grid_parameters]
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    if scheduler is not None:
        scheduler = scheduler(optim, cooldown=5)

    # Optim loop (1-based, so every window averages exactly 10 iterations)
    loss_avg = 0
    for n_iter in range(1, max_iter + 1):
        optim.zero_grad(set_to_none=True)
        disp, grid = to_dense(grid_parameters)
        moved = warp(affine_parameters, grid)
        # Keep the batch dimension: `last2channel` expects
        # (batch, *spatial, D).
        loss_val = image_loss(moved, target) + def_loss(disp)
        loss_val.backward()
        optim.step()
        with torch.no_grad():
            loss_avg += loss_val
        if n_iter % 10 == 0:
            loss_avg /= 10
            if scheduler is not None:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(loss_avg)
                else:
                    scheduler.step()
            print('{:4d} {:12.6f} | lr={:g}'.format(
                n_iter, loss_avg.item(), optim.param_groups[0]['lr']),
                end='\r')
            loss_avg = 0
    print('')

    with torch.no_grad():
        # Rebuild the grid from the final coefficients (the loop's grid
        # predates the last optimizer step, and is undefined if
        # max_iter == 0).
        _, grid = to_dense(grid_parameters)
        moved = warp(affine_parameters, grid)
        aff = core.linalg.expm(affine_parameters, basis)
        if origin == 'center':
            # Undo the origin shift: express the affine in native world
            # coordinates.
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
    aff.requires_grad_(False)
    return affine_parameters, aff, grid_parameters, moved