def forward(self, x, noise=None, return_resolution=False):
    """Simulate a random loss of resolution, one (batch, channel) at a time.

    For every (batch, channel) pair, a resolution is drawn from
    ``self.resolution(exp, scale).sample()`` and clamped to be at least 1.
    The volume is then smoothed with that FWHM and optionally corrupted by
    ``noise``. It is downsampled by ``1/resolution`` and finally resampled
    back to the original spatial shape.

    Parameters
    ----------
    x : (batch, channel, *spatial) tensor
        Input image.
    noise : tensor, optional
        Added after smoothing and before downsampling. Must be
        expandable to ``x.shape``.
    return_resolution : bool, default=False
        Also return the sampled upsampling factors.

    Returns
    -------
    out : (batch, channel, *spatial) tensor
        Degraded image.
    resolutions : tensor, optional
        One list of ``dim`` factors per (batch, channel) pair, in
        batch-major order, converted by ``utils.as_tensor``
        (presumably shape (batch*channel, dim) — TODO confirm).
        Only returned if ``return_resolution``.
    """
    if noise is not None:
        noise = noise.expand(x.shape)
    ndim = x.dim() - 2
    backend = utils.backend(x)
    nb_channels = x.shape[1]
    res_exp = utils.make_vector(self.resolution_exp, nb_channels, **backend)
    res_scale = utils.make_vector(self.resolution_scale, nb_channels, **backend)

    sampled_factors = []
    out = torch.empty_like(x)
    for bidx in range(len(x)):
        for cidx in range(nb_channels):
            res = self.resolution(res_exp[cidx], res_scale[cidx]).sample()
            res = res.clamp_min(1)
            degraded = smooth(x[bidx, cidx], fwhm=[res] * ndim, dim=ndim,
                              padding='same', bound='dct2')
            if noise is not None:
                degraded += noise[bidx, cidx]
            # resize expects leading batch and channel dimensions
            degraded = resize(degraded[None, None], factor=[1 / res] * ndim,
                              anchor='f')
            up_factor = [res] * ndim
            sampled_factors.append(up_factor)
            degraded = resize(degraded, factor=up_factor, shape=x.shape[2:],
                              anchor='f')
            out[bidx, cidx] = degraded[0, 0]

    sampled_factors = utils.as_tensor(sampled_factors, **backend)
    if return_resolution:
        return out, sampled_factors
    return out
def _resize(maps, rls, aff, shape):
    """Resample each map (and optionally ``rls``) onto a grid of ``shape``.

    Parameters
    ----------
    maps : iterable
        Container whose items expose ``.volume`` and ``.affine``; the
        container itself also has an ``.affine`` attribute. Modified in place.
    rls : tensor or None
        Either spatial-only (``rls.dim() == len(shape)``) or with one extra
        leading dimension.
    aff : tensor
        Orientation matrix assigned to every map and to the container.
    shape : sequence[int]
        Target spatial shape.

    Returns
    -------
    maps, rls
        The updated container and the resized ``rls`` (or None).
    """
    for one_map in maps:
        # spatial.resize works on (batch, channel, *spatial) tensors
        batched = one_map.volume[None, None, ...]
        one_map.volume = spatial.resize(batched, shape=shape)[0, 0]
        one_map.affine = aff
    maps.affine = aff

    if rls is None:
        return maps, rls

    has_leading_dim = rls.dim() != len(shape)
    if has_leading_dim:
        rls = spatial.resize(rls[None], shape=shape)[0]
    else:
        rls = spatial.resize(rls[None, None], shape=shape)[0, 0]
    return maps, rls
def _resize(maps, rls, aff, shape):
    """Resize (prolong) current maps to a target resolution.

    Parameters
    ----------
    maps : object
        Has a ``.volume`` tensor with one leading (non-spatial) dimension
        and an ``.affine`` attribute. Modified in place.
    rls : tensor or None
        Either spatial-only (``rls.dim() == len(shape)``) or with one extra
        leading dimension.
    aff : tensor
        Orientation matrix assigned to ``maps``.
    shape : sequence[int]
        Target spatial shape.

    Returns
    -------
    maps, rls
        The updated maps and the resized ``rls`` (or None).
    """
    # add a batch dimension for spatial.resize, then remove it
    maps.volume = spatial.resize(maps.volume[None], shape=shape)[0]
    maps.affine = aff

    if rls is None:
        return maps, rls

    if rls.dim() == len(shape):
        # spatial-only: add both batch and channel dimensions
        resized = spatial.resize(rls[None, None], shape=shape)[0, 0]
    else:
        # already carries one leading dimension: only add batch
        resized = spatial.resize(rls[None], shape=shape)[0]
    return maps, resized
def forward(self, image, affine=None, **overload):
    """Resize an image, optionally overriding build-time options.

    Parameters
    ----------
    image : (batch, channel, *spatial_in) tensor
        Input image to resize.
    affine : (batch, ndim[+1], ndim+1), optional
        Orientation matrix of the input image. If provided, the
        orientation matrix of the resized image is returned as well.
    overload : dict
        Any of ``factor``, ``shape``, ``anchor``, ``interpolation``,
        ``bound``, ``extrapolate`` overrides the attribute of the same
        name for this call only. Other keys are ignored.

    Returns
    -------
    resized : (batch, channel, ...) tensor
        Resized image.
    affine : (batch, ndim[+1], ndim+1) tensor, optional
        Orientation matrix.
    """
    option_names = ('factor', 'shape', 'anchor',
                    'interpolation', 'bound', 'extrapolate')
    options = {name: overload.get(name, getattr(self, name))
               for name in option_names}
    return spatial.resize(image, affine=affine, **options)
def forward(self, image, affine=None, output_shape=None):
    """Resize an image using the options stored on this module.

    Parameters
    ----------
    image : (batch, channel, *spatial_in) tensor
        Input image to resize.
    affine : (batch, ndim[+1], ndim+1), optional
        Orientation matrix of the input image. If provided, the
        orientation matrix of the resized image is returned as well.
    output_shape : sequence[int], optional
        Forwarded to ``self.shape`` to compute the full output shape.

    Returns
    -------
    resized : (batch, channel, ...) tensor
        Resized image.
    affine : (batch, ndim[+1], ndim+1) tensor, optional
        Orientation matrix.
    """
    full_shape = self.shape(image, output_shape=output_shape)
    spatial_shape = full_shape[2:]  # drop batch and channel
    return spatial.resize(
        image,
        affine=affine,
        shape=spatial_shape,
        factor=self.factor,
        anchor=self.anchor,
        interpolation=self.interpolation,
        bound=self.bound,
        extrapolate=self.extrapolate,
        prefilter=self.prefilter,
    )
def resize(self):
    """Resample the maps (and ``rls``, if any) to the grid of ``self.level``.

    The grid is obtained by scaling the original grid
    (``self.affine0``, ``self.shape0``) by ``1 / 2**(self.level - 1)``.

    Side effects
    ------------
    - ``self.lam_scale`` is set to the ratio of the new voxel volume to the
      original one.
    - Every map's ``.volume`` and ``.affine`` are updated, as is
      ``self.maps.affine``.
    - If ``self.rls`` is not None, it is resized and
      ``self.nll['rls']`` is set to the sum of its reciprocal (in double).
    """
    affine, shape = spatial.affine_resize(self.affine0, self.shape0,
                                          1 / (2 ** (self.level - 1)))

    # ratio of voxel volumes: current level vs. original grid
    scl0 = spatial.voxel_size(self.affine0).prod()
    scl = spatial.voxel_size(affine).prod() / scl0
    self.lam_scale = scl

    for prm_map in self.maps:
        prm_map.volume = spatial.resize(prm_map.volume[None, None, ...],
                                        shape=shape)[0, 0]
        prm_map.affine = affine
    self.maps.affine = affine

    if self.rls is not None:
        if self.rls.dim() == len(shape):
            # spatial-only rls: add batch and channel dimensions.
            # (Was previously called with the misspelled keyword `hape`.)
            self.rls = spatial.resize(self.rls[None, None],
                                      shape=shape)[0, 0]
        else:
            self.rls = spatial.resize(self.rls[None], shape=shape)[0]
        self.nll['rls'] = self.rls.reciprocal().sum(dtype=torch.double)
def forward(self, x, affine=None):
    """Segment a volume: preprocess, run the network, then postprocess.

    Parameters
    ----------
    x : (X, Y, Z) tensor or str or io.MappedArray
        Input volume. A string is opened with ``io.map``; a mapped array
        is loaded with ``fdata()`` and reshaped to its first three
        dimensions.
    affine : (4, 4) tensor, optional
        Orientation matrix of the input. If not provided and ``x`` is a
        file / mapped array, the file's own affine is used.

    Returns
    -------
    seg : (oX, oY, oZ) tensor
        Label map: argmax over the network's output channels, mapped
        through ``self.relabel`` and padded back to the pre-crop shape.
    resliced : (oX, oY, oZ) tensor
        Preprocessed input (noise added by ``SynthPreproc.addnoise``,
        reoriented to RAS and resampled when an affine is available,
        then ``SynthPreproc.preproc``), padded back to the pre-crop shape.
    affine : (4, 4) tensor or None
        Output orientation matrix (None if no affine was available).

    """
    if self.verbose:
        print('Preprocessing... ', end='', flush=True)
    if isinstance(x, str):
        x = io.map(x)
    if isinstance(x, io.MappedArray):
        if affine is None:
            affine = x.affine
        x = x.fdata()
        # keep the first three dimensions (any trailing dimensions must be
        # singleton for this reshape to succeed)
        x = x.reshape(x.shape[:3])
    x = SynthPreproc.addnoise(x)
    if affine is not None:
        affine, x = spatial.affine_reorient(affine, x, 'RAS')
        vx = spatial.voxel_size(affine)
        # smooth before resampling; axes with voxels larger than 1 are
        # not smoothed (fwhm set to 0)
        fwhm = 0.25 * vx.reciprocal()
        fwhm[vx > 1] = 0
        x = spatial.smooth(x, fwhm=fwhm.tolist(), dim=3)
        # resize factor equal to the voxel size -> presumably 1 mm voxels
        x, affine = spatial.resize(x[None, None], vx.tolist(), affine=affine)
        x = x[0, 0]
    # shape before cropping, used to pad the outputs back
    oshape = x.shape
    x, crop = SynthPreproc.crop(x)
    x = SynthPreproc.preproc(x)[None, None]
    if self.verbose:
        print('done.', flush=True)
        print('Segmenting... ', end='', flush=True)
    # [0] presumably selects the single batch element of the network output
    s, x = super().forward(x)[0], x[0, 0]
    if self.verbose:
        print('done.', flush=True)
        print('Postprocessing... ', end='', flush=True)
    # hard labels: argmax over the channel dimension, then relabel
    s = self.relabel(s.argmax(0))
    x = SynthPreproc.pad(x, oshape, crop)
    s = SynthPreproc.pad(s, oshape, crop)
    if self.verbose:
        print('done.', flush=True)
    return s, x, affine
def resize(cls, x, affine, target_vx=1):
    """Resample a volume to a target voxel size, smoothing before downsampling.

    Parameters
    ----------
    x : (*spatial) tensor
        Volume without batch/channel dimensions. Its number of dimensions
        is used as the spatial dimensionality.
    affine : tensor
        Orientation matrix of ``x``.
    target_vx : float or sequence[float], default=1
        Target voxel size.

    Returns
    -------
    x : tensor
        Resampled volume.
    affine : tensor
        Orientation matrix of the resampled volume.
    """
    dim = x.dim()
    target_vx = utils.make_vector(target_vx, dim, **utils.backend(affine))
    vx = spatial.voxel_size(affine)
    factor = vx / target_vx
    # Smooth axes that are downsampled (factor < 1) to limit aliasing;
    # upsampled axes (factor > 1) are left untouched.
    fwhm = 0.25 * factor.reciprocal()
    fwhm[factor > 1] = 0
    # Use the same dimensionality as `target_vx` (was hard-coded to 3).
    x = spatial.smooth(x, fwhm=fwhm.tolist(), dim=dim)
    x, affine = spatial.resize(x[None, None], factor.tolist(), affine=affine)
    x = x[0, 0]
    return x, affine
def load(self, x, affine=None):
    """Load a volume and, when an affine is known, reslice it to RAS.

    Parameters
    ----------
    x : tensor or str or io.MappedArray
        Input volume or path. Mapped arrays are loaded with ``fdata()``
        and reshaped to their first three dimensions.
    affine : tensor, optional
        Orientation matrix; defaults to the file's affine for mapped input.

    Returns
    -------
    x : tensor
        Volume, reoriented to RAS and resized by its voxel size when an
        affine is available; otherwise unchanged.
    affine : tensor or None
        Orientation matrix of the returned volume.
    x_original : torch.Size
        Shape of the volume before reorientation/resizing.
    affine_original : tensor or None
        Orientation matrix before reorientation/resizing.
    """
    if isinstance(x, str):
        x = io.map(x)
    if isinstance(x, io.MappedArray):
        affine = x.affine if affine is None else affine
        x = x.fdata()
        x = x.reshape(x.shape[:3])

    affine_original, x_original = affine, x.shape
    if affine is None:
        return x, affine, x_original, affine_original

    affine, x = spatial.affine_reorient(affine, x, 'RAS')
    voxel_size = spatial.voxel_size(affine)
    x, affine = spatial.resize(x[None, None], voxel_size.tolist(),
                               affine=affine)
    return x[0, 0], affine, x_original, affine_original
def forward(self, batch=1, **overload):
    """Generate a smooth random field from upsampled random spline coefficients.

    Parameters
    ----------
    batch : int, default=1
        Batch size

    Other Parameters
    ----------------
    shape : sequence[int], optional
    channel : int, optional
    device : torch.device, optional
    dtype : torch.dtype, optional

    Returns
    -------
    field : (batch, channel, *shape) tensor
        Generated random field
    """
    shape = overload.get('shape', self.shape)
    channel = overload.get('channel', self.channel)
    backend = dict(dtype=overload.get('dtype', self.dtype),
                   device=overload.get('device', self.device))

    ndim = len(shape)
    mean = utils.make_vector(self.mean, channel, **backend)
    amplitude = utils.make_vector(self.amplitude, channel, **backend)
    fwhm = utils.make_vector(self.fwhm, ndim, **backend)

    # number of spline nodes per axis: ceil(length / fwhm)
    nb_nodes = [(length / width).ceil().int().item()
                for length, width in zip(shape, fwhm)]
    coeffs = torch.randn([batch, channel, *nb_nodes], **backend)
    # scale each channel by its amplitude (broadcast over spatial axes)
    coeffs *= utils.unsqueeze(amplitude, -1, ndim)
    field = spatial.resize(coeffs, shape=shape, interpolation=self.basis,
                           bound='dct2', prefilter=False)
    field += utils.unsqueeze(mean, -1, ndim)
    return field
def write_data(files, options):
    """Resize each input volume and write the result to disk.

    Parameters
    ----------
    files : sequence
        Input file descriptors exposing ``fname``, ``dir``, ``base``,
        ``ext``, ``shape``, ``channels`` and ``affine``.
    options : object
        Parsed options. Uses ``gpu``, ``output`` (format pattern with
        ``{dir}``, ``{base}``, ``{ext}``, ``{sep}``), exactly one of
        ``voxel_size`` / ``factor`` (``shape`` is used only if neither is
        set), and ``anchor``, ``bound``, ``interpolation``, ``prefilter``,
        ``grid``.

    Raises
    ------
    ValueError
        If both ``factor`` and ``voxel_size`` are given, or none of
        ``factor`` / ``voxel_size`` / ``shape`` is given.
    """
    device = options.gpu if isinstance(options.gpu, str) else f'cuda:{options.gpu}'
    backend = dict(dtype=torch.float32, device=device)
    ofiles = py.make_list(options.output, len(files))

    for file, ofile in zip(files, ofiles):
        ofile = ofile.format(dir=file.dir, base=file.base, ext=file.ext,
                             sep=os.sep)
        print(f'Resizing: {file.fname}\n'
              f' -> {ofile}')

        dat = io.loadf(file.fname, **backend)
        # channels last: (*spatial, channels)
        dat = dat.reshape([*file.shape, file.channels])

        # compute resizing factor
        input_vx = spatial.voxel_size(file.affine)
        if options.voxel_size:
            if options.factor:
                raise ValueError('Cannot use both factor and voxel size')
            factor = input_vx / utils.make_vector(options.voxel_size, 3)
        elif options.factor:
            factor = utils.make_vector(options.factor, 3)
        elif options.shape:
            input_shape = utils.make_vector(dat.shape[:-1], 3,
                                            dtype=torch.float32)
            output_shape = utils.make_vector(options.shape, 3,
                                             dtype=torch.float32)
            factor = output_shape / input_shape
        else:
            raise ValueError('Need at least one of factor/voxel_size/shape')
        factor = factor.tolist()

        # check if output shape is provided
        if options.shape:
            output_shape = py.ensure_list(options.shape, 3)
        else:
            output_shape = None

        # Perform resize
        opt = dict(
            anchor=options.anchor,
            bound=options.bound,
            interpolation=options.interpolation,
            prefilter=options.prefilter,
        )
        if options.grid:
            # resize_grid returns (grid, affine) when an affine is given:
            # unpack first, then drop the batch dimension. (Indexing [0]
            # before unpacking tried to split the grid tensor itself.)
            dat, affine = spatial.resize_grid(dat[None], factor, output_shape,
                                              type=options.grid,
                                              affine=file.affine, **opt)
            dat = dat[0]
        else:
            # resize expects (batch, channel, *spatial): move channels first
            dat = utils.movedim(dat, -1, 0)
            dat, affine = spatial.resize(dat[None], factor, output_shape,
                                         affine=file.affine, **opt)
            dat = utils.movedim(dat[0], 0, -1)

        # Write output file
        io.volumes.savef(dat, ofile, like=file.fname, affine=affine)
def resize(maps, rls, aff, shape):
    """Resample ``maps.volume`` (and ``rls``, if given) to ``shape``.

    ``maps`` is modified in place (``volume`` and ``affine``). Both tensors
    are passed to ``spatial.resize`` as-is, without adding batch/channel
    dimensions.

    Returns
    -------
    maps, rls
        The updated maps and the resized ``rls`` (or None).
    """
    maps.volume = spatial.resize(maps.volume, shape=shape)
    maps.affine = aff
    resized_rls = None if rls is None else spatial.resize(rls, shape=shape)
    return maps, resized_rls
def correct_smooth(x, sigma=None, lam=10, gamma=10, downsample=None,
                   max_iter=16, max_rls=8, tol=1e-6, verbose=False,
                   device=None):
    """Correct the intensity non-uniformity in a SPIM image.

    The signal is modelled as: f = exp(s + b) + eps, with a penalty on
    the (squared) gradients of s and on the (squared) curvature of b.

    Parameters
    ----------
    x : tensor
        SPIM image with the z dimension last and the z=0 plane first
    sigma : float, optional
        Noise standard deviation. Default: educated guess.
    lam : float, default=10
        Regularisation on the signal.
    gamma : float, default=10
        Regularisation on the bias field.
    downsample : int or sequence[int], optional
        If given, `x` is pooled by this factor before fitting, and the
        fitted log-signal and log-bias are resized back to the input shape.
    max_iter : int, default=16
        Maximum number of Newton iterations (per reweighting iteration).
    max_rls : int, default=8
        Maximum number of reweighting iterations.
        If 1, this is effectively an l2 regularisation; if > 1, the
        membrane weights are re-estimated after each reweighting iteration.
    tol : float, default=1e-6
        Tolerance for early stopping.
    verbose : int or bool, default=False
        Verbosity level
    device : torch.device, default=x.device
        Use this device during fitting.

    Returns
    -------
    y : tensor
        Fitted image
    bias : tensor
        Fitted bias
    x : tensor
        Corrected image (input divided by the bias)
    """
    x = torch.as_tensor(x)
    if not x.dtype.is_floating_point:
        x = x.to(dtype=torch.get_default_dtype())
    dim = x.dim()

    # downsampling
    if downsample:
        x0 = x
        downsample = py.make_list(downsample, dim)
        x = spatial.pool(dim, x, downsample)
    shape = x.shape
    x = x.to(device)

    # noise educated guess: assume SNR=5 at z=1/2
    center = tuple(slice(s // 3, 2 * s // 3) for s in shape)
    sigma = sigma or x[center].median() / 5
    lam = lam ** 2 * sigma ** 2
    gamma = gamma ** 2 * sigma ** 2
    regy = lambda y, w: spatial.regulariser(
        y[None], membrane=lam, dim=dim, weights=w)[0]
    regb = lambda b: spatial.regulariser(b[None], bending=gamma, dim=dim)[0]
    solvey = lambda h, g, w: spatial.solve_field_sym(
        h[None], g[None], membrane=lam, dim=dim, weights=w)[0]
    solveb = lambda h, g: spatial.solve_field_sym(
        h[None], g[None], bending=gamma, dim=dim)[0]

    # init
    l1 = max_rls > 1
    if l1:
        # Honour the caller's `max_rls` (it was previously overwritten
        # with a hard-coded 10).
        w = torch.ones_like(x)[None]
        llw = w.sum()
    else:
        w = None
        llw = 0
        max_rls = 1

    logb = torch.zeros_like(x)
    logy = x.clamp_min(1e-3).log_()
    y = logy.exp()
    b = logb.exp()
    fit = y * b
    res = fit - x
    llx = res.square().sum()
    lly = (regy(logy, w).mul_(logy)).sum()
    llb = (regb(logb).mul_(logb)).sum()
    ll0 = llx + lly + llb + llw
    ll1 = ll0

    for it_ls in range(max_rls):
        for it in range(max_iter):

            # update bias
            # NOTE: g aliases `fit` and is modified in place below; `fit`
            # is recomputed before it is used again.
            g = h = fit
            h = (h * res).abs_()
            h.addcmul_(g, g)  # h = |fit * res| + fit**2
            g *= res
            g += regb(logb)
            logb -= solveb(h, g)
            # keep log-bias zero-mean; move its mean into the log-signal
            logb0 = logb.mean()
            logb -= logb0
            logy += logb0

            # update fit / ll
            llb = (regb(logb).mul_(logb)).sum()
            b = torch.exp(logb, out=b)
            y = torch.exp(logy, out=y)
            fit = y * b
            res = fit - x

            # update y
            g = h = fit
            h = (h * res).abs_()
            h.addcmul_(g, g)
            g *= res
            g += regy(logy, w)
            logy -= solvey(h, g, w)

            # update fit / ll
            y = torch.exp(logy, out=y)
            fit = y * b
            res = fit - x
            lly = (regy(logy, w).mul_(logy)).sum()

            # compute objective
            llx = res.square().sum()
            ll = llx + lly + llb + llw
            gain = (ll1 - ll) / ll0
            ll1 = ll
            if verbose:
                end = '\n' if verbose > 1 else '\r'
                pre = f'{it_ls:3d} | ' if l1 else ''
                print(pre + f'{it:3d} | {ll:12.6g} | gain = {gain:12.6g}',
                      end=end)
            if it > 0 and abs(gain) < tol:
                break

        if l1:
            w, llw = spatial.membrane_weights(logy[None], lam, dim=dim,
                                              return_sum=True)
            ll0 = ll
    if verbose:
        print('')

    if downsample:
        # resize the log-fields back to the input grid, then exponentiate
        b = spatial.resize(logb.to(x0.device)[None, None], downsample,
                           shape=x0.shape, anchor='f')[0, 0].exp_()
        y = spatial.resize(logy.to(x0.device)[None, None], downsample,
                           shape=x0.shape, anchor='f')[0, 0].exp_()
        x = x0
    else:
        y = torch.exp(logy, out=y)
    x = x / b
    return y, b, x