def _nonlin_rls(maps, lam=1., norm='jtv'):
    """Update the (L1) weights.

    Parameters
    ----------
    maps : (P, *shape) ParameterMaps
        Parameter maps
    lam : float or (P,) sequence[float], default=1
        Regularisation factor
    norm : {'tv', 'jtv'}, default='jtv'

    Returns
    -------
    rls : ([P], *shape) tensor
        Weights from the reweighted least squares scheme
    """
    if norm not in ('tv', 'jtv', '__internal__'):
        return None

    if isinstance(maps, ParameterMap):
        # single map
        # this should only be an internal call
        # -> we return the squared gradient map
        assert norm == '__internal__'
        vx = spatial.voxel_size(maps.affine)
        grad_fwd = spatial.diff(maps.fdata(), dim=[0, 1, 2], voxel_size=vx, side='f')
        grad_bwd = spatial.diff(maps.fdata(), dim=[0, 1, 2], voxel_size=vx, side='b')

        grad = grad_fwd.square_().sum(-1)
        grad += grad_bwd.square_().sum(-1)
        grad *= lam / 2.   # average across sides (2)
        return grad

    # multiple maps
    if norm == 'tv':
        rls = []
        for map, l in zip(maps, lam):
            rls1 = _nonlin_rls(map, l, '__internal__')
            rls1 = rls1.sqrt_()
            rls.append(rls1)
        return torch.stack(rls, dim=0)
    else:
        assert norm == 'jtv'
        rls = 0
        for map, l in zip(maps, lam):
            rls += _nonlin_rls(map, l, '__internal__')
        rls = rls.sqrt_()
        return rls
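
# Hedged sketch (not part of the original code): the same JTV weight map computed
# directly on plain tensors, mirroring the '__internal__' and 'jtv' branches above.
# The shapes, unit voxel size and `lam` values are made up for illustration, and
# the helper name is hypothetical.
def _example_jtv_weights():
    maps = torch.randn([2, 8, 8, 8])     # two parameter maps
    lam = [1., 0.1]                      # one regularisation factor per map
    w = 0
    for map1, lam1 in zip(maps, lam):
        gf = spatial.diff(map1, dim=[0, 1, 2], side='f')
        gb = spatial.diff(map1, dim=[0, 1, 2], side='b')
        w += (gf.square_().sum(-1) + gb.square_().sum(-1)) * (lam1 / 2.)
    return w.sqrt_()                     # joint weight map, shape (8, 8, 8)
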
def reg(tensor, vx=1., rls=None, lam=1., do_grad=True):
    """Compute the gradient of the regularisation term.

    The regularisation term has the form:
    `0.5 * lam * sum(w[i] * (g+[i]**2 + g-[i]**2) / 2)`
    where `i` indexes a voxel, `lam` is the regularisation factor,
    `w[i]` is the RLS weight, `g+` and `g-` are the forward and
    backward spatial gradients of the parameter map.

    Parameters
    ----------
    tensor : (K, *shape) tensor
        Parameter map
    vx : float or sequence[float], default=1
        Voxel size
    rls : (K|1, *shape) tensor, optional
        Weights from the reweighted least squares scheme
    lam : float or sequence[float], default=1
        Regularisation factor
    do_grad : bool, default=True
        Return both the criterion and gradient

    Returns
    -------
    reg : () tensor[double]
        Regularisation term
    grad : (K, *shape) tensor
        Gradient with respect to the parameter map
    """
    nb_prm = tensor.shape[0]
    backend = dict(dtype=tensor.dtype, device=tensor.device)
    vx = core.utils.make_vector(vx, 3, **backend)
    lam = core.utils.make_vector(lam, nb_prm, **backend)

    grad_fwd = spatial.diff(tensor, dim=[1, 2, 3], voxel_size=vx, side='f')
    grad_bwd = spatial.diff(tensor, dim=[1, 2, 3], voxel_size=vx, side='b')
    if rls is not None:
        grad_fwd *= rls[..., None]
        grad_bwd *= rls[..., None]
    grad_fwd = spatial.div(grad_fwd, dim=[1, 2, 3], voxel_size=vx, side='f')
    grad_bwd = spatial.div(grad_bwd, dim=[1, 2, 3], voxel_size=vx, side='b')

    grad = grad_fwd
    grad += grad_bwd
    grad *= lam.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) / 2.
    # ^ average across side

    if do_grad:
        reg = (tensor * grad).sum(dtype=torch.double)
        return 0.5 * reg, grad
    else:
        grad *= tensor
        return 0.5 * grad.sum(dtype=torch.double)
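
# Hedged usage sketch (not part of the original code): evaluates the penalty and
# its gradient on a random two-channel map, and checks that the `do_grad=False`
# path returns the same criterion. Shapes and `lam` values are arbitrary; the
# helper name is hypothetical.
def _example_reg():
    x = torch.randn([2, 8, 8, 8], dtype=torch.double)
    ll, g = reg(x, vx=1., lam=[1., 0.5], do_grad=True)   # criterion + gradient
    ll0 = reg(x, vx=1., lam=[1., 0.5], do_grad=False)    # criterion only
    assert torch.allclose(ll, ll0)
    return ll, g
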
def reg1(tensor, vx=1., rls=None, lam=1., do_grad=True):
    """Compute the gradient of the regularisation term.

    The regularisation term has the form:
    `0.5 * lam * sum(w[i] * (g+[i]**2 + g-[i]**2) / 2)`
    where `i` indexes a voxel, `lam` is the regularisation factor,
    `w[i]` is the RLS weight, `g+` and `g-` are the forward and
    backward spatial gradients of the parameter map.

    Parameters
    ----------
    tensor : (*shape) tensor
        Parameter map
    vx : float or sequence[float], default=1
        Voxel size
    rls : (*shape) tensor, optional
        Weights from the reweighted least squares scheme
    lam : float, default=1
        Regularisation factor
    do_grad : bool, default=True
        Return both the criterion and gradient

    Returns
    -------
    reg : () tensor[double]
        Regularisation term
    grad : (*shape) tensor
        Gradient with respect to the parameter map
    """
    grad_fwd = spatial.diff(tensor, dim=[0, 1, 2], voxel_size=vx, side='f')
    grad_bwd = spatial.diff(tensor, dim=[0, 1, 2], voxel_size=vx, side='b')
    if rls is not None:
        grad_fwd *= rls[..., None]
        grad_bwd *= rls[..., None]
    grad_fwd = spatial.div(grad_fwd, dim=[0, 1, 2], voxel_size=vx, side='f')
    grad_bwd = spatial.div(grad_bwd, dim=[0, 1, 2], voxel_size=vx, side='b')

    grad = grad_fwd
    grad += grad_bwd
    grad *= lam / 2.   # average across sides (2)

    if do_grad:
        reg = (tensor * grad).sum(dtype=torch.double)
        return 0.5 * reg, grad
    else:
        grad *= tensor
        return 0.5 * grad.sum(dtype=torch.double)
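
# Hedged sketch (not part of the original code): `reg1` should agree with `reg`
# called on the same map with a singleton channel dimension. Random shape and
# `lam` for illustration; the helper name is hypothetical.
def _example_reg1():
    x = torch.randn([8, 8, 8], dtype=torch.double)
    ll1, g1 = reg1(x, lam=0.5)          # single-map version
    llk, gk = reg(x[None], lam=0.5)     # multi-map version, K=1
    assert torch.allclose(ll1, llk) and torch.allclose(g1, gk[0])
    return ll1, g1
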
def _loss_ssqd_jtv(dat_x, dat_y, tau, lam, voxel_size=1, side='f',
                   bound='dct2'):
    """Computes an image denoising loss function, where:

    * fidelity term: sum-of-squared differences (SSQD)
    * regularisation term: joint total variation (JTV)
    * hyper-parameters: tau, lambda

    Parameters
    ----------
    dat_x : (dmx, dmy, dmz, nchannels) tensor
        Input image
    dat_y : (dmx, dmy, dmz, nchannels) tensor
        Reconstruction image
    tau : (nchannels) tensor
        Channel-specific noise precisions
    lam : (nchannels) tensor
        Channel-specific regularisation values
    voxel_size : float or sequence[float], default=1
        Unit size used in the denominator of the gradient.
    side : {'c', 'f', 'b'}, default='f'
        * 'c': central finite differences
        * 'f': forward finite differences
        * 'b': backward finite differences
    bound : {'dct2', 'dct1', 'dst2', 'dst1', 'dft', 'repeat', 'zero'}, default='dct2'
        Boundary condition.

    Returns
    -------
    nll_yx : tensor
        Loss function value (negative log-posterior)
    """
    # compute negative log-likelihood (SSQD fidelity term)
    nll_xy = 0.5 * torch.sum(tau * torch.sum((dat_x - dat_y) ** 2, dim=(0, 1, 2)))

    # compute gradients of reconstruction, shape=(dmx, dmy, dmz, nchannels, dmgr)
    nll_y = diff(dat_y, order=1, dim=(0, 1, 2), voxel_size=voxel_size,
                 side=side, bound=bound)

    # modulate channels with regularisation
    nll_y = lam[None, None, None, :, None] * nll_y

    # compute negative log-prior (JTV regularisation term)
    nll_y = torch.sum(nll_y ** 2 + eps(), dim=-1)  # to gradient magnitudes (sum over gradient directions)
    nll_y = torch.sum(nll_y, dim=-1)  # sum over reconstruction channels
    nll_y = torch.sqrt(nll_y)
    nll_y = torch.sum(nll_y)  # sum over voxels

    # compute negative log-posterior (loss function)
    nll_yx = nll_xy + nll_y

    return nll_yx
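
# Hedged usage sketch (not part of the original code): evaluates the SSQD+JTV
# objective on a random two-channel volume. The noise precisions and
# regularisation values are arbitrary; the helper name is hypothetical.
def _example_loss_ssqd_jtv():
    dat_x = torch.randn([8, 8, 8, 2])      # noisy observation
    dat_y = torch.randn([8, 8, 8, 2])      # current reconstruction
    tau = torch.tensor([1.0, 2.0])         # per-channel noise precision
    lam = torch.tensor([10.0, 10.0])       # per-channel regularisation
    return _loss_ssqd_jtv(dat_x, dat_y, tau, lam)
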
def grad(self):
    """Compute the image gradients in each voxel.

    Almost equivalent to `self.pull_grad(identity)`.

    Returns
    -------
    grad : ([C], *spatial, dim) tensor
    """
    return spatial.diff(self.dat, dim=list(range(-self.dim, 0)),
                        bound=self.bound)
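
# Hedged sketch (not part of the original code): the equivalent call on a plain
# channel-first 3D volume, outside the class. The 'dct2' boundary is only an
# example value for `self.bound`; the helper name is hypothetical.
def _example_image_grad():
    dat = torch.randn([1, 8, 8, 8])                            # (C, *spatial)
    return spatial.diff(dat, dim=[-3, -2, -1], bound='dct2')   # (C, *spatial, 3)
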
def derivatives_distortion(contrast, distortion, intercept, decay, opt,
                           do_grad=True):
    """Compute the gradient and Hessian of the distortion field.

    Parameters
    ----------
    contrast : (nb_echo, *obs_shape) GradientEchoMulti
        A single echo series (with the same weighting)
    distortion : ParameterizedDeformation
        A model of distortions caused by B0 inhomogeneities.
    intercept : (*recon_shape) ParameterMap
        Log-intercept of the contrast
    decay : (*recon_shape) ParameterMap
        Exponential decay
    opt : Options

    Returns
    -------
    crit : () tensor
        Log-likelihood
    grad : (*shape, 3) tensor
    hess : (*shape, 6) tensor
    """
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    obs_shape = contrast.volume.shape[1:]
    recon_shape = intercept.volume.shape
    aff = core.linalg.lmdiv(intercept.affine, contrast.affine)
    aff = aff.to(**backend)
    lam = 1 / contrast.noise
    df = contrast.dof
    chi = opt.likelihood[0].lower() == 'c'

    # pull parameter maps to observed space
    grid = smart_grid(aff, obs_shape, recon_shape)
    inter = smart_pull(intercept.fdata(**backend), grid)
    slope = smart_pull(decay.fdata(**backend), grid)
    readout = contrast.readout

    if opt.distortion.te_scaling != 'pre':
        grid_up, grid_down = distortion.exp2(
            add_identity=not opt.distortion.te_scaling)
    else:
        grid_up = grid_down = None

    crit = 0
    grad = torch.zeros(obs_shape + (3,), **backend) if do_grad else None
    hess = torch.zeros(obs_shape + (6,), **backend) if do_grad else None

    te0 = 0
    for e, echo in enumerate(contrast):

        te = echo.te
        te0 = te0 or te
        blip = echo.blip or (2 * (e % 2) - 1)
        grid_blip = grid_up if blip > 0 else grid_down
        vscl = te / te0
        if opt.distortion.te_scaling == 'pre':
            vexp = distortion.iexp if blip < 0 else distortion.exp
            grid_blip = vexp(add_identity=True, alpha=vscl)
        elif opt.distortion.te_scaling:
            grid_blip = spatial.add_identity_grid_(vscl * grid_blip)

        # compute residuals
        dat = echo.fdata(**backend, rand=True, cache=False)  # observed
        fit = recon_fit(inter, slope, te)                    # fitted
        if do_grad and isinstance(distortion, DenseDeformation):
            # D(fit) o phi
            gfit = smart_grad(fit, grid_blip, bound='dft', extrapolate=True)
        fit = smart_pull(fit, grid_blip, bound='dft', extrapolate=True)
        msk = get_mask_missing(dat, fit)  # mask of missing values
        if do_grad and isinstance(distortion, SVFDeformation):
            # D(fit o phi)
            gfit = spatial.diff(fit, bound='dft', dim=[-3, -2, -1])
            gfit.masked_fill_(msk.unsqueeze(-1), 0)
        dat.masked_fill_(msk, 0)
        fit.masked_fill_(msk, 0)
        msk = msk.bitwise_not_()

        if chi:
            crit1, res = nll_chi(dat, fit, msk, lam, df)
        else:
            crit1, res = nll_gauss(dat, fit, msk, lam)
        del dat, fit, msk
        crit += crit1

        if do_grad:
            g1 = res.unsqueeze(-1).mul(gfit)
            h1 = torch.zeros_like(hess)
            if readout is None:
                h1[..., :3] = gfit.square()
                h1[..., 3] = gfit[..., 0] * gfit[..., 1]
                h1[..., 4] = gfit[..., 0] * gfit[..., 2]
                h1[..., 5] = gfit[..., 1] * gfit[..., 2]
            else:
                h1[..., readout] = gfit[..., readout].square()

            # propagate backward
            if isinstance(distortion, SVFDeformation):
                vel = distortion.volume
                if opt.distortion.te_scaling == 'pre':
                    vel = ((-vscl) * vel) if blip < 0 else (vscl * vel)
                elif blip < 0:
                    vel = -vel
                g1, h1 = spatial.exp_backward(vel, g1, h1,
                                              steps=distortion.steps)

            alpha_g = alpha_h = lam
            alpha_g = alpha_g * blip
            if opt.distortion.te_scaling == 'pre':
                alpha_g = alpha_g * vscl
                alpha_h = alpha_h * (vscl * vscl)
            grad.add_(g1, alpha=alpha_g)
            hess.add_(h1, alpha=alpha_h)

    if not do_grad:
        return crit

    if readout is None:
        mask_nan_(grad)
        mask_nan_(hess[..., :-3], 1e-3)  # diagonal
        mask_nan_(hess[..., -3:])        # off-diagonal
    else:
        grad = grad[..., readout]
        hess = hess[..., readout]
        mask_nan_(grad)
        mask_nan_(hess)

    return crit, grad, hess
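
# Hedged sketch (not part of the original code): the compact Hessian layout used
# above stores the symmetric per-voxel outer product g*g' as 6 values: the three
# diagonal terms followed by the (0,1), (0,2), (1,2) off-diagonal terms. Random
# shape for illustration; the helper name is hypothetical.
def _example_pack_outer():
    gfit = torch.randn([8, 8, 8, 3])           # per-voxel spatial gradient of the fit
    h = gfit.new_zeros(gfit.shape[:-1] + (6,))
    h[..., :3] = gfit.square()                 # diagonal: gx*gx, gy*gy, gz*gz
    h[..., 3] = gfit[..., 0] * gfit[..., 1]    # off-diagonal: gx*gy
    h[..., 4] = gfit[..., 0] * gfit[..., 2]    # off-diagonal: gx*gz
    h[..., 5] = gfit[..., 1] * gfit[..., 2]    # off-diagonal: gy*gz
    return h
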
def update_rls(maps, lam=1., norm='jtv'):
    """Update the (L1) weights.

    Parameters
    ----------
    maps : (P, *shape) ParameterMaps
        Parameter maps
    lam : float or (P,) sequence[float], default=1
        Regularisation factor
    norm : {'tv', 'jtv'}, default='jtv'

    Returns
    -------
    rls : ([P], *shape) tensor
        (Inverted) weights from the reweighted least squares scheme
    sumrls : () tensor
        Sum of the (non-inverted) weights
    """
    # vx = spatial.voxel_size(maps.affine)
    # return spatial.membrane_weights(maps.volume, dim=3, factor=lam,
    #                                 joint=(norm == 'jtv'), voxel_size=vx,
    #                                 return_sum=True)

    # ----------------------------------------------------------------
    # This is the old version of the code, before it got refactored
    # and generalized in the `spatial` module.
    # ----------------------------------------------------------------

    if norm not in ('tv', 'jtv', '__internal__'):
        return None

    if isinstance(maps, ParameterMap):
        # single map
        # this should only be an internal call
        # -> we return the squared gradient map
        assert norm == '__internal__'
        vx = spatial.voxel_size(maps.affine)
        grad_fwd = spatial.diff(maps.volume, dim=[0, 1, 2], voxel_size=vx, side='f')
        grad_bwd = spatial.diff(maps.volume, dim=[0, 1, 2], voxel_size=vx, side='b')

        grad = grad_fwd.square_().sum(-1)
        grad += grad_bwd.square_().sum(-1)
        grad *= lam / 2   # average across side (2)
        return grad

    # multiple maps
    if norm == 'tv':
        rls = []
        for map, l in zip(maps, lam):
            rls1 = update_rls(map, l, '__internal__')
            rls1 = rls1.sqrt_()
            rls.append(rls1)
        rls = torch.stack(rls, dim=0)
    else:
        assert norm == 'jtv'
        rls = 0
        for map, l in zip(maps, lam):
            rls += update_rls(map, l, '__internal__')
        rls = rls.sqrt_()

    sumrls = rls.sum(dtype=torch.double)
    eps = core.constants.eps(rls.dtype)
    rls = rls.clamp_min_(eps).reciprocal_()
    return rls, sumrls
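
# Hedged sketch (not part of the original code): the clamp-and-invert step at the
# end of `update_rls`, and how the inverted weights re-enter the penalty via the
# `rls` argument of `reg`. The random weight map stands in for the JTV weights
# computed above; the helper name is hypothetical.
def _example_inverted_rls():
    x = torch.randn([2, 8, 8, 8])
    w = torch.rand([8, 8, 8])                   # stand-in (non-inverted) weight map
    sumrls = w.sum(dtype=torch.double)
    rls = w.clamp_min_(core.constants.eps(w.dtype)).reciprocal_()
    # the inverted weights modulate the spatial gradients inside `reg`
    return reg(x, rls=rls[None], lam=[1., 1.], do_grad=False), sumrls
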
def test_adjoint_3d(order, bound, side):
    u = torch.randn([64, 64, 64, 3], dtype=torch.double)
    v = torch.randn([64, 64, 64], dtype=torch.double)
    Lv = diff(v, dim=[0, 1, 2], side=side, order=order, bound=bound)
    Ku = div(u, dim=[0, 1, 2], side=side, order=order, bound=bound)
    assert torch.allclose((Lv * u).sum(), (Ku * v).sum())
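
# Hedged companion sketch (not part of the original tests): the same adjointness
# check <Lv, u> == <v, Ku> in 2D, with arbitrary default parameters. The function
# name is hypothetical and it is not meant to be collected by the test runner.
def _example_adjoint_2d(order=1, bound='dct2', side='f'):
    u = torch.randn([32, 32, 2], dtype=torch.double)
    v = torch.randn([32, 32], dtype=torch.double)
    Lv = diff(v, dim=[0, 1], side=side, order=order, bound=bound)
    Ku = div(u, dim=[0, 1], side=side, order=order, bound=bound)
    assert torch.allclose((Lv * u).sum(), (Ku * v).sum())
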
def estimate_fwhm(dat, vx=None, verbose=0, mn=-inf, mx=inf):
    """Estimates full width at half maximum (FWHM) and noise standard
    deviation (sd) of a 2D or 3D image.

    It is assumed that the image has been generated as:
        dat = Ky + n,
    where K is Gaussian smoothing with some FWHM and n is additive
    Gaussian noise. FWHM and n are estimated.

    Parameters
    ----------
    dat : str or (*spatial) tensor
        Image data or path to nifti file
    vx : [sequence of] float, default=1
        Voxel size
    verbose : {0, 1, 2}, default=0
        Verbosity level:
            * 0: No verbosity
            * 1: Print FWHM and sd to screen
            * 2: 1 + show mask
    mn : float, optional
        Exclude values below
    mx : float, optional
        Exclude values above

    Returns
    -------
    fwhm : (dim,) tensor
        Estimated FWHM
    sd : scalar tensor
        Estimated noise standard deviation.

    References
    ----------
    ..[1] "Linked independent component analysis for multimodal data fusion."
          Appendix A
          Groves AR, Beckmann CF, Smith SM, Woolrich MW.
          Neuroimage. 2011 Feb 1;54(3):2198-217.
    """
    if isinstance(dat, str):
        dat = io.map(dat)
    if isinstance(dat, io.MappedArray):
        if vx is None:
            vx = get_voxel_size(dat.affine)
        dat = dat.fdata(rand=True, missing=0)
    dat = torch.as_tensor(dat)
    dim = dat.dim()
    if vx is None:
        vx = 1
    vx = utils.make_vector(vx, dim)
    backend = utils.backend(dat)

    # Make mask
    msk = (dat > mn).bitwise_and_(dat <= mx)
    dat = dat.masked_fill(~msk, 0)
    # TODO: we should erode the mask so that only voxels whose neighbours
    #       are in the mask are considered when computing gradients.
    if verbose >= 2:
        show_slices(msk)

    # Compute image gradient
    g = diff(dat, dim=range(dim), side='central', voxel_size=vx, bound='dft').abs_()
    slicer = (slice(1, -1),) * dim
    g = g[(*slicer, None)]
    g[~msk[slicer], :] = 0   # discard gradients outside the mask
    g = g.reshape([-1, dim]).sum(0, dtype=torch.double)

    # Make dat have zero mean
    dat = dat[slicer]
    dat = dat[msk[slicer]]
    x0 = dat - dat.mean()

    # Compute FWHM
    fwhm = pymath.sqrt(4 * pymath.log(2)) * x0.abs().sum(dtype=torch.double)
    fwhm = fwhm / g
    if verbose >= 1:
        print(f'FWHM={fwhm.tolist()}')

    # Compute noise standard deviation
    sx = smooth('gauss', fwhm[0], x=0, **backend)[0][0, 0, 0]
    sy = smooth('gauss', fwhm[1], x=0, **backend)[0][0, 0, 0]
    sz = 1.0
    if dim == 3:
        sz = smooth('gauss', fwhm[2], x=0, **backend)[0][0, 0, 0]
    sc = (sx * sy * sz) / dim
    sc.clamp_min_(1)
    sd = torch.sqrt(x0.square().sum(dtype=torch.double) / (x0.numel() * sc))
    if verbose >= 1:
        print(f'sd={sd.tolist()}')

    return fwhm, sd
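
# Hedged usage sketch (not part of the original code): estimates smoothness and
# noise level of a random volume. For pure unit-variance white noise the
# estimated sd should land roughly around 1; the shape is arbitrary and the
# helper name is hypothetical.
def _example_estimate_fwhm():
    dat = torch.randn([32, 32, 32])
    fwhm, sd = estimate_fwhm(dat, vx=1.)
    return fwhm, sd      # per-dimension FWHM (in voxels here) and noise sd
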
def __call__(self, vel, grad=False, hess=False, gradmov=False, hessmov=False):
    # This loop performs the forward pass, and computes
    # derivatives along the way.

    dim = vel.shape[-1]
    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    in_line_search = not grad and not hess
    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % logplot == 0

    # forward
    if self.kernel is None:
        self.kernel = spatial.greens(vel.shape[-dim - 1:-1], **self.prm,
                                     **utils.backend(vel))
    grid = spatial.shoot(vel, self.kernel, steps=self.steps, **self.prm)
    warped = spatial.grid_pull(self.moving, grid, bound='dct2',
                               extrapolate=True)

    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, vel, cat=iscat, dim=dim)

    # gradient/Hessian of the log-likelihood in observed space
    if not grad and not hess:
        llx = self.loss.loss(warped, self.fixed)
    elif not hess:
        llx, grad = self.loss.loss_grad(warped, self.fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
    else:
        llx, grad, hess = self.loss.loss_grad_hess(warped, self.fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
        if hessmov:
            hessmov = spatial.grid_push(hess, grid, **pullopt)
    del warped

    # compose with spatial gradients
    if grad is not False or hess is not False:
        if self.mugrad is None:
            self.mugrad = spatial.diff(self.moving,
                                       dim=list(range(-dim, 0)),
                                       bound='dct2')
        if grad is not False:
            grad = grad.neg_()  # "final inverse" to "initial"
            grad = spatial.grid_push(grad, grid)
            grad = jg(self.mugrad, grad)
        if hess is not False:
            hess = spatial.grid_push(hess, grid)
            hess = jhj(self.mugrad, hess)

    # add regularization term
    vgrad = spatial.regulariser_grid(vel, **self.prm, kernel=True)
    llv = 0.5 * (vel * vgrad).sum()
    if grad is not False:
        grad += vgrad
    del vgrad

    # print objective
    llx = llx.item()
    llv = llv.item()
    ll = llx + llv
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g}',
                  end='\r')
        else:
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g} '
                  f'| {gain:12.6g}', end='\r')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [ll]
    if grad is not False:
        out.append(grad)
    if hess is not False:
        out.append(hess)
    if gradmov is not False:
        out.append(gradmov)
    if hessmov is not False:
        out.append(hessmov)
    return tuple(out) if len(out) > 1 else out[0]
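
# Hedged usage sketch (not part of the original code): the calling convention of
# the closure above. `closure` is assumed to be an instance of the class this
# method belongs to, and `vel` a velocity field of matching shape; both are
# placeholders, as is the helper name.
def _example_shoot_closure(closure, vel):
    ll = closure(vel)                               # objective only (line search)
    ll, g = closure(vel, grad=True)                 # objective + gradient
    ll, g, h = closure(vel, grad=True, hess=True)   # objective + gradient + Hessian
    return ll, g, h
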
def compute_grad(dat):
    med = dat.reshape([dat.shape[0], -1]).median(dim=-1).values
    med = utils.unsqueeze(med, -1, 3)
    dat /= 0.5 * med
    dat = spatial.diff(dat, dim=[1, 2, 3]).square().sum(-1)
    return dat
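
# Hedged usage sketch (not part of the original code): squared gradient magnitude
# of a random multi-channel volume. Note that `compute_grad` rescales its input
# in place (dat /= 0.5 * med), hence the clone; the helper name is hypothetical.
def _example_compute_grad():
    dat = torch.rand([2, 8, 8, 8]) + 0.5       # strictly positive, so the median is nonzero
    return compute_grad(dat.clone())           # (2, 8, 8, 8) squared gradient magnitude
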