def build_from_target(target):
    """Compose all transformations, starting from the final orientation.

    A coordinate grid is created from `target.affine` and `target.shape`,
    then pushed through `options.transformations` in reverse order.
    `options` and `backend` are taken from the enclosing scope.

    Parameters
    ----------
    target : object exposing `affine` and `shape`

    Returns
    -------
    grid : tensor
        The grid obtained after applying every transformation.
    """
    target_affine = target.affine.to(**backend)
    coords = spatial.affine_grid(target_affine, target.shape)

    for transform in reversed(options.transformations):
        if isinstance(transform, Linear):
            # Linear transforms are applied directly to the coordinates.
            coords = spatial.affine_matvec(transform.affine.to(**backend),
                                           coords)
            continue

        lattice = transform.affine.to(**backend)
        if not transform.inv:
            field = transform.dat
            interpolation = transform.spline
        else:
            # Resize the field by the ratio of voxel sizes (transform
            # lattice over target lattice) before inverting it; the
            # inverted field is then sampled with order 1.
            ratio = (spatial.voxel_size(lattice)
                     / spatial.voxel_size(target_affine))
            field, lattice = spatial.resize_grid(
                transform.dat[None], ratio,
                affine=lattice, interpolation=transform.spline)
            field = spatial.grid_inv(field[0], type='disp')
            interpolation = 1

        # Apply the inverse of the lattice affine, add the sampled
        # displacement, then apply the lattice affine again.
        coords = spatial.affine_matvec(spatial.affine_inv(lattice), coords)
        coords += helpers.pull_grid(field, coords,
                                    interpolation=interpolation)
        coords = spatial.affine_matvec(lattice, coords)

    return coords
def pull(q, grid):
    """Sample `source` at `grid` transformed by the affine built from `q`.

    The matrix `expm(q, basis)` is right-multiplied by `target_aff` and
    left-divided by `source_aff`, then applied to `grid`.  `basis`,
    `target_aff`, `source_aff`, `dim`, `source` and `pull_opt` come from
    the enclosing scope.

    Parameters
    ----------
    q : tensor
        Parameters passed to `core.linalg.expm`.
    grid : tensor
        Coordinates to transform and sample at.

    Returns
    -------
    moved : tensor
        Result of `spatial.grid_pull(source, ...)`.
    """
    matrix = spatial.affine_lmdiv(
        source_aff,
        spatial.affine_matmul(core.linalg.expm(q, basis), target_aff))
    # Insert `dim` singleton axes right after the first axis so the
    # matrix broadcasts against the grid: [:, None, ..., None, :, :]
    broadcast = (slice(None), *([None] * dim), slice(None), slice(None))
    coords = spatial.affine_matvec(matrix[broadcast], grid)
    return spatial.grid_pull(source, coords, **pull_opt)
def pull(q, vel):
    """Sample `source` through `spatial.exp(vel)` followed by an affine.

    The grid returned by `spatial.exp(vel)` is transformed by the matrix
    `source_aff \\ (expm(q, basis) @ target_aff)` and used to sample
    `source`.  `basis`, `target_aff`, `source_aff`, `source` and
    `pull_opt` come from the enclosing scope.

    Parameters
    ----------
    q : tensor
        Parameters passed to `core.linalg.expm`.
    vel : tensor
        Input of `spatial.exp`.

    Returns
    -------
    moved : tensor
        Result of `spatial.grid_pull(source, ...)`.
    """
    coords = spatial.exp(vel)
    matrix = spatial.affine_lmdiv(
        source_aff,
        spatial.affine_matmul(core.linalg.expm(q, basis), target_aff))
    coords = spatial.affine_matvec(matrix, coords)
    return spatial.grid_pull(source, coords, **pull_opt)
def _warp_image1(image, target, shape=None, affine=None, nonlin=None,
                 backward=False, reslice=False):
    """Returns the warped image, with channel dimension last

    Parameters
    ----------
    image : object exposing `affine`, `dat` and `pull(grid)`
        Image to warp.
    target : matrix
        Matrix used as the right-most factor of the composed transform.
    shape : sequence[int], optional
        Shape of the lattice on which the sampling grid is built.
    affine : optional
        Affine model. Its matrix is obtained from
        `affine.exp(recompute=False, cache_result=True)` and
        `affine.position` decides on which side(s) it is composed when a
        nonlinear model is present.
    nonlin : optional
        Nonlinear model (uses `exp`/`iexp`, `affine`, `shape` and
        `add_identity`).
    backward : bool, default=False
        Use the inverse affine matrix and `nonlin.iexp` instead of
        `nonlin.exp`.
    reslice : bool, default=False
        Only read when `nonlin` is falsy: if True the image is resampled
        through a single affine grid, otherwise `image.dat` is returned
        without resampling.

    Returns
    -------
    warped : tensor
        If the first dimension has length 1 it is dropped; otherwise it
        is moved to the last position.
    """
    # build transform
    aff_right = target
    aff_left = spatial.affine_inv(image.affine)
    aff = None
    if affine:
        aff = affine.exp(recompute=False, cache_result=True)
        if backward:
            aff = spatial.affine_inv(aff)

    if nonlin:
        if affine:
            position = affine.position[0].lower()
            if position in ('ms' if backward else 'fs'):
                aff_right = spatial.affine_matmul(aff, aff_right)
            if position in ('fs' if backward else 'ms'):
                aff_left = spatial.affine_matmul(aff_left, aff)
        exp = nonlin.iexp if backward else nonlin.exp
        phi = exp(recompute=False, cache_result=True)
        aff_left = spatial.affine_matmul(aff_left, nonlin.affine)
        aff_right = spatial.affine_lmdiv(nonlin.affine, aff_right)
        if _almost_identity(aff_right) and nonlin.shape == shape:
            phi = nonlin.add_identity(phi)
        else:
            tmp = spatial.affine_grid(aff_right, shape)
            phi = regutils.smart_pull_grid(phi, tmp).add_(tmp)
            del tmp
        if not _almost_identity(aff_left):
            phi = spatial.affine_matvec(aff_left, phi)
    elif reslice:
        # no nonlin: single affine even if position == 'symmetric'
        if aff is None:
            # No affine model either: nothing to compose with `target`.
            # (Previously `None` was handed to `affine_matmul`.)
            aff = aff_right
        else:
            aff = spatial.affine_matmul(aff, aff_right)
        aff = spatial.affine_matmul(aff_left, aff)
        phi = spatial.affine_grid(aff, shape)
    else:
        phi = None

    # warp image
    warped = image.dat if phi is None else image.pull(phi)

    # put the channel dimension last (or drop it when it is a singleton)
    if len(warped) == 1:
        return warped[0]
    return utils.movedim(warped, 0, -1)
def forward(self, source, target, source_affine=None, target_affine=None):
    """Register `source` to `target` and warp the original source.

    Parameters
    ----------
    source : (sX, sY, sZ) tensor or str
    target : (tX, tY, tZ) tensor or str
    source_affine : (4, 4) tensor, optional
    target_affine : (4, 4) tensor, optional

    Returns
    -------
    warped : (tX, tY, tZ) tensor
        Source warped to target
    velocity : (vX, vY, vZ, 3) tensor
        Stationary velocity field
    affine : (4, 4) tensor, optional
        Affine of the velocity space

    """
    verbose = self.verbose
    if verbose:
        print('Preprocessing... ', end='', flush=True)
    (source, source_affine,
     source_orig, source_affine_orig) = self.load(source, source_affine)
    (target, target_affine,
     target_orig, target_affine_orig) = self.load(target, target_affine)
    # bring the (loaded) source onto the (loaded) target lattice
    source = spatial.reslice(source, source_affine, target_affine,
                             target.shape)
    if verbose:
        print('done.', flush=True)
        print('Registering... ', end='', flush=True)

    # the parent model receives inputs with two extra leading dimensions
    source = source[None, None]
    target = target[None, None]
    warped, vel, grid = super().forward(source, target)
    if self.verbose:
        print('done.', flush=True)
    del source, target, warped

    vel = vel[0]
    grid = grid[0]
    # subtract the identity grid from the returned grid
    grid -= spatial.identity_grid(grid.shape[:-1],
                                  dtype=grid.dtype, device=grid.device)

    # resample that field on the original target lattice, and add the
    # sampling coordinates back
    orig_to_loaded = target_affine.inverse() @ target_affine_orig
    orig_to_loaded = spatial.affine_grid(orig_to_loaded, target_orig.shape)
    grid = spatial.grid_pull(utils.movedim(grid, -1, 0), orig_to_loaded,
                             bound='nearest', extrapolate=True)
    grid = utils.movedim(grid, 0, -1).add_(orig_to_loaded)

    # map the coordinates into the voxel space of the original source
    loaded_to_source = source_affine_orig.inverse() @ target_affine
    grid = spatial.affine_matvec(loaded_to_source, grid)

    return spatial.grid_pull(source_orig, grid), vel, target_affine
def voxelize_rois(rois, shape, roi_to_vox=None, device=None):
    """Create a volume of labels from a parametric ROI.

    Each region in `rois['regions']` receives a label equal to its
    (1-based) position in the dictionary. Every contour of a region is
    rasterised in the slice given by the rounded third coordinate of its
    first vertex. Voxels that belong to no contour are labelled 0.

    Parameters
    ----------
    rois : dict
        Object returned by `read_asc`. `rois['regions']` maps a region
        name to a list of contours; each contour has a `'points'` list
        of dicts with keys `'x'`, `'y'` and `'z'`.
    shape : sequence[int]
        Shape of the output volume. The first two dimensions define the
        in-plane grid; the last dimension is the slice axis.
    roi_to_vox : (d+1, d+1) tensor, optional
        Affine matrix applied to the vertices before rasterisation.
    device : torch.device, optional
        Device on which the in-plane grid and vertices are created.
        The output volume itself is allocated on the default device.

    Returns
    -------
    roi : (*shape) tensor[int]
    names : list[str]

    """
    # Zero-initialise: `torch.empty` would leave arbitrary values in all
    # voxels that are not covered by a contour.
    out = torch.zeros(shape, dtype=torch.long)
    grid = spatial.identity_grid(shape[:2], device=device)
    if roi_to_vox is not None:
        roi_to_vox = roi_to_vox.to(device=device)

    names = list(rois['regions'].keys())
    for index, (name, contours) in enumerate(rois['regions'].items()):
        print(name)
        label = index + 1
        for i, contour in enumerate(contours):
            print(i + 1, '/', len(contours), end='\r')
            vertices = [[p['x'], p['y'], p['z']] for p in contour['points']]
            vertices = torch.as_tensor(vertices, device=device)
            if roi_to_vox is not None:
                vertices = spatial.affine_matvec(roi_to_vox, vertices)
            # slice index = rounded third coordinate of the first vertex
            z = math.round(vertices[0, 2]).int().item()
            if not (0 <= z < out.shape[-1]):
                print('Contour not in FOV. Skipping it...')
                continue
            vertices = vertices[:, :2]
            # closed polygon: each vertex is linked to the next one, and
            # the last vertex is linked back to the first
            nb_vertices = len(vertices)
            faces = [(v, (v + 1) % nb_vertices) for v in range(nb_vertices)]
            mask = is_inside(grid, vertices, faces).cpu()
            out[..., z][mask] = label
        print('')
    return out, names
def _crop_to_param(aff0, aff, shape):
    """Derive crop parameters (in voxels) from a reference lattice.

    Parameters
    ----------
    aff0 : matrix
        Orientation matrix of the input.
    aff : matrix
        Orientation matrix of the reference.
    shape : sequence[int]
        Shape of the reference; only its first `dim` entries are used,
        where `dim = aff0.shape[-1] - 1`.

    Returns
    -------
    size : sequence[int]
        `shape[:dim]`.
    center : tensor
        `(size - 1) / 2` transformed by `affine_lmdiv(aff0, aff)`.
    unit : str
        Always `'vox'`.
    layout : None

    Raises
    ------
    ValueError
        If both matrices do not have the same layout.
    """
    ndim = aff0.shape[-1] - 1
    size = shape[:ndim]

    input_layout = spatial.affine_to_layout(aff0)
    ref_layout = spatial.affine_to_layout(aff)
    if (input_layout != ref_layout).any():
        raise ValueError('Input and Ref do not have the same layout: '
                         f'{spatial.volume_layout_to_name(input_layout)} vs '
                         f'{spatial.volume_layout_to_name(ref_layout)}.')

    # centre of the reference lattice, in reference voxels ...
    center = torch.as_tensor(size, dtype=torch.float).sub_(1).mul_(0.5)
    # ... transformed by `aff0 \ aff`
    center = spatial.affine_matvec(spatial.affine_lmdiv(aff0, aff), center)
    return size, center, 'vox', None
def propagate_grad(self, g, h, moving, phi, left=None, right=None,
                   inv=False):
    """Convert derivatives wrt the warped image (in loss space) into
    derivatives wrt parameters.

    Parameters
    ----------
    g : tensor
        Gradient wrt warped image (negated in-place when `inv`).
    h : tensor
        Hessian wrt warped image.
    moving : Image
        Moving image.
    phi : tensor
        Dense (exponentiated) displacement field.
    left : matrix, optional
        Left affine.
    right : matrix, optional
        Right affine.
    inv : bool, default=False
        Whether we're in a backward symmetric pass.

    Returns
    -------
    g : tensor
        Pushed gradient.
    h : tensor
        Pushed Hessian.
    gmu : tensor
        Rotated spatial gradients.
    """
    if inv:
        g = g.neg_()

    # build bits of warp
    ndim = phi.shape[-1]
    fixed_shape = g.shape[-ndim:]
    moving_shape = moving.shape  # not used below

    # differentiate wrt δ in: Left o Phi o (Id + δ) o Right
    # we'll then propagate them through Phi by scaling and squaring
    if right is not None:
        right_grid = spatial.affine_grid(right, fixed_shape)
        g = regutils.smart_push(g, right_grid, shape=self.shape)
        h = regutils.smart_push(h, right_grid, shape=self.shape)
        del right_grid

    # Left o (Id + phi)
    coords = spatial.identity_grid(self.shape, **utils.backend(phi))
    coords += phi
    if left is not None:
        coords = spatial.affine_matvec(left, coords)
    spatial_grad = moving.pull_grad(coords, rotate=False)
    del coords

    return g, h, _rotate_grad(spatial_grad, left, phi)
def orient(inp, layout=None, voxel_size=None, center=None, like=None,
           output=None):
    """Overwrite the orientation matrix

    Parameters
    ----------
    inp : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix.
    layout : str or layout-like, default=None (= preserve)
        Target orientation. `'like'` and `'self'` select the layout of
        `like` and of the input, respectively.
    voxel_size : [sequence of] float, default=None (= preserve)
        Target voxel size. `'like'` and `'self'` select the voxel size of
        `like` and of the input, respectively.
    center : [sequence of] float, default=None (= preserve)
        World coordinate of the center of the field of view. `'like'` and
        `'self'` select the center of `like` and of the input,
        respectively.
    like : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix. When not provided, the input itself is
        used as reference.
    output : str, optional
        Output filename.
        If the input is not a path, the reoriented data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.{layout}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.

    Returns
    -------
    output : str or (tuple, tensor)
        If the input is a path, the output path is returned.
        Else, the new shape and orientation matrix are returned.

    """
    def _is_empty(value):
        """True for zero-length values; unsized values are never empty."""
        try:
            return len(value) == 0
        except TypeError:
            return False

    dirname = ''
    base = ''
    ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        dim = f.affine.shape[-1] - 1
        inp = (f.shape[:dim], f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{layout}{ext}'
        dirname, base, ext = py.fileparts(fname)

    like_is_file = isinstance(like, str) and like
    if like_is_file:
        f = io.volumes.map(like)
        dim = f.affine.shape[-1] - 1
        like = (f.shape[:dim], f.affine)

    shape, aff0 = inp
    dim = aff0.shape[-1] - 1
    if like:
        shape_like, aff_like = like
    else:
        shape_like, aff_like = (shape, aff0)

    if voxel_size in (None, 'like') or _is_empty(voxel_size):
        voxel_size = spatial.voxel_size(aff_like)
    elif voxel_size == 'self':
        voxel_size = spatial.voxel_size(aff0)
    voxel_size = utils.make_vector(voxel_size, dim)

    if not layout or layout == 'like':
        layout = spatial.affine_to_layout(aff_like)
    elif layout == 'self':
        layout = spatial.affine_to_layout(aff0)
    layout = spatial.volume_layout(layout)

    # NOTE: the emptiness test must look at `center` (it used to look at
    # `voxel_size`, which is never empty at this point).
    # The centre of a lattice with N voxels indexed 0..N-1 is (N-1)/2.
    if center in (None, 'like') or _is_empty(center):
        center = (torch.as_tensor(shape_like, dtype=torch.float) - 1) * 0.5
        center = spatial.affine_matvec(aff_like, center)
    elif center == 'self':
        center = (torch.as_tensor(shape, dtype=torch.float) - 1) * 0.5
        center = spatial.affine_matvec(aff0, center)
    center = utils.make_vector(center, dim)

    aff = spatial.affine_default(shape, voxel_size=voxel_size, layout=layout,
                                 center=center, dtype=torch.double)

    if output:
        dat = io.volumes.load(fname, numpy=True)
        layout = spatial.volume_layout_to_name(layout)
        if is_file:
            output = output.format(dir=dirname or '.', base=base, ext=ext,
                                   sep=os.path.sep, layout=layout)
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            output = output.format(sep=os.path.sep, layout=layout)
            io.volumes.save(dat, output, affine=aff)

    if is_file:
        return output
    return shape, aff
def orient(inp, affine=None, layout=None, voxel_size=None, center=None,
           like=None, output=None, output_transform=None):
    """Overwrite the orientation matrix

    Parameters
    ----------
    inp : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix.
    affine : {'self', 'like', 'standard'} or (4, 4) tensor_like, default='like'
        Target affine matrix. `'standard'` is the identity.
    layout : {'self', 'like', 'standard'} or layout-like, default='like'
        Target orientation. `'standard'` is `'RAS'`.
    voxel_size : {'self', 'like', 'standard'} or [sequence of] float, default='like'
        Target voxel size. `'standard'` is 1.
    center : {'self', 'like', 'standard'} or [sequence of] float, default='like'
        World coordinate of the center of the field of view.
        `'standard'` is 0.
    like : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix. When not provided, the input itself is
        used as reference.
    output : str, optional
        Output filename.
        If the input is not a path, the reoriented data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.{layout}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    output_transform : str, optional
        Filename of output transform.
        If the input is not a path, the transform is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}_to_{layout}.lta', where `dir` and `base`
        are the directory and base name of the input file.

    Returns
    -------
    output : str or (tuple, tensor)
        If the input is a path, the output path is returned.
        Else, the new shape and orientation matrix are returned.

    Raises
    ------
    ValueError
        If `affine` has neither `dim*(dim+1)` nor `(dim+1)**2` elements.

    """
    def _is_empty(value):
        """True for zero-length values; unsized values are never empty."""
        try:
            return len(value) == 0
        except TypeError:
            return False

    dirname = ''
    base = ''
    ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        dim = f.affine.shape[-1] - 1
        inp = (f.shape[:dim], f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{layout}{ext}'
        if output_transform is None:
            output_transform = '{dir}{sep}{base}_to_{layout}.lta'
        dirname, base, ext = py.fileparts(fname)

    like_is_file = isinstance(like, str) and like
    if like_is_file:
        f = io.volumes.map(like)
        dim = f.affine.shape[-1] - 1
        like = (f.shape[:dim], f.affine)

    shape, aff0 = inp
    dim = aff0.shape[-1] - 1
    if like:
        shape_like, aff_like = like
    else:
        shape_like, aff_like = (shape, aff0)

    if voxel_size in (None, 'like') or _is_empty(voxel_size):
        voxel_size = spatial.voxel_size(aff_like)
    elif voxel_size == 'self':
        voxel_size = spatial.voxel_size(aff0)
    elif voxel_size == 'standard':
        voxel_size = 1.
    voxel_size = utils.make_vector(voxel_size, dim)

    if not layout or layout == 'like':
        layout = spatial.affine_to_layout(aff_like)
    elif layout == 'self':
        layout = spatial.affine_to_layout(aff0)
    elif layout == 'standard':
        layout = 'RAS'
    layout = spatial.volume_layout(layout)

    if center in (None, 'like') or _is_empty(center):
        center = (torch.as_tensor(shape_like, dtype=torch.float) - 1) * 0.5
        center = spatial.affine_matvec(aff_like, center)
    elif center == 'self':
        center = (torch.as_tensor(shape, dtype=torch.float) - 1) * 0.5
        center = spatial.affine_matvec(aff0, center)
    elif center == 'standard':
        center = 0.
    center = utils.make_vector(center, dim)

    if affine in (None, 'like') or _is_empty(affine):
        affine = aff_like
    elif affine == 'self':
        affine = aff0
    elif affine == 'standard':
        affine = torch.eye(dim+1, dim+1)
    affine = torch.as_tensor(affine, dtype=torch.float)
    if affine.numel() == dim*(dim+1):
        # (dim, dim+1) matrix: complete it into a (dim+1, dim+1) one
        affine = spatial.affine_make_rect(affine.reshape(dim, dim+1))
    elif affine.numel() == (dim+1)**2:
        affine = affine.reshape(dim+1, dim+1)
    else:
        raise ValueError(f'Input affine should have {dim*(dim+1)} or '
                         f'{(dim+1)**2} element but got {affine.numel()}.')

    affine = spatial.affine_modify(affine, shape, voxel_size=voxel_size,
                                   layout=layout, center=center)
    affine = affine.double()

    # Both filename templates may contain '{layout}': convert the layout
    # to its name once, whichever output is requested (it used to be
    # converted only when `output` was set).
    if output or output_transform:
        layout = spatial.volume_layout_to_name(layout)

    if output:
        dat = io.volumes.load(fname, numpy=True)
        if is_file:
            output = output.format(dir=dirname or '.', base=base, ext=ext,
                                   sep=os.path.sep, layout=layout)
            io.volumes.save(dat, output, like=fname, affine=affine)
        else:
            output = output.format(sep=os.path.sep, layout=layout)
            io.volumes.save(dat, output, affine=affine)

    if output_transform:
        transform = spatial.affine_rmdiv(affine, aff0)
        output_transform = output_transform.format(
            dir=dirname or '.', base=base, sep=os.path.sep, layout=layout)
        io.transforms.savef(transform.cpu(), output_transform, type=2)

    if is_file:
        return output
    return shape, affine
def write_data(options):
    """Compose the transformations in `options` and write the result.

    With a target (`options.target`), a dense field is built on the
    target lattice and saved as a volume ('.nii.gz'); depending on
    `options.output_unit` it is expressed relative to the identity grid
    (first letter 'v') or relative to the target's affine grid.
    Without a target, a single transformation is expected and its affine
    is saved as a transform ('.lta').

    Raises
    ------
    RuntimeError
        If there is no target but more than one transformation.
    """
    backend = dict(dtype=torch.float32, device=options.device)

    # Load dense transforms; velocities are exponentiated right away
    for trf in options.transformations:
        is_velocity = isinstance(trf, Velocity)
        if not (is_velocity or isinstance(trf, Displacement)):
            continue
        mapped = io.volumes.map(trf.file)
        trf.affine = mapped.affine
        trf.shape = squeeze_to_nd(mapped.shape, 3, 1)
        trf.dat = mapped.fdata(**backend).reshape(trf.shape)
        trf.shape = trf.shape[:3]
        if is_velocity:
            trf.dat = spatial.exp(trf.dat[None], displacement=True,
                                  inverse=trf.inv)[0]
            trf.inv = False
            trf.order = 1
        elif trf.unit == 'mm':
            # convert mm displacement to vox displacement
            trf.dat = spatial.affine_lmdiv(trf.affine, trf.dat[..., None])
            trf.dat = trf.dat[..., 0]
            trf.unit = 'vox'

    def build_from_target(target):
        """Compose all transformations, starting from the final orientation"""
        target_affine = target.affine.to(**backend)
        coords = spatial.affine_grid(target_affine, target.shape)
        for transform in reversed(options.transformations):
            if isinstance(transform, Linear):
                coords = spatial.affine_matvec(
                    transform.affine.to(**backend), coords)
                continue
            lattice = transform.affine.to(**backend)
            if not transform.inv:
                field = transform.dat
                interpolation = transform.spline
            else:
                ratio = (spatial.voxel_size(lattice)
                         / spatial.voxel_size(target_affine))
                field, lattice = spatial.resize_grid(
                    transform.dat[None], ratio,
                    affine=lattice, interpolation=transform.spline)
                field = spatial.grid_inv(field[0], type='disp')
                interpolation = 1
            coords = spatial.affine_matvec(spatial.affine_inv(lattice),
                                           coords)
            coords += helpers.pull_grid(field, coords,
                                        interpolation=interpolation)
            coords = spatial.affine_matvec(lattice, coords)
        return coords

    if not options.target:
        # No target: only a single transform can be written
        if len(options.transformations) > 1:
            raise RuntimeError('Something weird happened: '
                               'multiple transforms and no target')
        io.transforms.savef(options.transformations[0].affine,
                            options.output.format(ext='.lta'))
        return

    # If target is provided, we build a dense transformation field
    grid = build_from_target(options.target)
    oaffine = options.target.affine
    if options.output_unit[0] == 'v':
        grid = spatial.affine_matvec(spatial.affine_inv(oaffine), grid)
        reference = spatial.identity_grid(grid.shape[:-1],
                                          **utils.backend(grid))
    else:
        reference = spatial.affine_grid(oaffine.to(**utils.backend(grid)),
                                        grid.shape[:-1])
    grid = grid - reference
    io.volumes.savef(grid, options.output.format(ext='.nii.gz'),
                     affine=oaffine)
def write_data(options):
    """Resample every file in `options.files` through the transformations
    in `options.transformations` and save the result.

    If `options.target` is set, the sampling grid is built once on the
    target lattice; otherwise it is rebuilt on the lattice of each input
    file.
    """
    backend = dict(dtype=torch.float32, device=options.device)

    # 1) Load dense transforms; velocities are exponentiated right away
    for trf in options.transformations:
        is_velocity = isinstance(trf, struct.Velocity)
        if not (is_velocity or isinstance(trf, struct.Displacement)):
            continue
        mapped = io.volumes.map(trf.file)
        trf.affine = mapped.affine
        trf.shape = squeeze_to_nd(mapped.shape, 3, 1)
        trf.dat = mapped.fdata(**backend).reshape(trf.shape)
        trf.shape = trf.shape[:3]
        if is_velocity:
            trf.dat = spatial.exp(trf.dat[None], displacement=True,
                                  inverse=trf.inv)[0]
            trf.inv = False
            trf.order = 1

    # 2) If the first transform is linear, compose it with the input
    #    orientation matrix (and drop it from the list)
    transforms = options.transformations
    if transforms and isinstance(transforms[0], struct.Linear):
        first = transforms[0]
        for file in options.files:
            file_affine = file.affine.to(**backend)
            linear = first.affine.to(**backend)
            file.affine = spatial.affine_lmdiv(linear, file_affine)
        options.transformations = options.transformations[1:]

    def build_from_target(target):
        """Compose all transformations, starting from the final orientation"""
        target_affine = target.affine.to(**backend)
        coords = spatial.affine_grid(target_affine, target.shape)
        for transform in reversed(options.transformations):
            if isinstance(transform, struct.Linear):
                coords = spatial.affine_matvec(
                    transform.affine.to(**backend), coords)
                continue
            lattice = transform.affine.to(**backend)
            if not transform.inv:
                field = transform.dat
                interpolation = transform.order
            else:
                ratio = (spatial.voxel_size(lattice)
                         / spatial.voxel_size(target_affine))
                field, lattice = spatial.resize_grid(
                    transform.dat[None], ratio,
                    affine=lattice, interpolation=transform.order)
                field = spatial.grid_inv(field[0], type='disp')
                interpolation = 1
            coords = spatial.affine_matvec(spatial.affine_inv(lattice),
                                           coords)
            coords += helpers.pull_grid(field, coords,
                                        interpolation=interpolation)
            coords = spatial.affine_matvec(lattice, coords)
        return coords

    # 3) If a target is provided, most of the transform is built once;
    #    only a file-specific affine remains to be applied per input.
    if options.target:
        grid = build_from_target(options.target)
        oaffine = options.target.affine

    # 4) Loop across input files
    pull_options = dict(interpolation=options.interpolation,
                        bound=options.bound,
                        extrapolate=options.extrapolate)
    output_names = utils.make_list(options.output, len(options.files))
    for file, ofname in zip(options.files, output_names):
        ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext)
        print(f'Reslicing: {file.fname}\n'
              f' -> {ofname}')
        dat = io.volumes.loadf(file.fname, rand=True, **backend)
        dat = dat.reshape([*file.shape, file.channels])
        dat = utils.movedim(dat, -1, 0)  # channels first

        if not options.target:
            grid = build_from_target(file)
            oaffine = file.affine
        # apply the inverse of the file's affine to the grid
        to_voxel = spatial.affine_inv(file.affine.to(**backend))
        dat = helpers.pull(dat, spatial.affine_matvec(to_voxel, grid),
                           **pull_options)
        dat = utils.movedim(dat, 0, -1)  # channels last

        io.volumes.savef(dat, ofname, like=file.fname, affine=oaffine)
def get_oriented_slice(image, dim=-1, index=None, affine=None,
                       space=None, bbox=None, interpolation=1,
                       transpose_sagittal=False, return_index=False,
                       return_mat=False):
    """Sample a slice in a RAS system

    Parameters
    ----------
    image : (..., *shape3)
    dim : int or str, default=-1
        Index of spatial dimension to sample in the visualization space
        If RAS: -1 = axial / -2 = coronal / -3 = sagittal.
        A string is reduced to its lower-cased first letter
        ('a', 'c' or 's').
    index : int, default=shape//2
        Coordinate of the slice to extract. When provided, it is mapped
        through the inverse of `space`.
    affine : (4, 4) tensor, optional
        Orientation matrix of the image
    space : (4, 4) tensor, optional
        Orientation matrix of the visualisation space.
        Default: RAS with minimum voxel size of all inputs.
    bbox : (2, D) tensor_like, optional
        Bounding box: min and max coordinates (in millimetric
        visualisation space). Default: bounding box of the input image.
    interpolation : {0, 1}, default=1
        Interpolation order.
    transpose_sagittal : bool, default=False
        Use the alternative in-plane axis arrangement for sagittal slices.
    return_index : bool, default=False
        Also return the in-plane coordinates of the cursor.
    return_mat : bool, default=False
        Also return the orientation matrix of the sampled slice.

    Returns
    -------
    slice : (..., *shape2) tensor
        Slice in the visualisation space.
    index : tuple, if `return_index`
    space : tensor, if `return_mat`

    Raises
    ------
    ValueError
        If `dim` is not one of -1, -2, -3 (or 'a', 'c', 's').

    """
    # preproc dim: plane names are mapped to negative indices
    if isinstance(dim, str):
        letter = dim.lower()[0]
        dim = {'a': -1, 'c': -2, 's': -3}.get(letter, letter)

    backend = utils.backend(image)

    # compute default space (mn/mx are in voxels)
    affine, native_shape = _get_default_native(affine, image.shape[-3:])
    space, mn, mx = _get_default_space(affine, [native_shape], space, bbox)
    affine, native_shape = affine[0], native_shape[0]

    # compute default cursor (in voxels)
    if index is None:
        cursor = (mx + mn) / 2
    else:
        cursor = torch.as_tensor(index)
        cursor = spatial.affine_matvec(spatial.affine_inv(space), cursor)

    # include slice to volume matrix
    fov = tuple(((mx-mn) + 1).round().int().tolist())
    if dim == -1:
        # axial
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 1, 0, - mn[1] + 1],
                 [0, 0, 1, - cursor[2]],
                 [0, 0, 0, 1]]
        shape = fov[:-1]
        cursor = (cursor[0] - mn[0] + 1, cursor[1] - mn[1] + 1)
    elif dim == -2:
        # coronal
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 0, 1, - mn[2] + 1],
                 [0, 1, 0, - cursor[1]],
                 [0, 0, 0, 1]]
        shape = (fov[0], fov[2])
        cursor = (cursor[0] - mn[0] + 1, cursor[2] - mn[2] + 1)
    elif dim == -3 and not transpose_sagittal:
        # sagittal
        shift = [[0, 0, 1, - mn[2] + 1],
                 [0, 1, 0, - mn[1] + 1],
                 [1, 0, 0, - cursor[0]],
                 [0, 0, 0, 1]]
        shape = (fov[2], fov[1])
        cursor = (cursor[2] - mn[2] + 1, cursor[1] - mn[1] + 1)
    elif dim == -3:
        # sagittal, alternative arrangement
        shift = [[0, -1, 0, mx[1] + 1],
                 [0, 0, 1, - mn[2] + 1],
                 [1, 0, 0, - cursor[0]],
                 [0, 0, 0, 1]]
        shape = (fov[1], fov[2])
        cursor = (mx[1] + 1 - cursor[1], cursor[2] - mn[2] + 1)
    else:
        raise ValueError(f'Unknown dimension {dim}')
    shift = utils.as_tensor(shift, **backend)

    # sample
    space = spatial.affine_rmdiv(space, shift)
    affine = spatial.affine_lmdiv(affine, space).to(**backend)
    grid = spatial.affine_grid(affine, [*shape, 1])
    *channel, s0, s1, s2 = image.shape
    image = image.reshape([1, -1, s0, s1, s2])
    image = spatial.grid_pull(image, grid[None], interpolation=interpolation,
                              bound='dct2', extrapolate=False)
    image = image.reshape([*channel, *shape])

    if return_index and return_mat:
        return image, cursor, space
    if return_index:
        return image, cursor
    if return_mat:
        return image, space
    return image
def __call__(self, logaff, grad=False, hess=False,
             gradmov=False, hessmov=False, in_line_search=False):
    """Evaluate the loss (and optionally its derivatives) for given
    Lie parameters.

    Parameters
    ----------
    logaff : (..., nb) tensor, Lie parameters
    grad : Whether to compute and return the gradient wrt `logaff`
    hess : Whether to compute and return the Hessian wrt `logaff`
    gradmov : Whether to compute and return the gradient wrt `moving`
    hessmov : Whether to compute and return the Hessian wrt `moving`
    in_line_search : If True, nothing is plotted or printed and the
        iteration counter / loss history are not updated.

    Returns
    -------
    ll : float, loss value (objective to minimize)
    g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
    h : (..., logaff) tensor, optional, Hessian wrt Lie parameters
    gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
    hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

    A single value is returned (not a tuple) when only `ll` is requested.
    """
    # This loop performs the forward pass, and computes
    # derivatives along the way.
    # NOTE: `grad`, `hess`, `gradmov` and `hessmov` start as flags and are
    # rebound to tensors once computed; `x is not False` is then used to
    # tell "computed" from "not requested".

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    # plot roughly 20 times over `max_iter` iterations, never in line search
    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
        and (self.n_iter - 1) % logplot == 0

    # jitter
    # if not hasattr(self, '_fixed'):
    #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
    #                                 jitter=True,
    #                                 **utils.backend(self.fixed))
    #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
    #     del idj
    # fixed = self._fixed
    fixed = self.fixed

    # forward
    if not torch.is_tensor(self.basis):
        # lazily build the basis and cache it on the instance
        self.basis = spatial.affine_basis(self.basis, self.dim,
                                          **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    with torch.no_grad():
        # derivative of the matrix exponential (value discarded)
        _, gaff = linalg._expm(logaff, self.basis,
                               grad_X=True, hess_X=False)

    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    # /!\ derivatives are not "homogeneous" (they do not have a one
    # on the bottom right): we should *not* use affine_matmul and
    # such (I only lost a day...)
    gaff = torch.matmul(gaff, self.affine_fixed)
    gaff = linalg.lmdiv(self.affine_moving, gaff)
    # haff = torch.matmul(haff, self.affine_fixed)
    # haff = linalg.lmdiv(self.affine_moving, haff)
    if self.id is None:
        # identity grid of the fixed image, cached on the instance
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff),
                                        jitter=False)
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, cat=iscat, dim=self.dim)

    # gradient/Hessian of the log-likelihood in observed space
    if not grad and not hess:
        llx = self.loss.loss(warped, fixed)
    elif not hess:
        llx, grad = self.loss.loss_grad(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
    else:
        llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
        if hessmov:
            hessmov = spatial.grid_push(hess, grid, **pullopt)
    del warped

    # compose with spatial gradients + dot product with grid
    if grad is not False or hess is not False:
        mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
        grad = jg(mugrad, grad)
        if hess is not False:
            hess = jhj(mugrad, hess)
            grad, hess = regutils.affine_grid_backward(grad, hess,
                                                       grid=self.id)
        else:
            grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
        # flatten the matrix dimensions into dim*(dim+1) entries
        dim2 = self.dim * (self.dim + 1)
        grad = grad.reshape([*grad.shape[:-2], dim2])
        # drop the last row of each derivative matrix before flattening
        gaff = gaff[..., :-1, :]
        gaff = gaff.reshape([*gaff.shape[:-2], dim2])
        grad = linalg.matvec(gaff, grad)
        if hess is not False:
            hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
            # haff = haff[..., :-1, :, :-1, :]
            # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
            hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
            # replace the Hessian by a diagonal matrix holding the sums
            # of absolute values along its last dimension
            hess = hess.abs().sum(-1).diag_embed()
        del mugrad

    # print objective
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                  end='\n')
        else:
            # improvement since the previous iteration, relative to the
            # distance to the largest loss recorded so far
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                  end='\n')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    # only the quantities that were computed are returned
    out = [ll]
    if grad is not False:
        out.append(grad)
    if hess is not False:
        out.append(hess)
    if gradmov is not False:
        out.append(gradmov)
    if hessmov is not False:
        out.append(hessmov)
    return tuple(out) if len(out) > 1 else out[0]
def do_affine(self, logaff, grad=False, hess=False, in_line_search=False):
    """Forward pass for updating the affine component (nonlin is not None)

    Parameters
    ----------
    logaff : tensor
        Lie parameters of the affine, passed to `self.affine.exp`/`exp2`.
    grad : bool, default=False
        Also return the gradient wrt the Lie parameters.
    hess : bool, default=False
        Also return the Hessian wrt the Lie parameters.
    in_line_search : bool, default=False
        If True, the affine exponential is not cached, the state of each
        loss is restored after evaluation, and the iteration counter and
        loss history are left untouched.

    Returns
    -------
    loss : tensor
        Sum over `self.losses` of `factor * loss`, plus `self.llv`.
    grad : tensor, if `grad`
    hess : tensor, if `hess`

    A single value is returned (not a tuple) when only the loss is
    requested.
    """

    sumloss = None
    sumgrad = None
    sumhess = None

    # ==============================================================
    #                     EXPONENTIATE TRANSFORMS
    # ==============================================================
    logaff0 = logaff
    # first letter of the affine position; tested below with
    # `in 'fs'` / `in 'ms'` (substring tests)
    aff_pos = self.affine.position[0].lower()
    if any(loss.backward for loss in self.losses):
        # at least one backward loss: forward and inverse transforms
        # are both needed
        aff0, iaff0, gaff0, igaff0 = \
            self.affine.exp2(logaff0, grad=True,
                             cache_result=not in_line_search)
        phi0, iphi0 = self.nonlin.exp2(cache_result=True, recompute=False)
    else:
        iaff0, igaff0, iphi0 = None, None, None
        aff0, gaff0 = self.affine.exp(logaff0, grad=True,
                                      cache_result=not in_line_search)
        phi0 = self.nonlin.exp(cache_result=True, recompute=False)

    has_printed = False
    for loss in self.losses:

        moving, fixed, factor = loss.moving, loss.fixed, loss.factor
        # backward losses use the inverse transforms
        if loss.backward:
            phi00, aff00, gaff00 = iphi0, iaff0, igaff0
        else:
            phi00, aff00, gaff00 = phi0, aff0, gaff0

        # ----------------------------------------------------------
        # build left and right affine matrices
        # ----------------------------------------------------------
        aff_right, gaff_right = fixed.affine, None
        if aff_pos in 'fs':
            gaff_right = gaff00 @ aff_right
            gaff_right = linalg.lmdiv(self.nonlin.affine, gaff_right)
            aff_right = aff00 @ aff_right
        aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
        aff_left, gaff_left = self.nonlin.affine, None
        if aff_pos in 'ms':
            gaff_left = gaff00 @ aff_left
            gaff_left = linalg.lmdiv(moving.affine, gaff_left)
            aff_left = aff00 @ aff_left
        aff_left = linalg.lmdiv(moving.affine, aff_left)

        # ----------------------------------------------------------
        # build full transform
        # ----------------------------------------------------------
        if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
            # fixed and nonlin lattices coincide: no resampling needed
            right = None
            phi = spatial.add_identity_grid(phi00)
        else:
            right = spatial.affine_grid(aff_right, fixed.shape)
            phi = regutils.smart_pull_grid(phi00, right)
            phi += right
        # keep the transform before the left affine is applied
        phi_right = phi
        if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
            left = None
        else:
            left = spatial.affine_grid(aff_left, self.nonlin.shape)
            phi = spatial.affine_matvec(aff_left, phi)

        # ----------------------------------------------------------
        # forward pass
        # ----------------------------------------------------------
        warped, mask = moving.pull(phi, mask=True)
        if fixed.masked:
            if mask is None:
                mask = fixed.mask
            else:
                mask = mask * fixed.mask

        # plot at most once per call, only at verbosity >= 3, outside
        # of line searches, and for a forward loss
        do_print = not (has_printed or self.verbose < 3 or in_line_search
                        or loss.backward)
        if do_print:
            has_printed = True
            if moving.previewed:
                preview = moving.pull(phi, preview=True, dat=False)
            else:
                preview = warped
            # moving image resampled on the fixed lattice, without any
            # of the estimated transforms
            init = spatial.affine_lmdiv(moving.affine, fixed.affine)
            if _almost_identity(init) and moving.shape == fixed.shape:
                init = moving.dat
            else:
                init = spatial.affine_grid(init, fixed.shape)
                init = moving.pull(init, preview=True, dat=False)
            self.mov2fix(fixed.dat, init, preview, dim=fixed.dim,
                         title=f'(affine) {self.n_iter:03d}')

        # ----------------------------------------------------------
        # derivatives wrt moving
        # ----------------------------------------------------------
        g = h = None
        loss_args = (warped, fixed.dat)
        loss_kwargs = dict(dim=fixed.dim, mask=mask)
        # save the loss state so that it can be restored after a
        # line-search evaluation
        state = loss.loss.get_state()
        if not grad and not hess:
            llx = loss.loss.loss(*loss_args, **loss_kwargs)
        elif not hess:
            llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
        else:
            llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
        del loss_args, loss_kwargs
        if in_line_search:
            loss.loss.set_state(state)

        # ----------------------------------------------------------
        # chain rule -> derivatives wrt Lie parameters
        # ----------------------------------------------------------

        def compose_grad(g, h, g_mu, g_aff):
            """
            g, h : gradient/Hessian of loss wrt moving image
            g_mu : spatial gradients of moving image
            g_aff : gradient of affine matrix wrt Lie parameters
            returns g, h: gradient/Hessian of loss wrt Lie parameters
            """
            # Note that `h` can be `None`, but the functions I
            # use deal with this case correctly.
            dim = g_mu.shape[-1]
            g = jg(g_mu, g)
            h = jhj(g_mu, h)
            g, h = regutils.affine_grid_backward(g, h)
            # flatten the matrix dimensions into dim*(dim+1) entries
            dim2 = dim * (dim + 1)
            g = g.reshape([*g.shape[:-2], dim2])
            # drop the last row of each derivative matrix
            g_aff = g_aff[..., :-1, :]
            g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
            g = linalg.matvec(g_aff, g)
            if h is not None:
                h = h.reshape([*h.shape[:-4], dim2, dim2])
                h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
                # h = h.abs().sum(-1).diag_embed()
            return g, h

        if grad or hess:
            # keep the derivatives wrt the warped image in g0/h0;
            # g/h accumulate the derivatives wrt the Lie parameters
            g0, g = g, None
            h0, h = h, None
            if aff_pos in 'ms':
                # affine on the left: push derivatives to the nonlin lattice
                g_left = regutils.smart_push(g0, phi_right,
                                             shape=self.nonlin.shape)
                h_left = regutils.smart_push(h0, phi_right,
                                             shape=self.nonlin.shape)
                # NOTE(review): `left` may be None here (see above) --
                # presumably `pull_grad` handles it; confirm.
                mugrad = moving.pull_grad(left, rotate=False)
                g_left, h_left = compose_grad(g_left, h_left,
                                              mugrad, gaff_left)
                g, h = g_left, h_left
            if aff_pos in 'fs':
                # affine on the right
                g_right, h_right = g0, h0
                mugrad = moving.pull_grad(phi, rotate=False)
                # NOTE(review): uses `phi0` (not `phi00`) even for
                # backward losses -- confirm this is intended.
                jac = spatial.grid_jacobian(phi0, right, type='disp',
                                            extrapolate=False)
                jac = torch.matmul(aff_left[:-1, :-1], jac)
                mugrad = linalg.matvec(jac.transpose(-1, -2), mugrad)
                g_right, h_right = compose_grad(g_right, h_right,
                                                mugrad, gaff_right)
                g = g_right if g is None else g.add_(g_right)
                h = h_right if h is None else h.add_(h_right)

            if loss.backward:
                g = g.neg_()
            # weighted in-place accumulation across losses
            sumgrad = (g.mul_(factor) if sumgrad is None else
                       sumgrad.add_(g, alpha=factor))
            if hess:
                sumhess = (h.mul_(factor) if sumhess is None else
                           sumhess.add_(h, alpha=factor))
        sumloss = (llx.mul_(factor) if sumloss is None else
                   sumloss.add_(llx, alpha=factor))

    # TODO add regularization term
    lla = 0

    # ==============================================================
    #                           VERBOSITY
    # ==============================================================
    # llx: data term only (read before the other terms are added in-place)
    llx = sumloss.item()
    sumloss += lla
    sumloss += self.llv
    self.loss_value = sumloss.item()
    if self.verbose and (self.verbose > 1 or not in_line_search):
        ll = sumloss.item()
        llv = self.llv
        if in_line_search:
            line = '(search) | '
        else:
            line = '(affine) | '
        line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
        if not in_line_search:
            if self.ll_prev is not None:
                gain = self.ll_prev - ll
                # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                line += f' | {gain:12.6g}'
            # the history is only updated outside of line searches
            self.all_ll.append(ll)
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)
            self.n_iter += 1
        print(line, end='\r')

    # ==============================================================
    #                           RETURN
    # ==============================================================
    out = [sumloss]
    if grad:
        out.append(sumgrad)
    if hess:
        out.append(sumhess)
    return tuple(out) if len(out) > 1 else out[0]
def ffd(source, target, grid_shape=10, group='SE',
        image_loss=None, def_loss=None, pull=None, preproc=True,
        max_iter=1000, device=None, origin='center', init=None,
        lr=1e-4, optim_affine=True, scheduler=ReduceLROnPlateau):
    """FFD (= cubic spline) registration

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate losses.
       The drawback of this approach is that all inputs must share
       the same lattice and orientation matrix, as well as the same
       interpolation order. The advantage is that it simplifies the
       signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
    target : tensor or (tensor, affine)
    grid_shape : [sequence of] int, default=10
        Shape of the lattice of spline coefficients (expanded to one
        value per spatial dimension).
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
    image_loss : float or callable(moved, target) -> loss, optional
        If not callable, a `MutualInfoLoss` weighted by this value
        (1 if None) is used.
    def_loss : float or callable(disp) -> loss, optional
        If not callable, a `BendingLoss(bound='dft')` weighted by this
        value (1 if None) is used.
    pull : dict
        interpolation : default='linear'
        bound : bound_like, default='dft'
        extrapolate : bool, default=False
        The dictionary is copied; the caller's object is not modified.
    preproc : bool, default=True
        Pass both images through `rescale` before registration.
    max_iter : int, default=1000
    device : device, optional
    origin : {'native', 'center'}, default='center'
    init : tensor_like, default=0
        Initial affine parameters, reshaped to (batch, nb_prm).
    lr : float or (float, float), default=1e-4
        Learning rates. The first value is used for the spline
        coefficients, the second for the affine parameters.
    optim_affine : bool, default=True
        Whether the affine parameters are optimized as well.
    scheduler : callable(optimizer, cooldown=...) or None,
                default=ReduceLROnPlateau

    Returns
    -------
    q : (batch, nb_prm) tensor
        Affine parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    prm : (batch, *grid_shape, D) tensor
        Spline coefficients of the displacement field.
    moved : tensor
        Source image moved to target space.

    """
    # Copy the options so that the caller's dictionary is left untouched.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dft')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2

    # Rescale to [0, 1]
    if preproc:
        source = rescale(source)
        target = rescale(target)

    # Shift origin
    if origin == 'center':
        # Only the spatial dimensions define the centre of the field of view.
        shift = torch.as_tensor(target.shape[2:], **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        # Clone first: the affines may be the caller's own tensors.
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        affine_parameters = torch.as_tensor(init, **backend).clone().detach()
        affine_parameters = affine_parameters.reshape([batch, nb_prm])
    else:
        affine_parameters = torch.zeros([batch, nb_prm], **backend)
    affine_parameters = nn.Parameter(affine_parameters,
                                     requires_grad=optim_affine)
    grid_shape = core.pyutils.make_list(grid_shape, dim)
    grid_parameters = torch.zeros([batch, *grid_shape, dim], **backend)
    grid_parameters = nn.Parameter(grid_parameters, requires_grad=True)

    def pull_source(q, grid):
        """Compose the affine with the dense grid and sample the source."""
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        expd = (slice(None), ) + (None, ) * dim + (slice(None), slice(None))
        grid = spatial.affine_matvec(aff[expd], grid)
        moved = spatial.grid_pull(source, grid, **pull_opt)
        return moved

    def exp(prm):
        """Upsample the coefficients to a dense displacement and grid."""
        disp = spatial.resize_grid(prm, type='displacement',
                                   shape=target.shape[2:],
                                   interpolation=3, bound='dft')
        grid = disp + spatial.identity_grid(target.shape[2:], **backend)
        return disp, grid

    # Prepare loss and optimizer
    # Each weight gets its own name: a shared `factor` variable would be
    # rebound by the second branch and silently change the first lambda.
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        image_factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)

    if not callable(def_loss):
        def_loss_fn = nni.BendingLoss(bound='dft')
        def_factor = 1. if def_loss is None else def_loss
        def_loss = lambda x: def_factor * def_loss_fn(
            core.utils.last2channel(x))

    lr = core.utils.make_list(lr, 2)
    if optim_affine:
        opt_prm = [{'params': affine_parameters, 'lr': lr[1]},
                   {'params': grid_parameters, 'lr': lr[0]}]
    else:
        opt_prm = [grid_parameters]
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    if scheduler is not None:
        scheduler = scheduler(optim, cooldown=5)

    # Optim loop
    loss_avg = 0
    # Count from 1 so that each reported average covers 10 iterations.
    for n_iter in range(1, max_iter + 1):

        zero_grad_([affine_parameters, grid_parameters])
        disp, grid = exp(grid_parameters)
        moved = pull_source(affine_parameters, grid)
        loss_val = image_loss(moved, target) + def_loss(disp[0])
        loss_val.backward()
        optim.step()

        with torch.no_grad():
            loss_avg += loss_val
            if n_iter % 10 == 0:
                loss_avg /= 10
                if scheduler is not None:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(loss_avg)
                    else:
                        scheduler.step()
                print('{:4d} {:12.6f} | lr={:g}'.format(
                    n_iter, loss_avg.item(), optim.param_groups[0]['lr']),
                    end='\r')
                loss_avg = 0
    print('')

    with torch.no_grad():
        # Rebuild the grid from the final coefficients: the one left over
        # from the loop predates the last optimizer step (and does not
        # exist at all when `max_iter` is 0).
        _, grid = exp(grid_parameters)
        moved = pull_source(affine_parameters, grid)
        aff = core.linalg.expm(affine_parameters, basis)
        if origin == 'center':
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
    aff.requires_grad_(False)
    return affine_parameters, aff, grid_parameters, moved
def _index_from_cursor(self, x, y, image, n_ax):
    """Store in `self.index` the cursor position mapped by a view matrix.

    The point `[x, y, 0]` is transformed by `image._mats[n_ax]` and the
    result is assigned to `self.index`.
    """
    cursor = utils.as_tensor([x, y, 0])
    view_mat = image._mats[n_ax]
    self.index = spatial.affine_matvec(view_mat, cursor)
def diffeo(source, target, group='SE', image_loss=None, vel_loss=None,
           pull=None, preproc=False, max_iter=1000, device=None,
           origin='center', init=None, lr=1e-4, optim_affine=True,
           scheduler=ReduceLROnPlateau):
    """Diffeomorphic (affine + stationary velocity) registration

    Parameters
    ----------
    source : path or tensor or (tensor, affine)
    target : path or tensor or (tensor, affine)
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
    image_loss : float or callable(moved, target) -> loss, optional
        If not callable, a `MutualInfoLoss` weighted by this value
        (1 if None) is used.
    vel_loss : float or callable(vel) -> loss, optional
        If not callable, a `BendingLoss(bound='dft')` weighted by this
        value (1 if None) is used.
    pull : dict
        interpolation : default='linear'
        bound : bound_like, default='dct2'
        extrapolate : bool, default=False
        The dictionary is copied; the caller's object is not modified.
    preproc : bool, default=False
        Currently not consulted: both images are always passed
        through `rescale`.
    max_iter : int, default=1000
    device : device, optional
    origin : {'native', 'center'}, default='center'
    init : tensor_like, default=0
        Initial affine parameters, reshaped to (batch, nb_prm).
    lr : float or (float, float), default=1e-4
        Learning rates. The first value is used for the affine
        parameters, the second for the velocity field.
    optim_affine : bool, default=True
    scheduler : callable(optimizer, cooldown=...) or None,
                default=ReduceLROnPlateau

    Returns
    -------
    q : (batch, nb_prm) tensor
        Affine parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    vel : (batch, *spatial, D) tensor
        Initial velocity of the diffeomorphic transform.
        The full warp is `(aff @ aff_src).inv() @ aff_trg @ exp(vel)`
    moved : tensor
        Source image moved to target space.

    """
    # Copy the options so that the caller's dictionary is left untouched.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2

    # Rescale to [0, 1]
    # NOTE(review): `preproc` is not consulted here; the rescaling has
    # always been unconditional in this function -- confirm intent.
    source = rescale(source)
    target = rescale(target)

    # Shift origin
    if origin == 'center':
        # Only the spatial dimensions define the centre of the field of view.
        shift = torch.as_tensor(target.shape[2:], **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=optim_affine)
    velocity = torch.zeros([batch, *target.shape[2:], dim], **backend)
    velocity = nn.Parameter(velocity, requires_grad=True)

    def pull_source(q, vel):
        """Exponentiate both transforms, compose them, sample the source."""
        grid = spatial.exp(vel)
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff, grid)
        moved = spatial.grid_pull(source, grid, **pull_opt)
        return moved

    # Prepare loss and optimizer
    # Each weight gets its own name: a shared `factor` variable would be
    # rebound by the second branch and silently change the first lambda.
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        image_factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)

    if not callable(vel_loss):
        vel_loss_fn = nni.BendingLoss(bound='dft')
        vel_factor = 1. if vel_loss is None else vel_loss
        vel_loss = lambda x: vel_factor * vel_loss_fn(
            core.utils.last2channel(x))

    lr = core.utils.make_list(lr, 2)
    if optim_affine:
        opt_prm = [{'params': parameters},
                   {'params': velocity, 'lr': lr[1]}]
    else:
        opt_prm = [velocity]
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    if scheduler is not None:
        scheduler = scheduler(optim, cooldown=5)

    # Optim loop
    loss_avg = 0
    for n_iter in range(1, max_iter + 1):

        optim.zero_grad(set_to_none=True)
        moved = pull_source(parameters, velocity)
        loss_val = image_loss(moved, target) + vel_loss(velocity)
        loss_val.backward()
        optim.step()

        with torch.no_grad():
            loss_avg += loss_val
            if n_iter % 10 == 0:
                loss_avg /= 10
                if scheduler is not None:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(loss_avg)
                    else:
                        scheduler.step()
                print('{:4d} {:12.6f} | lr={:g}'.format(
                    n_iter, loss_avg.item(), optim.param_groups[0]['lr']),
                    end='\r')
                loss_avg = 0
    print('')

    with torch.no_grad():
        moved = pull_source(parameters, velocity)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
    aff.requires_grad_(False)
    return parameters, aff, velocity, moved
def diffeo(source, target, group='SE', origin='center',
           image_loss=None, vel_loss=None, pull=None, optim_affine=True,
           max_iter=1000, lr=0.1, min_lr=1e-7, init=None, device=None):
    """Diffeomorphic registration

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate losses.
       The drawback of this approach is that all inputs must share
       the same lattice and orientation matrix, as well as the same
       interpolation order. The advantage is that it simplifies the
       signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
        The source (moving) image, with shape (batch, channel, *spatial).
    target : tensor or (tensor, affine)
        The target (fixed) image, with shape (batch, channel, *spatial).
    group : str, default='SE'
        Affine sub-group to optimize. The value is passed through
        `affine_group_converter`, so aliases such as
        {'tr', 'rot', 'rigid', 'sim', 'lin', 'aff'} are presumably
        accepted as well (confirm against the converter).
    origin : {'native', 'center'}, default='center'
        Whether to rotate about the origin of the world-space ('native')
        or the center of the target field-of-view ('center').
        When the origin of the world-space is far off (say you are
        registering smaller blocks cropped from a larger MRI), it can
        be beneficial to rotate about the center of the FOV.
    image_loss : float or callable(mov, fix) -> loss, optional
        Either a factor to multiply the mutual information loss with
        (1 if None) or a loss function that takes two inputs of shape
        (batch, channel, *spatial).
    vel_loss : float or callable(vel) -> loss, optional
        Either a factor to multiply the bending loss with (1 if None)
        or a loss function that takes the velocity field, of shape
        (batch, *spatial, D).
    pull : dict
        interpolation : default='linear'
            Interpolation order
        bound : bound_like, default='dct2'
            Boundary condition
        extrapolate : bool, default=False
            Extrapolate out-of-bound data using the boundary conditions.
        The dictionary is copied; the caller's object is not modified.
    optim_affine : bool, default=True
        Whether the affine parameters are optimized as well.
    max_iter : int, default=1000
        Maximum number of iterations
    lr : float or (float, float), default=0.1
        Initial learning rates. The first value is used for the affine
        parameters, the second for the velocity field.
    min_lr : float or (float, float), default=1e-7
        Minimum learning rates. The optimization is stopped once
        every parameter group is below its minimum.
    init : ([batch], nb_prm) tensor_like, default=0
        Initial guess for the affine parameters.
    device : {'cpu', 'cuda', 'cuda:<id>'}, optional
        Backend to use

    Returns
    -------
    q : (batch, nb_prm) tensor
        Parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    vel : (batch, *shape, D) tensor
        Initial velocity
    moved : tensor
        Source image moved to target space.

    """
    group = affine_group_converter(group)

    # Copy the options so that the caller's dictionary is left untouched.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2

    # Shift origin
    if origin == 'center':
        # Only the spatial dimensions define the centre of the field of view.
        shift = torch.as_tensor(target.shape[2:], **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=optim_affine)
    velocity = torch.zeros([batch, *target.shape[2:], dim], **backend)
    velocity = nn.Parameter(velocity, requires_grad=True)

    def pull_source(q, vel):
        """Exponentiate both transforms, compose them, sample the source."""
        grid = spatial.exp(vel)
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff, grid)
        moved = spatial.grid_pull(source, grid, **pull_opt)
        return moved

    # Prepare loss and optimizer
    # Each weight gets its own name: a shared `factor` variable would be
    # rebound by the second branch and silently change the first lambda.
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        image_factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)

    if not callable(vel_loss):
        vel_loss_fn = nni.BendingLoss(bound='dft')
        vel_factor = 1. if vel_loss is None else vel_loss
        vel_loss = lambda x: vel_factor * vel_loss_fn(
            core.utils.last2channel(x))

    lr = core.utils.make_list(lr, 2)
    min_lr = core.utils.make_list(min_lr, 2)
    if optim_affine:
        opt_prm = [{'params': parameters},
                   {'params': velocity, 'lr': lr[1]}]
    else:
        opt_prm = [velocity]
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    scheduler = ReduceLROnPlateau(optim)

    def forward():
        """Warp the source and evaluate the full objective."""
        moved = pull_source(parameters, velocity)
        loss_val = image_loss(moved, target) + vel_loss(velocity)
        return loss_val

    # Optim loop
    loss_avg = 0
    for n_iter in range(1, max_iter + 1):

        optim.zero_grad(set_to_none=True)
        loss_val = forward()
        loss_val.backward()
        # No closure: gradients are already populated, and passing
        # `forward` would only trigger a second, discarded forward pass.
        optim.step()

        with torch.no_grad():
            loss_avg += loss_val

        if n_iter % 10 == 0:
            loss_avg /= 10
            scheduler.step(loss_avg)

            print('{:4d} {:12.6f} | lr={:g} '
                  .format(n_iter, loss_avg.item(),
                          optim.param_groups[0]['lr']),
                  end='\r')
            loss_avg = 0

            # The learning rate only changes when the scheduler steps.
            groups = optim.param_groups
            if (groups[0]['lr'] < min_lr[0] and
                    (len(groups) == 1 or groups[1]['lr'] < min_lr[1])):
                print('\nConverged.')
                break
    print('')

    with torch.no_grad():
        moved = pull_source(parameters, velocity)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
    return (parameters.detach(), aff.detach(),
            velocity.detach(), moved.detach())
def reslice(moving, fname, like, inv=False, lin=None, nonlin=None,
            interpolation=1, bound='dct2', extrapolate=False, device=None,
            verbose=True):
    """Warp an image with the linear and non-linear transforms and
    resample it on the lattice of a reference image.

    Notes
    -----
    .. The output volumes have the shape of `like` and are saved with
       the orientation matrix of `like`.
    .. This function writes new files (it does not modify the input
       files in place).

    Parameters
    ----------
    moving : ImageFile
        An object describing the image to resample
        (uses `affine`, `type` and `files`).
    fname : list of str
        Output filename for each file of `moving`.
    like : ImageFile
        An object describing the target space (uses `affine`, `shape`).
    inv : bool, default=False
        If True, the linear transform is composed with the orientation
        matrix of `like`; otherwise with that of `moving`.
    lin : (4, 4) tensor, optional
        Linear (or rather affine) transformation
    nonlin : dict, optional
        Non-linear displacement field, with keys:
        disp : (..., 3) tensor
            Displacement field (in voxels)
        affine : (4, 4) tensor
            Orientation matrix of the displacement field
    interpolation : int, default=1
        Forced to 0 when `moving.type == 'labels'`.
    bound : str, default='dct2'
    extrapolate : bool, default = False
    device : default='cpu'
    verbose : bool, default=True
        Print the input and output filenames.

    """
    nonlin = nonlin or dict(disp=None, affine=None)
    prm = dict(interpolation=interpolation, bound=bound,
               extrapolate=extrapolate)

    mov_aff = moving.affine.to(device)
    ref_aff = like.affine.to(device)

    # The linear transform corrects the reference space when inverting,
    # and the moving space otherwise.
    if inv:
        if lin is not None:
            ref_aff = affine_matmul(lin, ref_aff)
    elif lin is not None:
        mov_aff = affine_matmul(lin, mov_aff)

    disp = nonlin['disp']
    if disp is None:
        # purely affine mapping: reference voxels -> moving voxels
        ref2mov = affine_lmdiv(mov_aff, ref_aff)
        g = affine_grid(ref2mov, like.shape)
    else:
        disp = disp.to(device)
        nlin_aff = nonlin['affine'].to(device)
        # reference voxels -> voxels of the displacement field
        ref2nlin = affine_lmdiv(nlin_aff, ref_aff)
        if samespace(ref2nlin, disp.shape[:-1], like.shape):
            g = smalldef(disp)
        else:
            g = affine_grid(ref2nlin, like.shape)
            g += pull_grid(disp, g)
        # voxels of the displacement field -> moving voxels
        nlin2mov = affine_lmdiv(mov_aff, nlin_aff)
        g = affine_matvec(nlin2mov, g)

    if moving.type == 'labels':
        prm['interpolation'] = 0

    for file, ofname in zip(moving.files, fname):
        if verbose:
            print(f'Resliced: {file.fname}\n'
                  f' -> {ofname}')
        dat = io.volumes.loadf(file.fname, rand=True, device=device)
        dat = dat.reshape([*file.shape, file.channels])
        if g is not None:
            # `pull` is called with channels first
            dat = utils.movedim(dat, -1, 0)
            dat = pull(dat, g, **prm)
            dat = utils.movedim(dat, 0, -1)
        io.savef(dat.cpu(), ofname, like=file.fname,
                 affine=like.affine.cpu())
def forward():
    """Forward pass up to the loss

    Builds the affine matrix from the first linear transformation,
    the displacement field from the first free non-linear one, then
    accumulates the regularization terms and, for every image pair and
    pyramid level, the matching term between warped and fixed images.
    """
    loss = 0

    # --- affine matrix (first linear transformation) ---
    lin_mat = None
    for trf in options.transformations:
        trf.update()
        if not isinstance(trf, struct.Linear):
            continue
        lie_prm = trf.optdat.to(**backend)
        lie_basis = trf.basis.to(**backend)
        lin_mat = linalg.expm(lie_prm, lie_basis)
        if torch.is_tensor(trf.shift):
            # include shift
            offset = trf.shift.to(**backend)
            identity = torch.eye(options.dim, **backend)
            # clone needed because expm is a custom autograd.Function
            lin_mat = lin_mat.clone()
            lin_mat[:-1, -1] += torch.matmul(lin_mat[:-1, :-1] - identity,
                                             offset)
        for reg in trf.losses:
            loss += reg.call(lie_prm)
        break

    # --- non-linear displacement field (first free FFD/Diffeo) ---
    disp = None
    disp_aff = None
    for trf in options.transformations:
        if not trf.isfree():
            continue
        if isinstance(trf, struct.FFD):
            disp = trf.dat.to(**backend)
            disp = ffd_exp(disp, trf.shape, returns='disp')
            for reg in trf.losses:
                loss += reg.call(disp)
            disp_aff = trf.affine.to(**backend)
            break
        if isinstance(trf, struct.Diffeo):
            disp = trf.dat.to(**backend)
            if not trf.smalldef:
                # penalty on velocity fields
                for reg in trf.losses:
                    loss += reg.call(disp)
            disp = spatial.exp(disp[None], displacement=True)[0]
            if trf.smalldef:
                # penalty on exponentiated transform
                for reg in trf.losses:
                    loss += reg.call(disp)
            disp_aff = trf.affine.to(**backend)
            break

    # --- loop over image pairs ---
    for match in options.losses:
        if not match.fixed:
            continue
        nb_levels = len(match.fixed.dat)
        pull_opt = dict(interpolation=match.moving.interpolation,
                        bound=match.moving.bound,
                        extrapolate=match.moving.extrapolate)
        # loop over pyramid levels
        for (mov_dat, mov_aff), (fix_dat, fix_aff) in zip(match.moving.dat,
                                                           match.fixed.dat):
            mov_dat = mov_dat.to(**backend)
            mov_aff = mov_aff.to(**backend)
            fix_dat = fix_dat.to(**backend)
            fix_aff = fix_aff.to(**backend)
            fix_shape = fix_dat.shape[1:]

            # affine-corrected moving space
            if lin_mat is None:
                mov_space = mov_aff
            else:
                mov_space = affine_matmul(lin_mat, mov_aff)

            if disp is None:
                # fixed to moving
                fix2mov = affine_lmdiv(mov_space, fix_aff)
                grid = affine_grid(fix2mov, fix_shape)
            else:
                # fixed to param
                fix2prm = affine_lmdiv(disp_aff, fix_aff)
                if samespace(fix2prm, disp.shape[:-1], fix_shape):
                    grid = smalldef(disp)
                else:
                    grid = affine_grid(fix2prm, fix_shape)
                    grid = grid + pull_grid(disp, grid)
                # param to moving
                prm2mov = affine_lmdiv(mov_space, disp_aff)
                grid = affine_matvec(prm2mov, grid)

            # pull moving image
            warped_dat = pull(mov_dat, grid, **pull_opt)
            loss += match.call(warped_dat, fix_dat) / float(nb_levels)

    return loss
def load_transforms(s):
    """Initialize transforms

    For every transformation in `s.transformations`, scale the factors
    of its regularization losses by `trf.factor` and allocate (or load
    from `trf.init`) its parameters. Linear transformations also get
    their `shift`; non-linear ones get their `dat`, `affine` and
    `shape`. The options object `s` is modified in place and nothing is
    returned.
    """
    device = torch.device(s.device)

    def reshape3d(dat, channels=None, dim=3):
        """Reshape as (*spatial) or (C, *spatial) or (*spatial, C).
        `channels` should be in ('first', 'last', None).
        """
        # drop the singleton axis that directly follows the spatial axes
        drop_first_extra = (slice(None),) * dim + (0, Ellipsis)
        while len(dat.shape) > dim:
            if dat.shape[-1] == 1:
                dat = dat[..., 0]
            elif dat.shape[dim] == 1:
                dat = dat[drop_first_extra]
            else:
                break
        if len(dat.shape) > dim + bool(channels):
            raise ValueError('Too many channel dimensions')
        if channels and len(dat.shape) == dim:
            dat = dat[..., None]
        if channels == 'first':
            dat = utils.movedim(dat, -1, 0)
        return dat

    # compute mean space
    #   it is used to define the space of the nonlinear transform, but
    #   also to shift the center of rotation of the linear transform.
    all_affines = []
    all_shapes = []
    all_affines_fixed = []
    all_shapes_fixed = []
    for loss in s.losses:
        if isinstance(loss, struct.NoLoss):
            continue
        if getattr(loss, 'exclude', False):
            continue
        all_shapes_fixed.append(loss.fixed.shape)
        all_affines_fixed.append(loss.fixed.affine)
        all_shapes.append(loss.fixed.shape)
        all_affines.append(loss.fixed.affine)
        all_shapes.append(loss.moving.shape)
        all_affines.append(loss.moving.affine)
    affine0, shape0 = mean_space(all_affines, all_shapes,
                                 pad=s.pad, pad_unit=s.pad_unit)
    affinef, shapef = mean_space(all_affines_fixed, all_shapes_fixed,
                                 pad=s.pad, pad_unit=s.pad_unit)
    backend = dict(dtype=affine0.dtype, device=affine0.device)

    for trf in s.transformations:
        # the transformation-level factor scales each regularization term
        for reg in trf.losses:
            if isinstance(reg.factor, (list, tuple)):
                reg.factor = [f * trf.factor for f in reg.factor]
            else:
                reg.factor = reg.factor * trf.factor

        if isinstance(trf, struct.Linear):
            # Affine
            if isinstance(trf.init, str):
                trf.dat = io.transforms.loadf(trf.init, dtype=torch.float32,
                                              device=device)
            else:
                trf.dat = torch.zeros(trf.nb_prm(3), dtype=torch.float32,
                                      device=device)
            if trf.shift:
                # minus the world position of the centre of the mean
                # fixed space
                shift = torch.as_tensor(shapef, **backend) * 0.5
                trf.shift = -spatial.affine_matvec(affinef, shift)
            else:
                trf.shift = 0.
        else:
            affine, shape = (affine0, shape0)
            trf.pyramid = list(sorted(trf.pyramid))
            max_level = max(trf.pyramid)
            factor = 2**(max_level-1)
            affine, shape = affine_resize(affine, shape, 1/factor)

            # FFD/Diffeo
            if isinstance(trf.init, str):
                f = io.volumes.map(trf.init)
                trf.dat = reshape3d(f.loadf(dtype=torch.float32,
                                            device=device), 'last')
                # Channels are last after `reshape3d(..., 'last')`;
                # `len()` would test the first spatial size instead.
                if trf.dat.shape[-1] != trf.dim:
                    raise ValueError('Field should have 3 channels')
                # NOTE(review): `trf.shape` and `trf.affine` are read here
                # before being assigned in this function -- presumably
                # they are provided by the transformation object; confirm.
                factor = [int(full // coarse)
                          for coarse, full in zip(trf.shape[:-1], shape)]
                trf.affine, trf.shape = affine_resize(trf.affine,
                                                      trf.shape[:-1],
                                                      factor)
            else:
                trf.dat = torch.zeros([*shape, trf.dim], dtype=torch.float32,
                                      device=device)
                trf.affine = affine
                trf.shape = shape
def crop(inp, size=None, center=None, space='vx', like=None,
         output=None, transform=None):
    """Crop a ND volume, while preserving the orientation matrices.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    size : [sequence of] int, optional
        Size of the patch to extract, expressed in `space`.
    center : [sequence of] int, optional
        Coordinate of the center of the patch, expressed in `space`.
        By default, the center of the FOV is used.
    space : [sequence of] str, default='vx'
        The space in which the `size` and `center` parameters are
        expressed (first value for `size`, second for `center`).
        Values equal to 'ras' (case-insensitive) trigger a conversion
        to voxels; any other value is taken as voxels.
    like : str or (tensor, tensor), optional
        Reference patch.
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data (or its shape) and
        the second contains the orientation matrix.
    output : str, optional
        Output filename pattern.
        If the input is not a path, the cropped data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}{sep}{base}.crop{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    transform : str, optional
        Input or output filename of the corresponding transform.
        Not written by default.
        If a transform is provided and `size` and `like` are both
        empty, the transform is considered as an input transform
        whose destination space defines the patch.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the cropped data and orientation matrix are returned.

    """
    dirname, base, ext = '', '', ''
    fname = None
    transform_in = False

    # --- Open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        mapped = io.volumes.map(inp)
        inp = (mapped.data(numpy=True), mapped.affine)
        if output is None:
            output = '{dir}{sep}{base}.crop{ext}'
        dirname, base, ext = py.fileparts(fname)
    dat, aff0 = inp
    dim = aff0.shape[-1] - 1
    shape0 = dat.shape[:dim]

    if size and like:
        raise ValueError('Cannot use both `size` and `like`.')

    if like:
        # --- Open reference and compute size/center ---
        if isinstance(like, str):
            mapped = io.volumes.map(like)
            like = (mapped.shape, mapped.affine)
        like_shape, like_aff = like
        if torch.is_tensor(like_shape):
            like_shape = like_shape.shape
        size, center, unit, layout = _crop_to_param(aff0, like_aff,
                                                    like_shape)
    elif not size:
        # --- Open transformation file and compute size/center ---
        if not transform:
            raise ValueError('At least one of size/like/transform must '
                             'be provided')
        transform_in = True
        lta = io.transforms.map(transform)
        if not isinstance(lta, io.transforms.LinearTransformArray):
            raise TypeError('Expected an LTA file')
        like_aff, like_shape = lta.destination_space()
        size, center, unit, layout = _crop_to_param(aff0, like_aff,
                                                    like_shape)

    # --- use center of the FOV ---
    if not torch.is_tensor(center) and not center:
        center = torch.as_tensor(shape0[:dim], dtype=torch.float) * 0.5

    # --- convert size/center to voxels ---
    size = utils.make_vector(size, dim, dtype=torch.long)
    center = utils.make_vector(center, dim, dtype=torch.float)
    space_size, space_center = py.make_list(space, 2)
    if space_center.lower() == 'ras':
        world2vox = spatial.affine_inv(aff0)
        center = spatial.affine_matvec(world2vox, center)
    if space_size.lower() == 'ras':
        perm = spatial.affine_to_layout(aff0)[:, 0]
        size = size[perm.long()]
        size = size / spatial.voxel_size(aff0)

    # --- compute first/last ---
    center = center.float()
    size = size.ceil().long()
    lower = (center - size.float() / 2).round().long()
    # the upper bound is derived from the unclamped lower bound
    upper = (lower + size).tolist()
    first = [max(lo, 0) for lo in lower.tolist()]
    last = [min(up, full) for up, full in zip(upper, shape0[:dim])]
    ranges = ', '.join(f'{lo}:{up}' for lo, up in zip(first, last))
    print(f'Cropping patch [{ranges}] '
          f'from volume with shape {shape0[:dim]}')
    slicer = tuple(slice(lo, up) for lo, up in zip(first, last))

    # --- do crop ---
    dat = dat[slicer]
    aff, _ = spatial.affine_sub(aff0, shape0[:dim], slicer)
    shape = dat.shape[:dim]

    # --- write cropped volume ---
    if output:
        if is_file:
            output = output.format(dir=dirname or '.', base=base, ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff)

    # --- write transform between full and cropped spaces ---
    if transform and not transform_in:
        if is_file:
            transform = transform.format(dir=dirname or '.', base=base,
                                         ext=ext, sep=os.path.sep)
        else:
            transform = transform.format(sep=os.path.sep)
        trf = io.transforms.LinearTransformArray(transform, 'w')
        trf.set_source_space(aff0, shape0)
        trf.set_destination_space(aff, shape)
        trf.set_metadata({'src': {'filename': fname},
                          'dst': {'filename': output},
                          'type': 1})  # RAS_TO_RAS
        trf.set_fdata(torch.eye(4))
        trf.save()

    if is_file:
        return output
    return dat, aff
def vexp(inp, type='displacement', unit='voxel', inverse=False,
         bound='dft', steps=8, device=None, output=None):
    """Exponentiate a stationary velocity field.

    Parameters
    ----------
    inp : str or tensor or (tensor, tensor)
        Either a path to a volume file, a bare tensor (it is cloned and
        given a default orientation matrix), or a tuple `(dat, affine)`,
        where the first element contains the volume data and the second
        contains the orientation matrix.
    type : {'displacement', 'transformation'}, default='displacement'
        Whether to return a displacement field (coord-to-shift) or a
        transformation field (coord-to-coord). Only the first letter
        is checked.
    unit : {'voxel', 'mm'}, default='voxel'
        Whether to return displacement/coordinates in voxel or in mm.
        If mm, the orientation matrix is applied to the output field.
    inverse : bool, default=False
        Whether to return the inverse field.
    bound : str, default='dft'
        Boundary conditions.
    steps : int, default=8
        Number of scaling and squaring steps.
    device : str, optional
        Device to use.
    output : str, optional
        Output filename pattern.
        If the input is not a path, the data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}{sep}{base}.vexp{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the output tensor and orientation matrix are returned.

    """
    dirname, base, ext = '', '', ''
    fname = None

    # --- Open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        mapped = io.volumes.map(inp)
        inp = (mapped.fdata(device=device), mapped.affine)
        if output is None:
            output = '{dir}{sep}{base}.vexp{ext}'
        dirname, base, ext = py.fileparts(fname)
    elif torch.is_tensor(inp):
        inp = (inp.clone(), spatial.affine_default(shape=inp.shape[:3]))
    dat, aff = inp
    dat = dat.to(device=device)
    aff = aff.to(device=device)

    # --- exponentiate (a batch axis is added, then removed) ---
    as_displacement = type.lower()[0] == 'd'
    dat = spatial.exp(dat[None], inverse=inverse, steps=steps, bound=bound,
                      inplace=True, displacement=as_displacement)
    dat = dat[0]

    # --- voxel to mm ---
    if unit == 'mm':
        dat = spatial.affine_matvec(aff, dat)

    # --- write output ---
    if output:
        if is_file:
            output = output.format(dir=dirname or '.', base=base, ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff.cpu())
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff.cpu())

    if is_file:
        return output
    return dat, aff
def __call__(self, logaff, grad=False, hess=False, in_line_search=False):
    """Evaluate the registration objective for a set of Lie parameters.

    The moving image is resampled through the affine encoded by
    `logaff` and compared to the fixed image with `self.loss`.

    Parameters
    ----------
    logaff : (..., nb) tensor
        Lie parameters.
    grad : bool, default=False
        Whether to compute (by autograd) and return the gradient
        wrt `logaff`.
    hess : bool, default=False
        Accepted for interface compatibility. No Hessian is computed
        or returned by this implementation.
    in_line_search : bool, default=False
        If True, plotting and the verbose iteration log are skipped.

    Returns
    -------
    ll : () tensor
        Loss value (objective to minimize).
    g : (..., nb) tensor, optional
        Gradient wrt Lie parameters. Only returned if `grad`.

    """
    # select correct gradient mode: re-enter with autograd switched
    # on (if a gradient is required) or off (if it is not)
    if grad:
        logaff.requires_grad_()
        if logaff.grad is not None:
            logaff.grad.zero_()
    if grad and not torch.is_grad_enabled():
        with torch.enable_grad():
            return self(logaff, grad, hess, in_line_search=in_line_search)
    elif not grad and torch.is_grad_enabled():
        with torch.no_grad():
            return self(logaff, grad, hess, in_line_search=in_line_search)

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    # plot roughly 20 times over the course of the optimization
    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
        and (self.n_iter - 1) % logplot == 0

    fixed = self.fixed

    # forward: build the fixed-voxel to moving-voxel affine
    if not torch.is_tensor(self.basis):
        # lazily convert the basis specification into a tensor basis
        self.basis = spatial.affine_basis(self.basis, self.dim,
                                          **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    if self.id is None:
        # the identity grid is cached across calls
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff))
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, cat=iscat, dim=self.dim)

    # loss in observed space
    llx = self.loss.loss(warped, fixed)
    del warped

    # print objective
    lll = llx           # tensor, kept for backpropagation
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(
                f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                end='\n')
        else:
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(
                f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                end='\n')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [lll]
    # Use the same truthiness test as the mode selection above, so that
    # `backward` is never called on a graph built under `no_grad`.
    if grad:
        lll.backward()
        out.append(logaff.grad.clone())
    logaff.requires_grad_(False)
    return tuple(out) if len(out) > 1 else out[0]
def compose(self, orient_in, deformation, orient_mean, affine=None,
            orient_out=None, shape_out=None):
    """Composes a deformation defined in a mean space to an image space.

    Parameters
    ----------
    orient_in : (4, 4) tensor
        Orientation of the input image
    deformation : (*shape_mean, 3) tensor
        Random deformation, stored as a voxel-to-voxel transformation
        of the mean space (it is converted to a displacement internally
        by subtracting the identity grid).
    orient_mean : (4, 4) tensor
        Orientation of the mean space (where the deformation is)
    affine : (4, 4) tensor, default=identity
        Random affine
    orient_out : (4, 4) tensor, default=orient_in
        Orientation of the output image
    shape_out : sequence[int], default=shape_mean
        Shape of the output image

    Returns
    -------
    grid : (*shape_out, 3)
        Voxel-to-voxel transform

    """
    if orient_out is None:
        orient_out = orient_in
    if shape_out is None:
        shape_out = deformation.shape[:-1]
    if affine is None:
        affine = torch.eye(4, 4,
                           device=orient_in.device, dtype=orient_in.dtype)
    shape_mean = deformation.shape[:-1]

    orient_in, affine, deformation, orient_mean, orient_out \
        = utils.to_max_backend(orient_in, affine, deformation,
                               orient_mean, orient_out)
    backend = utils.backend(deformation)
    eye = torch.eye(4, **backend)

    # Compose deformation on the right
    right_affine = spatial.affine_lmdiv(orient_mean, orient_out)
    # compare as tuples so that a list `shape_out` can match a torch.Size
    same_space = (tuple(shape_mean) == tuple(shape_out)
                  and torch.allclose(right_affine, eye))
    if same_space:
        # mean space and output space coincide: the deformation already
        # maps output voxels to mean voxels
        trf = deformation
    else:
        # the mean space and native space are not the same
        # we must compose the diffeo with a dense affine transform
        # we write the diffeo as an identity plus a displacement
        # (id + disp)(aff) = aff + disp(aff)
        # -------
        # to displacement
        deformation = deformation - spatial.identity_grid(
            deformation.shape[:-1], **backend)
        trf = spatial.affine_grid(right_affine, shape_out)
        deformation = spatial.grid_pull(
            utils.movedim(deformation, -1, 0)[None], trf[None],
            bound='dft', extrapolate=True)
        deformation = utils.movedim(deformation[0], 0, -1)
        trf = trf + deformation  # add displacement

    # Compose deformation on the left
    # the output of the diffeo(right) are mean_space voxels
    # we must compose on the left with `in\(aff(mean))`
    # -------
    left_affine = spatial.affine_matmul(spatial.affine_inv(orient_in), affine)
    left_affine = spatial.affine_matmul(left_affine, orient_mean)
    trf = spatial.affine_matvec(left_affine, trf)

    return trf
def crop(inp, size=None, center=None, space='vx', like=None, bbox=False,
         output=None, transform=None):
    """Crop a ND volume, while preserving the orientation matrices.

    Exactly one of `size`, `like` and `bbox` may be used. If none of
    them is provided, `transform` is read as an input transform whose
    destination space defines the patch.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    size : [sequence of] int, optional
        Size of the patch to extract.
        Its unit is defined by `space`.
    center : [sequence of] int, optional
        Coordinate of the center of the patch.
        Its unit is defined by `space`.
        By default, the center of the FOV is used.
    space : [sequence of] str, default='vx'
        The space in which the `size` and `center` parameters are
        expressed (in that order when two values are given).
        The value 'ras' (case-insensitive) triggers a conversion to
        voxels; any other value is treated as voxels.
    like : str or (tensor, tensor), optional
        Reference patch.
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data (or its shape) and
        the second contains the orientation matrix.
    bbox : bool or float, default=False
        Crop at the bounding box of `inp > threshold`.
        If `bbox` is a float, it is the threshold to use.
        If `bbox` is `True`, the threshold is 0.
    output : str, optional
        Output filename.
        If the input is not a path, the cropped data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.crop{ext}', where `dir`, `base` and `ext` are the
        directory, base name and extension of the input file.
    transform : str, optional
        Input or output filename of the corresponding transform.
        Not written by default.
        If a transform is provided and all other parameters
        (i.e., `size`, `like` and `bbox`) are unset, the transform is
        considered as an input transform to apply.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the cropped data and orientation matrix are returned.

    Raises
    ------
    ValueError
        If more than one of `size`, `like` and `bbox` is used, or if
        none of `size`, `like`, `bbox` and `transform` is provided.
    TypeError
        If the input transform is not a `LinearTransformArray`.

    """
    dir = ''
    base = ''
    ext = ''
    fname = None
    transform_in = False
    # a float threshold of 0. is falsy, hence the explicit isinstance test
    use_bbox = bool(bbox or isinstance(bbox, float))

    # --- Open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        # bbox mode needs the voxel values; otherwise the mapped file is
        # kept as is and only loaded after slicing
        inp = (f.data(numpy=True) if use_bbox else f, f.affine)
        if output is None:
            output = '{dir}{sep}{base}.crop{ext}'
        dir, base, ext = py.fileparts(fname)
    dat, aff0 = inp
    dim = aff0.shape[-1] - 1
    shape0 = dat.shape[:dim]
    layout0 = spatial.affine_to_layout(aff0)

    # save input space in case we reorient later
    aff00 = aff0
    shape00 = shape0

    if bool(size) + bool(like) + bool(bbox or isinstance(bbox, float)) > 1:
        raise ValueError('Can only use one of `size`, `like` and `bbox`.')

    # --- Open reference and compute size/center ---
    if like:
        like_is_file = isinstance(like, str)
        if like_is_file:
            f = io.volumes.map(like)
            like = (f.shape, f.affine)
        like_shape, like_aff = like
        like_layout = spatial.affine_to_layout(like_aff)
        if (layout0 != like_layout).any():
            # reorient the input so that its layout matches the reference
            aff0, dat = spatial.affine_reorient(aff0, dat, like_layout)
            shape0 = dat.shape[:dim]
        if torch.is_tensor(like_shape):
            # the reference was given as data: only its shape is needed
            like_shape = like_shape.shape
        size, center, unit, layout = _crop_to_param(aff0, like_aff, like_shape)
        space = 'vox'

    # --- Compute size/center from the bounding box of `dat > bbox` ---
    elif bbox or isinstance(bbox, float):
        if bbox is True:
            bbox = 0.
        mask = torch.as_tensor(dat > bbox)
        # collapse trailing dimensions beyond the third
        # NOTE(review): the constant 3 is used here rather than `dim` --
        # confirm behaviour for non-3D volumes
        while mask.dim() > 3:
            mask = mask.any(dim=-1)
        mins = []
        maxs = []
        for d in range(dim):
            n = mask.shape[d]
            # indices along dimension `d` that contain at least one
            # voxel above the threshold
            # NOTE(review): `idx` is empty if no voxel is above the
            # threshold, in which case min()/max() presumably raise
            idx = utils.movedim(mask, d, 0).reshape([n, -1 ]).any(-1).nonzero(as_tuple=False)
            mins.append(idx.min())
            maxs.append(idx.max())
        mins = utils.as_tensor(mins)
        maxs = utils.as_tensor(maxs)
        size = maxs + 1 - mins
        center = (maxs + 1 + mins).float() / 2
        space = 'vox'
        del mask

    # --- Open transformation file and compute size/center ---
    elif not size:
        if not transform:
            raise ValueError('At least one of size/like/transform must '
                             'be provided')
        transform_in = True
        t = io.transforms.map(transform)
        if not isinstance(t, io.transforms.LinearTransformArray):
            raise TypeError('Expected an LTA file')
        like_aff, like_shape = t.destination_space()
        size, center, unit, layout = _crop_to_param(aff0, like_aff, like_shape)

    # --- use center of the FOV ---
    if not torch.is_tensor(center) and not center:
        center = torch.as_tensor(shape0[:dim], dtype=torch.float)
        center = center.sub_(1).mul_(0.5)

    # --- convert size/center to voxels ---
    size = utils.make_vector(size, dim, dtype=torch.long)
    center = utils.make_vector(center, dim, dtype=torch.float)
    space_size, space_center = py.make_list(space, 2)
    if space_center.lower() == 'ras':
        center = spatial.affine_matvec(spatial.affine_inv(aff0), center)
    if space_size.lower() == 'ras':
        perm = spatial.affine_to_layout(aff0)[:, 0]
        size = size[perm.long()]
        size = size / spatial.voxel_size(aff0)

    # --- compute first/last ---
    center = center.float()
    size = (size.ceil() if size.dtype.is_floating_point else size).long()
    first = center - size.float().sub_(1).mul_(0.5)
    first = first.round().long()
    last = (first + size).tolist()
    # clamp the patch to the field of view
    first = [max(f, 0) for f in first.tolist()]
    last = [min(l, s) for l, s in zip(last, shape0[:dim])]
    verb = 'Cropping patch ['
    verb += ', '.join([f'{f}:{l}' for f, l in zip(first, last)])
    verb += f'] from volume with shape {shape0[:dim]}'
    print(verb)
    slicer = tuple(slice(f, l) for f, l in zip(first, last))

    # --- do crop ---
    # NOTE(review): in bbox mode, file inputs were loaded with
    # `data(numpy=True)` above, so `dat` is presumably already a numpy
    # array at this point; confirm that `.numpy()` here and
    # `.data(numpy=True)` below are valid for every input type.
    if use_bbox:
        dat = dat.numpy()
    dat = dat[slicer]
    if not torch.is_tensor(dat):
        dat = dat.data(numpy=True)
    aff, _ = spatial.affine_sub(aff0, shape0[:dim], slicer)
    shape = dat.shape[:dim]

    # --- write cropped volume ---
    if output:
        if is_file:
            output = output.format(dir=dir or '.', base=base, ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff)

    # --- write transform (original space -> cropped space) ---
    if transform and not transform_in:
        if is_file:
            transform = transform.format(dir=dir or '.', base=base, ext=ext,
                                         sep=os.path.sep)
        else:
            transform = transform.format(sep=os.path.sep)
        trf = io.transforms.LinearTransformArray(transform, 'w')
        trf.set_source_space(aff00, shape00)
        trf.set_destination_space(aff, shape)
        trf.set_metadata({ 'src': { 'filename': fname },
                           'dst': { 'filename': output },
                           'type': 1 })  # RAS_TO_RAS
        trf.set_fdata(torch.eye(4))
        trf.save()

    if is_file:
        return output
    else:
        return dat, aff
def info(inp, meta=None, stat=False):
    """Print information on a volume.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    meta : sequence of str
        List of fields to print.
        By default, a list of common fields is used.
    stat : bool, default=False
        Compute intensity statistics (min, max, mean). They are computed
        per channel when the volume has more than three dimensions.

    """
    meta = meta or []
    metadata = {}
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        if stat:
            # statistics need the voxel values
            inp = (f.fdata(), f.affine)
        else:
            # otherwise the shape is enough: do not load the data
            inp = (f.shape, f.affine)
        metadata = f.metadata(meta)
        metadata['dtype'] = f.dtype
    dat, aff = inp
    if not is_file:
        metadata['dtype'] = dat.dtype
    if torch.is_tensor(dat):
        shape = dat.shape
    else:
        shape = dat

    # width of the tag column
    pad = max([0] + [len(m) for m in metadata.keys()])
    if not meta:
        more_fields = ['shape', 'layout', 'filename']
        pad = max(pad, max(len(field) for field in more_fields))

    def title(tag):
        """Left-justify a tag in a column of width `pad`."""
        return ('{tag:' + str(pad) + 's}').format(tag=tag)

    if not meta:
        if is_file:
            print(f'{title("filename")} : {fname}')
        print(f'{title("shape")} : {tuple(shape)}')
        layout = spatial.affine_to_layout(aff)
        layout = spatial.volume_layout_to_name(layout)
        print(f'{title("layout")} : {layout}')
        center = torch.as_tensor(shape[:3], dtype=torch.float) / 2
        center = spatial.affine_matvec(aff, center)
        print(f'{title("center")} : {tuple(center.tolist())} mm (RAS)')
        if stat and torch.is_tensor(dat):
            if dat.ndim <= 3:
                vmin = dat.min().tolist()
                vmax = dat.max().tolist()
                vmean = dat.mean().tolist()
            else:
                # flatten the spatial dimensions and keep the channel
                # dimensions (their *sizes*, not their indices)
                dat1 = dat.reshape([-1, *dat.shape[3:]])
                vmin = dat1.min(dim=0).values.tolist()
                vmax = dat1.max(dim=0).values.tolist()
                vmean = dat1.mean(dim=0).tolist()
            print(f'{title("min")} : {vmin}')
            print(f'{title("max")} : {vmax}')
            print(f'{title("mean")} : {vmean}')

    for key, value in metadata.items():
        if value is None and not meta:
            # unset fields are only shown if explicitly requested
            continue
        if torch.is_tensor(value):
            # align multi-line tensor values with the value column
            value = str(value.numpy())
            value = value.split('\n')
            value = ('\n' + ' ' * (pad + 3)).join(value)
        print(f'{title(key)} : {value}')
def _compute_cost(q, grid0, dat_fix, mat_fix, dat, mat, mov, cost_fun, B,
                  mx_int, fwhm, return_res=False):
    """Compute registration cost function.

    Parameters
    ----------
    q : (N, Nq) tensor_like
        Lie algebra of affine registration fit.
    grid0 : (X1, Y1, Z1) tensor_like
        Sub-sampled image data's resampling grid.
    dat_fix : (X1, Y1, Z1) tensor_like
        Fixed image data.
    mat_fix : (4, 4) tensor_like
        Fixed affine matrix.
    dat : [N,] tensor_like
        List of input images.
    mat : [N,] tensor_like
        List of affine matrices.
    mov : [N,] int
        Indices of moving images.
    cost_fun : str
        Cost function to compute (see run_affine_reg).
    B : (Nq, N, N) tensor_like
        Affine basis.
    mx_int : int
        This parameter sets the max intensity in the images, which decides
        how many bins to use in the joint image histograms
        (e.g, mx_int=511 -> H.shape = (512, 512)).
    fwhm : float
        Full-width at half max of Gaussian kernel, for smoothing
        histogram.
    return_res : bool, default=False
        Return registration results for plotting.

    Returns
    -------
    c : float
        Cost of aligning images with current estimate of q. A numpy
        array if `q` was given as a numpy array, else a tensor.
    res : tensor_like
        Registration results, for visualisation (only if return_res=True).

    Raises
    ------
    ValueError
        If `cost_fun` is not a supported cost function.

    """
    # Init
    device = grid0.device
    q = q.flatten()
    was_numpy = False
    if isinstance(q, np.ndarray):
        was_numpy = True
        q = torch.from_numpy(q).to(device)  # To torch tensor
    Nq = B.shape[0]
    # For modulating NJTV cost
    N = torch.tensor(len(dat), device=device, dtype=torch.float32)

    if cost_fun in _costs_edge:
        jtv = dat_fix.clone()
        if cost_fun == 'njtv':
            njtv = -dat_fix.sqrt()

    for i, m in enumerate(mov):  # Loop over moving images
        # Get affine matrix
        mat_a = expm(q[torch.arange(i * Nq, i * Nq + Nq)], B)
        # Compose matrices: mat_mov\mat_a*mat_fix
        rhs = mat_a.mm(mat_fix)
        if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'solve'):
            M = torch.linalg.solve(mat[m], rhs)
        else:
            # older PyTorch versions: Tensor.solve returns (solution, LU)
            M = rhs.solve(mat[m])[0]
        M = M.type(torch.float32)
        # Transform fixed grid
        grid = affine_matvec(M, grid0)
        # Resample to fixed grid
        dat_new = grid_pull(dat[m][None, None, ...], grid[None, ...],
                            bound='dft', extrapolate=False,
                            interpolation=1)[0, 0, ...]
        if cost_fun in _costs_edge:
            jtv += dat_new
            if cost_fun == 'njtv':
                njtv -= dat_new.sqrt()

    # Compute the cost function
    res = None
    if cost_fun in _costs_hist:
        # Histogram based costs
        # ----------
        # Compute joint histogram
        # OBS: This function expects both images to have the same max and
        # min intensities, this is ensured by the _data_loader() function.
        # Note that only the last resampled moving image is used here.
        H = _hist_2d(dat_fix, dat_new, mx_int, fwhm)
        res = H

        # Get probabilities
        pxy = H / H.sum()
        px = pxy.sum(dim=0, keepdim=True)
        py = pxy.sum(dim=1, keepdim=True)

        # Compute cost
        if cost_fun == 'mi':
            # Mutual information
            mi = torch.sum(pxy * torch.log2(pxy / py.mm(px)))
            c = -mi
        elif cost_fun == 'ecc':
            # Entropy Correlation Coefficient
            # Maes, Collignon, Vandermeulen, Marchal & Suetens (1997).
            # "Multimodality image registration by maximisation of mutual
            # information". IEEE Transactions on Medical Imaging
            # 16(2):187-198
            mi = torch.sum(pxy * torch.log2(pxy / py.mm(px)))
            ecc = -2 * mi / (torch.sum(px * px.log2())
                             + torch.sum(py * py.log2()))
            c = -ecc
        elif cost_fun == 'nmi':
            # Normalised Mutual Information
            # Studholme, Hill & Hawkes (1998).
            # "A normalized entropy measure of 3-D medical image
            # alignment". in Proc. Medical Imaging 1998, vol. 3338,
            # San Diego, CA, pp. 132-143.
            nmi = (torch.sum(px * px.log2())
                   + torch.sum(py * py.log2())) / torch.sum(pxy * pxy.log2())
            c = -nmi
        elif cost_fun == 'ncc':
            # Normalised Cross Correlation
            i = torch.arange(1, pxy.shape[0] + 1, device=device,
                             dtype=torch.float32)
            j = torch.arange(1, pxy.shape[1] + 1, device=device,
                             dtype=torch.float32)
            m1 = torch.sum(py * i[..., None])
            m2 = torch.sum(px * j[None, ...])
            sig1 = torch.sqrt(torch.sum(py[..., 0] * (i - m1)**2))
            sig2 = torch.sqrt(torch.sum(px[0, ...] * (j - m2)**2))
            i, j = torch.meshgrid(i - m1, j - m2)
            ncc = torch.sum(torch.sum(pxy * i * j)) / (sig1 * sig2)
            c = -ncc
        else:
            raise ValueError(f'Unknown cost function: {cost_fun}')
    elif cost_fun in _costs_edge:
        # Normalised Joint Total Variation
        # M Brudfors, Y Balbastre, J Ashburner (2020).
        # "Groupwise Multimodal Image Registration using Joint Total
        # Variation". in MIUA 2020.
        jtv.sqrt_()
        if cost_fun == 'njtv':
            njtv += torch.sqrt(N) * jtv
            res = njtv
            c = torch.sum(njtv)
        else:
            res = jtv
            c = torch.sum(jtv)
    else:
        raise ValueError(f'Unknown cost function: {cost_fun}')

    # _ = show_slices(res, fig_num=1, cmap='coolwarm')  # Can be uncommented for testing

    if was_numpy:
        # Back to numpy array
        c = c.cpu().numpy()

    if return_res:
        return c, res
    else:
        return c
def affine(source, target, group='SE', loss=None, pull=None, preproc=True,
           max_iter=1000, device=None, origin='center', init=None, lr=0.1,
           scheduler=ReduceLROnPlateau):
    """Affine registration

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as
       well as the same interpolation order. The advantage is that
       it simplifies the signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
    target : tensor or (tensor, affine)
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
    loss : Loss, default=MutualInfoLoss()
    pull : dict
        Resampling options. The dictionary is copied, not modified.
        interpolation : default='linear'
        bound : bound_like, default='dct2'
        extrapolate : bool, default=False
    preproc : bool, default=True
        Rescale both images before registration.
    max_iter : int, default=1000
    device : device, optional
    origin : {'native', 'center'}, default='center'
        If 'center', both orientation matrices are translated so that
        the origin lies at the center of the target field of view
        during optimization.
    init : tensor_like, default=0
    lr : float, default=0.1
    scheduler : Scheduler, default=ReduceLROnPlateau
        Scheduler class, instantiated on the optimizer and stepped
        every 10 iterations. Use None to disable.

    Returns
    -------
    q : tensor
        Parameters
    aff : (D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    moved : tensor
        Source image moved to target space.

    """
    # work on a copy so that the caller's dictionary is left untouched
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2

    # Rescale to [0, 1]
    if preproc:
        source = rescale(source)
        target = rescale(target)

    # Shift origin
    if origin == 'center':
        # only the spatial dimensions define the field of view
        # (the first two dimensions are batch and channel)
        shift = torch.as_tensor(target.shape[2:], **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=True)
    identity = spatial.identity_grid(target.shape[2:], **backend)

    def _pull(q):
        """Resample the source in target space for Lie parameters `q`."""
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        # insert singleton spatial dimensions between batch and matrix
        expd = (slice(None),) + (None,) * dim + (slice(None), slice(None))
        grid = spatial.affine_matvec(aff[expd], identity)
        moved = spatial.grid_pull(source, grid, **pull_opt)
        return moved

    # Prepare loss and optimizer
    if loss is None:
        loss = nni.MutualInfoLoss()

    optim = torch.optim.Adam([parameters], lr=lr)
    if scheduler is not None:
        scheduler = scheduler(optim)

    # Optim loop
    for n_iter in range(1, max_iter + 1):

        optim.zero_grad(set_to_none=True)
        moved = _pull(parameters)
        loss_val = loss(moved, target)
        loss_val.backward()
        optim.step()

        if n_iter % 10 == 0:
            if scheduler is not None:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(loss_val)
                else:
                    scheduler.step()
            print('{:4d} {:12.6f} | lr={:g}'.format(
                n_iter, loss_val.item(), optim.param_groups[0]['lr']),
                end='\r')
    print('')

    with torch.no_grad():
        moved = _pull(parameters)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            # express the affine with respect to the native origin
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
        aff.requires_grad_(False)
    return parameters, aff, moved