def forward(self, affine, **overload):
    """Build a dense transformation grid from an affine matrix.

    Parameters
    ----------
    affine : (batch, ndim[+1], ndim+1) tensor
        Affine matrix
    overload : dict
        Call-time overrides. `shape` and `shift` are looked up here and
        fall back to the module attributes of the same name.

    Returns
    -------
    grid : (batch, *shape, ndim) tensor
        Dense transformation grid

    """
    ndim = affine.shape[-1] - 1
    backend = dict(dtype=affine.dtype, device=affine.device)
    shape = make_list(overload.get('shape', self.shape), ndim)

    if overload.get('shift', self.shift):
        # Conjugate the affine by a translation of -shape/2:
        # affine <- recenter \ (affine @ recenter)
        half_fov = torch.as_tensor(shape, **backend)[:, None] / 2
        recenter = torch.cat((torch.eye(ndim, **backend), -half_fov), dim=1)
        affine = spatial.affine_lmdiv(
            recenter, spatial.affine_matmul(affine, recenter))

    return spatial.affine_grid(affine, shape)
def forward(self, affine, shape=None):
    """Build a dense transformation grid from an affine matrix.

    Parameters
    ----------
    affine : (batch, ndim[+1], ndim+1) tensor
        Affine matrix
    shape : sequence[int], default=self.shape
        Shape of the output grid.

    Returns
    -------
    grid : (batch, *shape, ndim) tensor
        Dense transformation grid

    """
    nb_dim = affine.shape[-1] - 1
    backend = {'dtype': affine.dtype, 'device': affine.device}
    shape = shape or self.shape
    if self.shift:
        # Conjugate the affine by a translation of -(shape - 1)/2.
        # The previous code chained the out-of-place `.sub(1).div(2).neg()`
        # on a view and discarded the result, so the translation column
        # silently stayed equal to `+shape`.
        affine_shift = torch.eye(nb_dim + 1, **backend)
        affine_shift[:nb_dim, -1] = \
            torch.as_tensor(shape, **backend).sub(1).div(2).neg()
        affine = spatial.affine_matmul(affine, affine_shift)
        affine = spatial.affine_lmdiv(affine_shift, affine)
    grid = spatial.affine_grid(affine, shape)
    return grid
def pull(q, vel):
    """Warp `source` by an exponentiated velocity followed by an affine.

    `q` is exponentiated on `basis`, right-multiplied by `target_aff` and
    left-divided by `source_aff`; the resulting matrix is applied to the
    grid returned by `spatial.exp(vel)`, which is then used to sample
    `source`. `basis`, `target_aff`, `source_aff`, `source` and `pull_opt`
    come from the enclosing scope.
    """
    phi = spatial.exp(vel)
    mat = core.linalg.expm(q, basis)
    mat = spatial.affine_lmdiv(
        source_aff, spatial.affine_matmul(mat, target_aff))
    return spatial.grid_pull(
        source, spatial.affine_matvec(mat, phi), **pull_opt)
def pull(q, grid):
    """Warp `source` by applying an affine to a precomputed grid.

    `q` is exponentiated on `basis`, right-multiplied by `target_aff` and
    left-divided by `source_aff`. `basis`, `target_aff`, `source_aff`,
    `source`, `dim` and `pull_opt` come from the enclosing scope.
    """
    mat = core.linalg.expm(q, basis)
    mat = spatial.affine_lmdiv(
        source_aff, spatial.affine_matmul(mat, target_aff))
    # Insert `dim` singleton axes between the leading axis and the two
    # matrix axes so the matrix broadcasts against the grid.
    index = (slice(None), *([None] * dim), slice(None), slice(None))
    warp = spatial.affine_matvec(mat[index], grid)
    return spatial.grid_pull(source, warp, **pull_opt)
def _warp_image1(image, target, shape=None, affine=None, nonlin=None,
                 backward=False, reslice=False):
    """Returns the warped image, with channel dimension last

    Parameters
    ----------
    image : object with `affine`, `dat` and `pull(grid)`
        Image to warp.
    target : tensor
        Orientation matrix of the target space.
    shape : sequence[int], optional
        Shape of the target space.
    affine : object with `exp(...)` and `position`, optional
        Affine model.
    nonlin : object with `exp`/`iexp`, `affine`, `shape`, `add_identity`, optional
        Nonlinear model.
    backward : bool, default=False
        Use the inverse transforms (inverted affine, `nonlin.iexp`).
    reslice : bool, default=False
        Without a nonlinear model, resample `image` on the target grid;
        otherwise the image data is returned unresampled.

    Returns
    -------
    warped : tensor
        Warped data. The leading axis is dropped when it has length 1,
        otherwise it is moved to the last position.

    """
    # build transform
    aff_right = target
    aff_left = spatial.affine_inv(image.affine)
    aff = None
    if affine:
        # The forward matrix is always exponentiated; the backward case is
        # obtained by inverting the matrix.
        exp = affine.exp
        aff = exp(recompute=False, cache_result=True)
        if backward:
            aff = spatial.affine_inv(aff)
    if nonlin:
        if affine:
            # The affine goes on the right and/or left of the nonlinear
            # warp depending on the first letter of `affine.position`.
            if affine.position[0].lower() in ('ms' if backward else 'fs'):
                aff_right = spatial.affine_matmul(aff, aff_right)
            if affine.position[0].lower() in ('fs' if backward else 'ms'):
                aff_left = spatial.affine_matmul(aff_left, aff)
        exp = nonlin.iexp if backward else nonlin.exp
        phi = exp(recompute=False, cache_result=True)
        aff_left = spatial.affine_matmul(aff_left, nonlin.affine)
        aff_right = spatial.affine_lmdiv(nonlin.affine, aff_right)
        if _almost_identity(aff_right) and nonlin.shape == shape:
            phi = nonlin.add_identity(phi)
        else:
            tmp = spatial.affine_grid(aff_right, shape)
            phi = regutils.smart_pull_grid(phi, tmp).add_(tmp)
            del tmp
        if not _almost_identity(aff_left):
            phi = spatial.affine_matvec(aff_left, phi)
    else:
        # no nonlin: single affine even if position == 'symmetric'
        if reslice:
            # Without an affine model, `aff` is None: treat it as the
            # identity instead of passing None to `affine_matmul`.
            if aff is None:
                aff = aff_right
            else:
                aff = spatial.affine_matmul(aff, aff_right)
            aff = spatial.affine_matmul(aff_left, aff)
            phi = spatial.affine_grid(aff, shape)
        else:
            phi = None

    # warp image
    if phi is not None:
        warped = image.pull(phi)
    else:
        warped = image.dat

    # drop a singleton leading axis, otherwise move it last
    if len(warped) == 1:
        warped = warped[0]
    else:
        warped = utils.movedim(warped, 0, -1)

    return warped
def slice_to(self, stack, cache_result=False, recompute=True):
    """Resample this volume in the layout of `stack` and smooth along slices.

    Parameters
    ----------
    stack : object with `layout` and `slice_width`
        Slice stack defining the target layout and slice thickness.
    cache_result : bool, default=False
        Store the result in `self._sliced`.
    recompute : bool, default=True
        Recompute even if `self._sliced` already exists. When False and a
        cached result exists, the cached result is returned.

    Returns
    -------
    sliced : tensor
        Resampled volume, smoothed along the last spatial dimension.

    """
    aff = self.exp(cache_result=cache_result, recompute=recompute)
    if recompute or not hasattr(self, '_sliced'):
        aff = spatial.affine_matmul(aff, self.affine)
        aff_reorient = spatial.affine_reorient(self.affine, self.shape,
                                               stack.layout)
        aff = spatial.affine_lmdiv(aff_reorient, aff)
        grid = spatial.affine_grid(aff, self.shape)
        sliced = spatial.grid_pull(self.dat, grid, bound=self.bound,
                                   extrapolate=self.extrapolate)
        # smooth only along the last dimension, with a width equal to the
        # slice width expressed in voxels of the reoriented space
        fwhm = [0] * self.dim
        fwhm[-1] = stack.slice_width / spatial.voxel_size(aff_reorient)[-1]
        sliced = spatial.smooth(sliced, fwhm, dim=self.dim, bound=self.bound)
        # TODO: per-slice handling of `stack.slices` is not implemented.
        # The earlier draft looped over the slices with incomplete
        # single-argument `affine_matmul`/`affine_lmdiv` calls whose
        # results were never used; that dead loop has been removed.
    else:
        # Cached path: previously `sliced` was left unbound here, which
        # raised UnboundLocalError on return.
        sliced = self._sliced

    if cache_result:
        self._sliced = sliced
    return sliced
def forward(self, batch=1, **overload):
    """Generate a resampling grid: `self.grid` composed with `self.affine`.

    Parameters
    ----------
    batch : int, default=1
        Batch shape.

    Other Parameters
    ----------------
    shape : sequence[int], optional
    device : torch.device, optional
    dtype : torch.dtype, optional

    Returns
    -------
    grid : (batch, *shape, 3) tensor
        Resampling grid

    """
    shape = overload.get('shape', self.grid.velocity.field.shape)
    dtype = overload.get('dtype', self.grid.velocity.field.dtype)
    device = overload.get('device', self.grid.velocity.field.device)
    backend = dict(dtype=dtype, device=device)

    if self.grid.velocity.field.amplitude == 0:
        # `identity_grid` has no batch axis. Add one so that the
        # `grid.shape[1:-1]` below yields the full spatial shape instead
        # of dropping the first spatial dimension. It broadcasts against
        # the batched affine in the final `matvec`.
        grid = identity_grid(shape, **backend)[None]
    else:
        grid = self.grid(batch, shape=shape, **backend)
    dtype = grid.dtype
    device = grid.device
    backend = dict(dtype=dtype, device=device)

    shape = grid.shape[1:-1]
    dim = len(shape)
    aff = self.affine(batch, dim=dim, **backend)

    # shift center of rotation: conjugate by a translation of -(shape-1)/2
    aff_shift = torch.cat((
        torch.eye(dim, **backend),
        torch.as_tensor(shape, **backend)[:, None].sub_(1).div_(-2)),
        dim=1)
    aff_shift = as_euclidean(aff_shift)

    aff = affine_matmul(aff, aff_shift)
    aff = affine_lmdiv(aff_shift, aff)

    # compose: apply the linear part and offset to every grid coordinate
    aff = utils.unsqueeze(aff, dim=-3, ndim=dim)
    lin = aff[..., :dim, :dim]
    off = aff[..., :dim, -1]
    grid = linalg.matvec(lin, grid) + off

    return grid
def exp(self, velocity, affine=None, displacement=False):
    """Turn tangent parameters into a deformation grid.

    Parameters
    ----------
    velocity : (batch, *spatial, nb_dim)
        Stationary velocity field
    affine : (batch, nb_prm)
        Affine parameters
    displacement : bool, default=False
        Return a displacement field (voxel to shift) rather than
        a transformation field (voxel to voxel).

    Returns
    -------
    grid : (batch, *spatial, nb_dim)
        Deformation grid (transformation or displacement).

    """
    backend = dict(dtype=velocity.dtype, device=velocity.device)
    spatial_shape = velocity.shape[1:-1]

    # resize -> exponentiate -> resize back to the input spatial shape
    grid = self.velexp(self.resize(velocity, type='displacement'))
    grid = self.resize(grid, shape=spatial_shape, type='grid')

    if affine is not None:
        ndim = self.dim
        # exponentiate each batch element separately
        matrices = torch.stack([self.affexp(prm) for prm in affine], dim=0)

        # shift center of rotation: conjugate by a translation of -shape/2
        half_fov = torch.as_tensor(spatial_shape, **backend)[:, None] / 2
        recenter = torch.cat((torch.eye(ndim, **backend), -half_fov), dim=1)
        matrices = spatial.affine_matmul(matrices, recenter)
        matrices = spatial.affine_lmdiv(recenter, matrices)

        # compose: linear part and offset applied to every coordinate
        matrices = unsqueeze(matrices, dim=-3, ndim=ndim)
        linear = matrices[..., :ndim, :ndim]
        offset = matrices[..., :ndim, -1]
        grid = matvec(linear, grid) + offset

    if displacement:
        grid = grid - spatial.identity_grid(grid.shape[1:-1], **backend)
    return grid
def _crop_to_param(aff0, aff, shape):
    """Derive crop parameters (size, center, unit, layout) from a reference.

    Parameters
    ----------
    aff0 : tensor
        Orientation matrix of the input.
    aff : tensor
        Orientation matrix of the reference.
    shape : sequence[int]
        Shape of the reference; only the first `ndim` entries are used.

    Returns
    -------
    size : sequence[int]
        Spatial shape of the reference.
    center : tensor
        Center of the reference field of view, mapped through `aff0 \\ aff`.
    unit : str
        Always 'vox'.
    layout : None

    Raises
    ------
    ValueError
        If the two matrices do not have the same layout.

    """
    ndim = aff0.shape[-1] - 1
    size = shape[:ndim]

    layout_in = spatial.affine_to_layout(aff0)
    layout_ref = spatial.affine_to_layout(aff)
    if (layout_in != layout_ref).any():
        raise ValueError('Input and Ref do not have the same layout: '
                         f'{spatial.volume_layout_to_name(layout_in)} vs '
                         f'{spatial.volume_layout_to_name(layout_ref)}.')

    # center of the reference field of view, (size - 1) / 2
    center = torch.as_tensor(size, dtype=torch.float).sub_(1).mul_(0.5)
    center = spatial.affine_matvec(spatial.affine_lmdiv(aff0, aff), center)
    return size, center, 'vox', None
def __call__(self, logaff, grad=False, hess=False, in_line_search=False):
    """Evaluate the loss (and optionally its gradient) for Lie parameters.

    Parameters
    ----------
    logaff : (..., nb) tensor
        Lie parameters
    grad : bool, default=False
        Whether to compute and return the gradient wrt `logaff`
        (obtained by autograd, via `backward()`).
    hess : bool, default=False
        Accepted but not used by this implementation.
    in_line_search : bool, default=False
        If True, plotting, printing and iteration bookkeeping are skipped.

    Returns
    -------
    ll : () tensor
        Loss value (objective to minimize)
    g : tensor, optional
        Gradient wrt Lie parameters (a clone of `logaff.grad`),
        only returned if `grad`.

    """
    # This loop performs the forward pass, and computes
    # derivatives along the way.

    # select correct gradient mode:
    # re-enter the call under the autograd mode that matches `grad`
    if grad:
        logaff.requires_grad_()
        if logaff.grad is not None:
            logaff.grad.zero_()
    if grad and not torch.is_grad_enabled():
        with torch.enable_grad():
            return self(logaff, grad, in_line_search=in_line_search)
    elif not grad and torch.is_grad_enabled():
        with torch.no_grad():
            return self(logaff, grad, in_line_search=in_line_search)

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    # plot at most ~20 times over `max_iter` iterations
    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
        and (self.n_iter - 1) % logplot == 0

    # jitter
    # idj = spatial.identity_grid(self.fixed.shape[-self.dim:], jitter=True,
    #                             **utils.backend(self.fixed))
    # fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
    # del idj
    fixed = self.fixed

    # forward
    # the basis is built once and stored back on the object
    if not torch.is_tensor(self.basis):
        self.basis = spatial.affine_basis(self.basis, self.dim,
                                          **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    # the identity grid is built once and stored back on the object
    if self.id is None:
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff))
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, cat=iscat, dim=self.dim)

    # gradient/Hessian of the log-likelihood in observed space
    llx = self.loss.loss(warped, fixed)
    del warped

    # print objective
    # `lll` keeps the tensor (needed for backward); `llx`/`ll` are floats
    lll = llx
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(
                f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                end='\n')
        else:
            # gain relative to the largest loss seen so far
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(
                f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                end='\n')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [lll]
    if grad is not False:
        lll.backward()
        # note: rebinds `grad` from the boolean flag to the gradient tensor
        grad = logaff.grad.clone()
        out.append(grad)
    logaff.requires_grad_(False)
    return tuple(out) if len(out) > 1 else out[0]
def write_data(options):
    """Compose the requested transformations and write the result to disk.

    If `options.target` is set, all transformations are composed into a
    dense displacement field defined on the target grid and saved as
    '.nii.gz'. Otherwise a single transformation is expected and its
    affine is saved as '.lta'.

    Parameters
    ----------
    options : object
        Reads `device`, `transformations`, `target`, `output_unit` and
        `output` (a format string with an `ext` field).

    Raises
    ------
    RuntimeError
        If there is no target but more than one transformation.

    """
    backend = dict(dtype=torch.float32, device=options.device)

    # Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, Velocity):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            # the inverse is handled by the exponentiation itself, so the
            # result is stored as a forward displacement
            trf.dat = spatial.exp(trf.dat[None], displacement=True,
                                  inverse=trf.inv)[0]
            trf.inv = False
            # NOTE(review): `order` is set here but the composition below
            # reads `trf.spline` -- confirm which one is intended.
            trf.order = 1
        elif isinstance(trf, Displacement):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            if trf.unit == 'mm':
                # convert mm displacement to vox displacement
                trf.dat = spatial.affine_lmdiv(trf.affine, trf.dat[..., None])
                trf.dat = trf.dat[..., 0]
                trf.unit = 'vox'

    def build_from_target(target):
        """Compose all transformations, starting from the final orientation"""
        grid = spatial.affine_grid(target.affine.to(**backend),
                                   target.shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            else:
                mat = trf.affine.to(**backend)
                if trf.inv:
                    # resize the displacement by the voxel-size ratio
                    # between its own grid and the target, then invert it
                    vx0 = spatial.voxel_size(mat)
                    vx1 = spatial.voxel_size(target.affine.to(**backend))
                    factor = vx0 / vx1
                    disp, mat = spatial.resize_grid(trf.dat[None], factor,
                                                    affine=mat,
                                                    interpolation=trf.spline)
                    disp = spatial.grid_inv(disp[0], type='disp')
                    order = 1
                else:
                    disp = trf.dat
                    order = trf.spline
                # map to the displacement's voxel space, add the sampled
                # displacement, and map back
                imat = spatial.affine_inv(mat)
                grid = spatial.affine_matvec(imat, grid)
                grid += helpers.pull_grid(disp, grid, interpolation=order)
                grid = spatial.affine_matvec(mat, grid)
        return grid

    if options.target:
        # If target is provided, we build a dense transformation field
        grid = build_from_target(options.target)
        oaffine = options.target.affine
        # Use the grid's dtype/device for both output units. Previously the
        # voxel branch used `oaffine` unconverted, unlike every other
        # affine in this function.
        grid_backend = utils.backend(grid)
        oaffine_grid = oaffine.to(**grid_backend)
        if options.output_unit[0] == 'v':
            grid = spatial.affine_matvec(spatial.affine_inv(oaffine_grid),
                                         grid)
            grid = grid - spatial.identity_grid(grid.shape[:-1],
                                                **grid_backend)
        else:
            grid = grid - spatial.affine_grid(oaffine_grid, grid.shape[:-1])
        io.volumes.savef(grid,
                         options.output.format(ext='.nii.gz'),
                         affine=oaffine)
    else:
        if len(options.transformations) > 1:
            raise RuntimeError('Something weird happened: '
                               'multiple transforms and no target')
        io.transforms.savef(options.transformations[0].affine,
                            options.output.format(ext='.lta'))
def do_vel(self, vel, grad=False, hess=False, in_line_search=False):
    """Forward pass for updating the nonlinear component

    Parameters
    ----------
    vel : tensor
        Current velocity/parameters of the nonlinear model.
    grad : bool, default=False
        Also return the gradient accumulated over all losses
        (plus the regulariser gradient).
    hess : bool, default=False
        Also return the Hessian accumulated over all losses.
    in_line_search : bool, default=False
        If True, the exponentiated transform is not cached, the state of
        each loss is restored after evaluation, and iteration bookkeeping
        (`n_iter`, `ll_prev`, `all_ll`, ...) is left untouched.

    Returns
    -------
    sumloss : tensor
        Weighted sum of data terms + regularisation + `self.lla`.
    sumgrad : tensor, optional
        Returned if `grad`.
    sumhess : tensor, optional
        Returned if `hess`.

    """
    sumloss = None
    sumgrad = None
    sumhess = None

    # ==============================================================
    #                     EXPONENTIATE TRANSFORMS
    # ==============================================================
    if self.affine:
        aff0, iaff0 = self.affine.exp2(cache_result=True, recompute=False)
        aff_pos = self.affine.position[0].lower()
    else:
        # 'x' matches neither 'fs' nor 'ms' below: identity is never applied
        aff_pos = 'x'
        aff0 = iaff0 = torch.eye(self.nonlin.dim + 1)
    vel0 = vel
    # the inverse transform is only computed if some loss is `backward`
    if any(loss.backward for loss in self.losses):
        phi0, iphi0 = self.nonlin.exp2(vel0, recompute=True,
                                       cache_result=not in_line_search)
        ivel0 = -vel0
    else:
        phi0 = self.nonlin.exp(vel0, recompute=True,
                               cache_result=not in_line_search)
        iphi0 = ivel0 = None
    aff0 = aff0.to(phi0)
    iaff0 = iaff0.to(phi0)

    # ==============================================================
    #                     ACCUMULATE DERIVATIVES
    # ==============================================================

    has_printed = False
    for loss in self.losses:
        # ==========================================================
        #                     ONE LOSS COMPONENT
        # ==========================================================

        moving, fixed, factor = loss.moving, loss.fixed, loss.factor
        if loss.backward:
            phi00, aff00, vel00 = iphi0, iaff0, ivel0
        else:
            phi00, aff00, vel00 = phi0, aff0, vel0

        # ----------------------------------------------------------
        # build left and right affine
        # ----------------------------------------------------------
        aff_right = fixed.affine
        if aff_pos in 'fs':  # affine position: fixed or symmetric
            aff_right = aff00 @ aff_right
        aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
        aff_left = self.nonlin.affine
        if aff_pos in 'ms':  # affine position: moving or symmetric
            aff_left = aff00 @ self.nonlin.affine
        aff_left = linalg.lmdiv(moving.affine, aff_left)

        # ----------------------------------------------------------
        # build full transform
        # ----------------------------------------------------------
        # `aff_right`/`aff_left` are set to None when they can be skipped;
        # these (possibly None) values are passed to `propagate_grad` below
        if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
            aff_right = None
            phi = spatial.add_identity_grid(phi00)
            disp = phi00
        else:
            phi = spatial.affine_grid(aff_right, fixed.shape)
            disp = regutils.smart_pull_grid(phi00, phi)
            phi += disp
        if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
            aff_left = None
        else:
            phi = spatial.affine_matvec(aff_left, phi)

        # ----------------------------------------------------------
        # forward pass
        # ----------------------------------------------------------
        warped, mask = moving.pull(phi, mask=True)
        if fixed.masked:
            if mask is None:
                mask = fixed.mask
            else:
                mask = mask * fixed.mask

        # plot only once per call, at verbosity >= 3, outside of line
        # search, and for a forward loss
        do_print = not (has_printed or self.verbose < 3 or in_line_search
                        or loss.backward)
        if do_print:
            has_printed = True
            if moving.previewed:
                preview = moving.pull(phi, preview=True, dat=False)
            else:
                preview = warped
            init = spatial.affine_lmdiv(moving.affine, fixed.affine)
            if _almost_identity(init) and moving.shape == fixed.shape:
                init = moving.dat
            else:
                init = spatial.affine_grid(init, fixed.shape)
                init = moving.pull(init, preview=True, dat=False)
            self.mov2fix(fixed.dat, init, preview, disp, dim=fixed.dim,
                         title=f'(nonlin) {self.n_iter:03d}')

        # ----------------------------------------------------------
        # derivatives wrt moving
        # ----------------------------------------------------------
        g = h = None
        loss_args = (warped, fixed.dat)
        loss_kwargs = dict(dim=fixed.dim, mask=mask)
        # snapshot the loss state so a line-search evaluation can undo
        # any update made by the loss
        state = loss.loss.get_state()
        if not grad and not hess:
            llx = loss.loss.loss(*loss_args, **loss_kwargs)
        elif not hess:
            llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
        else:
            llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
        del loss_args, loss_kwargs
        if in_line_search:
            loss.loss.set_state(state)

        # ----------------------------------------------------------
        # chain rule -> derivatives wrt phi
        # ----------------------------------------------------------
        if grad or hess:
            g, h, mugrad = self.nonlin.propagate_grad(
                g, h, moving, phi00, aff_left, aff_right,
                inv=loss.backward)
            g = regutils.jg(mugrad, g)
            h = regutils.jhj(mugrad, h)
            if isinstance(self.nonlin, SVFModel):
                # propagate backward by scaling and squaring
                g, h = spatial.exp_backward(vel00, g, h,
                                            steps=self.nonlin.steps)

            # weighted in-place accumulation across losses
            sumgrad = (g.mul_(factor) if sumgrad is None else
                       sumgrad.add_(g, alpha=factor))
            if hess:
                sumhess = (h.mul_(factor) if sumhess is None else
                           sumhess.add_(h, alpha=factor))
        sumloss = (llx.mul_(factor) if sumloss is None else
                   sumloss.add_(llx, alpha=factor))

    # ==============================================================
    #                       REGULARIZATION
    # ==============================================================
    # llv = 0.5 * <vel, regulariser(vel)>; regulariser(vel) is also the
    # term added to the gradient
    vgrad = self.nonlin.regulariser(vel0)
    llv = 0.5 * vel0.flatten().dot(vgrad.flatten())
    if grad:
        sumgrad += vgrad
    del vgrad

    # ==============================================================
    #                           VERBOSITY
    # ==============================================================
    llx = sumloss.item()  # data term only, read before adding the rest
    sumloss += llv
    sumloss += self.lla
    self.loss_value = sumloss.item()
    if self.verbose and (self.verbose > 1 or not in_line_search):
        llv = llv.item()
        ll = sumloss.item()
        lla = self.lla
        if in_line_search:
            line = '(search) | '
        else:
            line = '(nonlin) | '
        line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
        if not in_line_search:
            if self.ll_prev is not None:
                gain = self.ll_prev - ll
                # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                line += f' | {gain:12.6g}'
            # NOTE: this bookkeeping only happens when verbose is on
            self.llv = llv
            self.all_ll.append(ll)
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)
            self.n_iter += 1
        print(line, end='\r')

    # ==============================================================
    #                           RETURN
    # ==============================================================
    out = [sumloss]
    if grad:
        out.append(sumgrad)
    if hess:
        out.append(sumhess)
    return tuple(out) if len(out) > 1 else out[0]
def _warp_image(option, affine=None, nonlin=None, dim=None, device=None,
                odir=None):
    """Warp and save the moving and fixed images from a loss object

    Parameters
    ----------
    option : object
        Reads `option.mov` and `option.fix`, each with `files`, `output`,
        `resliced`, `world`, `affine`, `bound` and `extrapolate`.
        `output` and `resliced` are format strings with the fields
        `dir`, `base`, `sep` and `ext`.
    affine : object, optional
        Affine model, forwarded to `_warp_image1`.
    nonlin : object, optional
        Nonlinear model, forwarded to `_warp_image1`.
    dim : int, optional
        Number of spatial dimensions; defaults to `fix.dim - 1`.
    device : torch.device, optional
        Device on which the image data is loaded.
    odir : str, optional
        Output directory; defaults to the directory of the input file,
        or '.'.

    """

    # nothing to write
    if not (option.mov.output or option.mov.resliced or
            option.fix.output or option.fix.resliced):
        return

    fix, fix_affine = _map_image(option.fix.files, dim=dim)
    mov, mov_affine = _map_image(option.mov.files, dim=dim)
    fix_affine = fix_affine.float()
    mov_affine = mov_affine.float()
    dim = dim or (fix.dim - 1)

    if option.fix.world:  # overwrite orientation matrix
        fix_affine = io.transforms.map(option.fix.world).fdata().squeeze()
    # apply (left-divide by) each user-provided transform in turn
    for transform in (option.fix.affine or []):
        transform = io.transforms.map(transform).fdata().squeeze()
        fix_affine = spatial.affine_lmdiv(transform, fix_affine)

    if option.mov.world:  # overwrite orientation matrix
        mov_affine = io.transforms.map(option.mov.world).fdata().squeeze()
    for transform in (option.mov.affine or []):
        transform = io.transforms.map(transform).fdata().squeeze()
        mov_affine = spatial.affine_lmdiv(transform, mov_affine)

    # moving
    if option.mov.output or option.mov.resliced:
        ifname = option.mov.files[0]
        idir, base, ext = py.fileparts(ifname)
        odir_mov = odir or idir or '.'

        image = objects.Image(mov.fdata(rand=True, device=device), dim=dim,
                              affine=mov_affine, bound=option.mov.bound,
                              extrapolate=option.mov.extrapolate)

        if option.mov.output:
            # keep the moving image's own grid; only its orientation matrix
            # absorbs the affine (when positioned on the moving side)
            target_affine = mov_affine
            target_shape = image.shape
            if affine and affine.position[0].lower() in 'ms':
                aff = affine.exp(recompute=False, cache_result=True)
                target_affine = spatial.affine_lmdiv(aff, target_affine)

            fname = option.mov.output.format(dir=odir_mov, base=base,
                                             sep=os.path.sep, ext=ext)
            print(f'Minimal reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image, target_affine, target_shape,
                                  affine=affine, nonlin=nonlin)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped

        if option.mov.resliced:
            # resample the moving image on the fixed image's grid
            target_affine = fix_affine
            target_shape = fix.shape[1:]

            fname = option.mov.resliced.format(dir=odir_mov, base=base,
                                               sep=os.path.sep, ext=ext)
            print(f'Full reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image, target_affine, target_shape,
                                  affine=affine, nonlin=nonlin, reslice=True)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped

    # fixed
    if option.fix.output or option.fix.resliced:
        ifname = option.fix.files[0]
        idir, base, ext = py.fileparts(ifname)
        odir_fix = odir or idir or '.'

        image = objects.Image(fix.fdata(rand=True, device=device), dim=dim,
                              affine=fix_affine, bound=option.fix.bound,
                              extrapolate=option.fix.extrapolate)

        if option.fix.output:
            # keep the fixed image's own grid; only its orientation matrix
            # absorbs the affine (when positioned on the fixed side)
            target_affine = fix_affine
            target_shape = image.shape
            if affine and affine.position[0].lower() in 'fs':
                aff = affine.exp(recompute=False, cache_result=True)
                target_affine = spatial.affine_matmul(aff, target_affine)

            fname = option.fix.output.format(dir=odir_fix, base=base,
                                             sep=os.path.sep, ext=ext)
            print(f'Minimal reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image, target_affine, target_shape,
                                  affine=affine, nonlin=nonlin, backward=True)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped

        if option.fix.resliced:
            # resample the fixed image on the moving image's grid
            target_affine = mov_affine
            target_shape = mov.shape[1:]

            fname = option.fix.resliced.format(dir=odir_fix, base=base,
                                               sep=os.path.sep, ext=ext)
            print(f'Full reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image, target_affine, target_shape,
                                  affine=affine, nonlin=nonlin,
                                  backward=True, reslice=True)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped
def do_affine_only(self, logaff, grad=False, hess=False,
                   in_line_search=False):
    """Forward pass for updating the affine component (nonlin is None)

    Parameters
    ----------
    logaff : tensor
        Lie parameters of the affine model.
    grad : bool, default=False
        Also return the gradient accumulated over all losses.
    hess : bool, default=False
        Also return the Hessian accumulated over all losses.
    in_line_search : bool, default=False
        If True, the state of each loss is restored after evaluation and
        iteration bookkeeping (`n_iter`, `ll_prev`, `all_ll`, ...) is
        left untouched.

    Returns
    -------
    sumloss : tensor
        Weighted sum of the data terms.
    sumgrad : tensor, optional
        Returned if `grad`.
    sumhess : tensor, optional
        Returned if `hess`.

    """
    sumloss = None
    sumgrad = None
    sumhess = None

    # ==============================================================
    #                     EXPONENTIATE TRANSFORMS
    # ==============================================================
    logaff0 = logaff
    # forward/inverse matrices and their gradients wrt the parameters
    aff0, iaff0, gaff0, igaff0 = self.affine.exp2(logaff0, grad=True)

    has_printed = False
    for loss in self.losses:
        moving, fixed, factor = loss.moving, loss.fixed, loss.factor
        if loss.backward:
            aff00, gaff00 = iaff0, igaff0
        else:
            aff00, gaff00 = aff0, gaff0

        # ----------------------------------------------------------
        # build full transform
        # ----------------------------------------------------------
        # moving.affine \ (aff @ fixed.affine), same for its gradient
        aff = aff00 @ fixed.affine
        aff = linalg.lmdiv(moving.affine, aff)
        gaff = gaff00 @ fixed.affine
        gaff = linalg.lmdiv(moving.affine, gaff)
        phi = spatial.affine_grid(aff, fixed.shape)

        # ----------------------------------------------------------
        # forward pass
        # ----------------------------------------------------------
        warped, mask = moving.pull(phi, mask=True)
        if fixed.masked:
            if mask is None:
                mask = fixed.mask
            else:
                mask = mask * fixed.mask

        # plot only once per call, at verbosity >= 3, outside of line
        # search, and for a forward loss
        do_print = not (has_printed or self.verbose < 3 or in_line_search
                        or loss.backward)
        if do_print:
            has_printed = True
            if moving.previewed:
                preview = moving.pull(phi, preview=True, dat=False)
            else:
                preview = warped
            init = spatial.affine_lmdiv(moving.affine, fixed.affine)
            if _almost_identity(init) and moving.shape == fixed.shape:
                init = moving.preview
            else:
                init = spatial.affine_grid(init, fixed.shape)
                init = moving.pull(init, preview=True, dat=False)
            self.mov2fix(fixed.preview, init, preview, dim=fixed.dim,
                         title=f'(affine) {self.n_iter:03d}')

        # ----------------------------------------------------------
        # derivatives wrt moving
        # ----------------------------------------------------------
        g = h = None
        loss_args = (warped, fixed.dat)
        loss_kwargs = dict(dim=fixed.dim, mask=mask)
        # snapshot the loss state so a line-search evaluation can undo
        # any update made by the loss
        state = loss.loss.get_state()
        if not grad and not hess:
            llx = loss.loss.loss(*loss_args, **loss_kwargs)
        elif not hess:
            llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
        else:
            llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
        del loss_args, loss_kwargs
        if in_line_search:
            loss.loss.set_state(state)

        # ----------------------------------------------------------
        # chain rule -> derivatives wrt Lie parameters
        # ----------------------------------------------------------

        def compose_grad(g, h, g_mu, g_aff):
            """
            g, h : gradient/Hessian of loss wrt moving image
            g_mu : spatial gradients of moving image
            g_aff : gradient of affine matrix wrt Lie parameters
            returns g, h: gradient/Hessian of loss wrt Lie parameters
            """
            # Note that `h` can be `None`, but the functions I
            # use deal with this case correctly.
            dim = g_mu.shape[-1]
            g = jg(g_mu, g)
            h = jhj(g_mu, h)
            g, h = regutils.affine_grid_backward(g, h)
            # flatten the last two (matrix) axes into one of length
            # dim*(dim+1)
            dim2 = dim * (dim + 1)
            g = g.reshape([*g.shape[:-2], dim2])
            # drop the last row of the matrix gradient before flattening
            g_aff = g_aff[..., :-1, :]
            g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
            g = linalg.matvec(g_aff, g)
            if h is not None:
                h = h.reshape([*h.shape[:-4], dim2, dim2])
                h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
                # h = h.abs().sum(-1).diag_embed()
            return g, h

        # compose with spatial gradients
        if grad or hess:
            mugrad = moving.pull_grad(phi, rotate=False)
            g, h = compose_grad(g, h, mugrad, gaff)

            # the gradient is negated (in place) for backward losses
            if loss.backward:
                g = g.neg_()
            # weighted in-place accumulation across losses
            sumgrad = (g.mul_(factor) if sumgrad is None else
                       sumgrad.add_(g, alpha=factor))
            if hess:
                sumhess = (h.mul_(factor) if sumhess is None else
                           sumhess.add_(h, alpha=factor))
        sumloss = (llx.mul_(factor) if sumloss is None else
                   sumloss.add_(llx, alpha=factor))

    # TODO add regularization term
    lla = 0

    # ==============================================================
    #                           VERBOSITY
    # ==============================================================
    llx = sumloss.item()  # data term only, read before adding `lla`
    sumloss += lla
    lla = lla
    ll = sumloss.item()
    self.loss_value = ll
    if self.verbose and (self.verbose > 1 or not in_line_search):
        if in_line_search:
            line = '(search) | '
        else:
            line = '(affine) | '
        line += f'{self.n_iter:03d} | {llx:12.6g} + {lla:12.6g} = {ll:12.6g}'
        if not in_line_search:
            if self.ll_prev is not None:
                gain = self.ll_prev - ll
                # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                line += f' | {gain:12.6g}'
            # NOTE: this bookkeeping only happens when verbose is on
            self.all_ll.append(ll)
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)
            self.n_iter += 1
        print(line, end='\r')

    # ==============================================================
    #                           RETURN
    # ==============================================================
    out = [sumloss]
    if grad:
        out.append(sumgrad)
    if hess:
        out.append(sumhess)
    return tuple(out) if len(out) > 1 else out[0]
def compose(self, orient_in, deformation, orient_mean, affine=None,
            orient_out=None, shape_out=None):
    """Composes a deformation defined in a mean space to an image space.

    Parameters
    ----------
    orient_in : (4, 4) tensor
        Orientation of the input image
    deformation : (*shape_mean, 3) tensor
        Random deformation
    orient_mean : (4, 4) tensor
        Orientation of the mean space (where the deformation is)
    affine : (4, 4) tensor, default=identity
        Random affine
    orient_out : (4, 4) tensor, default=orient_in
        Orientation of the output image
    shape_out : sequence[int], default=shape_mean
        Shape of the output image

    Returns
    -------
    grid : (*shape_out, 3)
        Voxel-to-voxel transform

    """
    if orient_out is None:
        orient_out = orient_in
    if shape_out is None:
        shape_out = deformation.shape[:-1]
    if affine is None:
        affine = torch.eye(4, 4,
                           device=orient_in.device,
                           dtype=orient_in.dtype)
    shape_mean = deformation.shape[:-1]

    orient_in, affine, deformation, orient_mean, orient_out \
        = utils.to_max_backend(orient_in, affine, deformation,
                               orient_mean, orient_out)
    backend = utils.backend(deformation)
    eye = torch.eye(4, **backend)

    # Compose deformation on the right
    right_affine = spatial.affine_lmdiv(orient_mean, orient_out)
    # Compare shapes as tuples (`torch.Size == list` is always False) and
    # use `torch.allclose`: tensors have no `all_close` method.
    same_space = (tuple(shape_mean) == tuple(shape_out)
                  and torch.allclose(right_affine, eye))
    if same_space:
        # the deformation is already expressed on the output grid
        # (previously `trf` was left undefined on this path)
        trf = deformation
    else:
        # the mean space and native space are not the same
        # we must compose the diffeo with a dense affine transform
        # we write the diffeo as an identity plus a displacement
        # (id + disp)(aff) = aff + disp(aff)
        # -------
        # to displacement
        deformation = deformation - spatial.identity_grid(
            deformation.shape[:-1], **backend)
        trf = spatial.affine_grid(right_affine, shape_out)
        # sample the displacement (channel-first, with a batch axis)
        deformation = spatial.grid_pull(
            utils.movedim(deformation, -1, 0)[None], trf[None],
            bound='dft', extrapolate=True)
        deformation = utils.movedim(deformation[0], 0, -1)
        trf = trf + deformation  # add displacement

    # Compose deformation on the left
    # the output of the diffeo(right) are mean_space voxels
    # we must compose on the left with `in\(aff(mean))`
    # -------
    left_affine = spatial.affine_matmul(spatial.affine_inv(orient_in),
                                        affine)
    left_affine = spatial.affine_matmul(left_affine, orient_mean)
    trf = spatial.affine_matvec(left_affine, trf)

    return trf
def do_affine(self, logaff, grad=False, hess=False, in_line_search=False):
    """Forward pass for updating the affine component (nonlin is not None)

    Evaluates the sum of all losses in `self.losses` (each weighted by
    its `factor`) for the affine parameters `logaff`, with the nonlinear
    component read from the `self.nonlin` cache, and optionally
    accumulates the gradient and Hessian with respect to the affine
    Lie parameters.

    Parameters
    ----------
    logaff : tensor
        Affine (Lie) parameters, forwarded to `self.affine.exp`/`exp2`.
    grad : bool, default=False
        Also compute and return the gradient.
    hess : bool, default=False
        Also compute and return the Hessian.
    in_line_search : bool, default=False
        If True, the affine exponential is not cached, each loss state
        is restored after evaluation, and iteration bookkeeping
        (`ll_prev`, `ll_max`, `all_ll`, `n_iter`) is left untouched.

    Returns
    -------
    sumloss : tensor
        Weighted sum of losses, plus `self.llv`.
    sumgrad : tensor, if `grad`
    sumhess : tensor, if `hess`
        A single value is returned if neither `grad` nor `hess` is set,
        otherwise a tuple.

    """

    sumloss = None
    sumgrad = None
    sumhess = None

    # ==============================================================
    #                     EXPONENTIATE TRANSFORMS
    # ==============================================================
    logaff0 = logaff
    # first letter of the affine position; tested against 'fs' and 'ms'
    # NOTE(review): presumably f=fixed, m=moving, s=symmetric -- confirm
    aff_pos = self.affine.position[0].lower()
    if any(loss.backward for loss in self.losses):
        # at least one loss is evaluated backward: inverse transforms
        # (and their gradients) are needed as well
        aff0, iaff0, gaff0, igaff0 = \
            self.affine.exp2(logaff0, grad=True,
                             cache_result=not in_line_search)
        phi0, iphi0 = self.nonlin.exp2(cache_result=True, recompute=False)
    else:
        iaff0, igaff0, iphi0 = None, None, None
        aff0, gaff0 = self.affine.exp(logaff0, grad=True,
                                      cache_result=not in_line_search)
        phi0 = self.nonlin.exp(cache_result=True, recompute=False)

    has_printed = False
    for loss in self.losses:

        moving, fixed, factor = loss.moving, loss.fixed, loss.factor
        # backward losses use the inverse transforms
        if loss.backward:
            phi00, aff00, gaff00 = iphi0, iaff0, igaff0
        else:
            phi00, aff00, gaff00 = phi0, aff0, gaff0

        # ----------------------------------------------------------
        # build left and right affine matrices
        # ----------------------------------------------------------
        aff_right, gaff_right = fixed.affine, None
        if aff_pos in 'fs':
            gaff_right = gaff00 @ aff_right
            gaff_right = linalg.lmdiv(self.nonlin.affine, gaff_right)
            aff_right = aff00 @ aff_right
        aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
        aff_left, gaff_left = self.nonlin.affine, None
        if aff_pos in 'ms':
            gaff_left = gaff00 @ aff_left
            gaff_left = linalg.lmdiv(moving.affine, gaff_left)
            aff_left = aff00 @ aff_left
        aff_left = linalg.lmdiv(moving.affine, aff_left)

        # ----------------------------------------------------------
        # build full transform
        # ----------------------------------------------------------
        if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
            # fixed and nonlin lattices coincide: no resampling needed
            right = None
            phi = spatial.add_identity_grid(phi00)
        else:
            right = spatial.affine_grid(aff_right, fixed.shape)
            phi = regutils.smart_pull_grid(phi00, right)
            phi += right
        # keep the transform before the left affine is applied
        # (used below to push derivatives)
        phi_right = phi
        if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
            left = None
        else:
            left = spatial.affine_grid(aff_left, self.nonlin.shape)
            phi = spatial.affine_matvec(aff_left, phi)

        # ----------------------------------------------------------
        # forward pass
        # ----------------------------------------------------------
        warped, mask = moving.pull(phi, mask=True)
        if fixed.masked:
            if mask is None:
                mask = fixed.mask
            else:
                mask = mask * fixed.mask

        # plot at most once per call, only at verbosity >= 3, outside
        # of line searches and for forward losses
        do_print = not (has_printed or self.verbose < 3 or in_line_search
                        or loss.backward)
        if do_print:
            has_printed = True
            if moving.previewed:
                preview = moving.pull(phi, preview=True, dat=False)
            else:
                preview = warped
            init = spatial.affine_lmdiv(moving.affine, fixed.affine)
            if _almost_identity(init) and moving.shape == fixed.shape:
                init = moving.dat
            else:
                init = spatial.affine_grid(init, fixed.shape)
                init = moving.pull(init, preview=True, dat=False)
            self.mov2fix(fixed.dat, init, preview, dim=fixed.dim,
                         title=f'(affine) {self.n_iter:03d}')

        # ----------------------------------------------------------
        # derivatives wrt moving
        # ----------------------------------------------------------
        g = h = None
        loss_args = (warped, fixed.dat)
        loss_kwargs = dict(dim=fixed.dim, mask=mask)
        # saved so that a line-search evaluation can be undone below
        state = loss.loss.get_state()
        if not grad and not hess:
            llx = loss.loss.loss(*loss_args, **loss_kwargs)
        elif not hess:
            llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
        else:
            llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
        del loss_args, loss_kwargs
        if in_line_search:
            loss.loss.set_state(state)

        # ----------------------------------------------------------
        # chain rule -> derivatives wrt Lie parameters
        # ----------------------------------------------------------

        def compose_grad(g, h, g_mu, g_aff):
            """
            g, h : gradient/Hessian of loss wrt moving image
            g_mu : spatial gradients of moving image
            g_aff : gradient of affine matrix wrt Lie parameters
            returns g, h: gradient/Hessian of loss wrt Lie parameters
            """
            # Note that `h` can be `None`, but the functions I
            # use deal with this case correctly.
            dim = g_mu.shape[-1]
            g = jg(g_mu, g)
            h = jhj(g_mu, h)
            g, h = regutils.affine_grid_backward(g, h)
            dim2 = dim * (dim + 1)
            g = g.reshape([*g.shape[:-2], dim2])
            # drop the last row of the matrix derivative before flattening
            g_aff = g_aff[..., :-1, :]
            g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
            g = linalg.matvec(g_aff, g)
            if h is not None:
                h = h.reshape([*h.shape[:-4], dim2, dim2])
                h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
                # h = h.abs().sum(-1).diag_embed()
            return g, h

        if grad or hess:
            g0, g = g, None
            h0, h = h, None
            if aff_pos in 'ms':
                # contribution of the left affine: derivatives are
                # pushed through `phi_right` onto the nonlin lattice
                g_left = regutils.smart_push(g0, phi_right, shape=self.nonlin.shape)
                h_left = regutils.smart_push(h0, phi_right, shape=self.nonlin.shape)
                mugrad = moving.pull_grad(left, rotate=False)
                g_left, h_left = compose_grad(g_left, h_left, mugrad, gaff_left)
                g, h = g_left, h_left
            if aff_pos in 'fs':
                # contribution of the right affine
                g_right, h_right = g0, h0
                mugrad = moving.pull_grad(phi, rotate=False)
                # NOTE(review): uses `phi0` rather than `phi00`, even for
                # backward losses -- confirm this is intended
                jac = spatial.grid_jacobian(phi0, right, type='disp', extrapolate=False)
                jac = torch.matmul(aff_left[:-1, :-1], jac)
                mugrad = linalg.matvec(jac.transpose(-1, -2), mugrad)
                g_right, h_right = compose_grad(g_right, h_right, mugrad, gaff_right)
                g = g_right if g is None else g.add_(g_right)
                h = h_right if h is None else h.add_(h_right)

            if loss.backward:
                g = g.neg_()
            # accumulate across losses (in place), weighted by `factor`
            sumgrad = (g.mul_(factor) if sumgrad is None else
                       sumgrad.add_(g, alpha=factor))
            if hess:
                sumhess = (h.mul_(factor) if sumhess is None else
                           sumhess.add_(h, alpha=factor))
        sumloss = (llx.mul_(factor) if sumloss is None else
                   sumloss.add_(llx, alpha=factor))

    # TODO add regularization term
    lla = 0

    # ==============================================================
    #                           VERBOSITY
    # ==============================================================
    llx = sumloss.item()
    sumloss += lla
    sumloss += self.llv
    self.loss_value = sumloss.item()
    if self.verbose and (self.verbose > 1 or not in_line_search):
        ll = sumloss.item()
        llv = self.llv
        if in_line_search:
            line = '(search) | '
        else:
            line = '(affine) | '
        line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
        if not in_line_search:
            if self.ll_prev is not None:
                gain = self.ll_prev - ll
                # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                line += f' | {gain:12.6g}'
            # iteration bookkeeping (only updated when this block runs)
            self.all_ll.append(ll)
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)
            self.n_iter += 1
        print(line, end='\r')

    # ==============================================================
    #                           RETURN
    # ==============================================================
    out = [sumloss]
    if grad:
        out.append(sumgrad)
    if hess:
        out.append(sumhess)
    return tuple(out) if len(out) > 1 else out[0]
def write_data(options):
    """Reslice all input files through the chain of transformations.

    Velocity and displacement fields are loaded (and velocities are
    exponentiated/shot) once, a leading linear transform is folded into
    the input orientation matrices, and each input file is then
    resampled and written to its formatted output name.

    Parameters
    ----------
    options : object
        Parsed options. Reads `device`, `transformations`, `files`,
        `target`, `interpolation`, `bound`, `extrapolate`, `prefilter`,
        `voxel_size` and `output`. `transformations` and the `affine`
        of each file may be modified in place.

    """

    backend = dict(dtype=torch.float32, device=options.device)

    def map_field(trf):
        """Map `trf.file` and store its affine, data and 3D shape in `trf`"""
        vol = io.volumes.map(trf.file)
        trf.affine = vol.affine
        trf.shape = squeeze_to_nd(vol.shape, 3, 1)
        trf.dat = vol.fdata(**backend).reshape(trf.shape)
        trf.shape = trf.shape[:3]

    # 1) Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, struct.Velocity):
            map_field(trf)
            if trf.json:
                with open(trf.json) as f:
                    prm = json.load(f)
                # NOTE(review): `prm` is filled but not forwarded to `shoot`
                prm['voxel_size'] = spatial.voxel_size(trf.affine)
                trf.dat = spatial.shoot(trf.dat[None], displacement=True,
                                        return_inverse=trf.inv)
                if trf.inv:
                    trf.dat = trf.dat[-1]
            else:
                trf.dat = spatial.exp(trf.dat[None], displacement=True,
                                      inverse=trf.inv)
            trf.dat = trf.dat[0]  # drop batch dimension
            trf.inv = False
            trf.order = 1
        elif isinstance(trf, struct.Displacement):
            map_field(trf)
            if trf.unit == 'mm':
                # convert mm displacement to vox displacement
                trf.dat = spatial.affine_lmdiv(trf.affine, trf.dat[..., None])
                trf.dat = trf.dat[..., 0]
                trf.unit = 'vox'

    # 2) If the first transform is linear, compose it with the input
    #    orientation matrix
    chain = options.transformations
    if chain and isinstance(chain[0], struct.Linear):
        first = chain[0]
        for file in options.files:
            file_aff = file.affine.to(**backend)
            lin_aff = first.affine.to(**backend)
            file.affine = spatial.affine_lmdiv(lin_aff, file_aff)
        options.transformations = chain[1:]

    def build_from_target(affine, shape):
        """Compose all transformations, starting from the final orientation"""
        grid = spatial.affine_grid(affine.to(**backend), shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, struct.Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
                continue
            mat = trf.affine.to(**backend)
            if trf.inv:
                vx0 = spatial.voxel_size(mat)
                vx1 = spatial.voxel_size(affine.to(**backend))
                factor = vx0 / vx1
                disp, mat = spatial.resize_grid(trf.dat[None], factor,
                                                affine=mat,
                                                interpolation=trf.order)
                disp = spatial.grid_inv(disp[0], type='disp')
                order = 1
            else:
                disp, order = trf.dat, trf.order
            grid = spatial.affine_matvec(spatial.affine_inv(mat), grid)
            grid += helpers.pull_grid(disp, grid, interpolation=order)
            grid = spatial.affine_matvec(mat, grid)
        return grid

    # 3) If target is provided, we can build most of the transform once
    #    and just multiply it with a input-wise affine matrix.
    if options.target:
        grid = build_from_target(options.target.affine, options.target.shape)
        oaffine = options.target.affine

    # 4) Loop across input files
    opt_pull0 = dict(interpolation=options.interpolation,
                     bound=options.bound,
                     extrapolate=options.extrapolate)
    opt_coeff = dict(interpolation=options.interpolation,
                     bound=options.bound,
                     dim=3,
                     inplace=True)
    # label mode does not depend on the file: settle it once
    is_label = (isinstance(options.interpolation, str)
                and options.interpolation == 'l')
    if is_label:
        backend_int = dict(dtype=torch.long, device=backend['device'])
        opt_pull = dict(opt_pull0)
        opt_pull['interpolation'] = 1
    else:
        opt_pull = opt_pull0

    output = py.make_list(options.output, len(options.files))
    for file, ofname in zip(options.files, output):
        ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext)
        print(f'Reslicing: {file.fname}\n'
              f' -> {ofname}')
        if is_label:
            dat = io.volumes.load(file.fname, **backend_int)
        else:
            dat = io.volumes.loadf(file.fname, rand=options.interpolation > 0,
                                   **backend)
        dat = utils.movedim(dat.reshape([*file.shape, file.channels]), -1, 0)

        if not options.target:
            # no common target: each file defines its own output space
            oaffine = file.affine
            oshape = file.shape
            if options.voxel_size:
                ovx = utils.make_vector(options.voxel_size, 3,
                                        dtype=oaffine.dtype)
                factor = spatial.voxel_size(oaffine) / ovx
                oaffine, oshape = spatial.affine_resize(oaffine, oshape,
                                                        factor=factor,
                                                        anchor='f')
            grid = build_from_target(oaffine, oshape)

        imat = spatial.affine_inv(file.affine.to(**backend))
        if options.prefilter and not is_label:
            dat = spatial.spline_coeff_nd(dat, **opt_coeff)
        dat = helpers.pull(dat, spatial.affine_matvec(imat, grid), **opt_pull)
        dat = utils.movedim(dat, 0, -1)

        if is_label:
            io.volumes.save(dat, ofname, like=file.fname, affine=oaffine)
        else:
            io.volumes.savef(dat, ofname, like=file.fname, affine=oaffine)
def forward(self, image, **overload):
    """Deform an image with a random affine + nonlinear transform.

    Parameters
    ----------
    image : (batch, channel, *shape) tensor
        Input image
    overload : dict
        All parameters defined at build time can be overridden
        at call time

    Returns
    -------
    warped : (batch, channel, *shape) tensor
        Deformed image
    grid : (batch, *shape, 3) tensor
        Resampling grid

    """

    image = torch.as_tensor(image)
    dim = image.dim() - 2
    batch, _, *shape = image.shape
    info = dict(dtype=image.dtype, device=image.device)

    # get arguments: call-time overrides take precedence over the
    # values stored in the sub-modules
    pick = overload.get
    opt_grid = dict(
        dim=dim,
        shape=shape,
        amplitude=pick('vel_amplitude', self.grid.amplitude),
        fwhm=pick('vel_fwhm', self.grid.fwhm),
        bound=pick('vel_bound', self.grid.bound),
        interpolation=pick('interpolation', self.grid.interpolation),
        dtype=pick('dtype', self.grid.dtype),
        device=pick('device', self.grid.device),
    )
    opt_affine = dict(
        dim=dim,
        translation=pick('translation', self.affine.translation),
        rotation=pick('rotation', self.affine.rotation),
        zoom=pick('zoom', self.affine.zoom),
        shear=pick('shear', self.affine.shear),
        dtype=pick('dtype', self.affine.dtype),
        device=pick('device', self.affine.device),
    )
    opt_pull = dict(
        bound=pick('image_bound', self.pull.bound),
        interpolation=pick('interpolation', self.pull.interpolation),
    )

    # sample both components
    grid = self.grid(batch, **opt_grid)
    aff = self.affine(batch, **opt_affine)

    # shift center of rotation
    offset = -torch.as_tensor(opt_grid['shape'], **info)[:, None] / 2
    aff_shift = torch.cat((torch.eye(dim, **info), offset), dim=1)
    aff = affine_lmdiv(aff_shift, affine_matmul(aff, aff_shift))

    # compose: apply the affine (linear part + offset) to the grid
    aff = unsqueeze(aff, dim=-3, ndim=dim)
    grid = matvec(aff[..., :dim, :dim], grid) + aff[..., :dim, -1]

    # pull
    warped = self.pull(image, grid, **opt_pull)

    return warped, grid
def reslice(moving, fname, like, inv=False, lin=None, nonlin=None,
            interpolation=1, bound='dct2', extrapolate=False,
            device=None, verbose=True):
    """Apply the linear and non-linear components of the transform and
    reslice the image to the target space.

    Notes
    -----
    .. The shape and general orientation of the moving image is kept untouched.
    .. The linear transform is composed with the original orientation matrix.
    .. The non-linear component is "wrapped" in the input space, where
       it is applied.
    .. This function writes a new file (it does not modify the input files
       in place).

    Parameters
    ----------
    moving : ImageFile
        An object describing a moving image.
    fname : list of str
        Output filename for each input file of the moving image
        (since images can be encoded over multiple volumes)
    like : ImageFile
        An object describing the target space
    inv : bool, default=False
        True if we are warping the fixed image to the moving space.
        In the case, `moving` should be a `FixedImageFile` and `like` a
        `MovingImageFile`. Else it should be a `MovingImageFile` and
        `like` a `FixedImageFile`.
    lin : (4, 4) tensor, optional
        Linear (or rather affine) transformation
    nonlin : dict, optional
        Non-linear displacement field, with keys:
        disp : (..., 3) tensor
            Displacement field (in voxels)
        affine : (4, 4) tensor
            Orientation matrix of the displacement field
    interpolation : int, default=1
    bound : str, default='dct2'
    extrapolate : bool, default = False
    device : default=None
    verbose : bool, default=True
        Print the name of each resliced file.

    """
    nonlin = nonlin or dict(disp=None, affine=None)
    prm = dict(interpolation=interpolation, bound=bound,
               extrapolate=extrapolate)

    moving_affine = moving.affine.to(device)
    fixed_affine = like.affine.to(device)

    # The affine correction is applied on the side of the fixed image
    # when inverting, and on the side of the moving image otherwise.
    if inv:
        target_aff = (fixed_affine if lin is None
                      else affine_matmul(lin, fixed_affine))
        source_aff = moving_affine
    else:
        target_aff = fixed_affine
        source_aff = (moving_affine if lin is None
                      else affine_matmul(lin, moving_affine))

    disp = nonlin['disp']
    if disp is None:
        # purely affine: target voxels to source voxels
        g = affine_grid(affine_lmdiv(source_aff, target_aff), like.shape)
    else:
        disp = disp.to(device)
        nlin_aff = nonlin['affine'].to(device)
        # target voxels to param voxels (warps param to target)
        fix2nlin = affine_lmdiv(nlin_aff, target_aff)
        if samespace(fix2nlin, disp.shape[:-1], like.shape):
            g = smalldef(disp)
        else:
            g = affine_grid(fix2nlin, like.shape)
            g += pull_grid(disp, g)
        # param voxels to source voxels
        g = affine_matvec(affine_lmdiv(source_aff, nlin_aff), g)

    if moving.type == 'labels':
        prm['interpolation'] = 0

    for file, ofname in zip(moving.files, fname):
        if verbose:
            print(f'Resliced: {file.fname}\n'
                  f' -> {ofname}')
        dat = io.volumes.loadf(file.fname, rand=True, device=device)
        dat = dat.reshape([*file.shape, file.channels])
        if g is not None:
            # channels first for resampling, then back to channels last
            dat = utils.movedim(pull(utils.movedim(dat, -1, 0), g, **prm),
                                0, -1)
        io.savef(dat.cpu(), ofname, like=file.fname, affine=like.affine.cpu())
def forward():
    """Forward pass up to the loss

    Builds the affine matrix and the displacement field from the
    transformations in `options`, warps every moving image at every
    pyramid level to its fixed image, and returns the sum of all
    penalty and matching terms.
    """

    total = 0

    # affine matrix (first linear transformation only)
    affine_mat = None
    for trf in options.transformations:
        trf.update()
        if not isinstance(trf, struct.Linear):
            continue
        lie = trf.optdat.to(**backend)
        affine_mat = linalg.expm(lie, trf.basis.to(**backend))
        if torch.is_tensor(trf.shift):
            # include shift
            shift = trf.shift.to(**backend)
            eye = torch.eye(options.dim, **backend)
            # needed because expm is a custom autograd.Function
            affine_mat = affine_mat.clone()
            affine_mat[:-1, -1] += torch.matmul(affine_mat[:-1, :-1] - eye,
                                                shift)
        for penalty in trf.losses:
            total += penalty.call(lie)
        break

    # non-linear displacement field (first free FFD/Diffeo only)
    disp = None
    disp_aff = None
    for trf in options.transformations:
        if not trf.isfree():
            continue
        if isinstance(trf, struct.FFD):
            disp = ffd_exp(trf.dat.to(**backend), trf.shape, returns='disp')
            for penalty in trf.losses:
                total += penalty.call(disp)
            disp_aff = trf.affine.to(**backend)
            break
        elif isinstance(trf, struct.Diffeo):
            disp = trf.dat.to(**backend)
            if not trf.smalldef:
                # penalty on velocity fields
                for penalty in trf.losses:
                    total += penalty.call(disp)
                disp = spatial.exp(disp[None], displacement=True)[0]
            if trf.smalldef:
                # penalty on exponentiated transform
                for penalty in trf.losses:
                    total += penalty.call(disp)
            disp_aff = trf.affine.to(**backend)
            break

    # loop over image pairs
    for match in options.losses:
        if not match.fixed:
            continue
        nb_levels = len(match.fixed.dat)
        prm = dict(interpolation=match.moving.interpolation,
                   bound=match.moving.bound,
                   extrapolate=match.moving.extrapolate)
        # loop over pyramid levels
        for (moving_dat, moving_aff), (fixed_dat, fixed_aff) \
                in zip(match.moving.dat, match.fixed.dat):
            moving_dat = moving_dat.to(**backend)
            moving_aff = moving_aff.to(**backend)
            fixed_dat = fixed_dat.to(**backend)
            fixed_aff = fixed_aff.to(**backend)
            fixed_shape = fixed_dat.shape[1:]

            # affine-corrected moving space
            mov_space = (moving_aff if affine_mat is None
                         else affine_matmul(affine_mat, moving_aff))

            if disp is None:
                # fixed to moving
                grid = affine_grid(affine_lmdiv(mov_space, fixed_aff),
                                   fixed_shape)
            else:
                # fixed to param
                fix2prm = affine_lmdiv(disp_aff, fixed_aff)
                if samespace(fix2prm, disp.shape[:-1], fixed_shape):
                    grid = smalldef(disp)
                else:
                    grid = affine_grid(fix2prm, fixed_shape)
                    grid = grid + pull_grid(disp, grid)
                # param to moving
                grid = affine_matvec(affine_lmdiv(mov_space, disp_aff), grid)

            # pull moving image
            warped_dat = pull(moving_dat, grid, **prm)
            total += match.call(warped_dat, fixed_dat) / float(nb_levels)

    return total
def write_data(options):
    """Reslice all input files through the chain of transformations.

    Velocity and displacement fields are loaded once (velocities are
    exponentiated), a leading linear transform is folded into the input
    orientation matrices, and each input file is then resampled and
    written to its formatted output name.

    Parameters
    ----------
    options : object
        Parsed options. Reads `device`, `transformations`, `files`,
        `target`, `interpolation`, `bound`, `extrapolate` and `output`.
        `transformations` and the `affine` of each file may be
        modified in place.

    """

    backend = dict(dtype=torch.float32, device=options.device)

    def map_field(trf):
        """Map `trf.file` and store its affine, data and 3D shape in `trf`"""
        vol = io.volumes.map(trf.file)
        trf.affine = vol.affine
        trf.shape = squeeze_to_nd(vol.shape, 3, 1)
        trf.dat = vol.fdata(**backend).reshape(trf.shape)
        trf.shape = trf.shape[:3]

    # 1) Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, struct.Velocity):
            map_field(trf)
            trf.dat = spatial.exp(trf.dat[None], displacement=True,
                                  inverse=trf.inv)[0]
            trf.inv = False
            trf.order = 1
        elif isinstance(trf, struct.Displacement):
            map_field(trf)

    # 2) If the first transform is linear, compose it with the input
    #    orientation matrix
    chain = options.transformations
    if chain and isinstance(chain[0], struct.Linear):
        first = chain[0]
        for file in options.files:
            file_aff = file.affine.to(**backend)
            lin_aff = first.affine.to(**backend)
            file.affine = spatial.affine_lmdiv(lin_aff, file_aff)
        options.transformations = chain[1:]

    def build_from_target(target):
        """Compose all transformations, starting from the final orientation"""
        grid = spatial.affine_grid(target.affine.to(**backend),
                                   target.shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, struct.Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
                continue
            mat = trf.affine.to(**backend)
            if trf.inv:
                vx0 = spatial.voxel_size(mat)
                vx1 = spatial.voxel_size(target.affine.to(**backend))
                factor = vx0 / vx1
                disp, mat = spatial.resize_grid(trf.dat[None], factor,
                                                affine=mat,
                                                interpolation=trf.order)
                disp = spatial.grid_inv(disp[0], type='disp')
                order = 1
            else:
                disp, order = trf.dat, trf.order
            grid = spatial.affine_matvec(spatial.affine_inv(mat), grid)
            grid += helpers.pull_grid(disp, grid, interpolation=order)
            grid = spatial.affine_matvec(mat, grid)
        return grid

    # 3) If target is provided, we can build most of the transform once
    #    and just multiply it with a input-wise affine matrix.
    if options.target:
        grid = build_from_target(options.target)
        oaffine = options.target.affine

    # 4) Loop across input files
    opt = dict(interpolation=options.interpolation,
               bound=options.bound,
               extrapolate=options.extrapolate)
    output = utils.make_list(options.output, len(options.files))
    for file, ofname in zip(options.files, output):
        ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext)
        print(f'Reslicing: {file.fname}\n'
              f' -> {ofname}')
        dat = io.volumes.loadf(file.fname, rand=True, **backend)
        dat = utils.movedim(dat.reshape([*file.shape, file.channels]), -1, 0)

        if not options.target:
            # no common target: each file defines its own output space
            grid = build_from_target(file)
            oaffine = file.affine

        imat = spatial.affine_inv(file.affine.to(**backend))
        dat = helpers.pull(dat, spatial.affine_matvec(imat, grid), **opt)
        dat = utils.movedim(dat, 0, -1)

        io.volumes.savef(dat, ofname, like=file.fname, affine=oaffine)
def __call__(self, logaff, grad=False, hess=False,
             gradmov=False, hessmov=False, in_line_search=False):
    """Evaluate the loss (and optionally its derivatives) for the
    affine Lie parameters `logaff`.

    logaff : (..., nb) tensor, Lie parameters
    grad : Whether to compute and return the gradient wrt `logaff`
    hess : Whether to compute and return the Hessian wrt `logaff`
    gradmov : Whether to compute and return the gradient wrt `moving`
    hessmov : Whether to compute and return the Hessian wrt `moving`
    in_line_search : If True, nothing is plotted or printed and the
        iteration bookkeeping is not updated.

    Returns
    -------
    ll : float, loss value (objective to minimize)
    g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
    h : (..., logaff) tensor, optional, Hessian wrt Lie parameters
    gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
    hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

    """
    # This loop performs the forward pass, and computes
    # derivatives along the way.

    # NOTE: the arguments `grad`, `hess`, `gradmov` and `hessmov` are
    # reused below to hold the computed tensors; `is not False` is then
    # used to tell "not requested" apart from "computed".

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    # plot roughly 20 times over `max_iter` iterations
    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % logplot == 0

    # jitter
    # if not hasattr(self, '_fixed'):
    #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
    #                                 jitter=True,
    #                                 **utils.backend(self.fixed))
    #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
    #     del idj
    # fixed = self._fixed
    fixed = self.fixed

    # forward
    if not torch.is_tensor(self.basis):
        # build the basis on first use and cache it on the instance
        self.basis = spatial.affine_basis(self.basis, self.dim,
                                          **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    with torch.no_grad():
        _, gaff = linalg._expm(logaff, self.basis,
                               grad_X=True, hess_X=False)

    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    # /!\ derivatives are not "homogeneous" (they do not have a one
    # on the bottom right): we should *not* use affine_matmul and
    # such (I only lost a day...)
    gaff = torch.matmul(gaff, self.affine_fixed)
    gaff = linalg.lmdiv(self.affine_moving, gaff)
    # haff = torch.matmul(haff, self.affine_fixed)
    # haff = linalg.lmdiv(self.affine_moving, haff)
    if self.id is None:
        # identity grid of the fixed lattice, cached on the instance
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff),
                                        jitter=False)
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, cat=iscat, dim=self.dim)

    # gradient/Hessian of the log-likelihood in observed space
    # NOTE(review): if only `gradmov`/`hessmov` is set (neither `grad`
    # nor `hess`), the flag itself is returned unchanged -- confirm
    if not grad and not hess:
        llx = self.loss.loss(warped, fixed)
    elif not hess:
        llx, grad = self.loss.loss_grad(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
    else:
        llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
        if hessmov:
            hessmov = spatial.grid_push(hess, grid, **pullopt)
    del warped

    # compose with spatial gradients + dot product with grid
    if grad is not False or hess is not False:
        mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
        grad = jg(mugrad, grad)
        if hess is not False:
            hess = jhj(mugrad, hess)
            grad, hess = regutils.affine_grid_backward(grad, hess, grid=self.id)
        else:
            grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
        dim2 = self.dim * (self.dim + 1)
        grad = grad.reshape([*grad.shape[:-2], dim2])
        # drop the last row of the matrix derivative before flattening
        gaff = gaff[..., :-1, :]
        gaff = gaff.reshape([*gaff.shape[:-2], dim2])
        grad = linalg.matvec(gaff, grad)
        if hess is not False:
            hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
            # haff = haff[..., :-1, :, :-1, :]
            # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
            hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
            # replace the Hessian by a diagonal matrix holding the sums
            # of absolute values along its last dimension
            hess = hess.abs().sum(-1).diag_embed()
        del mugrad

    # print objective
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}', end='\n')
        else:
            # gain relative to the distance to the largest loss seen
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}', end='\n')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    # only the requested outputs are returned, in a fixed order
    out = [ll]
    if grad is not False:
        out.append(grad)
    if hess is not False:
        out.append(hess)
    if gradmov is not False:
        out.append(gradmov)
    if hessmov is not False:
        out.append(hessmov)
    return tuple(out) if len(out) > 1 else out[0]
def _make_image(option, dim=None, device=None):
    """Load an image and build a Gaussian pyramid (if required).

    The image may be masked, re-oriented, rescaled, padded and smoothed
    (in that order) before the pyramid is built, depending on `option`.

    Returns
    -------
    image : ImagePyramid

    """

    def split_unit(spec):
        """Split a trailing unit string from a sequence (default: 'vox')"""
        if isinstance(spec[-1], str):
            *values, unit = spec
            return values, unit
        return spec, 'vox'

    dat, mask, affine = _load_image(option.files, dim=dim, device=device,
                                    label=option.label)
    dim = dat.dim() - 1

    if option.mask:
        # combine the explicit mask with the one returned by the loader
        implicit_mask = mask
        mask, _, _ = _load_image([option.mask], dim=dim, device=device,
                                 label=option.label)
        if mask.shape[-dim:] != dat.shape[-dim:]:
            raise ValueError('Mask should have the same shape as the image. '
                             f'Got {mask.shape[-dim:]} and {dat.shape[-dim:]}')
        if implicit_mask is not None:
            mask = mask * implicit_mask
        del implicit_mask

    if option.world:
        # overwrite orientation matrix
        affine = io.transforms.map(option.world).fdata().squeeze()
    for transform in (option.affine or []):
        transform = io.transforms.map(transform).fdata().squeeze()
        affine = spatial.affine_lmdiv(transform, affine)

    if not option.discretize and any(option.rescale):
        dat = _rescale_image(dat, option.rescale)

    if option.pad:
        pad, unit = split_unit(option.pad)
        if unit == 'mm':
            # millimetres to (floored) voxels
            voxel_size = spatial.voxel_size(affine)
            pad = torch.as_tensor(pad, **utils.backend(voxel_size))
            pad = (pad / voxel_size).floor().int().tolist()
        else:
            pad = [int(p) for p in pad]
        pad = py.make_list(pad, dim)
        if any(pad):
            affine, _ = spatial.affine_pad(affine, dat.shape[-dim:], pad,
                                           side='both')
            dat = utils.pad(dat, pad, side='both', mode=option.bound)
            if mask is not None:
                mask = utils.pad(mask, pad, side='both', mode=option.bound)

    if option.fwhm:
        fwhm, unit = split_unit(option.fwhm)
        if unit == 'mm':
            # millimetres to voxels
            voxel_size = spatial.voxel_size(affine)
            fwhm = torch.as_tensor(fwhm, **utils.backend(voxel_size))
            fwhm = fwhm / voxel_size
        dat = spatial.smooth(dat, dim=dim, fwhm=fwhm, bound=option.bound)

    image = objects.ImagePyramid(
        dat,
        levels=option.pyramid,
        affine=affine,
        dim=dim,
        bound=option.bound,
        mask=mask,
        extrapolate=option.extrapolate,
        method=option.pyramid_method,
    )

    # keep the continuous data as a preview when the data is quantized
    if getattr(option, 'soft_quantize', False) and len(image[0].dat) == 1:
        for level in image:
            level.preview = level.dat
            level.dat = _soft_quantize_image(level.dat, option.soft_quantize)
    elif not option.label and option.discretize:
        for level in image:
            level.preview = level.dat
            level.dat = _discretize_image(level.dat, option.discretize)
    return image
def get_oriented_slice(image, dim=-1, index=None, affine=None,
                       space=None, bbox=None, interpolation=1,
                       transpose_sagittal=False, return_index=False,
                       return_mat=False):
    """Sample a slice in a RAS system

    Parameters
    ----------
    image : (..., *shape3)
    dim : int or str, default=-1
        Index of spatial dimension to sample in the visualization space
        If RAS: -1 = axial / -2 = coronal / -3 = sagittal
        A string is mapped through its first letter ('a', 'c', 's').
    index : int, default=shape//2
        Coordinate (in voxel) of the slice to extract
    affine : (4, 4) tensor, optional
        Orientation matrix of the image
    space : (4, 4) tensor, optional
        Orientation matrix of the visualisation space.
        Default: RAS with minimum voxel size of all inputs.
    bbox : (2, D) tensor_like, optional
        Bounding box: min and max coordinates (in millimetric
        visualisation space). Default: bounding box of the input image.
    interpolation : {0, 1}, default=1
        Interpolation order.
    transpose_sagittal : bool, default=False
        Use the alternative in-plane layout for sagittal slices.
    return_index : bool, default=False
        Also return the in-plane cursor coordinates.
    return_mat : bool, default=False
        Also return the slice-to-world matrix.

    Returns
    -------
    slice : (..., *shape2) tensor
        Slice in the visualisation space.
    index : tuple, if `return_index`
    space : tensor, if `return_mat`

    """
    # preproc dim
    if isinstance(dim, str):
        dim = dim.lower()[0]
        dim = {'a': -1, 'c': -2, 's': -3}.get(dim, dim)

    backend = utils.backend(image)

    # compute default space (mn/mx are in voxels)
    affine, shape = _get_default_native(affine, image.shape[-3:])
    space, mn, mx = _get_default_space(affine, [shape], space, bbox)
    affine, shape = (affine[0], shape[0])

    # compute default cursor (in voxels)
    if index is None:
        index = (mx + mn) / 2
    else:
        index = spatial.affine_matvec(spatial.affine_inv(space),
                                      torch.as_tensor(index))

    # include slice to volume matrix
    full_shape = tuple(((mx-mn) + 1).round().int().tolist())
    if dim == -1:
        # axial
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 1, 0, - mn[1] + 1],
                 [0, 0, 1, - index[2]],
                 [0, 0, 0, 1]]
        shape = full_shape[:-1]
        index = (index[0] - mn[0] + 1,
                 index[1] - mn[1] + 1)
    elif dim == -2:
        # coronal
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 0, 1, - mn[2] + 1],
                 [0, 1, 0, - index[1]],
                 [0, 0, 0, 1]]
        shape = (full_shape[0], full_shape[2])
        index = (index[0] - mn[0] + 1,
                 index[2] - mn[2] + 1)
    elif dim == -3 and not transpose_sagittal:
        # sagittal
        shift = [[0, 0, 1, - mn[2] + 1],
                 [0, 1, 0, - mn[1] + 1],
                 [1, 0, 0, - index[0]],
                 [0, 0, 0, 1]]
        shape = (full_shape[2], full_shape[1])
        index = (index[2] - mn[2] + 1,
                 index[1] - mn[1] + 1)
    elif dim == -3:
        # sagittal (transposed layout)
        shift = [[0, -1, 0, mx[1] + 1],
                 [0, 0, 1, - mn[2] + 1],
                 [1, 0, 0, - index[0]],
                 [0, 0, 0, 1]]
        shape = (full_shape[1], full_shape[2])
        index = (mx[1] + 1 - index[1],
                 index[2] - mn[2] + 1)
    else:
        raise ValueError(f'Unknown dimension {dim}')
    shift = utils.as_tensor(shift, **backend)

    # sample
    space = spatial.affine_rmdiv(space, shift)
    affine = spatial.affine_lmdiv(affine, space).to(**backend)
    grid = spatial.affine_grid(affine, [*shape, 1])
    *channel, s0, s1, s2 = image.shape
    # flatten all leading dimensions into a single channel dimension
    image = image.reshape([1, -1, s0, s1, s2])
    image = spatial.grid_pull(image, grid[None], interpolation=interpolation,
                              bound='dct2', extrapolate=False)
    image = image.reshape([*channel, *shape])

    if return_index and return_mat:
        return image, index, space
    if return_index:
        return image, index
    if return_mat:
        return image, space
    return image