def solve(h, g, w, optim):
    # `dim` and `prm` are captured from the enclosing scope.
    return spatial.solve_field_fmg(h, g, dim=dim, weights=w, **prm,
                                   optim=optim)

def solve_prm(self, hess, grad):
    """Solve the Newton step for the parameter maps"""
    hess = check_nans_(hess, warn='hessian')
    hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
    delta = spatial.solve_field_fmg(hess, grad, self.rls, **self.lam_prm,
                                    voxel_size=self.voxel_size)
    delta = check_nans_(delta, warn='delta')
    return delta

def solve_dist(self, hess, grad, vx, readout):
    """Solve the Newton step for the distortion field"""
    hess = check_nans_(hess, warn='hessian')
    hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
    # scale the regularization factor by the squared voxel size
    # along the readout direction
    lam = dict(self.lam_dist)
    lam['factor'] = lam['factor'] * (vx[readout] ** 2)
    delta = spatial.solve_field_fmg(hess, grad, **lam, dim=3,
                                    bound=self.DIST_BOUND, voxel_size=vx)
    delta = check_nans_(delta, warn='delta')
    return delta

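
# For reference, a generic sketch of the Levenberg-Marquardt-style damping
# that `hessian_loaddiag_` is used for above: add a small multiple of the
# largest diagonal entry, plus an absolute floor, to the diagonal so the
# Newton system stays well conditioned. This is an illustration under
# assumed semantics and a dense layout; the actual nitorch implementation
# operates on a compact Hessian layout and may differ.
def _loaddiag_dense_sketch(hess, major=1e-6, minor=1e-8):
    """Damp the diagonal of a dense (..., P, P) Hessian in-place."""
    diag = hess.diagonal(dim1=-2, dim2=-1)
    diag += major * diag.abs().max(-1, keepdim=True).values + minor
    return hess
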
def solve_parameters(hess, grad, rls, lam, vx, opt):
    """Solve the regularized linear system

    Parameters
    ----------
    hess : (2*P+1, *shape) tensor
    grad : (P+1, *shape) tensor
    rls : ([P+1], *shape) tensor or None
    lam : (P,) sequence[float]
    vx : (D,) sequence[float]
    opt : Options

    Returns
    -------
    delta : (P+1, *shape) tensor

    """
    # The ESTATICS Hessian has a very particular form (intercepts do not
    # have cross elements). We therefore need to tell the solver how to
    # operate on it.

    def matvec(m, x):
        m = m.transpose(-1, -4)
        x = x.transpose(-1, -4)
        return hessian_matmul(m, x).transpose(-4, -1)

    def matsolve(m, x):
        m = m.transpose(-1, -4)
        x = x.transpose(-1, -4)
        return hessian_solve(m, x).transpose(-4, -1)

    def matdiag(m, d):
        return m[..., ::2]

    return spatial.solve_field_fmg(hess, grad, rls, factor=lam, membrane=1,
                                   voxel_size=vx, verbose=opt.verbose - 1,
                                   nb_iter=opt.optim.max_iter_cg,
                                   tolerance=opt.optim.tolerance_cg,
                                   matvec=matvec, matsolve=matsolve,
                                   matdiag=matdiag)

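
# The sparsity pattern referred to above: because intercepts have no cross
# elements, the dense per-voxel ESTATICS Hessian is an "arrowhead" matrix
# (diagonal plus the last row/column). A minimal dense illustration follows;
# the compact channel layout actually consumed by `hessian_matmul` /
# `hessian_solve` is an internal detail and may differ from this sketch.
import torch

def _dense_estatics_hessian(diag, cross):
    """Assemble the dense (P+1, P+1) arrowhead Hessian from its P+1
    diagonal terms and P intercept/decay cross terms."""
    mat = torch.diag(diag)
    mat[-1, :-1] = cross
    mat[:-1, -1] = cross
    return mat

# toy example with P=3 intercepts + 1 decay
_H = _dense_estatics_hessian(torch.tensor([4., 3., 5., 2.]),
                             torch.tensor([1., -1., .5]))
_delta = torch.linalg.solve(_H, torch.randn(4))  # per-voxel Newton step
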
def greeq(data, transmit=None, receive=None, opt=None, **kwopt):
    """Fit a non-linear relaxometry model to multi-echo Gradient-Echo data.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    transmit : sequence[PrecomputedFieldMap], optional
        Map(s) of the transmit field (b1+). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
    receive : sequence[PrecomputedFieldMap], optional
        Map(s) of the receive field (b1-). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
        If no receive map is provided, the output `pd` map will have
        a remaining b1- bias field.
    opt : GREEQOptions or dict, optional
        Algorithm options.
        {'preproc': {'register':     True},      # Co-register contrasts
         'optim':   {'nb_levels':    1,          # Number of pyramid levels
                     'max_iter_rls': 10,         # Max reweighting iterations
                     'max_iter_gn':  5,          # Max Gauss-Newton iterations
                     'max_iter_cg':  32,         # Max Conjugate-Gradient iterations
                     'tolerance':    1e-05,      # Tolerance for early stopping (RLS)
                     'tolerance_cg': 1e-03},     # Tolerance for early stopping (CG)
         'backend': {'dtype':  torch.float32,    # Data type
                     'device': 'cpu'},           # Device
         'penalty': {'norm':   'jtv',            # Type of penalty: {'tkh', 'tv', 'jtv', None}
                     'factor': {'r1':  10,       # Penalty factor per (log) map
                                'pd':  10,
                                'r2s': 2,
                                'mt':  2}},
         'verbose': 1}

    Returns
    -------
    pd : ParameterMap
        Proton density
    r1 : ParameterMap
        Longitudinal relaxation rate
    r2s : ParameterMap
        Apparent transverse relaxation rate
    mt : ParameterMap, optional
        Magnetisation transfer saturation
        Only returned if MT-weighted data is provided.

    """
    opt = GREEQOptions().update(opt, **kwopt)
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    # --- estimate noise / register / initialize maps ---
    data, transmit, receive, maps = preproc(data, transmit, receive, opt)
    has_mt = hasattr(maps, 'mt')

    # --- prepare penalty factor ---
    lam = opt.penalty.factor
    if isinstance(lam, dict):
        lam = [lam.get('pd', 0), lam.get('r1', 0),
               lam.get('r2s', 0), lam.get('mt', 0)]
        if not has_mt:
            lam = lam[:3]
    lam = core.utils.make_vector(lam, 3 + has_mt, **backend)  # PD, R1, R2*, [MT]

    # --- initialize weights (RLS) ---
    if str(opt.penalty.norm).lower().startswith('no') or all(lam == 0):
        opt.penalty.norm = ''
    opt.penalty.norm = opt.penalty.norm.lower()
    mean_shape = maps[0].shape
    rls = None
    sumrls = 0
    if opt.penalty.norm in ('tv', 'jtv'):
        rls_shape = mean_shape
        if opt.penalty.norm == 'tv':
            rls_shape = (len(maps),) + rls_shape
        rls = torch.ones(rls_shape, **backend)
        sumrls = 0.5 * core.py.prod(rls_shape)

    if opt.penalty.norm:
        print(f'With {opt.penalty.norm.upper()} penalty:')
        print(f' - PD:  {lam[0]:.3g}')
        print(f' - R1:  {lam[1]:.3g}')
        print(f' - R2*: {lam[2]:.3g}')
        if has_mt:
            print(f' - MT:  {lam[3]:.3g}')
    else:
        print('Without penalty')

    if opt.penalty.norm not in ('tv', 'jtv'):
        # no reweighting -> do more gauss-newton updates instead
        opt.optim.max_iter_gn *= opt.optim.max_iter_rls
        opt.optim.max_iter_rls = 1
    print('Optimization:')
    print(f' - Tolerance: {opt.optim.tolerance}')
    if opt.penalty.norm.endswith('tv'):
        print(f' - IRLS iterations: {opt.optim.max_iter_rls}')
    print(f' - GN iterations: {opt.optim.max_iter_gn}')
    if opt.optim.solver == 'fmg':
        print(' - FMG cycles: 2')
        print(' - CG iterations: 2')
    else:
        print(f' - CG iterations: {opt.optim.max_iter_cg}'
              f' (tolerance: {opt.optim.tolerance_cg})')
    if opt.optim.nb_levels > 1:
        print(f' - Levels: {opt.optim.nb_levels}')

    printer = CritPrinter(max_levels=opt.optim.nb_levels,
                          max_rls=opt.optim.max_iter_rls,
                          max_gn=opt.optim.max_iter_gn,
                          penalty=opt.penalty.norm,
                          verbose=opt.verbose)
    printer.print_head()

    shape0 = shape = maps.shape[1:]
    aff0 = aff = maps.affine
    vx0 = vx = spatial.voxel_size(aff0)
    vol0 = vx0.prod()
    vol = vx.prod() / vol0

    ll_scl = sum(core.py.prod(dat.shape) for dat in data)

    for level in range(opt.optim.nb_levels, 0, -1):
        printer.level = level

        if opt.optim.nb_levels > 1:
            aff, shape = _get_level(level, aff0, shape0)
            vx = spatial.voxel_size(aff)
            vol = vx.prod() / vol0
            maps, rls = resize(maps, rls, aff, shape)
            if opt.penalty.norm in ('tv', 'jtv'):
                sumrls = 0.5 * vol * rls.reciprocal().sum(dtype=torch.double)

        # --- compute derivatives ---
        nb_prm = len(maps)
        nb_hes = nb_prm * (nb_prm + 1) // 2
        grad = torch.empty((nb_prm,) + shape, **backend)
        hess = torch.empty((nb_hes,) + shape, **backend)

        ll_rls = []
        ll_gn = []
        ll_max = float('inf')

        max_iter_rls = max(opt.optim.max_iter_rls // level, 1)
        for n_iter_rls in range(max_iter_rls):
            # --- Reweighted least-squares loop ---
            printer.rls = n_iter_rls

            # --- Gauss-Newton loop ---
            for n_iter_gn in range(opt.optim.max_iter_gn):
                printer.gn = n_iter_gn
                crit = 0
                grad.zero_()
                hess.zero_()

                # --- loop over contrasts ---
                for contrast, b1m, b1p in zip(data, receive, transmit):
                    crit1, g1, h1 = derivatives_parameters(
                        contrast, maps, b1m, b1p, opt)
                    # increment
                    if hasattr(maps, 'mt') and not contrast.mt:
                        # we optimize for mt but this particular contrast
                        # has no information about mt so g1/h1 are smaller
                        # than grad/hess.
                        grad[:-1] += g1
                        hind = list(range(nb_prm - 1))
                        cnt = nb_prm
                        for i in range(nb_prm):
                            for j in range(i + 1, nb_prm):
                                if i != nb_prm - 1 and j != nb_prm - 1:
                                    hind.append(cnt)
                                cnt += 1
                        hess[hind] += h1
                        crit += crit1
                    else:
                        grad += g1
                        hess += h1
                        crit += crit1
                    del g1, h1, crit1

                reg = 0.
                if opt.penalty.norm:
                    g = spatial.regulariser(maps.volume, weights=rls, dim=3,
                                            voxel_size=vx, membrane=1,
                                            factor=lam * vol)
                    grad += g
                    reg = 0.5 * dot(maps.volume, g)
                    del g

                # --- gauss-newton ---
                grad = check_nans_(grad, warn='gradient')
                hess = check_nans_(hess, warn='hessian')
                if opt.penalty.norm:
                    # hess = hessian_sym_loaddiag(hess, 1e-5, 1e-8)
                    if opt.optim.solver == 'fmg':
                        deltas = spatial.solve_field_fmg(
                            hess, grad, rls, factor=lam * vol, membrane=1,
                            voxel_size=vx, nb_iter=2)
                    else:
                        deltas = spatial.solve_field(
                            hess, grad, rls, factor=lam * vol, membrane=1,
                            voxel_size=vx, verbose=max(0, opt.verbose - 1),
                            optim='cg', max_iter=opt.optim.max_iter_cg,
                            tolerance=opt.optim.tolerance_cg, stop='diff')
                else:
                    # hess = hessian_sym_loaddiag(hess, 1e-3, 1e-4)
                    deltas = spatial.solve_field_closedform(hess, grad)
                deltas = check_nans_(deltas, warn='deltas')

                for map, delta in zip(maps, deltas):
                    map.volume -= delta
                    map.volume.clamp_(-64, 64)  # avoid exp overflow
                    del delta
                del deltas

                # --- Compute gain ---
                ll = crit + reg + sumrls
                ll_max = max(ll_max, ll)
                ll_prev = ll_gn[-1] if ll_gn else ll_max
                gain = ll_prev - ll
                ll_gn.append(ll)
                printer.print_crit(crit, reg, sumrls, gain / ll_scl)
                if gain < opt.optim.tolerance * ll_scl:
                    break

            # --- Update RLS weights ---
            if opt.penalty.norm in ('tv', 'jtv'):
                rls, sumrls = spatial.membrane_weights(
                    maps.volume, lam, vx, return_sum=True, dim=3,
                    joint=opt.penalty.norm == 'jtv',
                    eps=core.constants.eps(rls.dtype))
                sumrls *= 0.5 * vol

            # --- Compute gain ---
            # (we are late by one full RLS iteration when computing the
            #  gain but we save some computations)
            ll = ll_gn[-1]
            ll_prev = ll_rls[-1] if ll_rls else ll_max
            ll_rls.append(ll)
            gain = ll_prev - ll
            if abs(gain) < opt.optim.tolerance * ll_scl:
                break

    del grad
    if opt.uncertainty:
        uncertainty = compute_uncertainty(hess, rls, lam * vol, vx, opt)
        maps.pd.uncertainty = uncertainty[0]
        maps.r1.uncertainty = uncertainty[1]
        maps.r2s.uncertainty = uncertainty[2]
        if hasattr(maps, 'mt'):
            maps.mt.uncertainty = uncertainty[3]

    # --- Prepare output ---
    return postproc(maps)

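
# A minimal usage sketch for `greeq` (hypothetical file names and loader;
# assumes the option names shown in the docstring above):
#
#     data = [load_gre(f)  # hypothetical loader returning GradientEchoMulti
#             for f in ('pdw.nii', 't1w.nii', 'mtw.nii')]
#     pd, r1, r2s, mt = greeq(data, opt={'penalty': {'norm': 'jtv'},
#                                        'optim': {'nb_levels': 2}})
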
def zcorrect_exp_const(x, decay=None, sigma=None, lam=10, mask=None,
                       max_iter=128, tol=1e-6, verbose=False, snr=5):
    """Correct the z signal decay in a SPIM image.

    The signal is modelled as: f(z) = s * exp(-b * z) + eps
    where z=0 is (arbitrarily) the middle slice, s is the intercept
    and b is the decay coefficient.

    Parameters
    ----------
    x : (..., nz) tensor
        SPIM image with the z dimension last and the z=0 plane first
    decay : float, optional
        Initial guess for decay parameter. Default: educated guess.
    sigma : float, optional
        Noise standard deviation. Default: educated guess.
    lam : float or (float, float), default=10
        Regularisation.
    mask : tensor, optional
        Mask of voxels to include in the fit.
    max_iter : int, default=128
    tol : float, default=1e-6
    verbose : int or bool, default=False
    snr : float, default=5
        Assumed SNR, used for the noise educated guess.

    Returns
    -------
    y : tensor
        Fitted intercept map (s)
    decay : tensor
        Fitted decay map (b)
    x : tensor
        Corrected image

    """
    x = torch.as_tensor(x)
    if not x.dtype.is_floating_point:
        x = x.to(dtype=torch.get_default_dtype())
    backend = utils.backend(x)
    shape = x.shape
    dim = x.dim() - 1
    nz = shape[-1]
    b = decay

    x = utils.movedim(x, -1, 0).clone()
    if mask is None:
        mask = torch.isfinite(x) & (x > 0)
    else:
        mask = mask & (torch.isfinite(x) & (x > 0))
    x[~mask] = 0

    # decay educated guess: closed form from two values
    if b is None:
        z1 = 2 * nz // 5
        z2 = 3 * nz // 5
        x1 = x[z1]
        x1 = x1[x1 > 0].median()
        x2 = x[z2]
        x2 = x2[x2 > 0].median()
        z1 = float(z1)
        z2 = float(z2)
        b = (x2.log() - x1.log()) / (z1 - z2)

    y = x[(nz - 1) // 2]
    y = y[y > 0].median().log()
    b = b.item() if torch.is_tensor(b) else b
    y = y.item()
    if verbose:
        print(f'init: y = {y}, b = {b}')

    # noise educated guess: assume SNR=5 at z=1/2
    sigma = sigma or (y / snr)

    lam_y, lam_b = py.make_list(lam, 2)
    lam_y = lam_y ** 2 * sigma ** 2
    lam_b = lam_b ** 2 * sigma ** 2
    reg = lambda t: spatial.regulariser(
        t, membrane=1, dim=dim, factor=(lam_y, lam_b))
    solve = lambda h, g: spatial.solve_field_fmg(
        h, g, membrane=1, dim=dim, factor=(lam_y, lam_b))

    # init
    z = torch.arange(nz, **backend) - (nz - 1) / 2
    z = utils.unsqueeze(z, -1, dim)
    theta = z.new_empty([2, *x.shape[1:]], **backend)
    logy = theta[0].fill_(y)
    b = theta[1].fill_(b)
    y = logy.exp()
    ll0 = (mask * y * (-b * z).exp_() - x).square_().sum()
    ll0 = ll0 + (theta * reg(theta)).sum()
    ll1 = ll0

    g = torch.zeros_like(theta)
    h = theta.new_zeros([3, *theta.shape[1:]])
    for it in range(max_iter):

        # exponentiate
        y = torch.exp(logy, out=y)
        fit = (b * z).neg_().exp_().mul_(y).mul_(mask)
        res = fit - x

        # compute objective
        reg_theta = reg(theta)
        ll = res.square().sum() + (theta * reg_theta).sum()
        gain = (ll1 - ll) / ll0
        if verbose:
            end = '\n' if verbose > 1 else '\r'
            print(f'{it:3d} | {ll:12.6g} | gain = {gain:12.6g}', end=end)
        if it > 0 and gain < tol:
            break
        ll1 = ll

        # gradient and (robust) Hessian: channels are
        # [d/dlogy, d/db] and [d2/dlogy2, d2/db2, d2/dlogy.db]
        g[0] = (fit * res).sum(0)
        g[1] = -(fit * res * z).sum(0)
        h[0] = (fit * (fit + res.abs())).sum(0)
        h[1] = (fit * (fit + res.abs()) * (z * z)).sum(0)
        h[2] = -(z * fit * fit).sum(0)
        g += reg_theta

        # Gauss-Newton update
        theta -= solve(h, g)

    y = torch.exp(logy, out=y)
    x = x * (b * z).exp_()
    x = utils.movedim(x, 0, -1)
    x = x.reshape(shape)
    return y, b, x

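
# A quick sanity check of the two-point decay initialization used above:
# for noise-free exponential data s * exp(-b * z), the median-based
# estimator b = (log x2 - log x1) / (z1 - z2) recovers b exactly.
# (Illustration only, with made-up numbers.)
import math

def _check_two_point_decay():
    s, b_true = 100.0, 0.05
    z1, z2 = 40, 60
    x1 = s * math.exp(-b_true * z1)
    x2 = s * math.exp(-b_true * z2)
    b_hat = (math.log(x2) - math.log(x1)) / (z1 - z2)
    assert abs(b_hat - b_true) < 1e-12
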
def phase_fit(magnitude, phase, lam=(0, 1e1), penalty=('membrane', 'bending')):
    """Fit a complex image using a decreasing phase regularization

    Parameters
    ----------
    magnitude : tensor
    phase : tensor
    lam : (float, float)
        Regularization factors for the log-magnitude and phase.
    penalty : (str, str)
        Penalty types ('membrane' or 'bending') for the log-magnitude
        and phase.

    Returns
    -------
    magnitude : tensor
    phase : tensor

    """
    # estimate noise precision
    sd = estimate_noise(magnitude)[0]['sd']
    prec = 1 / (sd * sd)

    # initialize fit
    fit = magnitude.new_empty([2, *magnitude.shape])
    fit[0] = magnitude
    fit[0].clamp_min_(1e-8).log_()
    fit[1] = mean_phase(phase, magnitude)

    # allocate placeholders
    g = magnitude.new_empty([2, *magnitude.shape])
    h = magnitude.new_empty([2, *magnitude.shape])
    n = magnitude.numel()

    # prepare regularizer options
    prm = dict(
        membrane=[lam[0] * int(penalty[0] == 'membrane'),
                  lam[1] * int(penalty[1] == 'membrane')],
        bending=[lam[0] * int(penalty[0] == 'bending'),
                 lam[1] * int(penalty[1] == 'bending')],
        bound='dct2')
    lam0 = dict(membrane=prm['membrane'][-1], bending=prm['bending'][-1])

    ll0 = lr0 = factor = float('inf')
    for n_iter in range(20):

        # decrease regularization
        factor, factor_prev = 1 + 10 ** (5 - n_iter), factor
        factor_ratio = factor / factor_prev if n_iter else float('inf')
        prm['membrane'][-1] = lam0['membrane'] * factor
        prm['bending'][-1] = lam0['bending'] * factor

        # compute derivatives
        ll, g, h = derivatives(magnitude, phase, fit[0], fit[1], g, h)
        ll *= prec
        g *= prec
        h *= prec

        # compute regularization
        reg = spatial.regulariser(fit, **prm)
        lr = 0.5 * dot(fit, reg)
        g += reg

        # Gauss-Newton step
        fit -= spatial.solve_field_fmg(h, g, **prm)

        # compute progress
        l0 = ll0 + factor_ratio * lr0
        l = ll + lr
        gain = l0 - l
        print(f'{n_iter:3d} | {ll/n:12.6g} + {lr/n:6.3g} = {l/n:12.6g} '
              f'| gain = {gain/n:6.3g}')
        if abs(gain) < n * 1e-4:
            break
        ll0, lr0 = ll, lr

    # plot_fit(magnitude, phase, fit)
    return fit[0].exp_(), fit[1]

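
# The annealing schedule used by `phase_fit` above: the phase penalty starts
# roughly 1e5 times stronger than its target value and decays by a factor of
# ~10 per iteration until it settles at 1x. A standalone illustration:
def _show_phase_anneal_schedule(n=8):
    """Print the first `n` values of the decreasing-regularization factor."""
    for n_iter in range(n):
        factor = 1 + 10 ** (5 - n_iter)
        print(n_iter, f'{factor:.6g}')
    # 0 100001, 1 10001, 2 1001, 3 101, 4 11, 5 2, 6 1.1, 7 1.01
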
def meetup(data, dist=None, opt=None):
    """Fit the ESTATICS+MEETUP model to multi-echo Gradient-Echo data.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    dist : sequence[Optional[ParameterizedDistortion]], optional
        Pre-computed distortion fields.
    opt : Options, optional
        Algorithm options.

    Returns
    -------
    intercepts : sequence[GradientEcho]
        Echo series extrapolated to TE=0
    decay : estatics.ParameterMap
        R2* decay map
    distortions : sequence[ParameterizedDistortion]
        B0-induced distortion fields

    """
    # --- prepare everything -------------------------------------------
    data, maps, dist, opt, rls = _prepare(data, dist, opt)
    sumrls = 0.5 * rls.sum(dtype=torch.double) if rls is not None else 0
    backend = dict(dtype=opt.backend.dtype, device=opt.backend.device)

    # --- initialize tracking of the objective function ----------------
    crit, vreg, reg = float('inf'), 0, 0
    ll_rls = crit + vreg + reg
    sumrls_prev = sumrls
    rls_changed = False
    ll_scl = sum(core.py.prod(dat.shape) for dat in data)

    # --- Multi-Resolution loop ----------------------------------------
    shape0 = shape = maps.shape[1:]
    aff0 = aff = maps.affine
    lam0 = lam = opt.regularization.factor
    vx0 = vx = spatial.voxel_size(aff0)
    scl0 = vx0.prod()
    armijo_prev = 1
    for level in range(opt.optim.nb_levels, 0, -1):

        if opt.optim.nb_levels > 1:
            aff, shape = _get_level(level, aff0, shape0)
            vx = spatial.voxel_size(aff)
            scl = vx.prod() / scl0
            lam = [float(l * scl) for l in lam0]
            maps, rls = _resize(maps, rls, aff, shape)
            if opt.regularization.norm in ('tv', 'jtv'):
                sumrls = 0.5 * scl * rls.reciprocal().sum(dtype=torch.double)

        grad = torch.empty((len(data) + 1, *shape), **backend)
        hess = torch.empty((len(data) * 2 + 1, *shape), **backend)

        # --- RLS loop -------------------------------------------------
        max_iter_rls = 1 if level > 1 else opt.optim.max_iter_rls
        for n_iter_rls in range(1, max_iter_rls + 1):

            # --- helpers ----------------------------------------------
            regularizer_prm = lambda x: spatial.regulariser(
                x, weights=rls, dim=3, voxel_size=vx, membrane=1, factor=lam)
            solve_prm = lambda H, g: spatial.solve_field(
                H, g, rls, factor=lam,
                membrane=1 if opt.regularization.norm else 0,
                voxel_size=vx, max_iter=opt.optim.max_iter_cg,
                tolerance=opt.optim.tolerance_cg, dim=3)
            reweight = lambda x, **k: spatial.membrane_weights(
                x, lam, vx, dim=3, **k,
                joint=opt.regularization.norm == 'jtv',
                eps=core.constants.eps(rls.dtype))

            # ----------------------------------------------------------
            # Initial update of parameter maps
            # ----------------------------------------------------------
            crit_pre_prm = None
            max_iter_prm = opt.optim.max_iter_prm
            for n_iter_prm in range(1, max_iter_prm + 1):
                crit_prev = crit
                reg_prev = reg

                # --- loop over contrasts ------------------------------
                crit = 0
                grad.zero_()
                hess.zero_()
                for i, (contrast, intercept, distortion) \
                        in enumerate(zip(data, maps.intercepts, dist)):
                    # compute gradient
                    crit1, g1, h1 = derivatives_parameters(
                        contrast, distortion, intercept, maps.decay, opt)
                    # increment
                    gind = [i, -1]
                    hind = [i, len(grad) - 1, len(grad) + i]
                    grad[gind] += g1
                    hess[hind] += h1
                    crit += crit1
                    del g1, h1

                # --- regularization -----------------------------------
                reg = 0.
                if opt.regularization.norm:
                    g1 = regularizer_prm(maps.volume)
                    reg = 0.5 * dot(maps.volume, g1)
                    grad += g1
                    del g1
                if n_iter_prm == 1:
                    crit_pre_prm = crit

                # --- track RLS improvement ----------------------------
                # Updating the RLS weights changes `sumrls` and `reg`.
                # Now that we have the updated value of `reg` (before it
                # is modified by the map update), we can check that
                # updating the weights indeed improved the objective.
                if opt.verbose and rls_changed:
                    rls_changed = False
                    if reg + sumrls <= reg_prev + sumrls_prev:
                        evol = '<='
                    else:
                        evol = '>'
                    ll_tmp = crit_prev + reg + vreg + sumrls
                    pstr = (f'{n_iter_rls-1:3d} | {"---":3s} | {"rls":4s} | '
                            f'{crit_prev:12.6g} + {reg:12.6g} + '
                            f'{sumrls:12.6g} + {vreg:12.6g} '
                            f'= {ll_tmp:12.6g} | {evol}')
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)

                # --- gauss-newton -------------------------------------
                # Computing the GN step involves solving H\g
                hess = check_nans_(hess, warn='hessian')
                hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
                deltas = solve_prm(hess, grad)
                deltas = check_nans_(deltas, warn='delta')

                dd = regularizer_prm(deltas)
                dv = dot(dd, maps.volume)
                dd = dot(dd, deltas)
                delta_reg = 0.5 * (dd - 2 * dv)

                for map, delta in zip(maps, deltas):
                    map.volume -= delta
                    if map.min is not None or map.max is not None:
                        map.volume.clamp_(map.min, map.max)
                del deltas

                # --- track parameter map improvement ------------------
                gain = (crit_prev + reg_prev) - (crit + reg)
                if n_iter_prm > 1 and gain < opt.optim.tolerance * ll_scl:
                    break

            # ----------------------------------------------------------
            # Distortion update with line search
            # ----------------------------------------------------------
            max_iter_dist = opt.optim.max_iter_dist
            for n_iter_dist in range(1, max_iter_dist + 1):

                crit = 0
                vreg = 0
                new_crit = 0
                new_vreg = 0

                deltas, dd, dv = [], 0, 0
                # --- loop over contrasts ------------------------------
                for i, (contrast, intercept, distortion) \
                        in enumerate(zip(data, maps.intercepts, dist)):

                    # --- helpers --------------------------------------
                    vxr = distortion.voxel_size[contrast.readout]
                    lam_dist = dict(opt.distortion.factor)
                    lam_dist['factor'] *= vxr ** 2
                    regularizer_dist = lambda x: spatial.regulariser(
                        x[None], **lam_dist, dim=3, bound=DIST_BOUND,
                        voxel_size=distortion.voxel_size)[0]
                    solve_dist = lambda h, g: spatial.solve_field_fmg(
                        h, g, **lam_dist, dim=3, bound=DIST_BOUND,
                        voxel_size=distortion.voxel_size)

                    # --- likelihood -----------------------------------
                    crit1, g, h = derivatives_distortion(
                        contrast, distortion, intercept, maps.decay, opt)
                    crit += crit1

                    # --- regularization -------------------------------
                    vol = distortion.volume
                    g1 = regularizer_dist(vol)
                    vreg1 = 0.5 * vol.flatten().dot(g1.flatten())
                    vreg += vreg1
                    g += g1
                    del g1

                    # --- gauss-newton ---------------------------------
                    h = check_nans_(h, warn='hessian (distortion)')
                    h = hessian_loaddiag_(h[None], 1e-32, 1e-32, sym=True)[0]
                    delta = solve_dist(h, g)
                    delta = check_nans_(delta, warn='delta (distortion)')
                    deltas.append(delta)
                    del g, h

                    dd1 = regularizer_dist(delta)
                    dv += dot(dd1, vol)
                    dd1 = dot(dd1, delta)
                    dd += dd1
                    del delta, vol

                # --- track parameters improvement ---------------------
                if opt.verbose and n_iter_dist == 1:
                    gain = crit_pre_prm - (crit + delta_reg)
                    evol = '<=' if gain > 0 else '>'
                    ll_tmp = crit + reg + vreg + sumrls
                    pstr = (f'{n_iter_rls:3d} | {"---":3s} | {"prm":4s} | '
                            f'{crit:12.6g} + {reg + delta_reg:12.6g} + '
                            f'{sumrls:12.6g} + {vreg:12.6g} '
                            f'= {ll_tmp:12.6g} | '
                            f'gain = {gain / ll_scl:7.2g} | {evol}')
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)

                # --- line search --------------------------------------
                reg = reg + delta_reg
                new_crit, new_reg = crit, reg
                armijo, armijo_prev, ok = armijo_prev, 0, False
                maps0 = maps
                for n_ls in range(1, 12 + 1):
                    for delta1, dist1 in zip(deltas, dist):
                        dist1.volume.sub_(delta1, alpha=(armijo - armijo_prev))
                    armijo_prev = armijo
                    new_vreg = 0.5 * armijo * (armijo * dd - 2 * dv)

                    maps = maps0.deepcopy()
                    max_iter_prm = opt.optim.max_iter_prm
                    for n_iter_prm in range(1, max_iter_prm + 1):
                        crit_prev = new_crit
                        reg_prev = new_reg

                        # --- loop over contrasts ----------------------
                        new_crit = 0
                        grad.zero_()
                        hess.zero_()
                        for i, (contrast, intercept, distortion) \
                                in enumerate(zip(data, maps.intercepts, dist)):
                            # compute gradient
                            new_crit1, g1, h1 = derivatives_parameters(
                                contrast, distortion, intercept,
                                maps.decay, opt)
                            # increment
                            gind = [i, -1]
                            hind = [i, len(grad) - 1, len(grad) + i]
                            grad[gind] += g1
                            hess[hind] += h1
                            new_crit += new_crit1
                            del g1, h1

                        # --- regularization ---------------------------
                        new_reg = 0.
                        if opt.regularization.norm:
                            g1 = regularizer_prm(maps.volume)
                            new_reg = 0.5 * dot(maps.volume, g1)
                            grad += g1
                            del g1

                        new_gain = (crit_prev + reg_prev) - (new_crit + new_reg)
                        if new_gain < opt.optim.tolerance * ll_scl:
                            break

                        # --- gauss-newton -----------------------------
                        # Computing the GN step involves solving H\g
                        hess = check_nans_(hess, warn='hessian')
                        hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
                        delta = solve_prm(hess, grad)
                        delta = check_nans_(delta, warn='delta')

                        dd = regularizer_prm(delta)
                        dv = dot(dd, maps.volume)
                        dd = dot(dd, delta)
                        delta_reg = 0.5 * (dd - 2 * dv)

                        for map, delta1 in zip(maps, delta):
                            map.volume -= delta1
                            if map.min is not None or map.max is not None:
                                map.volume.clamp_(map.min, map.max)
                        del delta

                    if new_crit + new_reg + new_vreg <= crit + reg:
                        ok = True
                        break
                    else:
                        armijo = armijo / 2

                if not ok:
                    for delta1, dist1 in zip(deltas, dist):
                        dist1.volume.add_(delta1, alpha=armijo_prev)
                    armijo_prev = 1
                    maps = maps0
                    new_crit = crit
                    new_vreg = 0
                    del deltas
                else:
                    armijo_prev *= 1.5
                    new_vreg = vreg + new_vreg

                # --- track distortion improvement ---------------------
                if not ok:
                    evol = '== (x)'
                elif new_crit + new_vreg <= crit + vreg:
                    evol = f'<= ({n_ls:2d})'
                else:
                    evol = f'> ({n_ls:2d})'
                gain = (crit + vreg + reg) - (new_crit + new_vreg + new_reg)
                crit, vreg, reg = new_crit, new_vreg, new_reg

                if opt.verbose:
                    ll_tmp = crit + reg + vreg + sumrls
                    pstr = (f'{n_iter_rls:3d} | {n_iter_dist:3d} | {"dist":4s} | '
                            f'{crit:12.6g} + {reg:12.6g} + {sumrls:12.6g} ')
                    pstr += f'+ {vreg:12.6g} '
                    pstr += f'= {ll_tmp:12.6g} | gain = {gain / ll_scl:7.2g} | '
                    pstr += f'{evol}'
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)
                if opt.plot:
                    _show_maps(maps, dist, data)
                if not ok:
                    break
                if n_iter_dist > 1 and gain < opt.optim.tolerance * ll_scl:
                    break

            # ----------------------------------------------------------
            # Update RLS weights
            # ----------------------------------------------------------
            if level == 1 and opt.regularization.norm in ('tv', 'jtv'):
                rls_changed = True
                sumrls_prev = sumrls
                rls, sumrls = reweight(maps.volume, return_sum=True)
                sumrls = 0.5 * sumrls

            # --- compute gain -----------------------------------------
            # (we are late by one full RLS iteration when computing the
            #  gain but we save some computations)
            ll_rls_prev = ll_rls
            ll_rls = crit + reg + vreg
            gain = ll_rls_prev - ll_rls
            if gain < opt.optim.tolerance * ll_scl:
                print(f'RLS converged ({gain / ll_scl:7.2g})')
                break

    # --- prepare output -----------------------------------------------
    out = postproc(maps, data)
    if opt.distortion.enable:
        out = (*out, dist)
    return out

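
# A minimal usage sketch for `meetup` (hypothetical loader; assuming options
# can be passed as a dict as in `greeq`; the distortion fields are only
# returned when `opt.distortion.enable` is set):
#
#     data = [load_gre(f)  # hypothetical loader returning GradientEchoMulti
#             for f in ('pdw.nii', 't1w.nii')]
#     intercepts, decay, distortions = meetup(
#         data, opt={'distortion': {'enable': True}})
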
def sin_b1(x, fa, lam=(0, 1e3), penalty=('m', 'b'), chi=True,
           pd=None, b1=None):
    """Estimate B1+ from variable flip angle data with long TR

    Parameters
    ----------
    x : (C, *spatial) tensor
        Input flash images
    fa : (C,) sequence[float]
        Flip angles (in deg)
    lam : (float, float), default=(0, 1e3)
        Regularization value for the T2*w signal map and the B1 map.
    penalty : 2x {'membrane', 'bending'}, default=('m', 'b')
        Regularization type for the T2*w signal map and the B1 map.
    chi : bool, default=True
        Assume noncentral chi (rather than Gaussian) noise.
    pd : (*spatial) tensor, optional
        Initial guess for the T2*w signal map.
    b1 : (*spatial) tensor, optional
        Initial guess for the B1+ map.

    Returns
    -------
    s : (*spatial) tensor
        T2*-weighted PD map
    b1 : (*spatial) tensor
        B1+ map

    """
    fa = utils.make_vector(fa, len(x), dtype=torch.double)
    fa = fa.mul_(pymath.pi / 180).tolist()
    lam = py.make_list(lam, 2)
    penalty = py.make_list(penalty, 2)

    # estimate noise statistics (geometric means across contrasts)
    sd, df, mu = 0, 0, 0
    for x1 in x:
        bg, fg = estimate_noise(x1, chi=True)
        sd += bg['sd'].log()
        df += bg['dof'].log()
        mu += fg['mean'].log()
    sd = (sd / len(x)).exp()
    df = (df / len(x)).exp()
    mu = (mu / len(x)).exp()
    prec = 1 / (sd * sd)
    if not chi:
        df = 1

    # mask low SNR voxels
    # x = x * (x > 5 * sd)

    shape = x.shape[1:]
    theta = x.new_empty([2, *shape])
    theta[0] = mu.log() if pd is None else pd.log()
    theta[1] = 0 if b1 is None else b1.log()
    n = (x != 0).sum()

    g = torch.zeros_like(theta)
    h = theta.new_zeros([3, *theta.shape[1:]])
    g1 = torch.zeros_like(theta)
    h1 = theta.new_zeros([3, *theta.shape[1:]])

    prm = dict(
        membrane=[lam[0] if penalty[0][0] == 'm' else 0,
                  lam[1] if penalty[1][0] == 'm' else 0],
        bending=[lam[0] if penalty[0][0] == 'b' else 0,
                 lam[1] if penalty[1][0] == 'b' else 0],
    )
    lam0 = dict(membrane=prm['membrane'][-1], bending=prm['bending'][-1])

    ll0 = lr0 = factor = float('inf')
    for n_iter in range(32):
        if n_iter == 1:
            ll0 = lr0 = float('inf')

        # decrease regularization
        factor, factor_prev = 1 + 10 ** (5 - n_iter), factor
        factor_ratio = factor / factor_prev if n_iter else float('inf')
        prm['membrane'][-1] = lam0['membrane'] * factor
        prm['bending'][-1] = lam0['bending'] * factor

        # derivatives of likelihood term
        df1 = 1 if n_iter == 0 else df
        ll, g, h = sin_full_derivatives(x, fa, theta[0], theta[1], prec,
                                        df1, g, h, g1, h1)

        # derivatives of regularization term
        reg = spatial.regulariser(theta, **prm)
        g += reg
        lr = 0.5 * dot(theta, reg)

        l, l0 = ll + lr, ll0 + factor_ratio * lr0
        gain = l0 - l
        print(f'{n_iter:2d} | {ll/n:12.6g} + {lr/n:12.6g} = {l/n:12.6g} '
              f'| gain = {gain/n:12.6g}')

        # Gauss-Newton update
        h[:2] += 1e-8 * h[:2].abs().max(0).values
        h[:2] += 1e-5
        delta = spatial.solve_field_fmg(h, g, **prm)
        mx = delta.abs().max()
        if mx > 64:
            delta *= 64 / mx

        # line search
        dd = spatial.regulariser(delta, **prm)
        dt = dot(dd, theta)
        dd = dot(dd, delta)
        success = False
        armijo = 1
        theta0 = theta
        ll0, lr0 = ll, lr
        for n_ls in range(12):
            theta = theta0.sub(delta, alpha=armijo)
            ll = sin_nll(x, fa, theta[0], theta[1], prec, df1)
            lr = 0.5 * armijo * (armijo * dd - 2 * dt)
            if ll + lr < ll0:  # and theta[1].max() < 0.69:
                print(n_ls, 'success', ((ll + lr) / n).item(), (ll0 / n).item())
                success = True
                break
            print(n_ls, 'failure', ((ll + lr) / n).item(), (ll0 / n).item())
            armijo /= 2
        if not success and n_iter > 5:
            theta = theta0
            break

        # diagnostic plots
        import matplotlib.pyplot as plt
        plt.rcParams["figure.figsize"] = (4, len(x))
        vmin, vmax = 0, 2 * mu
        y = sin_signals(fa, *theta)
        ex = 3
        for i in range(len(x)):
            plt.subplot(len(x) + ex, 4, 4 * i + 1)
            plt.imshow(x[i, ..., x.shape[-1] // 2], vmin=vmin, vmax=vmax)
            plt.axis('off')
            plt.subplot(len(x) + ex, 4, 4 * i + 2)
            plt.imshow(y[i, ..., x.shape[-1] // 2], vmin=vmin, vmax=vmax)
            plt.axis('off')
            plt.subplot(len(x) + ex, 4, 4 * i + 3)
            plt.imshow(x[i, ..., x.shape[-1] // 2] - y[i, ..., y.shape[-1] // 2],
                       cmap=plt.get_cmap('coolwarm'))
            plt.axis('off')
            plt.colorbar()
            plt.subplot(len(x) + ex, 4, 4 * i + 4)
            plt.imshow((theta[-1, ..., x.shape[-1] // 2].exp() * fa[i])
                       / pymath.pi)
            plt.axis('off')
            plt.colorbar()
        all_fa = torch.linspace(0, 2 * pymath.pi, 512)
        loc = [(theta.shape[1] // 2, theta.shape[2] // 2, theta.shape[3] // 2),
               (2 * theta.shape[1] // 3, theta.shape[2] // 2, theta.shape[3] // 2),
               (theta.shape[1] // 3, theta.shape[2] // 3, theta.shape[3] // 2)]
        for j, (nx, ny, nz) in enumerate(loc):
            plt.subplot(len(x) + ex, 1, len(x) + j + 1)
            plt.plot(all_fa, sin_signals(all_fa, *theta[:, nx, ny, nz]))
            plt.scatter(fa, x[:, nx, ny, nz])
        plt.show()

        # if gain < 1e-4 * n:
        #     break

    theta.exp_()
    return theta[0], theta[1]
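
# A minimal usage sketch for `sin_b1` (hypothetical `load_volume` helper;
# one FLASH volume per flip angle, stacked along the first dimension):
#
#     x = torch.stack([load_volume(f)
#                      for f in ('fa05.nii', 'fa15.nii', 'fa30.nii')])
#     s, b1 = sin_b1(x, fa=[5, 15, 30], lam=(0, 1e3))
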