def correct_smooth(x, sigma=None, lam=10, gamma=10, downsample=None,
                   max_iter=16, max_rls=8, tol=1e-6, verbose=False,
                   device=None):
    """Correct the intensity non-uniformity in a SPIM image.

    The signal is modelled as:
        f = exp(s + b) + eps
    with a penalty on the (squared) gradients of s (membrane energy) and on
    the (squared) curvature of b (bending energy). The mean of the log-bias
    b is kept at zero, so any global scaling is absorbed into s.

    Parameters
    ----------
    x : tensor
        SPIM image with the z dimension last and the z=0 plane first
    sigma : float, optional
        Noise standard deviation.
        Default: median of the central third of the image divided by 5.
    lam : float, default=10
        Regularisation on the signal (effective weight is lam**2 * sigma**2).
    gamma : float, default=10
        Regularisation on the bias field
        (effective weight is gamma**2 * sigma**2).
    downsample : int or list[int], optional
        Pooling factor(s) applied before fitting. The fitted fields are
        resized back to the input shape afterwards.
    max_iter : int, default=16
        Maximum number of Newton iterations (per reweighting iteration).
    max_rls : int, default=8
        Maximum number of reweighting iterations.
        If 1, this is effectively an l2 regularisation.
    tol : float, default=1e-6
        Tolerance for early stopping of the Newton iterations.
    verbose : int or bool, default=False
        Verbosity level (> 1 prints one line per iteration).
    device : torch.device, default=x.device
        Use this device during fitting.

    Returns
    -------
    y : tensor
        Fitted signal exp(s) (the model fit is y * bias).
    bias : tensor
        Fitted multiplicative bias field exp(b).
    x : tensor
        Corrected image, x / bias.
    """
    x = torch.as_tensor(x)
    if not x.dtype.is_floating_point:
        x = x.to(dtype=torch.get_default_dtype())
    dim = x.dim()

    # downsampling
    if downsample:
        x0 = x
        downsample = py.make_list(downsample, dim)
        x = spatial.pool(dim, x, downsample)
    shape = x.shape
    x = x.to(device)

    # noise educated guess: assume SNR=5 in the central region
    center = tuple(slice(s // 3, 2 * s // 3) for s in shape)
    sigma = sigma or x[center].median() / 5
    lam = lam**2 * sigma**2
    gamma = gamma**2 * sigma**2

    regy = lambda y, w: spatial.regulariser(
        y[None], membrane=lam, dim=dim, weights=w)[0]
    regb = lambda b: spatial.regulariser(b[None], bending=gamma, dim=dim)[0]
    solvey = lambda h, g, w: spatial.solve_field_sym(
        h[None], g[None], membrane=lam, dim=dim, weights=w)[0]
    solveb = lambda h, g: spatial.solve_field_sym(
        h[None], g[None], bending=gamma, dim=dim)[0]

    # init
    # Reweighted least squares (l1-like penalty on the signal gradients) is
    # only used when more than one reweighting iteration is requested. The
    # caller's `max_rls` is honoured (it used to be silently forced to 10).
    l1 = max_rls > 1
    if l1:
        w = torch.ones_like(x)[None]
        llw = w.sum()
    else:
        w = None
        llw = 0
        max_rls = 1
    logb = torch.zeros_like(x)
    logy = x.clamp_min(1e-3).log_()
    y = logy.exp()
    b = logb.exp()
    fit = y * b
    res = fit - x
    llx = res.square().sum()
    lly = (regy(logy, w).mul_(logy)).sum()
    llb = (regb(logb).mul_(logb)).sum()
    ll0 = llx + lly + llb + llw
    ll1 = ll0

    for it_ls in range(max_rls):
        for it in range(max_iter):
            # update bias (Gauss-Newton step in log space; the Hessian is
            # made positive by taking |fit * res| + fit**2)
            g = h = fit
            h = (h * res).abs_()
            h.addcmul_(g, g)
            g *= res  # NOTE: modifies `fit` in place; it is recomputed below
            g += regb(logb)
            logb -= solveb(h, g)
            # keep mean(log bias) at zero, move the offset into the signal
            logb0 = logb.mean()
            logb -= logb0
            logy += logb0

            # update fit / ll
            llb = (regb(logb).mul_(logb)).sum()
            b = torch.exp(logb, out=b)
            y = torch.exp(logy, out=y)
            fit = y * b
            res = fit - x

            # update y
            g = h = fit
            h = (h * res).abs_()
            h.addcmul_(g, g)
            g *= res
            g += regy(logy, w)
            logy -= solvey(h, g, w)

            # update fit / ll
            y = torch.exp(logy, out=y)
            fit = y * b
            res = fit - x
            lly = (regy(logy, w).mul_(logy)).sum()

            # compute objective
            llx = res.square().sum()
            ll = llx + lly + llb + llw
            gain = (ll1 - ll) / ll0
            ll1 = ll
            if verbose:
                end = '\n' if verbose > 1 else '\r'
                pre = f'{it_ls:3d} | ' if l1 else ''
                print(pre + f'{it:3d} | {ll:12.6g} | gain = {gain:12.6g}',
                      end=end)
            if it > 0 and abs(gain) < tol:
                break

        if l1:
            # recompute the RLS weights from the current signal gradients
            w, llw = spatial.membrane_weights(logy[None], lam, dim=dim,
                                              return_sum=True)
            ll0 = ll
    if verbose:
        print('')

    if downsample:
        # upsample the fitted log-fields back to the input resolution
        b = spatial.resize(logb.to(x0.device)[None, None], downsample,
                           shape=x0.shape, anchor='f')[0, 0].exp_()
        y = spatial.resize(logy.to(x0.device)[None, None], downsample,
                           shape=x0.shape, anchor='f')[0, 0].exp_()
        x = x0
    else:
        y = torch.exp(logy, out=y)
    x = x / b
    return y, b, x
def zcorrect_square(x, decay=None, sigma=None, lam=10, max_iter=128,
                    tol=1e-6, verbose=False):
    """Correct the z signal decay in a SPIM image.

    The signal is modelled as:
        f(z) = s(z) / (1 + b * z**2) + eps
    where z=0 is the top slice, s(z) is the theoretical signal if there
    was no absorption and b is the decay coefficient. The decay is
    optimised in log space, so it stays positive.

    Parameters
    ----------
    x : (..., nz) tensor
        SPIM image with the z dimension last and the z=0 plane first
    decay : float, optional
        Initial guess for decay parameter.
        Default: closed-form fit of the medians at z=nz/3 and z=2nz/3.
    sigma : float, optional
        Noise standard deviation. Default: educated guess (SNR=5 at z=nz/2).
    lam : float, default=10
        Regularisation (effective membrane weight is lam**2 * sigma**2).
    max_iter : int, default=128
        Maximum number of iterations.
    tol : float, default=1e-6
        Stop when the relative objective gain falls below this value.
    verbose : int or bool, default=False
        Verbosity level (> 1 prints one line per iteration).

    Returns
    -------
    y : (..., nz) tensor
        Fitted absorption-free signal s(z)
    decay : () tensor
        Fitted decay parameter b
    x : (..., nz) tensor
        Corrected image, x * (1 + b * z**2)
    """
    x = torch.as_tensor(x)
    if not x.dtype.is_floating_point:
        x = x.to(dtype=torch.get_default_dtype())
    backend = utils.backend(x)
    shape = x.shape
    nz = shape[-1]
    x = x.reshape([-1, nz])
    b = decay

    # decay educated guess: closed form two values at z=1/3 and z=2/3
    z1 = nz // 3
    z2 = 2 * nz // 3
    x1 = x[:, z1].median()
    x2 = x[:, z2].median()
    z1 = float(z1)**2
    z2 = float(z2)**2
    b = b or (x2 - x1) / (x1 * z1 - x2 * z2)
    b = abs(b)
    y0 = x1 * (1 + b * z1)
    y0 = y0.item()
    b = b.item() if torch.is_tensor(b) else b

    # noise educated guess: assume SNR=5 at z=1/2
    sigma = sigma or (y0 / (1 + b * (nz / 2)**2)) / 5
    lam = lam**2 * sigma**2
    reg = lambda y: spatial.regulariser(
        y[:, None], membrane=lam, dim=1)[:, 0]
    solve = lambda h, g: spatial.solve_field_sym(
        h[:, None], g[:, None], membrane=lam, dim=1)[:, 0]
    if verbose:
        print(y0, b, sigma, lam)

    # init
    z2 = torch.arange(nz, **backend).square_()
    logy = torch.full_like(x, y0).log_()
    # The decay is parameterised by its log: initialise with log(b), not b
    # (b is recovered below as exp(logb)). A zero guess (e.g. equal
    # medians) would give log(0) = -inf, so floor it at a negligible decay
    # (about 0.1% attenuation at the last slice).
    logb = torch.as_tensor(max(b, 1e-3 / nz**2), **backend).log()
    y = logy.exp()
    b = logb.exp()
    ll0 = (y / (1 + b * z2) - x).square_().sum() + (logy * reg(logy)).sum()
    ll1 = ll0
    for it in range(max_iter):
        # exponentiate
        y = torch.exp(logy, out=y)
        fit = y / (1 + b * z2)
        res = fit - x

        # compute objective
        ll = res.square().sum() + (logy * reg(logy)).sum()
        gain = (ll1 - ll) / ll0
        if verbose:
            end = '\n' if verbose > 1 else '\r'
            print(f'{it:3d} | {ll:12.6g} | gain = {gain:12.6g}', end=end)
        if it > 0 and gain < tol:
            break
        ll1 = ll

        # update decay (Newton step on log(b); second derivative made
        # positive by taking its absolute value weighted by |res|)
        g = -(z2 * b * y) / (z2 * b + 1).square()
        h = b * y - z2 * b.square() * y
        h *= z2 / (z2 * b + 1).pow(3)
        h = h.abs_() * res.abs()
        h += g.square()
        g *= res
        g = g.sum()
        h = h.sum()
        logb -= g / h

        # update fit
        b = torch.exp(logb, out=b)
        fit = y / (1 + b * z2)
        res = fit - x

        # update y (regularised Gauss-Newton step on log(y))
        g = h = y / (z2 * b + 1)
        h = h.abs() * res.abs()
        h += g.square()
        g *= res
        g += reg(logy)
        logy -= solve(h, g)
    if verbose:
        print('')

    y = torch.exp(logy, out=y)
    y = y.reshape(shape)
    x = x * (1 + b * z2)
    x = x.reshape(shape)
    return y, b, x
def grid_inv(grid, type='grid', lam=0.1, bound='dft', extrapolate=True):
    """Invert a dense deformation (or displacement) grid

    Notes
    -----
    The deformation/displacement grid must be expressed in
    voxels, and map from/to the same lattice.

    Let `f = id + d` be the transformation. The inverse displacement is
    obtained by pushing `d` with `f` (the adjoint, "push", of sampling
    with `f`, "pull") and dividing by the pushed image of ones, where
    the division is performed as a regularised least-squares solve
    (membrane energy weighted by `lam`) so that voxels that receive no
    pushed mass are filled in smoothly. This generalises
    `id - (k * (f.T @ d)) / (k * (f.T @ 1))` where `k` is a smoothing
    kernel.

    Parameters
    ----------
    grid : (..., *spatial, dim)
        Transformation (or displacement) grid
    type : {'grid', 'disp'}, default='grid'
        Type of deformation. Any value other than 'grid' is treated
        as a displacement.
    lam : float, default=0.1
        Regularisation (membrane weight of the least-squares fill).
    bound : str, default='dft'
        Boundary condition used for pushing and for the regularised solve.
    extrapolate : bool, default=True
        Whether to push values that fall outside the field of view.

    Returns
    -------
    grid_inv : (..., *spatial, dim)
        Inverse transformation (or displacement) grid
    """
    # get shape components
    dim = grid.shape[-1]
    shape = grid.shape[-(dim + 1):-1]
    batch = grid.shape[:-(dim + 1)]
    grid = grid.reshape([-1, *shape, dim])
    backend = dict(dtype=grid.dtype, device=grid.device)

    # get displacement
    identity = spatial.identity_grid(shape, **backend)
    if type == 'grid':
        disp = grid - identity
    else:
        disp = grid
        grid = disp + identity

    # push displacement
    push_opt = dict(bound=bound, extrapolate=extrapolate)
    disp = core.utils.movedim(disp, -1, 1)
    disp = spatial.grid_push(disp, grid, **push_opt)
    count = spatial.grid_count(grid, **push_opt)

    # Fill missing values using regularised least squares.
    # Use the caller's `lam` and `bound` (previously hard-coded to 0.1/'dft').
    disp = spatial.solve_field_sym(count, disp, membrane=lam, bound=bound,
                                   dim=dim)
    disp = core.utils.movedim(disp, 1, -1)
    disp = disp.reshape([*batch, *shape, dim])

    if type == 'grid':
        return identity - disp
    else:
        return -disp