def upsample(self, factor=2, **kwargs):
    """Return an upsampled copy of this displacement field.

    Parameters
    ----------
    factor : float or sequence[float], default=2
        Resizing factor, forwarded to ``spatial.resize_grid``.
    **kwargs
        Extra options forwarded to ``spatial.resize_grid``. When not
        provided, ``interpolation=1``, ``bound='dft'`` and ``anchor='c'``
        are used.

    Returns
    -------
    resized : same type as ``self``
        A new object built from the resized data and orientation matrix,
        with the same ``dim`` as ``self``. ``self`` is left untouched.
    """
    options = dict(kwargs)
    for key, value in dict(interpolation=1, bound='dft', anchor='c').items():
        options.setdefault(key, value)
    new_dat, new_affine = spatial.resize_grid(
        self.dat, factor, type='displacement', affine=self.affine, **options)
    klass = type(self)
    return klass(new_dat, new_affine, dim=self.dim)
def build_from_target(target):
    """Compose all transformations, starting from the final orientation.

    Builds a world-space sampling grid on the target lattice, then applies
    every transformation in ``options.transformations`` in reverse order.

    Relies on ``options``, ``backend``, ``Linear`` and ``helpers`` from the
    enclosing scope (not visible here).

    Parameters
    ----------
    target : object
        Must expose ``affine`` (voxel-to-world matrix) and ``shape``.

    Returns
    -------
    grid : tensor
        Composed sampling grid.
    """
    target_affine = target.affine.to(**backend)
    grid = spatial.affine_grid(target_affine, target.shape)
    for trf in reversed(options.transformations):
        if isinstance(trf, Linear):
            grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            continue

        # Nonlinear transform: `trf.dat` is a displacement field expressed
        # in the voxel space of `mat`.
        mat = trf.affine.to(**backend)
        if trf.inv:
            # Resample the displacement at the target resolution before
            # inverting it.
            vx0 = spatial.voxel_size(mat)
            vx1 = spatial.voxel_size(target_affine)
            factor = vx0 / vx1
            # The field is a *displacement*: it must be rescaled as such
            # (no translation term), consistent with `grid_inv(type='disp')`
            # below.
            disp, mat = spatial.resize_grid(trf.dat[None], factor,
                                            type='displacement',
                                            affine=mat,
                                            interpolation=trf.spline)
            disp = spatial.grid_inv(disp[0], type='disp')
            order = 1
        else:
            disp = trf.dat
            order = trf.spline

        # World -> displacement voxels, add the sampled displacement, and
        # map back to world space.
        imat = spatial.affine_inv(mat)
        grid = spatial.affine_matvec(imat, grid)
        grid += helpers.pull_grid(disp, grid, interpolation=order)
        grid = spatial.affine_matvec(mat, grid)
    return grid
def upsample_(self, factor=2, **kwargs):
    """Upsample this displacement field in place.

    Parameters
    ----------
    factor : float or sequence[float], default=2
        Resizing factor, forwarded to ``spatial.resize_grid``.
    **kwargs
        Extra options forwarded to ``spatial.resize_grid``. When not
        provided, ``interpolation=1``, ``bound='dft'`` and ``anchor='c'``
        are used.

    Returns
    -------
    self
        The same object, with ``dat`` and ``affine`` replaced by their
        resized versions.
    """
    defaults = (('interpolation', 1), ('bound', 'dft'), ('anchor', 'c'))
    for name, value in defaults:
        kwargs.setdefault(name, value)
    resized = spatial.resize_grid(
        self.dat, factor, type='disp', affine=self.affine, **kwargs)
    self.dat, self.affine = resized
    return self
def forward(self, grid, affine=None, **overload):
    """
    Parameters
    ----------
    grid : (batch, *spatial_in, ndim) tensor
        Input grid to deform
    affine : (batch, ndim[+1], ndim+1), optional
        Orientation matrix of the input image.
        If provided, the orientation matrix of the resized image is
        returned as well.
    overload : dict
        All parameters defined at build time can be overridden at call time.

    Returns
    -------
    resized : (batch, *spatial_out, ndim) tensor
        Resized image.
    affine : (batch, ndim[+1], ndim+1) tensor, optional
        Orientation matrix
    """
    # Each resize option falls back to the value stored on the module.
    option_names = ('factor', 'shape', 'type', 'anchor',
                    'interpolation', 'bound', 'extrapolate')
    kwargs = {name: overload.get(name, getattr(self, name))
              for name in option_names}
    return spatial.resize_grid(grid, affine=affine, **kwargs)
def forward(self, grid, affine=None, output_shape=None, **overload):
    """
    Parameters
    ----------
    grid : (batch, *spatial_in, ndim) tensor
        Input grid to deform
    affine : (batch, ndim[+1], ndim+1), optional
        Orientation matrix of the input image.
        If provided, the orientation matrix of the resized image is
        returned as well.
    output_shape : sequence[int], optional
        Output spatial shape. Takes precedence over ``overload['shape']``
        and over the shape defined at build time.
    overload : dict
        All parameters defined at build time (``factor``, ``shape``,
        ``type``, ``anchor``, ``interpolation``, ``bound``,
        ``extrapolate``, ``prefilter``) can be overridden at call time.

    Returns
    -------
    resized : (batch, *spatial_out, ndim) tensor
        Resized image.
    affine : (batch, ndim[+1], ndim+1) tensor, optional
        Orientation matrix
    """
    # Previously only `factor` was read from `overload`; every other
    # override was silently dropped.
    output_shape = output_shape or overload.get('shape', self.shape)
    kwargs = {
        'factor': overload.get('factor', self.factor),
        'shape': output_shape,
        'type': overload.get('type', self.type),
        'anchor': overload.get('anchor', self.anchor),
        'interpolation': overload.get('interpolation', self.interpolation),
        'bound': overload.get('bound', self.bound),
        'extrapolate': overload.get('extrapolate', self.extrapolate),
        'prefilter': overload.get('prefilter', self.prefilter),
    }
    return spatial.resize_grid(grid, affine=affine, **kwargs)
def exp(prm):
    """Upsample spline parameters to a dense displacement and grid.

    Relies on ``target`` and ``backend`` from the enclosing scope (not
    visible here).

    Parameters
    ----------
    prm : tensor
        Low-resolution displacement parameters.

    Returns
    -------
    disp : tensor
        Displacement resized to ``target.shape[2:]`` (cubic interpolation,
        ``'dft'`` boundary).
    grid : tensor
        ``disp`` plus the identity grid of the same spatial shape.
    """
    out_shape = target.shape[2:]
    disp = spatial.resize_grid(prm, type='displacement', shape=out_shape,
                               interpolation=3, bound='dft')
    identity = spatial.identity_grid(out_shape, **backend)
    return disp, disp + identity
def forward(self, source, target, source_seg=None, target_seg=None,
            *, _loss=None, _metric=None):
    """Affinely register `source` to `target`.

    Parameters
    ----------
    source : tensor (batch, channel, *spatial)
        Source/moving image
    target : tensor (batch, channel, *spatial)
        Target/fixed image
    source_seg : tensor (batch, classes, *spatial), optional
        Source/moving segmentation
    target_seg : tensor (batch, classes, *spatial), optional
        Target/fixed segmentation

    Other Parameters
    ----------------
    _loss : dict, optional
        If provided, all registered losses are computed and appended.
    _metric : dict, optional
        If provided, all registered metrics are computed and appended.

    Returns
    -------
    deformed_source : tensor (batch, channel, *spatial)
        Deformed source image
    deformed_source_seg : tensor (batch, classes, *spatial), optional
        Deformed source segmentation (only if `source_seg` is provided)
    affine_prm : tensor (batch, nb_prm)
        Affine Lie parameters (CNN output reshaped to its first two dims)
    """
    # sanity checks
    spatial_dims = range(2, self.dim + 2)
    check.dim(self.dim, source, target, source_seg, target_seg)
    check.shape(target, source, dims=[0], broadcast_ok=True)
    check.shape(target, source, dims=spatial_dims)
    check.shape(target_seg, source_seg, dims=[0], broadcast_ok=True)
    check.shape(target_seg, source_seg, dims=spatial_dims)

    # predict affine parameters from the concatenated pair
    pair = torch.cat((source, target), dim=1)
    affine_prm = self.cnn(pair)
    affine_prm = affine_prm.reshape(affine_prm.shape[:2])

    # exponentiate each batch element separately, then re-batch
    affine = torch.stack([self.exp(prm) for prm in affine_prm], dim=0)

    grid = self.grid(affine, shape=target.shape[2:])
    deformed_source = self.pull(source, grid)

    deformed_source_seg = None
    if source_seg is not None:
        # the segmentation may live on a different lattice than the image
        if source_seg.shape[2:] != source.shape[2:]:
            grid = spatial.resize_grid(grid, shape=source_seg.shape[2:])
        deformed_source_seg = self.pull(source_seg, grid)

    # compute loss and metrics
    self.compute(_loss, _metric,
                 image=[deformed_source, target],
                 affine=[affine_prm],
                 segmentation=[deformed_source_seg, target_seg])

    if source_seg is None:
        return deformed_source, affine_prm
    return deformed_source, deformed_source_seg, affine_prm
def ffd_exp(prm, shape, order=3, bound='dft', returns='disp'):
    """Transform FFD parameters into a displacement or transformation grid.

    Parameters
    ----------
    prm : (..., *spatial, dim) tensor
        FFD parameters
    shape : sequence[int]
        Exponentiated shape
    order : int, default=3
        Spline order
    bound : str, default='dft'
        Boundary condition
    returns : {'disp', 'grid', 'disp+grid'}, default='disp'
        What to return:
            - 'disp' -> displacement grid
            - 'grid' -> transformation grid
            - 'disp+grid' -> both, as a ``(disp, grid)`` tuple

    Returns
    -------
    disp : (..., *shape, dim), optional
        Displacement grid
    grid : (..., *shape, dim), optional
        Transformation grid

    Raises
    ------
    ValueError
        If `returns` mentions neither 'disp' nor 'grid'.
    """
    want_disp = 'disp' in returns
    want_grid = 'grid' in returns
    if not (want_disp or want_grid):
        # Previously fell through and silently returned None.
        raise ValueError(
            f"`returns` must contain 'disp' and/or 'grid', got {returns!r}")

    backend = dict(dtype=prm.dtype, device=prm.device)
    dim = prm.shape[-1]
    # Flatten all leading batch dimensions into one for `resize_grid`.
    batch = prm.shape[:-(dim + 1)]
    prm = prm.reshape([-1, *prm.shape[-(dim + 1):]])
    disp = resize_grid(prm, type='displacement', shape=shape,
                       interpolation=order, bound=bound)
    disp = disp.reshape(batch + disp.shape[1:])

    if not want_grid:
        # Skip building the identity grid when it is not needed.
        return disp
    grid = disp + identity_grid(shape, **backend)
    return (disp, grid) if want_disp else grid
def free(self):
    """Free the next batch/ladder of parameters.

    Does nothing if ``self.freeable()`` is falsy.

    On the first call, wraps ``self.dat`` in a trainable
    ``torch.nn.Parameter``. On later calls, it first drops the last entry
    of ``self.pyramid``. It then resizes the detached data (and
    ``self.affine``) by ``2 ** (dropped_level - new_last_level)`` along
    each spatial dimension, and wraps the result in a fresh Parameter.
    """
    if not self.freeable():
        return
    print('Free nonlin')

    if hasattr(self, 'optdat'):
        # Pop the last pyramid level and upsample to the new one.
        *remaining, previous_level = self.pyramid
        self.pyramid = remaining
        self.dat = self.dat.detach()
        scale = 2 ** (previous_level - self.pyramid[-1])
        new_shape = [size * scale for size in self.dat.shape[:-1]]
        resized, new_affine = spatial.resize_grid(
            self.dat[None], shape=new_shape, type='displacement',
            affine=self.affine[None])
        self.dat = resized[0]
        self.affine = new_affine[0]

    # Both paths end by exposing the (possibly resized) data as a
    # trainable parameter.
    self.optdat = torch.nn.Parameter(self.dat, requires_grad=True)
    self.dat = self.optdat
def write_data(files, options):
    """Resize each input volume (or displacement/transformation grid) and write it to disk.

    Parameters
    ----------
    files : sequence
        Input file descriptors. Each must expose ``fname``, ``dir``,
        ``base``, ``ext``, ``shape``, ``channels`` and ``affine``.
    options : object
        Parsed options. Reads ``gpu``, ``output``, ``voxel_size``,
        ``factor``, ``shape``, ``anchor``, ``bound``, ``interpolation``,
        ``prefilter`` and ``grid``.

    Raises
    ------
    ValueError
        If both ``voxel_size`` and ``factor`` are given, or if none of
        ``factor``/``voxel_size``/``shape`` is given.
    """
    device = options.gpu if isinstance(options.gpu, str) else f'cuda:{options.gpu}'
    backend = dict(dtype=torch.float32, device=device)

    ofiles = py.make_list(options.output, len(files))
    for file, ofile in zip(files, ofiles):
        ofile = ofile.format(dir=file.dir, base=file.base, ext=file.ext, sep=os.sep)
        print(f'Resizing:   {file.fname}\n'
              f'         -> {ofile}')

        dat = io.loadf(file.fname, **backend)
        dat = dat.reshape([*file.shape, file.channels])

        # compute resizing factor (priority: voxel_size > factor > shape)
        input_vx = spatial.voxel_size(file.affine)
        if options.voxel_size:
            if options.factor:
                raise ValueError('Cannot use both factor and voxel size')
            factor = input_vx / utils.make_vector(options.voxel_size, 3)
        elif options.factor:
            factor = utils.make_vector(options.factor, 3)
        elif options.shape:
            input_shape = utils.make_vector(dat.shape[:-1], 3, dtype=torch.float32)
            target_shape = utils.make_vector(options.shape, 3, dtype=torch.float32)
            factor = target_shape / input_shape
        else:
            raise ValueError('Need at least one of factor/voxel_size/shape')
        factor = factor.tolist()

        # check if output shape is provided
        output_shape = py.ensure_list(options.shape, 3) if options.shape else None

        # Perform resize
        opt = dict(
            anchor=options.anchor,
            bound=options.bound,
            interpolation=options.interpolation,
            prefilter=options.prefilter,
        )
        if options.grid:
            # `resize_grid(..., affine=...)` returns a (grid, affine) pair;
            # unpack it first, then drop the batch dimension. (Indexing the
            # pair with [0] before unpacking tried to split the batch-1 grid
            # tensor into two values.)
            dat, affine = spatial.resize_grid(dat[None], factor, output_shape,
                                              type=options.grid,
                                              affine=file.affine, **opt)
            dat = dat[0]
        else:
            dat = utils.movedim(dat, -1, 0)
            dat, affine = spatial.resize(dat[None], factor, output_shape,
                                         affine=file.affine, **opt)
            dat = utils.movedim(dat[0], 0, -1)

        # Write output file
        io.volumes.savef(dat, ofile, like=file.fname, affine=affine)
def forward(self, source, target, source_seg=None, target_seg=None,
            *, _loss=None, _metric=None):
    """Register `source` to `target` with an affine + nonlinear model.

    Parameters
    ----------
    source : tensor (batch, channel, *spatial)
        Source/moving image
    target : tensor (batch, channel, *spatial)
        Target/fixed image
    source_seg : tensor (batch, classes, *spatial), optional
        Source/moving segmentation
    target_seg : tensor (batch, classes, *spatial), optional
        Target/fixed segmentation

    Other Parameters
    ----------------
    _loss : dict, optional
        If provided, all registered losses are computed and appended.
    _metric : dict, optional
        If provided, all registered metrics are computed and appended.

    Returns
    -------
    deformed_source : tensor (batch, channel, *spatial)
        Deformed source image
    deformed_source_seg : tensor (batch, classes, *spatial), optional
        Deformed source segmentation (only if `source_seg` is provided)
    velocity : tensor (batch, *spatial, len(spatial))
        Velocity field (UNet output moved to channel-last)
    affine_prm : tensor (batch, nb_prm)
        Affine Lie parameters (CNN output reshaped to its first two dims)
    """
    # sanity checks
    spatial_dims = range(2, self.dim + 2)
    check.dim(self.dim, source, target, source_seg, target_seg)
    check.shape(target, source, dims=[0], broadcast_ok=True)
    check.shape(target, source, dims=spatial_dims)
    check.shape(target_seg, source_seg, dims=[0], broadcast_ok=True)
    check.shape(target_seg, source_seg, dims=spatial_dims)

    pair = torch.cat((source, target), dim=1)

    # affine component
    raw_affine = self.cnn(pair)
    affine_prm = raw_affine.reshape(raw_affine.shape[:2])

    # nonlinear component
    velocity = channel2last(self.unet(pair))

    # compose both into a single sampling grid and warp
    grid = self.exp(velocity, affine_prm)
    deformed_source = self.pull(source, grid)

    deformed_source_seg = None
    if source_seg is not None:
        # the segmentation may live on a different lattice than the image
        if source_seg.shape[2:] != source.shape[2:]:
            grid = spatial.resize_grid(grid, shape=source_seg.shape[2:])
        deformed_source_seg = self.pull(source_seg, grid)

    # compute loss and metrics
    self.compute(_loss, _metric,
                 image=[deformed_source, target],
                 velocity=[velocity],
                 segmentation=[deformed_source_seg, target_seg],
                 affine=[affine_prm])

    if deformed_source_seg is None:
        return deformed_source, velocity, affine_prm
    return deformed_source, deformed_source_seg, velocity, affine_prm
def forward(self, source, target, source_seg=None, target_seg=None,
            *, _loss=None, _metric=None):
    """
    Parameters
    ----------
    source : tensor (batch, channel, *spatial)
        Source/moving image
    target : tensor (batch, channel, *spatial)
        Target/fixed image
    source_seg : tensor (batch, classes, *spatial), optional
        Source/moving segmentation
    target_seg : tensor (batch, classes, *spatial), optional
        Target/fixed segmentation

    Other Parameters
    ----------------
    _loss : dict, optional
        If provided, all registered losses are computed and appended.
    _metric : dict, optional
        If provided, all registered metrics are computed and appended.

    Returns
    -------
    deformed_source : tensor (batch, channel, *spatial)
        Deformed source image
    deformed_source_seg : tensor (batch, classes, *spatial), optional
        Deformed source segmentation (only if `source_seg` is provided)
    velocity : tensor (batch, *spatial, len(spatial))
        Velocity field
    """
    # sanity checks -- segmentations are dimension-checked too, as in the
    # sibling affine / affine+velocity models (None values are accepted there).
    check.dim(self.dim, source, target, source_seg, target_seg)
    check.shape(target, source, dims=[0], broadcast_ok=True)
    check.shape(target, source, dims=range(2, self.dim+2))
    check.shape(target_seg, source_seg, dims=[0], broadcast_ok=True)
    check.shape(target_seg, source_seg, dims=range(2, self.dim+2))

    # chain operations
    source_and_target = torch.cat((source, target), dim=1)
    velocity = self.unet(source_and_target)
    velocity = core.utils.channel2last(velocity)
    grid = self.exp(velocity)
    deformed_source = self.pull(source, grid)

    if source_seg is not None:
        # the segmentation may live on a different lattice than the image
        if source_seg.shape[2:] != source.shape[2:]:
            grid = spatial.resize_grid(grid, shape=source_seg.shape[2:])
        deformed_source_seg = self.pull(source_seg, grid)
    else:
        deformed_source_seg = None

    # compute loss and metrics
    self.compute(_loss, _metric,
                 image=[deformed_source, target],
                 velocity=[velocity],
                 segmentation=[deformed_source_seg, target_seg])

    if source_seg is None:
        return deformed_source, velocity
    else:
        return deformed_source, deformed_source_seg, velocity