def pull(q, vel):
    """Warp the source image with an affine-corrected exponentiated velocity.

    The velocity `vel` is exponentiated into a sampling field, which is then
    mapped through the affine built from the Lie-algebra parameters `q`
    (composed with the target orientation and expressed relative to the
    source orientation) before sampling `source`.

    NOTE(review): `basis`, `target_aff`, `source_aff`, `source` and
    `pull_opt` are free variables resolved from an enclosing scope that is
    not visible here — confirm they are in scope wherever this is defined.

    Parameters
    ----------
    q : tensor
        Lie-algebra parameters of the affine (exponentiated with `basis`).
    vel : tensor
        Stationary velocity field passed to `spatial.exp`.

    Returns
    -------
    tensor
        Source image resampled on the warped grid.
    """
    phi = spatial.exp(vel)
    lin = core.linalg.expm(q, basis)
    # voxel(target) -> world -> affine -> voxel(source)
    lin = spatial.affine_lmdiv(source_aff,
                               spatial.affine_matmul(lin, target_aff))
    phi = spatial.affine_matvec(lin, phi)
    return spatial.grid_pull(source, phi, **pull_opt)
def forward(self, velocity, **kwargs):
    """Exponentiate a stationary velocity field.

    Parameters
    ----------
    velocity : (batch, *spatial, dim) tensor
        Stationary velocity field.
    **kwargs :
        All parameters of the module can be overridden at call time:
        `fwd`, `inverse`, `steps`, `interpolation`, `bound`, `displacement`.

    Returns
    -------
    forward : (batch, *spatial, dim) tensor, if `fwd is True`
        Forward displacement (if `displacement is True`) or
        transformation (if `displacement is False`) field.
    inverse : (batch, *spatial, dim) tensor, if `inverse is True`
        Inverse displacement (if `displacement is True`) or
        transformation (if `displacement is False`) field.

    If both fields are requested a list `[forward, inverse]` is returned,
    if only one is requested that tensor is returned, otherwise `None`.
    """
    # `self.forward` is this very method (always truthy), so it cannot be
    # the default. Read the module's `fwd` flag when it exists and keep the
    # historical default (compute the forward field) otherwise.
    fwd = kwargs.get('fwd', getattr(self, 'fwd', True))
    inv = kwargs.get('inverse', self.inv)
    opt = {
        'steps': kwargs.get('steps', self.steps),
        'interpolation': kwargs.get('interpolation', self.interpolation),
        'bound': kwargs.get('bound', self.bound),
        'displacement': kwargs.get('displacement', self.displacement),
        'inplace': False,  # the input velocity must not be modified
    }

    output = []
    if fwd:
        output.append(spatial.exp(velocity, inverse=False, **opt))
    if inv:
        output.append(spatial.exp(velocity, inverse=True, **opt))

    if not output:
        return None
    if len(output) == 1:
        return output[0]
    return output
def forward(self, velocity, fwd=None, inv=None):
    """Exponentiate a stationary velocity field.

    Parameters
    ----------
    velocity : (batch, *spatial, dim) tensor
        Stationary velocity field.
    fwd : bool, default=self.fwd
        Whether to return the forward field.
    inv : bool, default=self.inv
        Whether to return the inverse field.

    Returns
    -------
    forward : (batch, *spatial, dim) tensor, if `fwd is True`
        Forward displacement (if `displacement is True`) or
        transformation (if `displacement is False`) field.
    inverse : (batch, *spatial, dim) tensor, if `inv is True`
        Inverse displacement (if `displacement is True`) or
        transformation (if `displacement is False`) field.

    Both requested -> list `[forward, inverse]`; one -> that tensor;
    none -> `None`.
    """
    if fwd is None:
        fwd = self.fwd
    if inv is None:
        inv = self.inv
    exp_options = {
        'steps': self.steps,
        'interpolation': self.interpolation,
        'bound': self.bound,
        'displacement': self.displacement,
        'anagrad': self.anagrad,
    }

    fields = []
    if fwd:
        fields.append(spatial.exp(velocity, inverse=False, **exp_options))
    if inv:
        fields.append(spatial.exp(velocity, inverse=True, **exp_options))

    if not fields:
        return None
    if len(fields) == 1:
        return fields[0]
    return fields
def write_data(options):
    """Compose the transformations and write the result to disk.

    If `options.target` is set, a dense field sampled on the target lattice
    is written (as a NIfTI volume) in voxel or mm units depending on
    `options.output_unit`. Otherwise, exactly one transformation is
    expected and its `affine` is written as an LTA file.
    """
    backend = dict(dtype=torch.float32, device=options.device)

    # Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, Velocity):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            trf.dat = spatial.exp(trf.dat[None], displacement=True,
                                  inverse=trf.inv)[0]
            # the field is now an explicit (already inverted) displacement
            trf.inv = False
            # NOTE(review): `build_from_target` reads `trf.spline`, not
            # `trf.order`, so this setting is not used there — confirm.
            trf.order = 1
        elif isinstance(trf, Displacement):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            if trf.unit == 'mm':
                # convert mm displacement to vox displacement
                trf.dat = spatial.affine_lmdiv(trf.affine, trf.dat[..., None])
                trf.dat = trf.dat[..., 0]
                trf.unit = 'vox'

    def build_from_target(target):
        """Compose all transformations, starting from the final orientation"""
        grid = spatial.affine_grid(target.affine.to(**backend), target.shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            else:
                mat = trf.affine.to(**backend)
                if trf.inv:
                    # Invert the displacement on a lattice whose voxel size
                    # matches the target before using it.
                    vx0 = spatial.voxel_size(mat)
                    vx1 = spatial.voxel_size(target.affine.to(**backend))
                    factor = vx0 / vx1
                    disp, mat = spatial.resize_grid(trf.dat[None], factor,
                                                    affine=mat,
                                                    interpolation=trf.spline)
                    disp = spatial.grid_inv(disp[0], type='disp')
                    order = 1
                else:
                    disp = trf.dat
                    order = trf.spline
                # world -> displacement voxels, add displacement, -> world
                imat = spatial.affine_inv(mat)
                grid = spatial.affine_matvec(imat, grid)
                grid += helpers.pull_grid(disp, grid, interpolation=order)
                grid = spatial.affine_matvec(mat, grid)
        return grid

    if options.target:
        # If target is provided, we build a dense transformation field
        grid = build_from_target(options.target)
        oaffine = options.target.affine
        if options.output_unit[0] == 'v':
            # Voxel displacement: express coordinates in target voxels and
            # subtract the identity. The orientation matrix is moved to the
            # grid's dtype/device first, as in the mm branch below.
            ioaffine = spatial.affine_inv(oaffine.to(**utils.backend(grid)))
            grid = spatial.affine_matvec(ioaffine, grid)
            grid = grid - spatial.identity_grid(grid.shape[:-1],
                                                **utils.backend(grid))
        else:
            # mm displacement: subtract the target's world coordinates
            grid = grid - spatial.affine_grid(
                oaffine.to(**utils.backend(grid)), grid.shape[:-1])
        io.volumes.savef(grid, options.output.format(ext='.nii.gz'),
                         affine=oaffine)
    else:
        if len(options.transformations) > 1:
            raise RuntimeError('Something weird happened: '
                               'multiple transforms and no target')
        io.transforms.savef(options.transformations[0].affine,
                            options.output.format(ext='.lta'))
def vexp(inp, type='displacement', unit='voxel', inverse=False,
         bound='dft', steps=8, device=None, output=None):
    """Exponentiate a stationary velocity field.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    type : {'displacement', 'transformation'}, default='displacement'
        Whether to return a displacement field (coord-to-shift) or a
        transformation field (coord-to-coord).
    unit : {'voxel', 'mm'}, default='voxel'
        Whether to return displacement/coordinates in voxel or in mm.
        If mm, the input orientation matrix is used to convert voxels to mm
        (only its linear part is applied to displacements).
    inverse : bool, default=False
        Whether to return the inverse field.
    bound : str, default='dft'
        Boundary conditions.
    steps : int, default=8
        Number of scaling and squaring steps.
    device : str, optional
        Device to use.
    output : str, optional
        Output filename(s).
        If the input is not a path, the exponentiated data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.vexp{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
        These fields can also be used in a user-provided template.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the output tensor and orientation matrix are returned.

    """
    dirname = ''
    base = ''
    ext = ''
    fname = None

    # --- Open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.fdata(device=device), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.vexp{ext}'
        # Split the input path unconditionally so that user-provided
        # templates can use {dir}/{base}/{ext} as well.
        dirname, base, ext = py.fileparts(fname)
    elif torch.is_tensor(inp):
        inp = (inp.clone(), spatial.affine_default(shape=inp.shape[:3]))
    else:
        # The exponentiation below runs in place: work on a copy so that
        # the caller's tensor is never overwritten.
        dat0, aff0 = inp
        inp = (dat0.clone(), aff0)
    dat, aff = inp
    dat = dat.to(device=device)
    aff = aff.to(device=device)

    # --- Exponentiate ---
    displacement = type.lower()[0] == 'd'
    dat = spatial.exp(dat[None], inverse=inverse, steps=steps, bound=bound,
                      inplace=True, displacement=displacement)[0]
    if unit == 'mm':
        if displacement:
            # A displacement is a difference of coordinates: only the
            # linear part of the voxel-to-mm matrix applies (no translation).
            lin = aff[:-1, :-1].to(dat.dtype)
            dat = torch.matmul(lin, dat[..., None])[..., 0]
        else:
            dat = spatial.affine_matvec(aff, dat)

    # --- Write ---
    if output:
        if is_file:
            output = output.format(dir=dirname or '.', base=base, ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff.cpu())
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff.cpu())

    if is_file:
        return output
    return dat, aff
def smart_exp(vel, **kwargs):
    """Call `spatial.exp(vel, **kwargs)`, passing `None` through unchanged.

    Parameters
    ----------
    vel : tensor or None
        Velocity field, or None.
    **kwargs :
        Forwarded to `spatial.exp`.

    Returns
    -------
    tensor or None
        The exponentiated field, or None if `vel` is None.
    """
    if vel is None:
        return None
    return spatial.exp(vel, **kwargs)
def write_data(options):
    """Apply the estimated transformations and write updated/resliced images.

    The first `struct.Linear` transformation (if any) is exponentiated into
    an affine matrix, and the first `struct.FFD` or `struct.Diffeo`
    transformation (if any) is turned into a displacement field. Moving
    images are then updated/resliced with the forward transform. Fixed
    images are updated/resliced with the inverse transform (`inv=True`).
    """
    backend = dict(dtype=torch.float, device='cpu')

    # The inverse non-linear field is only needed if a fixed image must be
    # written (updated or resliced).
    need_inv = any(loss.fixed and (loss.fixed.resliced or loss.fixed.updated)
                   for loss in options.losses)

    # affine matrix
    lin = None
    for trf in options.transformations:
        if isinstance(trf, struct.Linear):
            q = trf.dat.to(**backend)
            B = trf.basis.to(**backend)
            lin = linalg.expm(q, B)
            if torch.is_tensor(trf.shift):
                # include shift
                shift = trf.shift.to(**backend)
                # size the identity from the matrix so that 2-D works too
                eye = torch.eye(lin.shape[-1] - 1, **backend)
                # expm is a custom autograd.Function: do not modify its
                # output in place
                lin = lin.clone()
                lin[:-1, -1] += torch.matmul(lin[:-1, :-1] - eye, shift)
            break

    # non-linear displacement field
    disp = None
    inv_disp = None
    disp_aff = None
    for trf in options.transformations:
        if isinstance(trf, struct.FFD):
            disp = trf.dat.to(**backend)
            disp = ffd_exp(disp, trf.shape, returns='disp')
            if need_inv:
                inv_disp = grid_inv(disp)
            disp_aff = trf.affine.to(**backend)
            break
        if isinstance(trf, struct.Diffeo):
            vel = trf.dat.to(**backend)
            if need_inv:
                inv_disp = spatial.exp(vel[None], displacement=True,
                                       inverse=True)[0]
            disp = spatial.exp(vel[None], displacement=True)[0]
            disp_aff = trf.affine.to(**backend)
            break

    # loop over image pairs
    for match in options.losses:
        moving = match.moving
        fixed = match.fixed

        # moving image: forward transform
        prm = dict(interpolation=moving.interpolation,
                   bound=moving.bound,
                   extrapolate=moving.extrapolate,
                   device='cpu',
                   verbose=options.verbose)
        nonlin = dict(disp=disp, affine=disp_aff)
        if moving.updated:
            update(moving, moving.updated, lin=lin, nonlin=nonlin, **prm)
        if moving.resliced:
            reslice(moving, moving.resliced, like=fixed,
                    lin=lin, nonlin=nonlin, **prm)

        if not fixed:
            continue

        # fixed image: inverse transform
        prm = dict(interpolation=fixed.interpolation,
                   bound=fixed.bound,
                   extrapolate=fixed.extrapolate,
                   device='cpu',
                   verbose=options.verbose)
        nonlin = dict(disp=inv_disp, affine=disp_aff)
        if fixed.updated:
            update(fixed, fixed.updated, inv=True,
                   lin=lin, nonlin=nonlin, **prm)
        if fixed.resliced:
            reslice(fixed, fixed.resliced, inv=True, like=moving,
                    lin=lin, nonlin=nonlin, **prm)
def forward():
    """Forward pass up to the loss

    Builds the current affine (from the first `struct.Linear`) and
    displacement field (from the first free `struct.FFD`/`struct.Diffeo`),
    adds their regularization losses, then warps each moving image onto its
    fixed image at every pyramid level. Each level's matching loss is
    weighted by 1 / number of fixed levels.

    NOTE(review): `options`, `backend`, `samespace`, `smalldef`, `pull`,
    `pull_grid`, `affine_*` and `ffd_exp` come from an enclosing or module
    scope not visible here.
    NOTE(review): SOURCE had its newlines collapsed; the nesting below
    (notably `trf.update()` at loop level and the unconditional
    `spatial.exp` in the Diffeo branch) is reconstructed — confirm
    against the original file.
    """
    loss = 0

    # affine matrix
    A = None
    for trf in options.transformations:
        trf.update()
        if isinstance(trf, struct.Linear):
            q = trf.optdat.to(**backend)
            B = trf.basis.to(**backend)
            A = linalg.expm(q, B)
            if torch.is_tensor(trf.shift):
                # include shift
                shift = trf.shift.to(**backend)
                eye = torch.eye(options.dim, **backend)
                A = A.clone()  # needed because expm is a custom autograd.Function
                A[:-1, -1] += torch.matmul(A[:-1, :-1] - eye, shift)
            # regularization of the Lie-algebra parameters
            for loss1 in trf.losses:
                loss += loss1.call(q)
            break

    # non-linear displacement field
    d = None
    d_aff = None
    for trf in options.transformations:
        if not trf.isfree():
            continue
        if isinstance(trf, struct.FFD):
            d = trf.dat.to(**backend)
            d = ffd_exp(d, trf.shape, returns='disp')
            for loss1 in trf.losses:
                loss += loss1.call(d)
            d_aff = trf.affine.to(**backend)
            break
        elif isinstance(trf, struct.Diffeo):
            d = trf.dat.to(**backend)
            if not trf.smalldef:
                # penalty on velocity fields
                for loss1 in trf.losses:
                    loss += loss1.call(d)
            d = spatial.exp(d[None], displacement=True)[0]
            if trf.smalldef:
                # penalty on exponentiated transform
                for loss1 in trf.losses:
                    loss += loss1.call(d)
            d_aff = trf.affine.to(**backend)
            break

    # loop over image pairs
    for match in options.losses:
        if not match.fixed:
            continue
        nb_levels = len(match.fixed.dat)
        prm = dict(interpolation=match.moving.interpolation,
                   bound=match.moving.bound,
                   extrapolate=match.moving.extrapolate)
        # loop over pyramid levels
        # NOTE(review): zip stops at the shorter pyramid while the weight
        # uses the fixed pyramid's length — confirm both have equal length.
        for moving, fixed in zip(match.moving.dat, match.fixed.dat):
            moving_dat, moving_aff = moving
            fixed_dat, fixed_aff = fixed

            moving_dat = moving_dat.to(**backend)
            moving_aff = moving_aff.to(**backend)
            fixed_dat = fixed_dat.to(**backend)
            fixed_aff = fixed_aff.to(**backend)

            # affine-corrected moving space
            if A is not None:
                Ms = affine_matmul(A, moving_aff)
            else:
                Ms = moving_aff

            if d is not None:
                # fixed to param
                Mt = affine_lmdiv(d_aff, fixed_aff)
                if samespace(Mt, d.shape[:-1], fixed_dat.shape[1:]):
                    # lattices coincide: no resampling of `d` needed
                    g = smalldef(d)
                else:
                    g = affine_grid(Mt, fixed_dat.shape[1:])
                    g = g + pull_grid(d, g)
                # param to moving
                Ms = affine_lmdiv(Ms, d_aff)
                g = affine_matvec(Ms, g)
            else:
                # fixed to moving
                Mt = fixed_aff
                Ms = affine_lmdiv(Ms, Mt)
                g = affine_grid(Ms, fixed_dat.shape[1:])

            # pull moving image
            warped_dat = pull(moving_dat, g, **prm)
            loss += match.call(warped_dat, fixed_dat) / float(nb_levels)

    return loss
def write_data(options):
    """Reslice every input file through the chain of transformations.

    Velocities are exponentiated into displacements and displacements are
    loaded. A leading linear transform is folded into each file's
    orientation matrix. Each file is then resampled either onto
    `options.target` or onto its own lattice, and written to the matching
    entry of `options.output`.
    """
    backend = dict(dtype=torch.float32, device=options.device)

    # 1) Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, struct.Velocity):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            trf.dat = spatial.exp(trf.dat[None], displacement=True,
                                  inverse=trf.inv)[0]
            # the field is now an explicit (already inverted) displacement
            trf.inv = False
            trf.order = 1
        elif isinstance(trf, struct.Displacement):
            # NOTE(review): no unit conversion here — the displacement is
            # presumably expected in voxels; confirm.
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]

    # 2) If the first transform is linear, compose it with the input
    #    orientation matrix
    if (options.transformations and
            isinstance(options.transformations[0], struct.Linear)):
        trf = options.transformations[0]
        for file in options.files:
            mat = file.affine.to(**backend)
            aff = trf.affine.to(**backend)
            file.affine = spatial.affine_lmdiv(aff, mat)
        options.transformations = options.transformations[1:]

    def build_from_target(target):
        """Compose all transformations, starting from the final orientation"""
        # `target` must expose `.affine` and `.shape`
        grid = spatial.affine_grid(target.affine.to(**backend), target.shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, struct.Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            else:
                mat = trf.affine.to(**backend)
                if trf.inv:
                    # invert the displacement on a lattice whose voxel size
                    # matches the target
                    vx0 = spatial.voxel_size(mat)
                    vx1 = spatial.voxel_size(target.affine.to(**backend))
                    factor = vx0 / vx1
                    disp, mat = spatial.resize_grid(trf.dat[None], factor,
                                                    affine=mat,
                                                    interpolation=trf.order)
                    disp = spatial.grid_inv(disp[0], type='disp')
                    order = 1
                else:
                    disp = trf.dat
                    order = trf.order
                # world -> displacement voxels, add displacement, -> world
                imat = spatial.affine_inv(mat)
                grid = spatial.affine_matvec(imat, grid)
                grid += helpers.pull_grid(disp, grid, interpolation=order)
                grid = spatial.affine_matvec(mat, grid)
        return grid

    # 3) If target is provided, we can build most of the transform once
    #    and just multiply it with a input-wise affine matrix.
    if options.target:
        grid = build_from_target(options.target)
        oaffine = options.target.affine

    # 4) Loop across input files
    opt = dict(interpolation=options.interpolation,
               bound=options.bound,
               extrapolate=options.extrapolate)
    output = utils.make_list(options.output, len(options.files))
    for file, ofname in zip(options.files, output):
        ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext)
        print(f'Reslicing: {file.fname}\n'
              f' -> {ofname}')
        # NOTE(review): `rand=True` is passed for every interpolation
        # order — confirm this is desired.
        dat = io.volumes.loadf(file.fname, rand=True, **backend)
        dat = dat.reshape([*file.shape, file.channels])
        dat = utils.movedim(dat, -1, 0)  # channels first for pulling

        if not options.target:
            # no target: resample onto the file's own lattice
            grid = build_from_target(file)
            oaffine = file.affine
        # world coordinates -> voxels of the input file
        mat = file.affine.to(**backend)
        imat = spatial.affine_inv(mat)
        dat = helpers.pull(dat, spatial.affine_matvec(imat, grid), **opt)
        dat = utils.movedim(dat, 0, -1)

        io.volumes.savef(dat, ofname, like=file.fname, affine=oaffine)
def write_data(options):
    """Reslice every input file through the chain of transformations.

    Velocities are integrated into displacements: by geodesic shooting when
    a JSON parameter file is attached, otherwise by exponentiation.
    Displacements are loaded and converted to voxel units. A leading linear
    transform is folded into each file's orientation matrix. Each file is
    then resampled either onto `options.target` or onto its own lattice,
    optionally resized to `options.voxel_size`. Label images
    (`interpolation == 'l'`) are loaded and saved as integers.
    """
    backend = dict(dtype=torch.float32, device=options.device)

    # 1) Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, struct.Velocity):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            if trf.json:
                # Geodesic shooting: the parameters stored in the JSON file
                # (plus the field's voxel size) must reach `shoot`; they were
                # previously loaded and then discarded.
                with open(trf.json) as fjson:
                    prm = json.load(fjson)
                prm['voxel_size'] = spatial.voxel_size(trf.affine)
                shoot_opt = dict(prm)
                # these two are decided here, not by the JSON file
                shoot_opt.update(displacement=True, return_inverse=trf.inv)
                trf.dat = spatial.shoot(trf.dat[None], **shoot_opt)
                if trf.inv:
                    trf.dat = trf.dat[-1]
            else:
                trf.dat = spatial.exp(trf.dat[None], displacement=True,
                                      inverse=trf.inv)
            trf.dat = trf.dat[0]  # drop batch dimension
            # the field is now an explicit (already inverted) displacement
            trf.inv = False
            trf.order = 1
        elif isinstance(trf, struct.Displacement):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            if trf.unit == 'mm':
                # convert mm displacement to vox displacement
                trf.dat = spatial.affine_lmdiv(trf.affine, trf.dat[..., None])
                trf.dat = trf.dat[..., 0]
                trf.unit = 'vox'

    # 2) If the first transform is linear, compose it with the input
    #    orientation matrix
    if (options.transformations and
            isinstance(options.transformations[0], struct.Linear)):
        trf = options.transformations[0]
        for file in options.files:
            mat = file.affine.to(**backend)
            aff = trf.affine.to(**backend)
            file.affine = spatial.affine_lmdiv(aff, mat)
        options.transformations = options.transformations[1:]

    def build_from_target(affine, shape):
        """Compose all transformations, starting from the final orientation"""
        grid = spatial.affine_grid(affine.to(**backend), shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, struct.Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            else:
                mat = trf.affine.to(**backend)
                if trf.inv:
                    # invert the displacement on a lattice whose voxel size
                    # matches the output
                    vx0 = spatial.voxel_size(mat)
                    vx1 = spatial.voxel_size(affine.to(**backend))
                    factor = vx0 / vx1
                    disp, mat = spatial.resize_grid(trf.dat[None], factor,
                                                    affine=mat,
                                                    interpolation=trf.order)
                    disp = spatial.grid_inv(disp[0], type='disp')
                    order = 1
                else:
                    disp = trf.dat
                    order = trf.order
                # world -> displacement voxels, add displacement, -> world
                imat = spatial.affine_inv(mat)
                grid = spatial.affine_matvec(imat, grid)
                grid += helpers.pull_grid(disp, grid, interpolation=order)
                grid = spatial.affine_matvec(mat, grid)
        return grid

    # 3) If target is provided, we can build most of the transform once
    #    and just multiply it with a input-wise affine matrix.
    if options.target:
        grid = build_from_target(options.target.affine, options.target.shape)
        oaffine = options.target.affine

    # 4) Loop across input files
    opt_pull0 = dict(interpolation=options.interpolation,
                     bound=options.bound,
                     extrapolate=options.extrapolate)
    opt_coeff = dict(interpolation=options.interpolation,
                     bound=options.bound,
                     dim=3,
                     inplace=True)
    # loop-invariant: whether inputs are label maps
    is_label = (isinstance(options.interpolation, str)
                and options.interpolation == 'l')
    output = py.make_list(options.output, len(options.files))
    for file, ofname in zip(options.files, output):
        ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext)
        print(f'Reslicing: {file.fname}\n'
              f' -> {ofname}')
        if is_label:
            backend_int = dict(dtype=torch.long, device=backend['device'])
            dat = io.volumes.load(file.fname, **backend_int)
            opt_pull = dict(opt_pull0)
            opt_pull['interpolation'] = 1
        else:
            dat = io.volumes.loadf(file.fname, rand=options.interpolation > 0,
                                   **backend)
            opt_pull = opt_pull0
        dat = dat.reshape([*file.shape, file.channels])
        dat = utils.movedim(dat, -1, 0)  # channels first for pulling

        if not options.target:
            # no target: resample onto the file's own (optionally resized)
            # lattice
            oaffine = file.affine
            oshape = file.shape
            if options.voxel_size:
                ovx = utils.make_vector(options.voxel_size, 3,
                                        dtype=oaffine.dtype)
                factor = spatial.voxel_size(oaffine) / ovx
                oaffine, oshape = spatial.affine_resize(oaffine, oshape,
                                                        factor=factor,
                                                        anchor='f')
            grid = build_from_target(oaffine, oshape)
        # world coordinates -> voxels of the input file
        mat = file.affine.to(**backend)
        imat = spatial.affine_inv(mat)
        if options.prefilter and not is_label:
            dat = spatial.spline_coeff_nd(dat, **opt_coeff)
        dat = helpers.pull(dat, spatial.affine_matvec(imat, grid), **opt_pull)
        dat = utils.movedim(dat, 0, -1)

        if is_label:
            io.volumes.save(dat, ofname, like=file.fname, affine=oaffine)
        else:
            io.volumes.savef(dat, ofname, like=file.fname, affine=oaffine)