def _build_pyramid(self, dat, levels, method, dim, bound,
                   mask=None, preview=None):
    """Build one pyramid level per requested `level`, preserving the
    order in which the caller listed the levels (level 0 = original
    resolution, each further level halves the resolution)."""
    levels = list(levels)
    indexed_levels = list(enumerate(levels))
    indexed_levels.sort(key=lambda x: x[1])
    nb_levels = max(levels)
    if mask is not None:
        mask = mask.to(dat.device)
    dats = [dat] * levels.count(0)
    masks = [mask] * levels.count(0)
    previews = [preview] * levels.count(0)
    if mask is not None:
        mask = mask.to(dat.dtype)
    if preview is not None:
        preview = preview.to(dat.dtype)
    for level in range(1, nb_levels+1):
        shape = dat.shape[-dim:]
        kernel_size = [min(2, s) for s in shape]
        if method[0] == 'g':  # gaussian pyramid
            # We assume the original data has a PSF of 1 input voxel.
            # We smooth by an additional 1-vx FWHM so that the data has a
            # PSF of 2 input voxels == 1 output voxel, then subsample.
            smooth = lambda x: spatial.smooth(x, fwhm=1, stride=2,
                                              dim=dim, bound=bound)
        elif method[0] == 'a':  # average window
            smooth = lambda x: spatial.pool(dim, x, kernel_size=kernel_size,
                                            stride=2, reduction='mean')
        elif method[0] == 'm':  # median window
            smooth = lambda x: spatial.pool(dim, x, kernel_size=kernel_size,
                                            stride=2, reduction='median')
        elif method[0] == 's':  # strides
            slicer = [slice(None, None, 2)] * dim
            smooth = lambda x: x[(Ellipsis, *slicer)]
        else:
            raise ValueError(method)
        dat = smooth(dat)
        if mask is not None:
            mask = smooth(mask)
        if preview is not None:
            preview = smooth(preview)
        dats += [dat] * levels.count(level)
        masks += [mask] * levels.count(level)
        previews += [preview] * levels.count(level)
    # scatter the generated levels back into the caller's requested order
    reordered_dats = [None] * len(levels)
    reordered_masks = [None] * len(levels)
    reordered_previews = [None] * len(levels)
    for (i, level), dat, mask, preview \
            in zip(indexed_levels, dats, masks, previews):
        reordered_dats[i] = dat
        reordered_masks[i] = mask
        reordered_previews[i] = preview
    return reordered_dats, reordered_masks, reordered_previews
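
# Standalone sketch (pure Python, no nitorch needed) of the reordering
# bookkeeping in `_build_pyramid`: one entry is generated per requested
# occurrence, in increasing level order, then scattered back into the
# order in which the caller listed the levels. The string payloads are
# placeholders for the actual tensors.
levels = [2, 0, 1, 0]
indexed_levels = sorted(enumerate(levels), key=lambda x: x[1])

# emulate the pyramid loop: one entry per occurrence, sorted by level
dats = [f'dat@{lvl}' for lvl in sorted(levels)]

reordered = [None] * len(levels)
for (i, lvl), dat in zip(indexed_levels, dats):
    reordered[i] = dat
assert reordered == ['dat@2', 'dat@0', 'dat@1', 'dat@0']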
def downsample(x, aff_in, vx_out):
    """Downsample an image (by an integer factor) to approximately
    match a target voxel size."""
    vx_in = spatial.voxel_size(aff_in)
    dim = len(vx_in)
    vx_out = utils.make_vector(vx_out, dim)
    factor = (vx_out / vx_in).clamp_min(1).floor().long()
    if (factor == 1).all():
        return x, aff_in
    factor = factor.tolist()
    x, aff_out = spatial.pool(dim, x, factor, affine=aff_in)
    return x, aff_out
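
# Hedged sketch of the factor computation in `downsample`, runnable with
# plain PyTorch; the voxel sizes below are made-up values for illustration.
import torch

vx_in = torch.tensor([0.5, 0.5, 2.0])   # input voxel size (mm)
vx_out = torch.tensor([1.0, 1.0, 1.0])  # requested voxel size (mm)

# integer pooling factor per dimension, never upsampling (clamped at 1)
factor = (vx_out / vx_in).clamp_min(1).floor().long()
print(factor.tolist())                  # [2, 2, 1] -> pool x/y, keep z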
def forward(self, x, **overload):
    """
    Parameters
    ----------
    x : (batch, channel, *spatial) tensor
        Tensor to pool
    overload : dict
        Most parameters defined at build time can be overridden
        at call time.

    Returns
    -------
    x : (batch, channel, *spatial_out) tensor
        Pooled tensor
    indices : (batch, channel, *spatial_out, dim) tensor, if `return_indices`
        Indices of input elements.
    """
    dim = self.dim
    kernel_size = make_list(overload.get('kernel_size', self.kernel_size), dim)
    stride = make_list(overload.get('stride', self.stride), dim)
    padding = make_list(overload.get('padding', self.padding), dim)
    dilation = make_list(overload.get('dilation', self.dilation), dim)
    reduction = overload.get('reduction', self.reduction)
    return_indices = overload.get('return_indices', self.return_indices)

    # Activation: a string is mapped to a known activation class,
    # a class is instantiated, a ready-made callable is used as-is.
    activation = overload.get('activation', self.activation)
    if isinstance(activation, str):
        activation = _map_activations.get(activation.lower(), None)
    activation = (activation() if inspect.isclass(activation)
                  else activation if callable(activation)
                  else None)

    x = pool(dim, x,
             kernel_size=kernel_size,
             stride=stride,
             dilation=dilation,
             padding=padding,
             reduction=reduction,
             return_indices=return_indices)
    # `pool` returns `(x, indices)` when `return_indices=True`;
    # unpack before applying the activation.
    ind = None
    if return_indices:
        x, ind = x
    if activation:
        x = activation(x)
    return (x, ind) if return_indices else x
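
# Standalone sketch of the activation-resolution idiom used in `forward`
# above. `resolve_activation` is a hypothetical helper written for this
# example, and `_map_activations` is a stand-in for the module-level
# name -> class map assumed by the code.
import inspect
import torch.nn as tnn

_map_activations = {'relu': tnn.ReLU, 'elu': tnn.ELU}  # illustrative subset

def resolve_activation(activation):
    if isinstance(activation, str):
        activation = _map_activations.get(activation.lower(), None)
    return (activation() if inspect.isclass(activation)
            else activation if callable(activation)
            else None)

assert isinstance(resolve_activation('ReLU'), tnn.ReLU)  # string -> instance
assert isinstance(resolve_activation(tnn.ELU), tnn.ELU)  # class -> instance
assert resolve_activation(None) is None                  # anything else -> None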
def forward(self, x, return_indices=None):
    """
    Parameters
    ----------
    x : (batch, channel, *spatial) tensor
        Tensor to pool
    return_indices : bool, default=self.return_indices

    Returns
    -------
    x : (batch, channel, *spatial_out) tensor
        Pooled tensor
    indices : (batch, channel, *spatial_out, dim) tensor, if `return_indices`
        Indices of input elements.
    """
    if return_indices is None:
        return_indices = self.return_indices

    x = pool(self.dim, x,
             kernel_size=self.kernel_size,
             stride=self.stride,
             dilation=self.dilation,
             padding=self.padding,
             reduction=self.reduction,
             return_indices=return_indices,
             ceil=self.ceil)
    ind = None
    if return_indices:
        x, ind = x
    if self.activation:
        x = self.activation(x)
    return (x, ind) if return_indices else x
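
# For comparison: core PyTorch's max pooling has a similar
# `return_indices` switch, but it returns *flat* indices into the input
# plane, whereas the docstring above describes per-dimension coordinates
# of shape (batch, channel, *spatial_out, dim).
import torch
import torch.nn.functional as F

img = torch.arange(16.).reshape(1, 1, 4, 4)
out, ind = F.max_pool2d(img, kernel_size=2, return_indices=True)
print(out.shape, ind.shape)  # torch.Size([1, 1, 2, 2]) for both
print(ind[0, 0])             # flat indices: tensor([[ 5,  7], [13, 15]])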
def pool(inp, window=3, stride=None, method='mean', dim=3,
         output=None, device=None):
    """Pool a ND volume, while preserving the orientation matrices.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    window : [sequence of] int, default=3
        Window size
    stride : [sequence of] int, optional
        Stride between output elements. By default, it is the same
        as `window`.
    method : {'mean', 'sum', 'min', 'max', 'median'}, default='mean'
        Pooling function.
    dim : int, default=3
        Number of spatial dimensions.
    output : [sequence of] str, optional
        Output filename(s).
        If the input is not a path, the pooled data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.pool{ext}', where `dir`, `base` and `ext` are the
        directory, base name and extension of the input file.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the pooled data and orientation matrix are returned.
    """
    dir = ''
    base = ''
    ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.fdata(device=device), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.pool{ext}'
        dir, base, ext = py.fileparts(fname)

    dat, aff0 = inp
    dat = dat.to(device)
    dim = dim or aff0.shape[-1] - 1

    # `pool` needs the spatial dimensions at the end
    spatial_in = dat.shape[:dim]
    batch = dat.shape[dim:]
    dat = dat.reshape([*spatial_in, -1])
    dat = utils.movedim(dat, -1, 0)
    dat, aff = spatial.pool(dim, dat, kernel_size=window, stride=stride,
                            reduction=method, affine=aff0)
    dat = utils.movedim(dat, 0, -1)
    dat = dat.reshape([*dat.shape[:dim], *batch])

    if output:
        if is_file:
            output = output.format(dir=dir or '.', base=base, ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff)

    if is_file:
        return output
    else:
        return dat, aff
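
# Standalone sketch of the axis shuffle used above: `spatial.pool`
# expects the spatial dimensions last with leading batch dimensions, so
# the trailing batch/channel dimensions are flattened and moved to the
# front, then restored after pooling. Shapes only, with plain PyTorch
# (`torch.movedim` mirrors the `utils.movedim` assumed by the code):
import torch

dim = 3
dat = torch.zeros(32, 32, 32, 2, 4)            # (*spatial, *batch)
spatial_in = dat.shape[:dim]                   # (32, 32, 32)
batch = dat.shape[dim:]                        # (2, 4)

dat = dat.reshape([*spatial_in, -1])           # (32, 32, 32, 8)
dat = torch.movedim(dat, -1, 0)                # (8, 32, 32, 32)
# ... pool the three trailing spatial dimensions here ...
dat = torch.movedim(dat, 0, -1)                # (32, 32, 32, 8)
dat = dat.reshape([*dat.shape[:dim], *batch])  # (32, 32, 32, 2, 4)
assert dat.shape == (32, 32, 32, 2, 4)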
def correct_smooth(x, sigma=None, lam=10, gamma=10, downsample=None,
                   max_iter=16, max_rls=8, tol=1e-6, verbose=False,
                   device=None):
    """Correct the intensity non-uniformity in a SPIM image.

    The signal is modelled as: f = exp(s + b) + eps, with a penalty on
    the (squared) gradients of s and on the (squared) curvature of b.

    Parameters
    ----------
    x : tensor
        SPIM image with the z dimension last and the z=0 plane first
    sigma : float, optional
        Noise standard deviation. Default: educated guess.
    lam : float, default=10
        Regularisation on the signal.
    gamma : float, default=10
        Regularisation on the bias field.
    downsample : [sequence of] int, optional
        Downsampling factor used during the fit. The fitted fields are
        resized back to the input resolution before being returned.
    max_iter : int, default=16
        Maximum number of Newton iterations.
    max_rls : int, default=8
        Maximum number of reweighting iterations.
        If 1, this is effectively an l2 regularisation.
    tol : float, default=1e-6
        Tolerance for early stopping.
    verbose : int or bool, default=False
        Verbosity level
    device : torch.device, default=x.device
        Use this device during fitting.

    Returns
    -------
    y : tensor
        Fitted image
    bias : tensor
        Fitted bias
    x : tensor
        Corrected image
    """
    x = torch.as_tensor(x)
    if not x.dtype.is_floating_point:
        x = x.to(dtype=torch.get_default_dtype())
    dim = x.dim()

    # downsampling
    if downsample:
        x0 = x
        downsample = py.make_list(downsample, dim)
        x = spatial.pool(dim, x, downsample)
    shape = x.shape
    x = x.to(device)

    # noise educated guess: assume SNR=5 at z=1/2
    center = tuple(slice(s//3, 2*s//3) for s in shape)
    sigma = sigma or x[center].median() / 5
    lam = lam**2 * sigma**2
    gamma = gamma**2 * sigma**2

    regy = lambda y, w: spatial.regulariser(
        y[None], membrane=lam, dim=dim, weights=w)[0]
    regb = lambda b: spatial.regulariser(
        b[None], bending=gamma, dim=dim)[0]
    solvey = lambda h, g, w: spatial.solve_field_sym(
        h[None], g[None], membrane=lam, dim=dim, weights=w)[0]
    solveb = lambda h, g: spatial.solve_field_sym(
        h[None], g[None], bending=gamma, dim=dim)[0]

    # init
    l1 = max_rls > 1
    if l1:
        w = torch.ones_like(x)[None]
        llw = w.sum()
        max_rls = 10
    else:
        w = None
        llw = 0
        max_rls = 1
    logb = torch.zeros_like(x)
    logy = x.clamp_min(1e-3).log_()
    y = logy.exp()
    b = logb.exp()
    fit = y * b
    res = fit - x
    llx = res.square().sum()
    lly = (regy(logy, w).mul_(logy)).sum()
    llb = (regb(logb).mul_(logb)).sum()
    ll0 = llx + lly + llb + llw
    ll1 = ll0

    for it_ls in range(max_rls):
        for it in range(max_iter):

            # update bias
            g = h = fit
            h = (h * res).abs_()
            h.addcmul_(g, g)
            g *= res
            g += regb(logb)
            logb -= solveb(h, g)
            logb0 = logb.mean()
            logb -= logb0
            logy += logb0

            # update fit / ll
            llb = (regb(logb).mul_(logb)).sum()
            b = torch.exp(logb, out=b)
            y = torch.exp(logy, out=y)
            fit = y * b
            res = fit - x

            # update y
            g = h = fit
            h = (h * res).abs_()
            h.addcmul_(g, g)
            g *= res
            g += regy(logy, w)
            logy -= solvey(h, g, w)

            # update fit / ll
            y = torch.exp(logy, out=y)
            fit = y * b
            res = fit - x
            lly = (regy(logy, w).mul_(logy)).sum()

            # compute objective
            llx = res.square().sum()
            ll = llx + lly + llb + llw
            gain = (ll1 - ll) / ll0
            ll1 = ll
            if verbose:
                end = '\n' if verbose > 1 else '\r'
                pre = f'{it_ls:3d} | ' if l1 else ''
                print(pre + f'{it:3d} | {ll:12.6g} | gain = {gain:12.6g}',
                      end=end)
            if it > 0 and abs(gain) < tol:
                break

        if l1:
            w, llw = spatial.membrane_weights(logy[None], lam, dim=dim,
                                              return_sum=True)
            ll0 = ll
    if verbose:
        print('')

    # resize the fitted fields back to the input resolution, then
    # divide out the bias to get the corrected image
    if downsample:
        b = spatial.resize(logb.to(x0.device)[None, None], downsample,
                           shape=x0.shape, anchor='f')[0, 0].exp_()
        y = spatial.resize(logy.to(x0.device)[None, None], downsample,
                           shape=x0.shape, anchor='f')[0, 0].exp_()
        x = x0
    else:
        y = torch.exp(logy, out=y)
    x = x / b
    return y, b, x
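
# A small sketch of the Gauss-Newton quantities used in the updates above.
# With f = exp(s + b) and residual r = f - x, we have df/ds = df/db = f,
# so the data-term gradient is f*r and the Gauss-Newton curvature is f*f;
# the second-order term f*r is taken in absolute value so the curvature
# stays positive. Written out without the in-place aliasing tricks
# (`g = h = fit`) used in `correct_smooth`; the arrays are synthetic.
import torch

x = torch.rand(8, 8, 8) + 0.5       # observed image (synthetic)
logy = torch.zeros_like(x)          # current log-signal estimate
logb = torch.zeros_like(x)          # current log-bias estimate
fit = (logy + logb).exp()           # f = exp(s + b)
res = fit - x                       # residual r = f - x

g = fit * res                       # gradient of 0.5 * ||exp(s+b) - x||**2
h = fit * fit + (fit * res).abs()   # Gauss-Newton curvature + |f*r| term
assert h.min() > 0                  # curvature is strictly positive
# In `correct_smooth`, g additionally receives the regulariser gradient
# and the step is obtained by solving against h with spatial.solve_field_sym.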