def _warp_image1(image, target, shape=None, affine=None, nonlin=None,
                 backward=False, reslice=False):
    """Return the warped image, with channel dimension last.

    Parameters
    ----------
    image : Image-like
        Object exposing `.affine`, `.dat` and `.pull(grid)`.
    target : (D+1, D+1) tensor
        Orientation matrix of the output lattice.
    shape : sequence[int], optional
        Spatial shape of the output lattice.
    affine : optional
        Affine transform object exposing `.exp(...)` and `.position`.
    nonlin : optional
        Nonlinear transform object exposing `.exp`, `.iexp`, `.affine`,
        `.shape` and `.add_identity`.
    backward : bool, default=False
        Use the inverse transforms (warp the fixed image to moving space).
    reslice : bool, default=False
        When there is no nonlinear component, resample onto
        (`target`, `shape`). Otherwise the data are returned unwarped.

    Returns
    -------
    warped : tensor
        Warped data, channels last (or squeezed when single-channel).
    """
    # The voxel-to-voxel transform is built as aff_left * [phi] * aff_right
    aff_right = target
    aff_left = spatial.affine_inv(image.affine)

    aff = None
    if affine:
        aff = affine.exp(recompute=False, cache_result=True)
        if backward:
            aff = spatial.affine_inv(aff)

    if nonlin:
        if affine:
            # the side on which the affine is applied depends on its position
            position = affine.position[0].lower()
            if position in ('ms' if backward else 'fs'):
                aff_right = spatial.affine_matmul(aff, aff_right)
            if position in ('fs' if backward else 'ms'):
                aff_left = spatial.affine_matmul(aff_left, aff)
        exp = nonlin.iexp if backward else nonlin.exp
        phi = exp(recompute=False, cache_result=True)
        aff_left = spatial.affine_matmul(aff_left, nonlin.affine)
        aff_right = spatial.affine_lmdiv(nonlin.affine, aff_right)
        if _almost_identity(aff_right) and nonlin.shape == shape:
            # output lattice == displacement lattice: no resampling needed
            phi = nonlin.add_identity(phi)
        else:
            tmp = spatial.affine_grid(aff_right, shape)
            phi = regutils.smart_pull_grid(phi, tmp).add_(tmp)
            del tmp
        if not _almost_identity(aff_left):
            phi = spatial.affine_matvec(aff_left, phi)
    elif reslice:
        # No nonlin: single affine even if position == 'symmetric'.
        # When there is no affine component either, the transform reduces
        # to (image voxels <- target voxels); previously `aff` was None here
        # and affine_matmul crashed.
        if aff is not None:
            aff_right = spatial.affine_matmul(aff, aff_right)
        vox2vox = spatial.affine_matmul(aff_left, aff_right)
        phi = spatial.affine_grid(vox2vox, shape)
    else:
        phi = None

    # warp image
    if phi is not None:
        warped = image.pull(phi)
    else:
        warped = image.dat

    # squeeze a single channel, otherwise move channels last
    if len(warped) == 1:
        warped = warped[0]
    else:
        warped = utils.movedim(warped, 0, -1)
    return warped
def forward(self, affine, **overload):
    """Turn an affine matrix into a dense sampling grid.

    Parameters
    ----------
    affine : (batch, ndim[+1], ndim+1) tensor
        Affine matrix
    overload : dict
        All parameters of the module can be overridden at call time
        (`shape`, `shift`).

    Returns
    -------
    grid : (batch, *shape, ndim) tensor
        Dense transformation grid
    """
    ndim = affine.shape[-1] - 1
    backend = {'dtype': affine.dtype, 'device': affine.device}
    out_shape = make_list(overload.get('shape', self.shape), ndim)

    if overload.get('shift', self.shift):
        # Conjugate the affine with a translation by -shape/2 so that it
        # acts about shape/2 instead of about voxel 0.
        center = torch.as_tensor(out_shape, **backend)[:, None] / 2
        shift_mat = torch.cat((torch.eye(ndim, **backend), -center), dim=1)
        affine = spatial.affine_lmdiv(
            shift_mat, spatial.affine_matmul(affine, shift_mat))

    return spatial.affine_grid(affine, out_shape)
def forward(self, affine, shape=None):
    """Turn an affine matrix into a dense sampling grid.

    Parameters
    ----------
    affine : (batch, ndim[+1], ndim+1) tensor
        Affine matrix
    shape : sequence[int], default=self.shape
        Spatial shape of the output grid.

    Returns
    -------
    grid : (batch, *shape, ndim) tensor
        Dense transformation grid
    """
    nb_dim = affine.shape[-1] - 1
    backend = {'dtype': affine.dtype, 'device': affine.device}
    shape = shape or self.shape
    if self.shift:
        # Translation by -(shape - 1) / 2, so that the affine is applied
        # about the center of the field of view.
        affine_shift = torch.eye(nb_dim + 1, **backend)
        affine_shift[:nb_dim, -1] = torch.as_tensor(shape, **backend)
        # In-place ops on the column view: the previous out-of-place
        # `.sub(1).div(2).neg()` discarded its result, leaving `+shape`
        # as the translation.
        affine_shift[:nb_dim, -1].sub_(1).div_(2).neg_()
        affine = spatial.affine_matmul(affine, affine_shift)
        affine = spatial.affine_lmdiv(affine_shift, affine)
    grid = spatial.affine_grid(affine, shape)
    return grid
def pull(q, grid):
    """Resample `source` through the affine encoded by Lie parameters `q`.

    Closure: `basis`, `target_aff`, `source_aff`, `dim`, `source` and
    `pull_opt` come from the enclosing scope.
    """
    # Lie parameters -> matrix, then target voxels -> source voxels
    world = spatial.affine_matmul(core.linalg.expm(q, basis), target_aff)
    vox2vox = spatial.affine_lmdiv(source_aff, world)
    # one singleton axis per spatial dimension, between batch and matrix
    index = (slice(None),) + (None,) * dim + (slice(None),) * 2
    warped_grid = spatial.affine_matvec(vox2vox[index], grid)
    return spatial.grid_pull(source, warped_grid, **pull_opt)
def pull(q, vel):
    """Resample `source` through an affine (Lie parameters `q`) composed
    with the exponentiated velocity field `vel`.

    Closure: `basis`, `target_aff`, `source_aff`, `source` and `pull_opt`
    come from the enclosing scope.
    """
    grid = spatial.exp(vel)
    aff = core.linalg.expm(q, basis)
    aff = spatial.affine_matmul(aff, target_aff)
    aff = spatial.affine_lmdiv(source_aff, aff)
    # Insert singleton spatial axes between the leading (batch) axes and the
    # matrix axes. Otherwise a batched (B, D+1, D+1) matrix would be
    # right-aligned against a (B, *spatial, D) grid, matching B with the
    # last spatial axis. Assumes grid is (batch, *spatial, D), as in the
    # sibling `pull(q, grid)` -- TODO confirm for unbatched inputs.
    nb_spatial = max(grid.dim() - 2, 0)
    aff = aff.reshape(*aff.shape[:-2], *([1] * nb_spatial), *aff.shape[-2:])
    grid = spatial.affine_matvec(aff, grid)
    moved = spatial.grid_pull(source, grid, **pull_opt)
    return moved
def set_fdata(self, affine):
    """Store a (batch of) RAS affine matrices.

    If the file's affine type is not RAS, the matrices are first converted
    from RAS to that native space.

    Parameters
    ----------
    affine : (..., 4, 4) array_like
        RAS affine matrix or matrices.

    Returns
    -------
    self

    Raises
    ------
    ValueError
        If the two last dimensions are not (4, 4).
    """
    affine = torch.as_tensor(affine)
    backend = dict(dtype=affine.dtype, device=affine.device)
    if affine.shape[-2:] != (4, 4):
        raise ValueError('Expected a batch of 4x4 matrix')
    # we may need to convert from RAS space to a weird space
    afftype = self.type()[0]
    if afftype != 'ras':
        src, _ = self.source_space(afftype, 'ras', **backend)
        dst, _ = self.destination_space('ras', afftype, **backend)
        if src is not None and dst is not None:
            affine = affine_matmul(dst, affine_matmul(affine, src))
    # detach + cpu: converting a tensor that requires grad or lives on a
    # GPU directly to numpy raises
    affine = affine.detach().cpu().numpy().reshape([-1, 4, 4])
    self._struct.affine = affine
    self._struct.nxform = affine.shape[0]
    return self
def slice_to(self, stack, cache_result=False, recompute=True):
    """Resample this volume into the orientation of `stack` and smooth it
    along the last (slice) axis.

    Parameters
    ----------
    stack : object
        Must expose `.layout` and `.slice_width`.
    cache_result : bool, default=False
        Store the result in `self._sliced`.
    recompute : bool, default=True
        If False and a cached result exists, return it.

    Returns
    -------
    sliced : tensor
    """
    aff = self.exp(cache_result=cache_result, recompute=recompute)
    if not recompute and hasattr(self, '_sliced'):
        # previously this path left `sliced` undefined (NameError)
        return self._sliced

    aff = spatial.affine_matmul(aff, self.affine)
    aff_reorient = spatial.affine_reorient(self.affine, self.shape,
                                           stack.layout)
    aff = spatial.affine_lmdiv(aff_reorient, aff)
    grid = spatial.affine_grid(aff, self.shape)
    sliced = spatial.grid_pull(self.dat, grid, bound=self.bound,
                               extrapolate=self.extrapolate)
    # smooth only along the last axis, with a FWHM equal to the slice
    # width expressed in voxels of the reoriented lattice
    fwhm = [0] * self.dim
    fwhm[-1] = stack.slice_width / spatial.voxel_size(aff_reorient)[-1]
    sliced = spatial.smooth(sliced, fwhm, dim=self.dim, bound=self.bound)
    # TODO: per-slice resampling through `stack.slices` is not implemented.
    # The previous draft loop called affine_matmul/affine_lmdiv with a
    # missing argument (always raising) and discarded its results.
    if cache_result:
        self._sliced = sliced
    return sliced
def forward(self, batch=1, **overload):
    """Sample a random resampling grid (random velocity + random affine).

    Parameters
    ----------
    batch : int, default=1
        Batch shape.

    Other Parameters
    ----------------
    shape : sequence[int], optional
    device : torch.device, optional
    dtype : torch.dtype, optional

    Returns
    -------
    grid : (batch, *shape, 3) tensor
        Resampling grid
    """
    field = self.grid.velocity.field
    shape = overload.get('shape', field.shape)
    dtype = overload.get('dtype', field.dtype)
    device = overload.get('device', field.device)
    backend = dict(dtype=dtype, device=device)

    if field.amplitude == 0:
        # `identity_grid(shape)` returns (*shape, dim) with no batch axis.
        # Add one, otherwise `grid.shape[1:-1]` below drops a spatial axis.
        grid = identity_grid(shape, **backend)
        grid = grid[None].expand([batch, *grid.shape])
    else:
        grid = self.grid(batch, shape=shape, **backend)
    backend = dict(dtype=grid.dtype, device=grid.device)
    shape = grid.shape[1:-1]
    dim = len(shape)

    aff = self.affine(batch, dim=dim, **backend)

    # shift center of rotation to (shape - 1) / 2
    aff_shift = torch.cat((
        torch.eye(dim, **backend),
        torch.as_tensor(shape, **backend)[:, None].sub_(1).div_(-2)), dim=1)
    aff_shift = as_euclidean(aff_shift)
    aff = affine_matmul(aff, aff_shift)
    aff = affine_lmdiv(aff_shift, aff)

    # compose: apply the affine to every grid coordinate
    aff = utils.unsqueeze(aff, dim=-3, ndim=dim)
    lin = aff[..., :dim, :dim]
    off = aff[..., :dim, -1]
    grid = linalg.matvec(lin, grid) + off
    return grid
def fdata(self, dtype=None, device=None, numpy=False):
    """Return the affine matrix expressed in RAS space.

    Parameters
    ----------
    dtype : torch.dtype, default=torch.get_default_dtype()
    device : torch.device, optional
    numpy : bool, default=False
        Return a numpy array instead of a tensor.

    Returns
    -------
    affine : tensor or ndarray or None
        None when `self.data(...)` returns None.
    """
    dtype = dtype or torch.get_default_dtype()
    mat = self.data(dtype=dtype, device=device)
    if mat is None:
        return None

    native = self.type()[0]
    if native != 'ras':
        # native affine space -> RAS
        src, _ = self.source_space('ras', native, dtype=dtype, device=device)
        dst, _ = self.destination_space(native, 'ras', dtype=dtype,
                                        device=device)
        if not (src is None or dst is None):
            mat = affine_matmul(dst, affine_matmul(mat, src))

    mat = cast(mat, dtype)
    if not numpy:
        return torch.as_tensor(mat, dtype=dtype, device=device)
    return np.asarray(mat)
def exp(self, velocity, affine=None, displacement=False):
    """Generate a deformation grid from tangent parameters.

    Parameters
    ----------
    velocity : (batch, *spatial, nb_dim)
        Stationary velocity field
    affine : (batch, nb_prm), optional
        Affine parameters
    displacement : bool, default=False
        Return a displacement field (voxel to shift) rather than a
        transformation field (voxel to voxel).

    Returns
    -------
    grid : (batch, *spatial, nb_dim)
        Deformation grid (transformation or displacement).
    """
    backend = {'dtype': velocity.dtype, 'device': velocity.device}
    spatial_shape = velocity.shape[1:-1]

    # integrate the velocity at the working resolution, then resize back
    small_vel = self.resize(velocity, type='displacement')
    grid = self.resize(self.velexp(small_vel), shape=spatial_shape,
                       type='grid')

    if affine is not None:
        mats = torch.stack([self.affexp(prm) for prm in affine], dim=0)
        # apply the affine about shape/2
        center = torch.as_tensor(spatial_shape, **backend)[:, None] / 2
        shift = torch.cat((torch.eye(self.dim, **backend), -center), dim=1)
        mats = spatial.affine_lmdiv(shift, spatial.affine_matmul(mats, shift))
        mats = unsqueeze(mats, dim=-3, ndim=self.dim)
        grid = (matvec(mats[..., :self.dim, :self.dim], grid)
                + mats[..., :self.dim, -1])

    if displacement:
        grid = grid - spatial.identity_grid(grid.shape[1:-1], **backend)
    return grid
def affine_to_fs(affine, shape, source='voxel', dest='ras'):
    """Convert an affine matrix into FS parameters (vx/cosine/shift)

    Parameters
    ----------
    affine : (4, 4) tensor
        "<source> to <dest>" matrix.
    shape : (int, int, int)
    source : {'voxel', 'physical', 'ras'}, default='voxel'
    dest : {'voxel', 'physical', 'ras'}, default='ras'
        Only the first letter of `source` and `dest` is used. The two
        spaces must differ.

    Returns
    -------
    voxel_size : (float, float, float)
    x : (float, float, float)
    y : (float, float, float)
    z : (float, float, float)
    c : (float, float, float)

    Raises
    ------
    ValueError
        If (source, dest) is not a pair of distinct supported spaces.
    """
    affine = torch.as_tensor(affine)
    backend = dict(dtype=affine.dtype, device=affine.device)
    src = source.lower()[0]
    dst = dest.lower()[0]

    order = 'vpr'  # voxel -> physical -> ras
    if src not in order or dst not in order or src == dst:
        # explicit error instead of an `assert` (stripped under -O)
        raise ValueError(f'Unsupported conversion: {source!r} -> {dest!r}')
    if order.index(src) > order.index(dst):
        # Work in the voxel -> physical -> ras direction, so that voxel
        # sizes are read from a "voxel to world" matrix and not from
        # its inverse.
        affine = affine_inv(affine)
        src, dst = dst, src

    vx = get_voxel_size(affine)
    shape = torch.as_tensor(shape, **backend)
    shift = -(shape / 2.) * vx
    vox2phys = Orientation(shift, vx).affine()

    if (src, dst) == ('v', 'p'):
        phys2ras = torch.eye(4, **backend)
    elif (src, dst) == ('v', 'r'):
        phys2ras = affine_matmul(affine, affine_inv(vox2phys))
    else:  # ('p', 'r')
        phys2ras = affine

    phys2ras = HomogeneousAffineMatrix(phys2ras)
    return (vx.tolist(), phys2ras.xras().tolist(), phys2ras.yras().tolist(),
            phys2ras.zras().tolist(), phys2ras.cras().tolist())
def collapse_transforms(options):
    """Pre-invert affines and combine sequential affines.

    Every `Linear` transform with `inv` set has its matrix inverted in place
    (and `inv` cleared). Each run of consecutive `Linear` transforms is then
    merged into its first element by right-multiplying the matrices.
    Non-linear transforms are kept unchanged and break runs.
    `options.transformations` is replaced by the collapsed list.
    """
    collapsed = []
    run_head = None  # first Linear of the ongoing run, if any
    for trf in options.transformations:
        if not isinstance(trf, Linear):
            if run_head is not None:
                collapsed.append(run_head)
                run_head = None
            collapsed.append(trf)
            continue
        if trf.inv:
            trf.affine = spatial.affine_inv(trf.affine)
            trf.inv = False
        if run_head is None:
            run_head = trf
        else:
            run_head.affine = spatial.affine_matmul(run_head.affine,
                                                    trf.affine)
    if run_head is not None:
        collapsed.append(run_head)
    options.transformations = collapsed
def fs_to_affine(shape, voxel_size=1., x=None, y=None, z=None, c=0.,
                 source='voxel', dest='ras'):
    """Transform FreeSurfer orientation parameters into an affine matrix.

    The returned matrix is effectively a "<source> to <dest>" transform.

    Parameters
    ----------
    shape : sequence of int
    voxel_size : [sequence of] float, default=1
    x : [sequence of] float, default=[1, 0, 0]
    y : [sequence of] float, default=[0, 1, 0]
    z : [sequence of] float, default=[0, 0, 1]
    c : [sequence of] float, default=0
    source : {'voxel', 'physical', 'ras'} or layout, default='voxel'
    dest : {'voxel', 'physical', 'ras'} or layout, default='ras'

    Returns
    -------
    affine : (4, 4) tensor
    """
    dim = len(shape)
    # set defaults before the backend conversion so they share its backend
    if x is None:
        x = [1, 0, 0]
    if y is None:
        y = [0, 1, 0]
    if z is None:
        z = [0, 0, 1]
    shape, voxel_size, x, y, z, c \
        = utils.to_max_backend(shape, voxel_size, x, y, z, c)
    backend = dict(dtype=shape.dtype, device=shape.device)
    voxel_size = utils.make_vector(voxel_size, dim)
    x = utils.make_vector(x, dim)
    y = utils.make_vector(y, dim)
    z = utils.make_vector(z, dim)
    c = utils.make_vector(c, dim)

    shift = -(shape / 2.) * voxel_size
    vox2phys = Orientation(shift, voxel_size).affine()
    phys2ras = XYZC(x, y, z, c).affine()

    src = source.lower()
    dst = dest.lower()
    affines = []

    # 1) source -> intermediate space
    if src.startswith('vox'):
        affines.append(vox2phys)
        middle_space = 'phys'
    elif src.startswith('phys'):
        if dst.startswith('vox'):
            affines.append(affine_inv(vox2phys))
            middle_space = 'vox'
        else:
            affines.append(phys2ras)
            middle_space = 'ras'
    elif src == 'ras':
        affines.append(affine_inv(phys2ras))
        middle_space = 'phys'
    else:
        # source is a layout: switch orientations to RAS
        affines.append(layout_matrix(source, **backend))
        middle_space = 'ras'

    # 2) intermediate space -> destination
    if dst.startswith('phys'):
        if middle_space == 'vox':
            affines.append(vox2phys)
        elif middle_space == 'ras':
            affines.append(affine_inv(phys2ras))
    elif dst.startswith('vox'):
        if middle_space == 'phys':
            affines.append(affine_inv(vox2phys))
        elif middle_space == 'ras':
            affines.append(affine_inv(phys2ras))
            affines.append(affine_inv(vox2phys))
    elif dst.startswith('ras'):
        if middle_space == 'phys':
            affines.append(phys2ras)
        elif middle_space == 'vox':
            affines.append(vox2phys)
            affines.append(phys2ras)
    else:
        # destination is a layout: go to RAS first, then to the layout
        if middle_space == 'phys':
            # phys -> ras (was wrongly affine_inv(phys2ras), i.e. ras -> phys)
            affines.append(phys2ras)
        elif middle_space == 'vox':
            affines.append(vox2phys)
            affines.append(phys2ras)
        layout = layout_matrix(dest, **backend)
        affines.append(affine_inv(layout))

    # compose: later matrices are applied after earlier ones
    affine, *affines = affines
    for aff in affines:
        affine = affine_matmul(aff, affine)
    return affine
def forward(self, grid, **overload):
    """Fit an affine matrix to a displacement field in the least-squares
    sense.

    Parameters
    ----------
    grid : (N, *spatial, dim)
        Displacement grid
    overload : dict
        `shift` can be overridden at call time.

    Returns
    -------
    aff : (N, dim+1, dim+1)
        Affine matrix that is closest to grid in the least square sense
    """
    shift = overload.get('shift', self.shift)
    grid = torch.as_tensor(grid)
    info = dict(dtype=grid.dtype, device=grid.device)
    nb_dim = grid.shape[-1]
    shape = grid.shape[1:-1]

    if shift:
        # M = [I | -shape/2] in homogeneous form: M*x = x - shape/2
        affine_shift = torch.cat((torch.eye(
            nb_dim, **info), -torch.as_tensor(shape, **info)[:, None] / 2),
            dim=1)
        affine_shift = spatial.as_euclidean(affine_shift)

    # The forward model is:
    #   phi(x) = M\A*M*x
    # where phi is a *transformation* field, M is the shift matrix
    # and A is the affine matrix.
    # We can decompose phi(x) = x + d(x), where d is a *displacement*
    # field, yielding:
    #   d(x) = M\A*M*x - x = (M\A*M - I)*x := B*x
    # If we write `d(x)` and `x` as large vox*(dim+1) matrices `D`
    # and `G`, we have:
    #   D = G*B'
    # Therefore, the least squares B is obtained as:
    #   B' = inv(G'*G) * (G'*D)
    # Then, A is
    #   A = M*(B + I)/M
    # (When `shift` is False, M = I.)
    # NOTE: the returned value is the matrix A itself; no projection onto
    # the Lie algebra / tangent space is performed here.

    def igg(identity):
        # Compute inv(G'*G), where G = [g, 1] is g in homogeneous
        # coordinates. Instead of appending ones, we compute each
        # block of the matrix ourselves:
        #   [[g'*g, g'*1],
        #    [1'*g, 1'*1]]
        # where 1'*1 = N, the number of voxels.
        g = identity.reshape([identity.shape[0], -1, nb_dim])
        nb_vox = torch.as_tensor([[[g.shape[1]]]], **info)
        sumg = g.sum(dim=1, keepdim=True)
        gg = torch.matmul(g.transpose(-1, -2), g)
        gg = torch.cat((gg, sumg), dim=1)
        sumg = sumg.transpose(-1, -2)
        sumg = torch.cat((sumg, nb_vox), dim=1)
        gg = torch.cat((gg, sumg), dim=2)
        return gg.inverse()

    def gd(identity, disp):
        # Compute G'*[d, 1], where g and d are in homogeneous coordinates:
        #   [[g'*d, g'*1],
        #    [1'*d, 1'*1]]
        # The batch of `disp` is kept; the (batch-1) identity part is
        # expanded to match it.
        g = identity.reshape([identity.shape[0], -1, nb_dim])
        d = disp.reshape([disp.shape[0], -1, nb_dim])
        nb_vox = torch.as_tensor([[[g.shape[1]]]], **info)
        sumg = g.sum(dim=1, keepdim=True)
        sumd = d.sum(dim=1, keepdim=True)
        gd = torch.matmul(g.transpose(-1, -2), d)
        gd = torch.cat((gd, sumd), dim=1)
        sumg = sumg.transpose(-1, -2)
        sumg = torch.cat((sumg, nb_vox), dim=1)
        sumg = sumg.expand([d.shape[0], sumg.shape[1], sumg.shape[2]])
        gd = torch.cat((gd, sumg), dim=2)
        return gd

    def eye(d):
        # (d+1, d+1) matrix with the identity in the top-left (d, d) block
        # and zeros in the last row and column
        x = torch.eye(d, **info)
        z = x.new_zeros([1, d], **info)
        x = torch.cat((x, z), dim=0)
        z = x.new_zeros([d + 1, 1], **info)
        x = torch.cat((x, z), dim=1)
        return x

    identity = spatial.identity_grid(shape, **info)[None, ...]
    # B' = inv(G'G) G'D, then A_eff = B + I (padded identity)
    affine = torch.matmul(igg(identity), gd(identity, grid))
    affine = affine.transpose(-1, -2) + eye(nb_dim)
    affine = affine[..., :-1, :]
    if shift:
        # A = M * A_eff / M
        affine = spatial.as_euclidean(affine)
        affine = spatial.affine_matmul(affine_shift, affine)
        affine = spatial.as_euclidean(affine)
        affine = spatial.affine_rmdiv(affine, affine_shift)
    affine = spatial.affine_make_square(affine)
    return affine
def forward():
    """Forward pass up to the loss.

    Closure: reads `options`, `backend` and helper functions from the
    enclosing scope. The returned loss sums the regularization terms of
    the first Linear transform and of the first free FFD/Diffeo transform,
    plus every image matching term averaged over pyramid levels.
    """
    loss = 0

    # affine matrix (first Linear transform only)
    A = None
    for trf in options.transformations:
        trf.update()
        if isinstance(trf, struct.Linear):
            q = trf.optdat.to(**backend)
            B = trf.basis.to(**backend)
            A = linalg.expm(q, B)
            if torch.is_tensor(trf.shift):
                # include shift
                # NOTE(review): this adds (R - I)*shift to the translation,
                # i.e. A' = T(-shift) A T(shift); the center of rotation is
                # therefore -shift. Confirm the sign convention of trf.shift.
                shift = trf.shift.to(**backend)
                eye = torch.eye(options.dim, **backend)
                A = A.clone()  # needed because expm is a custom autograd.Function
                A[:-1, -1] += torch.matmul(A[:-1, :-1] - eye, shift)
            for loss1 in trf.losses:
                loss += loss1.call(q)
            break

    # non-linear displacement field (first free FFD or Diffeo only)
    d = None
    d_aff = None
    for trf in options.transformations:
        if not trf.isfree():
            continue
        if isinstance(trf, struct.FFD):
            d = trf.dat.to(**backend)
            d = ffd_exp(d, trf.shape, returns='disp')
            for loss1 in trf.losses:
                loss += loss1.call(d)
            d_aff = trf.affine.to(**backend)
            break
        elif isinstance(trf, struct.Diffeo):
            d = trf.dat.to(**backend)
            if not trf.smalldef:
                # penalty on velocity fields
                for loss1 in trf.losses:
                    loss += loss1.call(d)
                d = spatial.exp(d[None], displacement=True)[0]
            if trf.smalldef:
                # penalty on exponentiated transform
                for loss1 in trf.losses:
                    loss += loss1.call(d)
            d_aff = trf.affine.to(**backend)
            break

    # loop over image pairs
    for match in options.losses:
        if not match.fixed:
            continue
        nb_levels = len(match.fixed.dat)
        prm = dict(interpolation=match.moving.interpolation,
                   bound=match.moving.bound,
                   extrapolate=match.moving.extrapolate)
        # loop over pyramid levels
        for moving, fixed in zip(match.moving.dat, match.fixed.dat):
            moving_dat, moving_aff = moving
            fixed_dat, fixed_aff = fixed

            moving_dat = moving_dat.to(**backend)
            moving_aff = moving_aff.to(**backend)
            fixed_dat = fixed_dat.to(**backend)
            fixed_aff = fixed_aff.to(**backend)

            # affine-corrected moving space
            if A is not None:
                Ms = affine_matmul(A, moving_aff)
            else:
                Ms = moving_aff

            if d is not None:
                # fixed to param
                Mt = affine_lmdiv(d_aff, fixed_aff)
                if samespace(Mt, d.shape[:-1], fixed_dat.shape[1:]):
                    g = smalldef(d)
                else:
                    g = affine_grid(Mt, fixed_dat.shape[1:])
                    g = g + pull_grid(d, g)
                # param to moving
                Ms = affine_lmdiv(Ms, d_aff)
                g = affine_matvec(Ms, g)
            else:
                # fixed to moving
                Mt = fixed_aff
                Ms = affine_lmdiv(Ms, Mt)
                g = affine_grid(Ms, fixed_dat.shape[1:])

            # pull moving image
            warped_dat = pull(moving_dat, g, **prm)
            loss += match.call(warped_dat, fixed_dat) / float(nb_levels)

    return loss
def _warp_image(option, affine=None, nonlin=None, dim=None, device=None,
                odir=None):
    """Warp and save the moving and fixed images from a loss object.

    For each of the moving (`option.mov`) and fixed (`option.fix`) images:

    * ``output``   -- "minimal reslice": the image keeps its own lattice.
      When the affine is positioned on that image's side ('m'/'s' for
      moving, 'f'/'s' for fixed), its orientation matrix absorbs the affine.
    * ``resliced`` -- "full reslice": the image is resampled onto the
      lattice of the other image.

    Returns early, without reading anything, when no output is requested.
    """
    mov_opt, fix_opt = option.mov, option.fix
    want_mov = mov_opt.output or mov_opt.resliced
    want_fix = fix_opt.output or fix_opt.resliced
    if not (want_mov or want_fix):
        return

    fix, fix_affine = _map_image(fix_opt.files, dim=dim)
    mov, mov_affine = _map_image(mov_opt.files, dim=dim)
    fix_affine = fix_affine.float()
    mov_affine = mov_affine.float()
    dim = dim or (fix.dim - 1)

    def apply_orientation_options(opt, orient):
        # `world` overwrites the orientation matrix; each file listed in
        # `affine` is then left-divided into it, in order
        if opt.world:
            orient = io.transforms.map(opt.world).fdata().squeeze()
        for trf_file in (opt.affine or []):
            trf = io.transforms.map(trf_file).fdata().squeeze()
            orient = spatial.affine_lmdiv(trf, orient)
        return orient

    fix_affine = apply_orientation_options(fix_opt, fix_affine)
    mov_affine = apply_orientation_options(mov_opt, mov_affine)

    def warp_and_save(image, ifname, fname, target_affine, target_shape,
                      **kwargs):
        # warp onto (target_affine, target_shape), write with the header
        # of `ifname`, then report completion
        warped = _warp_image1(image, target_affine, target_shape,
                              affine=affine, nonlin=nonlin, **kwargs)
        io.savef(warped, fname, like=ifname, affine=target_affine)
        print('done.')

    # moving
    if want_mov:
        ifname = mov_opt.files[0]
        idir, base, ext = py.fileparts(ifname)
        odir_mov = odir or idir or '.'
        image = objects.Image(mov.fdata(rand=True, device=device), dim=dim,
                              affine=mov_affine, bound=mov_opt.bound,
                              extrapolate=mov_opt.extrapolate)
        if mov_opt.output:
            target_affine = mov_affine
            if affine and affine.position[0].lower() in 'ms':
                aff = affine.exp(recompute=False, cache_result=True)
                target_affine = spatial.affine_lmdiv(aff, target_affine)
            fname = mov_opt.output.format(dir=odir_mov, base=base,
                                          sep=os.path.sep, ext=ext)
            print(f'Minimal reslice: {ifname} -> {fname} ...', end=' ')
            warp_and_save(image, ifname, fname, target_affine, image.shape)
        if mov_opt.resliced:
            fname = mov_opt.resliced.format(dir=odir_mov, base=base,
                                            sep=os.path.sep, ext=ext)
            print(f'Full reslice: {ifname} -> {fname} ...', end=' ')
            warp_and_save(image, ifname, fname, fix_affine, fix.shape[1:],
                          reslice=True)

    # fixed
    if want_fix:
        ifname = fix_opt.files[0]
        idir, base, ext = py.fileparts(ifname)
        odir_fix = odir or idir or '.'
        image = objects.Image(fix.fdata(rand=True, device=device), dim=dim,
                              affine=fix_affine, bound=fix_opt.bound,
                              extrapolate=fix_opt.extrapolate)
        if fix_opt.output:
            target_affine = fix_affine
            if affine and affine.position[0].lower() in 'fs':
                aff = affine.exp(recompute=False, cache_result=True)
                target_affine = spatial.affine_matmul(aff, target_affine)
            fname = fix_opt.output.format(dir=odir_fix, base=base,
                                          sep=os.path.sep, ext=ext)
            print(f'Minimal reslice: {ifname} -> {fname} ...', end=' ')
            warp_and_save(image, ifname, fname, target_affine, image.shape,
                          backward=True)
        if fix_opt.resliced:
            fname = fix_opt.resliced.format(dir=odir_fix, base=base,
                                            sep=os.path.sep, ext=ext)
            print(f'Full reslice: {ifname} -> {fname} ...', end=' ')
            warp_and_save(image, ifname, fname, mov_affine, mov.shape[1:],
                          backward=True, reslice=True)
def reslice(moving, fname, like, inv=False, lin=None, nonlin=None,
            interpolation=1, bound='dct2', extrapolate=False, device=None,
            verbose=True):
    """Apply the linear and non-linear components of the transform and
    reslice the image to the target space.

    Notes
    -----
    .. The shape and general orientation of the moving image is kept
       untouched.
    .. The linear transform is composed with the original orientation
       matrix: with the target's (`like`) orientation when `inv` is True,
       with the moving image's orientation otherwise.
    .. The non-linear component is "wrapped" in the input space, where it
       is applied.
    .. This function writes a new file (it does not modify the input files
       in place).

    Parameters
    ----------
    moving : ImageFile
        An object describing a moving image.
    fname : list of str
        Output filename for each input file of the moving image
        (since images can be encoded over multiple volumes)
    like : ImageFile
        An object describing the target space
    inv : bool, default=False
        True if we are warping the fixed image to the moving space.
        In that case, `moving` should be a `FixedImageFile` and `like` a
        `MovingImageFile`. Else it should be a `MovingImageFile` and `like`
        a `FixedImageFile`.
    lin : (4, 4) tensor, optional
        Linear (or rather affine) transformation
    nonlin : dict, optional
        Non-linear displacement field, with keys:
        disp : (..., 3) tensor
            Displacement field (in voxels)
        affine : (4, 4) tensor
            Orientation matrix of the displacement field
    interpolation : int, default=1
        Forced to 0 when `moving.type == 'labels'`.
    bound : str, default='dct2'
    extrapolate : bool, default=False
    device : optional, default=None
    verbose : bool, default=True
    """
    nonlin = nonlin or dict(disp=None, affine=None)
    prm = dict(interpolation=interpolation, bound=bound,
               extrapolate=extrapolate)

    moving_affine = moving.affine.to(device)
    fixed_affine = like.affine.to(device)

    # `target_side` maps target voxels to world, `source_side` maps moving
    # voxels to world; `lin` is folded into one of them depending on `inv`.
    if inv:
        target_side = (fixed_affine if lin is None
                       else affine_matmul(lin, fixed_affine))
        source_side = moving_affine
    else:
        target_side = fixed_affine
        source_side = (moving_affine if lin is None
                       else affine_matmul(lin, moving_affine))

    disp = nonlin['disp']
    if disp is None:
        g = affine_grid(affine_lmdiv(source_side, target_side), like.shape)
    else:
        disp = disp.to(device)
        nlin_affine = nonlin['affine'].to(device)
        # target voxels -> displacement-field voxels
        tgt2nlin = affine_lmdiv(nlin_affine, target_side)
        if samespace(tgt2nlin, disp.shape[:-1], like.shape):
            g = smalldef(disp)
        else:
            g = affine_grid(tgt2nlin, like.shape)
            g += pull_grid(disp, g)
        # displacement-field voxels -> moving voxels
        g = affine_matvec(affine_lmdiv(source_side, nlin_affine), g)

    if moving.type == 'labels':
        prm['interpolation'] = 0
    for file, ofname in zip(moving.files, fname):
        if verbose:
            print(f'Resliced: {file.fname}\n'
                  f' -> {ofname}')
        dat = io.volumes.loadf(file.fname, rand=True, device=device)
        dat = dat.reshape([*file.shape, file.channels])
        if g is not None:
            # channels first for pulling, then back to channels last
            dat = utils.movedim(dat, -1, 0)
            dat = pull(dat, g, **prm)
            dat = utils.movedim(dat, 0, -1)
        io.savef(dat.cpu(), ofname, like=file.fname, affine=like.affine.cpu())
def __call__(self, logaff, grad=False, hess=False, in_line_search=False):
    """Evaluate the affine registration objective (autograd version).

    Parameters
    ----------
    logaff : (..., nb) tensor
        Lie parameters
    grad : bool
        Whether to compute and return the gradient wrt `logaff`
        (computed by autograd through the warp).
    hess : bool
        Accepted for interface compatibility with the analytic version;
        no Hessian is computed or returned here (and it is not forwarded
        on the recursive calls below).
    in_line_search : bool
        If True, skip plotting and progress printing.

    Returns
    -------
    ll : () tensor
        Loss value (objective to minimize)
    g : (..., nb) tensor, optional
        Gradient wrt Lie parameters (only if `grad`)
    """
    # select correct gradient mode
    if grad:
        logaff.requires_grad_()
        if logaff.grad is not None:
            logaff.grad.zero_()
    if grad and not torch.is_grad_enabled():
        # re-enter with autograd enabled
        with torch.enable_grad():
            return self(logaff, grad, in_line_search=in_line_search)
    elif not grad and torch.is_grad_enabled():
        # re-enter without building a graph
        with torch.no_grad():
            return self(logaff, grad, in_line_search=in_line_search)

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % logplot == 0

    fixed = self.fixed

    # forward: Lie parameters -> matrix -> fixed-voxel to moving-voxel
    if not torch.is_tensor(self.basis):
        self.basis = spatial.affine_basis(self.basis, self.dim,
                                          **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    if self.id is None:
        # identity grid of the fixed lattice, computed once and cached
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff))
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped,
                    cat=iscat, dim=self.dim)

    # log-likelihood in observed space
    llx = self.loss.loss(warped, fixed)
    del warped

    # print objective (`lll` keeps the tensor for backward)
    lll = llx
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}', end='\n')
        else:
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}', end='\n')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [lll]
    if grad is not False:
        lll.backward()
        grad = logaff.grad.clone()
        out.append(grad)
    logaff.requires_grad_(False)
    return tuple(out) if len(out) > 1 else out[0]
def forward(self, image, **overload):
    """Randomly deform an image (random velocity grid + random affine).

    Parameters
    ----------
    image : (batch, channel, *shape) tensor
        Input image
    overload : dict
        All parameters defined at build time can be overridden at call time

    Returns
    -------
    warped : (batch, channel, *shape) tensor
        Deformed image
    grid : (batch, *shape, 3) tensor
        Resampling grid
    """
    image = torch.as_tensor(image)
    dim = image.dim() - 2
    batch, channel, *shape = image.shape
    backend = {'dtype': image.dtype, 'device': image.device}
    get = overload.get

    grid_prm = dict(
        dim=dim,
        shape=shape,
        amplitude=get('vel_amplitude', self.grid.amplitude),
        fwhm=get('vel_fwhm', self.grid.fwhm),
        bound=get('vel_bound', self.grid.bound),
        interpolation=get('interpolation', self.grid.interpolation),
        dtype=get('dtype', self.grid.dtype),
        device=get('device', self.grid.device),
    )
    affine_prm = dict(
        dim=dim,
        translation=get('translation', self.affine.translation),
        rotation=get('rotation', self.affine.rotation),
        zoom=get('zoom', self.affine.zoom),
        shear=get('shear', self.affine.shear),
        dtype=get('dtype', self.affine.dtype),
        device=get('device', self.affine.device),
    )
    pull_prm = dict(
        bound=get('image_bound', self.pull.bound),
        interpolation=get('interpolation', self.pull.interpolation),
    )

    # sample the grid first, then the affine (same order as before)
    grid = self.grid(batch, **grid_prm)
    aff = self.affine(batch, **affine_prm)

    # apply the affine about shape/2
    center = torch.as_tensor(shape, **backend)[:, None] / 2
    shift = torch.cat((torch.eye(dim, **backend), -center), dim=1)
    aff = affine_lmdiv(shift, affine_matmul(aff, shift))

    # compose with the grid
    aff = unsqueeze(aff, dim=-3, ndim=dim)
    grid = matvec(aff[..., :dim, :dim], grid) + aff[..., :dim, -1]

    warped = self.pull(image, grid, **pull_prm)
    return warped, grid
def __call__(self, logaff, grad=False, hess=False, gradmov=False,
             hessmov=False, in_line_search=False):
    """Evaluate the affine registration objective with analytic derivatives.

    Parameters
    ----------
    logaff : (..., nb) tensor
        Lie parameters
    grad : bool
        Whether to compute and return the gradient wrt `logaff`.
    hess : bool
        Whether to compute and return the Hessian wrt `logaff`. When True,
        the gradient is also computed and returned, even if `grad` is False.
    gradmov : bool
        Whether to compute and return the gradient wrt `moving`
        (only computed when a gradient is computed).
    hessmov : bool
        Whether to compute and return the Hessian wrt `moving`
        (only computed when `hess` is True).
    in_line_search : bool
        If True, skip plotting and progress printing.

    Returns
    -------
    ll : float
        Loss value (objective to minimize)
    g : (..., nb) tensor, optional
        Gradient wrt Lie parameters
    h : (..., nb, nb) tensor, optional
        Hessian wrt Lie parameters: a diagonal matrix built from the
        absolute row sums of the Gauss-Newton-like Hessian
    gm : (..., *spatial, dim) tensor, optional
        Gradient wrt moving
    hm : (..., *spatial, ?) tensor, optional
        Hessian wrt moving
    """
    # This function performs the forward pass, and computes
    # derivatives along the way.

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % logplot == 0

    fixed = self.fixed

    # forward
    if not torch.is_tensor(self.basis):
        self.basis = spatial.affine_basis(self.basis, self.dim,
                                          **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    with torch.no_grad():
        # derivatives of expm wrt each Lie parameter
        _, gaff = linalg._expm(logaff, self.basis,
                               grad_X=True, hess_X=False)

    # fixed voxels -> moving voxels
    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    # /!\ derivatives are not "homogeneous" (they do not have a one
    # on the bottom right): we should *not* use affine_matmul and
    # such (I only lost a day...)
    gaff = torch.matmul(gaff, self.affine_fixed)
    gaff = linalg.lmdiv(self.affine_moving, gaff)
    if self.id is None:
        # identity grid of the fixed lattice, computed once and cached
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff),
                                        jitter=False)
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped,
                    cat=iscat, dim=self.dim)

    # gradient/Hessian of the log-likelihood in observed space
    if not grad and not hess:
        llx = self.loss.loss(warped, fixed)
    elif not hess:
        llx, grad = self.loss.loss_grad(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
    else:
        llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
        if hessmov:
            hessmov = spatial.grid_push(hess, grid, **pullopt)
    del warped

    # compose with spatial gradients + dot product with grid
    if grad is not False or hess is not False:
        mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
        grad = jg(mugrad, grad)
        if hess is not False:
            hess = jhj(mugrad, hess)
            grad, hess = regutils.affine_grid_backward(grad, hess,
                                                       grid=self.id)
        else:
            # NOTE: called without grid=self.id here, unlike the
            # Hessian branch above
            grad = regutils.affine_grid_backward(grad)
        # chain rule through the matrix exponential:
        # flatten the (dim, dim+1) blocks and contract with d(aff)/d(prm)
        dim2 = self.dim * (self.dim + 1)
        grad = grad.reshape([*grad.shape[:-2], dim2])
        gaff = gaff[..., :-1, :]
        gaff = gaff.reshape([*gaff.shape[:-2], dim2])
        grad = linalg.matvec(gaff, grad)
        if hess is not False:
            hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
            hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
            # diagonal matrix of absolute row sums
            hess = hess.abs().sum(-1).diag_embed()
        del mugrad

    # print objective
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}', end='\n')
        else:
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}', end='\n')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [ll]
    if grad is not False:
        out.append(grad)
    if hess is not False:
        out.append(hess)
    if gradmov is not False:
        out.append(gradmov)
    if hessmov is not False:
        out.append(hessmov)
    return tuple(out) if len(out) > 1 else out[0]
def compose(self, orient_in, deformation, orient_mean, affine=None,
            orient_out=None, shape_out=None):
    """Composes a deformation defined in a mean space to an image space.

    The returned grid maps output voxels to input-image voxels through
    the chain::

        inv(orient_in) @ affine @ orient_mean    (left, applied last)
        deformation                               (mean-space diffeo)
        orient_mean \\ orient_out                 (right, applied first)

    Parameters
    ----------
    orient_in : (4, 4) tensor
        Orientation of the input image
    deformation : (*shape_mean, 3) tensor
        Random deformation (voxel coordinates of the mean space)
    orient_mean : (4, 4) tensor
        Orientation of the mean space (where the deformation is)
    affine : (4, 4) tensor, default=identity
        Random affine
    orient_out : (4, 4) tensor, default=orient_in
        Orientation of the output image
    shape_out : sequence[int], default=shape_mean
        Shape of the output image

    Returns
    -------
    grid : (*shape_out, 3)
        Voxel-to-voxel transform

    """
    if orient_out is None:
        orient_out = orient_in
    if shape_out is None:
        shape_out = deformation.shape[:-1]
    if affine is None:
        affine = torch.eye(4, 4, device=orient_in.device,
                           dtype=orient_in.dtype)
    shape_mean = deformation.shape[:-1]
    orient_in, affine, deformation, orient_mean, orient_out \
        = utils.to_max_backend(orient_in, affine, deformation,
                               orient_mean, orient_out)
    backend = utils.backend(deformation)
    eye = torch.eye(4, **backend)

    # Compose deformation on the right: output voxels -> mean voxels.
    right_affine = spatial.affine_lmdiv(orient_mean, orient_out)
    # Compare shapes as tuples: `shape_out` may be a list, and a
    # torch.Size never compares equal to a list.
    same_space = (tuple(shape_mean) == tuple(shape_out)
                  and right_affine.allclose(eye))
    if same_space:
        # The output lattice coincides with the mean lattice, so the
        # deformation can be used directly (no resampling needed).
        trf = deformation
    else:
        # The mean space and output space differ: compose the diffeo
        # with a dense affine transform. Writing the diffeo as
        # identity + displacement:
        #   (id + disp)(aff) = aff + disp(aff)
        disp = deformation - spatial.identity_grid(shape_mean, **backend)
        trf = spatial.affine_grid(right_affine, shape_out)
        # Sample the displacement field at the affinely-mapped output
        # voxels (channel-first, with a leading batch dimension).
        disp = spatial.grid_pull(utils.movedim(disp, -1, 0)[None],
                                 trf[None], bound='dft', extrapolate=True)
        disp = utils.movedim(disp[0], 0, -1)
        trf = trf + disp

    # Compose deformation on the left. The diffeo outputs mean-space
    # voxels; map them to input voxels with `in \ (aff @ mean)`.
    left_affine = spatial.affine_matmul(spatial.affine_inv(orient_in), affine)
    left_affine = spatial.affine_matmul(left_affine, orient_mean)
    trf = spatial.affine_matvec(left_affine, trf)
    return trf