def solve(h, g, w, optim):
    # `dim`, `prm`, `sub_iter` and `sub_tol` are closed over from the
    # enclosing scope.
    return spatial.solve_field(h, g, dim=dim, weights=w, **prm, optim=optim,
                               max_iter=sub_iter, tolerance=sub_tol)
def _chi_fit_tv(dat, sigma=None, df=None, lam=10, max_iter=50, tol=1e-5):
    """TV-regularised maximum-likelihood fit of chi-distributed data."""
    noise, tissue = estimate_noise(dat, chi=True)
    mu = tissue['mean']
    sigma = sigma or noise['sd']
    df = df or noise['dof']
    print(f'sigma = {sigma}, dof = {df}, mu = {mu}')
    isigma2 = 1 / (sigma * sigma)
    lam = lam / mu

    # replace non-finite values (if any) with zeros
    msk = torch.isfinite(dat)
    if not msk.all():
        dat = dat.masked_fill(msk.bitwise_not_(), 0)
    fit = chi_bias_correction(dat, sigma, df)[0]
    msk = dat > 0
    n = msk.sum()

    w = torch.ones_like(dat)
    h = w.new_full([1] * dat.dim(), isigma2)[None]

    ll_prev = float('inf')
    for n_iter in range(max_iter):
        w, tv = spatial.membrane_weights(fit[None], factor=lam,
                                         return_sum=True)
        l, delta = nll_chi(dat, fit, msk, isigma2, df)
        delta += spatial.regulariser(fit[None], membrane=lam, weights=w)[0]
        delta = spatial.solve_field(h, delta[None], membrane=lam,
                                    weights=w)[0]
        fit.sub_(delta).clamp_min_(1e-8 * sigma)
        ll = l + tv
        gain, ll_prev = ll_prev - ll, ll
        print(f'{n_iter:3d} | {l/n:12.6g} + {tv/n:12.6g} = {ll/n:12.6g} '
              f'| gain = {gain/n:6.3}')
        if abs(gain) < tol * n:
            break
    return fit, sigma, df
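# ----------------------------------------------------------------------
# Usage sketch for `_chi_fit_tv` (illustrative only; `dat` is a
# hypothetical 3D magnitude volume with chi-distributed noise). Noise
# parameters are estimated internally when `sigma`/`df` are not given.
#
#     dat = torch.rand(64, 64, 64) * 100          # synthetic volume
#     fit, sigma, df = _chi_fit_tv(dat, lam=10)   # denoised fit + noise stats
# ----------------------------------------------------------------------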
def greeq(data, transmit=None, receive=None, opt=None, **kwopt):
    """Fit a non-linear relaxometry model to multi-echo Gradient-Echo data.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    transmit : sequence[PrecomputedFieldMap], optional
        Map(s) of the transmit field (b1+). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
    receive : sequence[PrecomputedFieldMap], optional
        Map(s) of the receive field (b1-). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
        If no receive map is provided, the output `pd` map will have
        a remaining b1- bias field.
    opt : GREEQOptions or dict, optional
        Algorithm options.
        {'preproc': {'register':     True},      # Co-register contrasts
         'optim':   {'nb_levels':    1,          # Number of pyramid levels
                     'max_iter_rls': 10,         # Max reweighting iterations
                     'max_iter_gn':  5,          # Max Gauss-Newton iterations
                     'max_iter_cg':  32,         # Max Conjugate-Gradient iterations
                     'tolerance':    1e-05,      # Tolerance for early stopping (RLS)
                     'tolerance_cg': 1e-03},     # Tolerance for early stopping (CG)
         'backend': {'dtype':  torch.float32,    # Data type
                     'device': 'cpu'},           # Device
         'penalty': {'norm':   'jtv',            # Type of penalty: {'tkh', 'tv', 'jtv', None}
                     'factor': {'r1':  10,       # Penalty factor per (log) map
                                'pd':  10,
                                'r2s': 2,
                                'mt':  2}},
         'verbose': 1}

    Returns
    -------
    pd : ParameterMap
        Proton density
    r1 : ParameterMap
        Longitudinal relaxation rate
    r2s : ParameterMap
        Apparent transverse relaxation rate
    mt : ParameterMap, optional
        Magnetisation transfer saturation.
        Only returned if MT-weighted data is provided.

    """
    opt = GREEQOptions().update(opt, **kwopt)
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    # --- estimate noise / register / initialize maps ---
    data, transmit, receive, maps = preproc(data, transmit, receive, opt)
    has_mt = hasattr(maps, 'mt')

    # --- prepare penalty factor ---
    lam = opt.penalty.factor
    if isinstance(lam, dict):
        lam = [lam.get('pd', 0), lam.get('r1', 0),
               lam.get('r2s', 0), lam.get('mt', 0)]
        if not has_mt:
            lam = lam[:3]
    lam = core.utils.make_vector(lam, 3 + has_mt, **backend)  # PD, R1, R2*, [MT]

    # --- initialize weights (RLS) ---
    if str(opt.penalty.norm).lower().startswith('no') or all(lam == 0):
        opt.penalty.norm = ''
    opt.penalty.norm = opt.penalty.norm.lower()
    mean_shape = maps[0].shape
    rls = None
    sumrls = 0
    if opt.penalty.norm in ('tv', 'jtv'):
        rls_shape = mean_shape
        if opt.penalty.norm == 'tv':
            rls_shape = (len(maps),) + rls_shape
        rls = torch.ones(rls_shape, **backend)
        sumrls = 0.5 * core.py.prod(rls_shape)

    if opt.penalty.norm:
        print(f'With {opt.penalty.norm.upper()} penalty:')
        print(f'    - PD:  {lam[0]:.3g}')
        print(f'    - R1:  {lam[1]:.3g}')
        print(f'    - R2*: {lam[2]:.3g}')
        if has_mt:
            print(f'    - MT:  {lam[3]:.3g}')
    else:
        print('Without penalty')

    if opt.penalty.norm not in ('tv', 'jtv'):
        # no reweighting -> do more gauss-newton updates instead
        opt.optim.max_iter_gn *= opt.optim.max_iter_rls
        opt.optim.max_iter_rls = 1
    print('Optimization:')
    print(f'    - Tolerance:       {opt.optim.tolerance}')
    if opt.penalty.norm.endswith('tv'):
        print(f'    - IRLS iterations: {opt.optim.max_iter_rls}')
    print(f'    - GN iterations:   {opt.optim.max_iter_gn}')
    if opt.optim.solver == 'fmg':
        print('    - FMG cycles:      2')
        print('    - CG iterations:   2')
    else:
        print(f'    - CG iterations:   {opt.optim.max_iter_cg}'
              f' (tolerance: {opt.optim.tolerance_cg})')
    if opt.optim.nb_levels > 1:
        print(f'    - Levels:          {opt.optim.nb_levels}')

    printer = CritPrinter(max_levels=opt.optim.nb_levels,
                          max_rls=opt.optim.max_iter_rls,
                          max_gn=opt.optim.max_iter_gn,
                          penalty=opt.penalty.norm,
                          verbose=opt.verbose)
    printer.print_head()

    shape0 = shape = maps.shape[1:]
    aff0 = aff = maps.affine
    vx0 = vx = spatial.voxel_size(aff0)
    vol0 = vx0.prod()
    vol = vx.prod() / vol0
    ll_scl = sum(core.py.prod(dat.shape) for dat in data)

    for level in range(opt.optim.nb_levels, 0, -1):
        printer.level = level

        if opt.optim.nb_levels > 1:
            aff, shape = _get_level(level, aff0, shape0)
            vx = spatial.voxel_size(aff)
            vol = vx.prod() / vol0
            maps, rls = resize(maps, rls, aff, shape)
            if opt.penalty.norm in ('tv', 'jtv'):
                sumrls = 0.5 * vol * rls.reciprocal().sum(dtype=torch.double)

        # --- compute derivatives ---
        nb_prm = len(maps)
        nb_hes = nb_prm * (nb_prm + 1) // 2
        grad = torch.empty((nb_prm,) + shape, **backend)
        hess = torch.empty((nb_hes,) + shape, **backend)

        ll_rls = []
        ll_gn = []
        ll_max = float('inf')

        max_iter_rls = max(opt.optim.max_iter_rls // level, 1)
        for n_iter_rls in range(max_iter_rls):
            # --- Reweighted least-squares loop ---
            printer.rls = n_iter_rls

            # --- Gauss-Newton loop ---
            for n_iter_gn in range(opt.optim.max_iter_gn):
                printer.gn = n_iter_gn
                crit = 0
                grad.zero_()
                hess.zero_()

                # --- loop over contrasts ---
                for contrast, b1m, b1p in zip(data, receive, transmit):
                    crit1, g1, h1 = derivatives_parameters(
                        contrast, maps, b1m, b1p, opt)
                    # increment
                    if hasattr(maps, 'mt') and not contrast.mt:
                        # we optimize for mt but this particular contrast
                        # has no information about mt so g1/h1 are smaller
                        # than grad/hess.
                        grad[:-1] += g1
                        hind = list(range(nb_prm - 1))
                        cnt = nb_prm
                        for i in range(nb_prm):
                            for j in range(i + 1, nb_prm):
                                if i != nb_prm - 1 and j != nb_prm - 1:
                                    hind.append(cnt)
                                cnt += 1
                        hess[hind] += h1
                        crit += crit1
                    else:
                        grad += g1
                        hess += h1
                        crit += crit1
                    del g1, h1, crit1

                reg = 0.
                if opt.penalty.norm:
                    g = spatial.regulariser(maps.volume, weights=rls, dim=3,
                                            voxel_size=vx, membrane=1,
                                            factor=lam * vol)
                    grad += g
                    reg = 0.5 * dot(maps.volume, g)
                    del g

                # --- gauss-newton ---
                grad = check_nans_(grad, warn='gradient')
                hess = check_nans_(hess, warn='hessian')
                if opt.penalty.norm:
                    # hess = hessian_sym_loaddiag(hess, 1e-5, 1e-8)
                    if opt.optim.solver == 'fmg':
                        deltas = spatial.solve_field_fmg(
                            hess, grad, rls, factor=lam * vol, membrane=1,
                            voxel_size=vx, nb_iter=2)
                    else:
                        deltas = spatial.solve_field(
                            hess, grad, rls, factor=lam * vol, membrane=1,
                            voxel_size=vx, verbose=max(0, opt.verbose - 1),
                            optim='cg', max_iter=opt.optim.max_iter_cg,
                            tolerance=opt.optim.tolerance_cg, stop='diff')
                else:
                    # hess = hessian_sym_loaddiag(hess, 1e-3, 1e-4)
                    deltas = spatial.solve_field_closedform(hess, grad)
                deltas = check_nans_(deltas, warn='deltas')

                for map, delta in zip(maps, deltas):
                    map.volume -= delta
                    map.volume.clamp_(-64, 64)  # avoid exp overflow
                    del delta
                del deltas

                # --- Compute gain ---
                ll = crit + reg + sumrls
                ll_max = max(ll_max, ll)
                ll_prev = ll_gn[-1] if ll_gn else ll_max
                gain = ll_prev - ll
                ll_gn.append(ll)
                printer.print_crit(crit, reg, sumrls, gain / ll_scl)
                if gain < opt.optim.tolerance * ll_scl:
                    break

            # --- Update RLS weights ---
            if opt.penalty.norm in ('tv', 'jtv'):
                rls, sumrls = spatial.membrane_weights(
                    maps.volume, lam, vx, return_sum=True, dim=3,
                    joint=opt.penalty.norm == 'jtv',
                    eps=core.constants.eps(rls.dtype))
                sumrls *= 0.5 * vol

            # --- Compute gain ---
            # (we are late by one full RLS iteration when computing the
            #  gain but we save some computations)
            ll = ll_gn[-1]
            ll_prev = ll_rls[-1] if ll_rls else ll_max
            ll_rls.append(ll)
            gain = ll_prev - ll
            if abs(gain) < opt.optim.tolerance * ll_scl:
                break

    del grad
    if opt.uncertainty:
        uncertainty = compute_uncertainty(hess, rls, lam * vol, vx, opt)
        maps.pd.uncertainty = uncertainty[0]
        maps.r1.uncertainty = uncertainty[1]
        maps.r2s.uncertainty = uncertainty[2]
        if hasattr(maps, 'mt'):
            maps.mt.uncertainty = uncertainty[3]

    # --- Prepare output ---
    return postproc(maps)
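# ----------------------------------------------------------------------
# Usage sketch for `greeq` (assumes `pdw`, `t1w`, `mtw` are
# GradientEchoMulti objects and `b1p`, `b1m` are PrecomputedFieldMap
# objects loaded elsewhere; the option values mirror the docstring above).
#
#     opt = {'penalty': {'norm': 'jtv',
#                        'factor': {'pd': 10, 'r1': 10, 'r2s': 2, 'mt': 2}},
#            'optim':   {'nb_levels': 1, 'max_iter_rls': 10},
#            'backend': {'device': 'cuda'}}
#     pd, r1, r2s, mt = greeq([pdw, t1w, mtw], transmit=[b1p],
#                             receive=[b1m], opt=opt)
# ----------------------------------------------------------------------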
def run_epic_noreverse(echoes, voxshift=None, lam=1, sigma=1,
                       max_iter=(10, 32), tol=1e-5, verbose=False):
    """Run EPIC on pre-loaded tensors (no reverse gradients or synthesis).

    Parameters
    ----------
    echoes : (N, *spatial) tensor
        Echoes acquired with a bipolar readout.
        The readout direction should be last.
    voxshift : (*spatial) tensor, optional
        Voxel shift map.
    lam : [list of] float
        Regularization factor (per echo)
    sigma : float
        Noise standard deviation
    max_iter : [pair of] int
        Maximum number of RLS and CG iterations
    tol : float
        Tolerance for early stopping
    verbose : int
        Verbosity level

    Returns
    -------
    echoes : (N, *spatial) tensor
        Undistorted + denoised echoes

    """
    ne = len(echoes)               # number of echoes
    nv = echoes.shape[1:].numel()  # number of voxels
    nd = echoes.dim() - 1          # number of spatial dimensions

    # initialize denoised echoes
    fit = echoes.clone()
    fwd_fit = torch.empty_like(fit)

    # prepare voxel shift maps
    do_voxshift = voxshift is not None
    if do_voxshift:
        ivoxshift = add_identity_1d(-voxshift)
        voxshift = add_identity_1d(voxshift)
    else:
        ivoxshift = None

    # prepare parameters
    max_iter, sub_iter = py.make_list(max_iter, 2)
    tol, sub_tol = py.make_list(tol, 2)
    lam = [l / ne for l in py.make_list(lam, ne)]
    isigma2 = 1 / (sigma * sigma)

    # compute hessian once and for all
    if do_voxshift:
        one = torch.ones_like(voxshift)[None]
        h = torch.empty_like(echoes)
        h[0::2] = push1d(pull1d(one, voxshift), voxshift)
        h[1::2] = push1d(pull1d(one, ivoxshift), ivoxshift)
        del one
    else:
        h = fit.new_ones([1] * (nd + 1))
    h *= isigma2

    loss = float('inf')
    for n_iter in range(max_iter):

        # update weights
        w, jtv = membrane_weights(fit, factor=lam, return_sum=True)

        # gradient of likelihood
        pull_forward(fit, voxshift, ivoxshift, out=fwd_fit)
        fwd_fit.sub_(echoes)
        ll = ssq(fwd_fit)
        push_forward(fwd_fit, voxshift, ivoxshift, out=fwd_fit)
        g = fwd_fit.mul_(isigma2)
        ll *= 0.5 * isigma2

        # gradient of prior
        g += regulariser(fit, membrane=1, factor=lam, weights=w)

        # solve
        fit -= solve_field(h, g, w, membrane=1, factor=lam,
                           max_iter=sub_iter, tolerance=sub_tol)

        # track objective
        ll, jtv = ll.item() / (ne * nv), jtv.item() / (ne * nv)
        loss, loss_prev = ll + jtv, loss
        if n_iter:
            gain = (loss_prev - loss) / max((loss_max - loss), 1e-8)
        else:
            gain = float('inf')
            loss_max = loss
        if verbose:
            end = '\n' if verbose > 1 else '\r'
            print(f'{n_iter+1:02d} | {ll:12.6g} + {jtv:12.6g} = {loss:12.6g} '
                  f'| gain = {gain:12.6g}', end=end)
        if gain < tol:
            break
    if verbose == 1:
        print('')
    return fit
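# ----------------------------------------------------------------------
# Usage sketch for `run_epic_noreverse` (hypothetical tensors). Note that
# `max_iter=(10, 32)` splits into 10 outer RLS iterations and 32 inner
# solver iterations, and `tol` splits the same way via `py.make_list`.
#
#     echoes = torch.rand(8, 128, 128, 64)   # N echoes, readout last
#     shift = torch.zeros(128, 128, 64)      # voxel shift map (or None)
#     clean = run_epic_noreverse(echoes, voxshift=shift, lam=1, sigma=0.05,
#                                max_iter=(10, 32), tol=1e-5, verbose=1)
# ----------------------------------------------------------------------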
def run_epic(echoes, reverse_echoes=None, voxshift=None, extrapolate=True,
             lam=1, sigma=1, max_iter=(10, 32), tol=1e-5, verbose=False):
    """Run EPIC on pre-loaded tensors.

    Parameters
    ----------
    echoes : (N, *spatial) tensor
        Echoes acquired with a bipolar readout.
        The readout direction should be last.
    reverse_echoes : (N, *spatial) tensor or False, optional
        Echoes acquired with a reversed bipolar readout.
        If None, they are synthesized. If False, run without
        reverse echoes.
    voxshift : (*spatial) tensor, optional
        Voxel shift map used to deform towards even (0, 2, ...) echoes.
        Its inverse is used to deform towards odd (1, 3, ...) echoes.
    extrapolate : bool
        Extrapolate first/last echo when reverse_echoes is None.
        Otherwise, only use interpolated echoes.
    lam : [list of] float
        Regularization factor (per echo)
    sigma : float
        Noise standard deviation
    max_iter : [pair of] int
        Maximum number of RLS and CG iterations
    tol : float
        Tolerance for early stopping
    verbose : int
        Verbosity level

    Returns
    -------
    echoes : (N, *spatial) tensor
        Undistorted + denoised echoes

    """
    if reverse_echoes is False:
        return run_epic_noreverse(echoes, voxshift, lam, sigma,
                                  max_iter, tol, verbose)

    ne = len(echoes)               # number of echoes
    nv = echoes.shape[1:].numel()  # number of voxels
    nd = echoes.dim() - 1          # number of spatial dimensions

    # synthesize echoes
    synth = not torch.is_tensor(reverse_echoes)
    if synth:
        neg = synthesize_neg(echoes[0::2])
        pos = synthesize_pos(echoes[1::2])
        reverse_echoes = torch.stack([x for y in zip(pos, neg) for x in y])
        del pos, neg
    else:
        extrapolate = True

    # initialize denoised echoes
    fit = (echoes + reverse_echoes).div_(2)
    if not extrapolate:
        fit[0] = echoes[0]
        fit[-1] = echoes[-1]
    fwd_fit = torch.zeros_like(fit)
    bwd_fit = torch.zeros_like(fit)

    # prepare voxel shift maps
    if voxshift is not None:
        ivoxshift = add_identity_1d(-voxshift)
        voxshift = add_identity_1d(voxshift)
    else:
        ivoxshift = None

    # prepare parameters
    max_iter, sub_iter = py.make_list(max_iter, 2)
    tol, sub_tol = py.make_list(tol, 2)
    lam = [l / ne for l in py.make_list(lam, ne)]
    isigma2 = 1 / (sigma * sigma)

    # compute hessian once and for all
    if voxshift is not None:
        one = torch.ones_like(voxshift)[None]
        if extrapolate:
            h = push1d(pull1d(one, voxshift), voxshift)
            h += push1d(pull1d(one, ivoxshift), ivoxshift)
            weight_ = lambda x: x.mul_(0.5)
            halfweight_ = lambda x: x.mul_(math.sqrt(0.5))
        else:
            h = torch.zeros_like(fit)
            h[:-1] += push1d(pull1d(one, voxshift), voxshift)
            h[1:] += push1d(pull1d(one, ivoxshift), ivoxshift)
            weight_ = lambda x: x[1:-1].mul_(0.5)
            halfweight_ = lambda x: x[1:-1].mul_(math.sqrt(0.5))
        del one
        weight_(h)
    else:
        h = fit.new_ones([ne] + [1] * nd)
        if extrapolate:
            h *= 2
            weight_ = lambda x: x.mul_(0.5)
            halfweight_ = lambda x: x.mul_(math.sqrt(0.5))
        else:
            h[1:-1] *= 2
            weight_ = lambda x: x[1:-1].mul_(0.5)
            halfweight_ = lambda x: x[1:-1].mul_(math.sqrt(0.5))
        weight_(h)
    h *= isigma2

    loss = float('inf')
    for n_iter in range(max_iter):

        # update weights
        w, jtv = membrane_weights(fit, factor=lam, return_sum=True)

        # gradient of likelihood (forward)
        pull_forward(fit, voxshift, ivoxshift, out=fwd_fit)
        fwd_fit.sub_(echoes)
        halfweight_(fwd_fit)
        ll = ssq(fwd_fit)
        halfweight_(fwd_fit)
        push_forward(fwd_fit, voxshift, ivoxshift, out=fwd_fit)

        # gradient of likelihood (reversed)
        pull_backward(fit, voxshift, ivoxshift, extrapolate, out=bwd_fit)
        if extrapolate:
            bwd_fit.sub_(reverse_echoes)
        else:
            bwd_fit[1:-1].sub_(reverse_echoes[1:-1])
        halfweight_(bwd_fit)
        ll += ssq(bwd_fit)
        halfweight_(bwd_fit)
        push_backward(bwd_fit, voxshift, ivoxshift, extrapolate, out=bwd_fit)

        g = fwd_fit.add_(bwd_fit).mul_(isigma2)
        ll *= 0.5 * isigma2

        # gradient of prior
        g += regulariser(fit, membrane=1, factor=lam, weights=w)

        # solve
        fit -= solve_field(h, g, w, membrane=1, factor=lam,
                           max_iter=sub_iter, tolerance=sub_tol)

        # track objective
        ll, jtv = ll.item() / (ne * nv), jtv.item() / (ne * nv)
        loss, loss_prev = ll + jtv, loss
        if n_iter:
            gain = (loss_prev - loss) / max((loss_max - loss), 1e-8)
        else:
            gain = float('inf')
            loss_max = loss
        if verbose:
            end = '\n' if verbose > 1 else '\r'
            print(f'{n_iter+1:02d} | {ll:12.6g} + {jtv:12.6g} = {loss:12.6g} '
                  f'| gain = {gain:12.6g}', end=end)
        if gain < tol:
            break
    if verbose == 1:
        print('')
    return fit
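# ----------------------------------------------------------------------
# Usage sketch for `run_epic` (hypothetical tensors). When
# `reverse_echoes=None`, the reversed-polarity echoes are synthesized from
# the measured ones; passing a tensor uses measured reverse echoes
# instead, and passing False falls back to `run_epic_noreverse`.
#
#     echoes = torch.rand(8, 128, 128, 64)   # bipolar echoes, readout last
#     shift = torch.zeros(128, 128, 64)      # B0 voxel shift map
#     clean = run_epic(echoes, reverse_echoes=None, voxshift=shift,
#                      extrapolate=True, lam=1, sigma=0.05)
# ----------------------------------------------------------------------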
def meetup(data, dist=None, opt=None):
    """Fit the ESTATICS+MEETUP model to multi-echo Gradient-Echo data.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    dist : sequence[Optional[ParameterizedDistortion]], optional
        Pre-computed distortion fields
    opt : Options, optional
        Algorithm options.

    Returns
    -------
    intercepts : sequence[GradientEcho]
        Echo series extrapolated to TE=0
    decay : estatics.ParameterMap
        R2* decay map
    distortions : sequence[ParameterizedDistortion]
        B0-induced distortion fields

    """
    # --- prepare everything --------------------------------------------
    data, maps, dist, opt, rls = _prepare(data, dist, opt)
    sumrls = 0.5 * rls.sum(dtype=torch.double) if rls is not None else 0
    backend = dict(dtype=opt.backend.dtype, device=opt.backend.device)

    # --- initialize tracking of the objective function ------------------
    crit, vreg, reg = float('inf'), 0, 0
    ll_rls = crit + vreg + reg
    sumrls_prev = sumrls
    rls_changed = False
    ll_scl = sum(core.py.prod(dat.shape) for dat in data)

    # --- Multi-Resolution loop ------------------------------------------
    shape0 = shape = maps.shape[1:]
    aff0 = aff = maps.affine
    lam0 = lam = opt.regularization.factor
    vx0 = vx = spatial.voxel_size(aff0)
    scl0 = vx0.prod()
    armijo_prev = 1
    for level in range(opt.optim.nb_levels, 0, -1):

        if opt.optim.nb_levels > 1:
            aff, shape = _get_level(level, aff0, shape0)
            vx = spatial.voxel_size(aff)
            scl = vx.prod() / scl0
            lam = [float(l * scl) for l in lam0]
            maps, rls = _resize(maps, rls, aff, shape)
            if opt.regularization.norm in ('tv', 'jtv'):
                sumrls = 0.5 * scl * rls.reciprocal().sum(dtype=torch.double)

        grad = torch.empty((len(data) + 1, *shape), **backend)
        hess = torch.empty((len(data) * 2 + 1, *shape), **backend)

        # --- RLS loop ----------------------------------------------------
        max_iter_rls = 1 if level > 1 else opt.optim.max_iter_rls
        for n_iter_rls in range(1, max_iter_rls + 1):

            # --- helpers ---------------------------------------------
            regularizer_prm = lambda x: spatial.regulariser(
                x, weights=rls, dim=3, voxel_size=vx, membrane=1, factor=lam)
            solve_prm = lambda H, g: spatial.solve_field(
                H, g, rls, factor=lam,
                membrane=1 if opt.regularization.norm else 0,
                voxel_size=vx, max_iter=opt.optim.max_iter_cg,
                tolerance=opt.optim.tolerance_cg, dim=3)
            reweight = lambda x, **k: spatial.membrane_weights(
                x, lam, vx, dim=3, **k,
                joint=opt.regularization.norm == 'jtv',
                eps=core.constants.eps(rls.dtype))

            # ----------------------------------------------------------
            # Initial update of parameter maps
            # ----------------------------------------------------------
            crit_pre_prm = None
            max_iter_prm = opt.optim.max_iter_prm
            for n_iter_prm in range(1, max_iter_prm + 1):
                crit_prev = crit
                reg_prev = reg

                # --- loop over contrasts ------------------------------
                crit = 0
                grad.zero_()
                hess.zero_()
                for i, (contrast, intercept, distortion) \
                        in enumerate(zip(data, maps.intercepts, dist)):
                    # compute gradient
                    crit1, g1, h1 = derivatives_parameters(
                        contrast, distortion, intercept, maps.decay, opt)
                    # increment
                    gind, hind = [i, -1], [i, len(grad) - 1, len(grad) + i]
                    grad[gind] += g1
                    hess[hind] += h1
                    crit += crit1
                    del g1, h1

                # --- regularization -----------------------------------
                reg = 0.
                if opt.regularization.norm:
                    g1 = regularizer_prm(maps.volume)
                    reg = 0.5 * dot(maps.volume, g1)
                    grad += g1
                    del g1

                if n_iter_prm == 1:
                    crit_pre_prm = crit

                # --- track RLS improvement ----------------------------
                # Updating the RLS weights changes `sumrls` and `reg`. Now
                # that we have the updated value of `reg` (before it is
                # modified by the map update), we can check that updating
                # the weights indeed improved the objective.
                if opt.verbose and rls_changed:
                    rls_changed = False
                    if reg + sumrls <= reg_prev + sumrls_prev:
                        evol = '<='
                    else:
                        evol = '>'
                    ll_tmp = crit_prev + reg + vreg + sumrls
                    pstr = (f'{n_iter_rls-1:3d} | {"---":3s} | {"rls":4s} | '
                            f'{crit_prev:12.6g} + {reg:12.6g} + '
                            f'{sumrls:12.6g} + {vreg:12.6g} '
                            f'= {ll_tmp:12.6g} | {evol}')
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)

                # --- gauss-newton -------------------------------------
                # Computing the GN step involves solving H\g
                hess = check_nans_(hess, warn='hessian')
                hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
                deltas = solve_prm(hess, grad)
                deltas = check_nans_(deltas, warn='delta')

                dd = regularizer_prm(deltas)
                dv = dot(dd, maps.volume)
                dd = dot(dd, deltas)
                delta_reg = 0.5 * (dd - 2 * dv)

                for map, delta in zip(maps, deltas):
                    map.volume -= delta
                    if map.min is not None or map.max is not None:
                        map.volume.clamp_(map.min, map.max)
                del deltas

                # --- track parameter map improvement ------------------
                gain = (crit_prev + reg_prev) - (crit + reg)
                if n_iter_prm > 1 and gain < opt.optim.tolerance * ll_scl:
                    break

            # ----------------------------------------------------------
            # Distortion update with line search
            # ----------------------------------------------------------
            max_iter_dist = opt.optim.max_iter_dist
            for n_iter_dist in range(1, max_iter_dist + 1):

                crit = 0
                vreg = 0
                new_crit = 0
                new_vreg = 0

                deltas, dd, dv = [], 0, 0
                # --- loop over contrasts ------------------------------
                for i, (contrast, intercept, distortion) \
                        in enumerate(zip(data, maps.intercepts, dist)):

                    # --- helpers --------------------------------------
                    vxr = distortion.voxel_size[contrast.readout]
                    lam_dist = dict(opt.distortion.factor)
                    lam_dist['factor'] *= vxr ** 2
                    regularizer_dist = lambda x: spatial.regulariser(
                        x[None], **lam_dist, dim=3, bound=DIST_BOUND,
                        voxel_size=distortion.voxel_size)[0]
                    solve_dist = lambda h, g: spatial.solve_field_fmg(
                        h, g, **lam_dist, dim=3, bound=DIST_BOUND,
                        voxel_size=distortion.voxel_size)

                    # --- likelihood -----------------------------------
                    crit1, g, h = derivatives_distortion(
                        contrast, distortion, intercept, maps.decay, opt)
                    crit += crit1

                    # --- regularization -------------------------------
                    vol = distortion.volume
                    g1 = regularizer_dist(vol)
                    vreg1 = 0.5 * vol.flatten().dot(g1.flatten())
                    vreg += vreg1
                    g += g1
                    del g1

                    # --- gauss-newton ---------------------------------
                    h = check_nans_(h, warn='hessian (distortion)')
                    h = hessian_loaddiag_(h[None], 1e-32, 1e-32, sym=True)[0]
                    delta = solve_dist(h, g)
                    delta = check_nans_(delta, warn='delta (distortion)')
                    deltas.append(delta)
                    del g, h

                    dd1 = regularizer_dist(delta)
                    dv += dot(dd1, vol)
                    dd1 = dot(dd1, delta)
                    dd += dd1
                    del delta, vol

                # --- track parameters improvement ---------------------
                if opt.verbose and n_iter_dist == 1:
                    gain = crit_pre_prm - (crit + delta_reg)
                    evol = '<=' if gain > 0 else '>'
                    ll_tmp = crit + reg + vreg + sumrls
                    pstr = (f'{n_iter_rls:3d} | {"---":3s} | {"prm":4s} | '
                            f'{crit:12.6g} + {reg + delta_reg:12.6g} + '
                            f'{sumrls:12.6g} + {vreg:12.6g} '
                            f'= {ll_tmp:12.6g} | '
                            f'gain = {gain / ll_scl:7.2g} | {evol}')
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)

                # --- line search --------------------------------------
                reg = reg + delta_reg
                new_crit, new_reg = crit, reg
                armijo, armijo_prev, ok = armijo_prev, 0, False
                maps0 = maps
                for n_ls in range(1, 12 + 1):
                    for delta1, dist1 in zip(deltas, dist):
                        dist1.volume.sub_(delta1, alpha=(armijo - armijo_prev))
                    armijo_prev = armijo
                    new_vreg = 0.5 * armijo * (armijo * dd - 2 * dv)

                    maps = maps0.deepcopy()
                    max_iter_prm = opt.optim.max_iter_prm
                    for n_iter_prm in range(1, max_iter_prm + 1):
                        crit_prev = new_crit
                        reg_prev = new_reg

                        # --- loop over contrasts ----------------------
                        new_crit = 0
                        grad.zero_()
                        hess.zero_()
                        for i, (contrast, intercept, distortion) \
                                in enumerate(zip(data, maps.intercepts,
                                                 dist)):
                            # compute gradient
                            new_crit1, g1, h1 = derivatives_parameters(
                                contrast, distortion, intercept,
                                maps.decay, opt)
                            # increment
                            gind = [i, -1]
                            hind = [i, len(grad) - 1, len(grad) + i]
                            grad[gind] += g1
                            hess[hind] += h1
                            new_crit += new_crit1
                            del g1, h1

                        # --- regularization ---------------------------
                        new_reg = 0.
                        if opt.regularization.norm:
                            g1 = regularizer_prm(maps.volume)
                            new_reg = 0.5 * dot(maps.volume, g1)
                            grad += g1
                            del g1

                        new_gain = ((crit_prev + reg_prev)
                                    - (new_crit + new_reg))
                        if new_gain < opt.optim.tolerance * ll_scl:
                            break

                        # --- gauss-newton -----------------------------
                        # Computing the GN step involves solving H\g
                        hess = check_nans_(hess, warn='hessian')
                        hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
                        delta = solve_prm(hess, grad)
                        delta = check_nans_(delta, warn='delta')

                        dd = regularizer_prm(delta)
                        dv = dot(dd, maps.volume)
                        dd = dot(dd, delta)
                        delta_reg = 0.5 * (dd - 2 * dv)

                        for map, delta1 in zip(maps, delta):
                            map.volume -= delta1
                            if map.min is not None or map.max is not None:
                                map.volume.clamp_(map.min, map.max)
                        del delta

                    if new_crit + new_reg + new_vreg <= crit + reg:
                        ok = True
                        break
                    else:
                        armijo = armijo / 2

                if not ok:
                    # revert the distortion step
                    for delta1, dist1 in zip(deltas, dist):
                        dist1.volume.add_(delta1, alpha=armijo_prev)
                    armijo_prev = 1
                    maps = maps0
                    new_crit = crit
                    new_vreg = 0
                    del deltas
                else:
                    armijo_prev *= 1.5
                new_vreg = vreg + new_vreg

                # --- track distortion improvement ---------------------
                if not ok:
                    evol = '== (x)'
                elif new_crit + new_vreg <= crit + vreg:
                    evol = f'<= ({n_ls:2d})'
                else:
                    evol = f'>  ({n_ls:2d})'
                gain = (crit + vreg + reg) - (new_crit + new_vreg + new_reg)
                crit, vreg, reg = new_crit, new_vreg, new_reg
                if opt.verbose:
                    ll_tmp = crit + reg + vreg + sumrls
                    pstr = (f'{n_iter_rls:3d} | {n_iter_dist:3d} | '
                            f'{"dist":4s} | '
                            f'{crit:12.6g} + {reg:12.6g} + {sumrls:12.6g} ')
                    pstr += f'+ {vreg:12.6g} '
                    pstr += f'= {ll_tmp:12.6g} | gain = {gain / ll_scl:7.2g} | '
                    pstr += f'{evol}'
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)
                if opt.plot:
                    _show_maps(maps, dist, data)
                if not ok:
                    break

                if n_iter_dist > 1 and gain < opt.optim.tolerance * ll_scl:
                    break

            # ----------------------------------------------------------
            # Update RLS weights
            # ----------------------------------------------------------
            if level == 1 and opt.regularization.norm in ('tv', 'jtv'):
                rls_changed = True
                sumrls_prev = sumrls
                rls, sumrls = reweight(maps.volume, return_sum=True)
                sumrls = 0.5 * sumrls

            # --- compute gain -----------------------------------------
            # (we are late by one full RLS iteration when computing the
            #  gain but we save some computations)
            ll_rls_prev = ll_rls
            ll_rls = crit + reg + vreg
            gain = ll_rls_prev - ll_rls
            if gain < opt.optim.tolerance * ll_scl:
                print(f'RLS converged ({gain / ll_scl:7.2g})')
                break

    # --- prepare output --------------------------------------------------
    out = postproc(maps, data)
    if opt.distortion.enable:
        out = (*out, dist)
    return out
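# ----------------------------------------------------------------------
# Note on the Hessian layout used by `meetup`/`nonlin` above (a reading
# aid, not executed): the ESTATICS Hessian is sparse because each TE=0
# intercept only interacts with the shared R2* decay map. With C
# contrasts, `grad` has C+1 channels (C intercepts + decay) and `hess`
# has 2C+1 channels: C intercept diagonals, one decay diagonal, and C
# intercept-decay off-diagonals. For contrast i:
#
#     gind = [i, -1]                  # intercept i + decay
#     hind = [i,                      # d2/d(intercept i)^2
#             len(grad) - 1,          # d2/d(decay)^2
#             len(grad) + i]          # d2/d(intercept i)d(decay)
#
# so `hess[hind] += h1` scatters the three unique second derivatives of
# one contrast into the packed sparse Hessian.
# ----------------------------------------------------------------------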
def nonlin(data, dist=None, opt=None):
    """Fit the ESTATICS model to multi-echo Gradient-Echo data.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    dist : sequence[Optional[ParameterizedDistortion]], optional
        Pre-computed distortion fields
    opt : Options, optional
        Algorithm options.

    Returns
    -------
    intercepts : sequence[GradientEcho]
        Echo series extrapolated to TE=0
    decay : estatics.ParameterMap
        R2* decay map
    distortions : sequence[ParameterizedDistortion], if opt.distortion.enable
        B0-induced distortion fields

    """
    if opt.distortion.enable:
        return meetup(data, dist, opt)

    # --- prepare everything --------------------------------------------
    data, maps, dist, opt, rls = _prepare(data, dist, opt)
    sumrls = 0.5 * rls.sum(dtype=torch.double) if rls is not None else 0
    backend = dict(dtype=opt.backend.dtype, device=opt.backend.device)

    # --- initialize tracking of the objective function ------------------
    crit = float('inf')
    vreg = 0
    reg = 0
    sumrls_prev = sumrls
    rls_changed = False
    ll_gn = ll_rls = float('inf')
    ll_scl = sum(core.py.prod(dat.shape) for dat in data)

    # --- Multi-Resolution loop ------------------------------------------
    shape0 = shape = maps.shape[1:]
    aff0 = aff = maps.affine
    lam0 = lam = opt.regularization.factor
    vx0 = vx = spatial.voxel_size(aff0)
    scl0 = vx0.prod()
    for level in range(opt.optim.nb_levels, 0, -1):

        if opt.optim.nb_levels > 1:
            aff, shape = _get_level(level, aff0, shape0)
            vx = spatial.voxel_size(aff)
            scl = vx.prod() / scl0
            lam = [float(l * scl) for l in lam0]
            maps, rls = _resize(maps, rls, aff, shape)
            if opt.regularization.norm in ('tv', 'jtv'):
                sumrls = 0.5 * scl * rls.reciprocal().sum(dtype=torch.double)

        grad = torch.empty((len(data) + 1, *shape), **backend)
        hess = torch.empty((len(data) * 2 + 1, *shape), **backend)

        # --- RLS loop ----------------------------------------------------
        # > max_iter_rls == 1 if regularization is not (J)TV
        max_iter_rls = opt.optim.max_iter_rls
        if level != 1:
            max_iter_rls = 1
        for n_iter_rls in range(1, max_iter_rls + 1):

            # --- helpers ---------------------------------------------
            regularizer = lambda x: spatial.regulariser(
                x, weights=rls, dim=3, voxel_size=vx, membrane=1, factor=lam)
            solve = lambda h, g: spatial.solve_field(
                h, g, rls, factor=lam,
                membrane=1 if opt.regularization.norm else 0,
                voxel_size=vx, max_iter=opt.optim.max_iter_cg,
                tolerance=opt.optim.tolerance_cg, dim=3)
            reweight = lambda x, **k: spatial.membrane_weights(
                x, lam, vx, dim=3, **k,
                joint=opt.regularization.norm == 'jtv',
                eps=core.constants.eps(rls.dtype))

            # ----------------------------------------------------------
            # Update parameter maps
            # ----------------------------------------------------------
            max_iter_prm = opt.optim.max_iter_prm
            for n_iter_prm in range(1, max_iter_prm + 1):
                crit_prev = crit
                reg_prev = reg

                # --- loop over contrasts ------------------------------
                crit = 0
                grad.zero_()
                hess.zero_()
                for i, (contrast, intercept, distortion) \
                        in enumerate(zip(data, maps.intercepts, dist)):
                    # compute gradient
                    crit1, g1, h1 = derivatives_parameters(
                        contrast, distortion, intercept, maps.decay, opt)
                    # increment
                    gind = [i, -1]
                    grad[gind] += g1
                    hind = [i, len(grad) - 1, len(grad) + i]
                    hess[hind] += h1
                    crit += crit1
                    del g1, h1

                # --- regularization -----------------------------------
                reg = 0.
                if opt.regularization.norm:
                    g1 = regularizer(maps.volume)
                    reg = 0.5 * dot(maps.volume, g1)
                    grad += g1
                    del g1

                # --- track RLS improvement ----------------------------
                # Updating the RLS weights changes `sumrls` and `reg`. Now
                # that we have the updated value of `reg` (before it is
                # modified by the map update), we can check that updating
                # the weights indeed improved the objective.
                if opt.verbose and rls_changed:
                    rls_changed = False
                    if reg + sumrls <= reg_prev + sumrls_prev:
                        evol = '<='
                    else:
                        evol = '>'
                    ll_tmp = crit_prev + reg + vreg + sumrls
                    pstr = (f'{n_iter_rls:3d} | {"---":3s} | {"rls":4s} | '
                            f'{crit_prev:12.6g} + {reg:12.6g} + '
                            f'{sumrls:12.6g} ')
                    if opt.distortion.enable:
                        pstr += f'+ {vreg:12.6g} '
                    pstr += f'= {ll_tmp:12.6g} | '
                    pstr += f'{evol}'
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)

                # --- gauss-newton -------------------------------------
                # Computing the GN step involves solving H\g
                hess = check_nans_(hess, warn='hessian')
                hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
                deltas = solve(hess, grad)
                deltas = check_nans_(deltas, warn='delta')

                for map, delta in zip(maps, deltas):
                    map.volume -= delta
                    if map.min is not None or map.max is not None:
                        map.volume.clamp_(map.min, map.max)
                del deltas

                # --- compute GN gain ----------------------------------
                ll_gn_prev = ll_gn
                ll_gn = crit + reg + vreg + sumrls
                gain = ll_gn_prev - ll_gn
                if opt.verbose:
                    pstr = f'{n_iter_rls:3d} | {n_iter_prm:3d} | '
                    if opt.distortion.enable:
                        pstr += f'{"----":4s} | '
                        pstr += f'{"-" * 72:72s} | '
                    else:
                        pstr += f'{"prm":4s} | '
                        pstr += f'{crit:12.6g} + {reg:12.6g} + {sumrls:12.6g} '
                        pstr += f'= {ll_gn:12.6g} | '
                        pstr += f'gain = {gain / ll_scl:7.2g}'
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)
                if opt.plot:
                    _show_maps(maps, dist, data)
                if gain < opt.optim.tolerance * ll_scl:
                    break

            # ----------------------------------------------------------
            # Update RLS weights
            # ----------------------------------------------------------
            if level == 1 and opt.regularization.norm in ('tv', 'jtv'):
                reg = 0.5 * dot(maps.volume, regularizer(maps.volume))
                rls_changed = True
                sumrls_prev = sumrls
                rls, sumrls = reweight(maps.volume, return_sum=True)
                sumrls = 0.5 * sumrls

            # --- compute gain -----------------------------------------
            # (we are late by one full RLS iteration when computing the
            #  gain but we save some computations)
            ll_rls_prev = ll_rls
            ll_rls = ll_gn
            gain = ll_rls_prev - ll_rls
            if gain < opt.optim.tolerance * ll_scl:
                print(f'Converged ({gain / ll_scl:7.2g})')
                break

    # --- prepare output --------------------------------------------------
    out = postproc(maps, data)
    if opt.distortion.enable:
        out = (*out, dist)
    return out
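# ----------------------------------------------------------------------
# Usage sketch for `nonlin` (assumes `echo_series` is a sequence of
# GradientEchoMulti objects; `ESTATICSOptions` is a placeholder name for
# the options class consumed by `_prepare`, taken from context rather
# than verified).
#
#     opt = ESTATICSOptions()
#     opt.regularization.norm = 'jtv'
#     opt.distortion.enable = False      # True would route to `meetup`
#     intercepts, decay = nonlin(echo_series, opt=opt)
# ----------------------------------------------------------------------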
def flash_b1(x, fa, tr, lam=(0, 0, 0), penalty=('m', 'm', 'm'),
             chi=True, pd=None, r1=None, b1=None):
    """Estimate B1+ from Variable Flip Angle data.

    Parameters
    ----------
    x : (C, *spatial) tensor
        Input flash images
    fa : (C,) sequence[float]
        Flip angle (in deg)
    tr : float or (C,) sequence[float]
        Repetition time (in sec)
    lam : (float, float, float), default=(0, 0, 0)
        Regularization value for the T2*w signal, T1 map and B1 map.
    penalty : 3x {'membrane', 'bending'}, default=('m', 'm', 'm')
        Regularization type for the T2*w signal, T1 map and B1 map.
    chi : bool, default=True
        Assume chi-distributed noise (else Gaussian).
    pd, r1, b1 : (*spatial) tensor, optional
        Initial estimates of the three maps.

    Returns
    -------
    s : (*spatial) tensor
        T2*-weighted PD map
    r1 : (*spatial) tensor
        R1 map
    b1 : (*spatial) tensor
        B1+ map

    """
    tr = py.make_list(tr, len(x))
    fa = utils.make_vector(fa, len(x), dtype=torch.double)
    fa = fa.mul_(pymath.pi / 180).tolist()
    lam = py.make_list(lam, 3)
    penalty = py.make_list(penalty, 3)

    # estimate noise parameters (log-average across contrasts)
    sd, df, mu = 0, 0, 0
    for x1 in x:
        bg, fg = estimate_noise(x1, chi=True)
        sd += bg['sd'].log()
        df += bg['dof'].log()
        mu += fg['mean'].log()
    sd = (sd / len(x)).exp()
    df = (df / len(x)).exp()
    mu = (mu / len(x)).exp()
    prec = 1 / (sd * sd)
    if not chi:
        df = 1

    # initialize log-parameters [log s, log r1, log b1]
    shape = x.shape[1:]
    theta = x.new_empty([3, *shape])
    theta[0] = mu.log() if pd is None else pd.log()
    theta[1] = 1 if r1 is None else r1.log()
    theta[2] = 0 if b1 is None else b1.log()

    n = (x != 0).sum()
    g = torch.zeros_like(theta)
    h = theta.new_zeros([6, *theta.shape[1:]])
    g1 = torch.zeros_like(theta)
    h1 = theta.new_zeros([6, *theta.shape[1:]])

    prm = dict(
        membrane=(lam[0] if penalty[0][0] == 'm' else 0,
                  lam[1] if penalty[1][0] == 'm' else 0,
                  lam[2] if penalty[2][0] == 'm' else 0),
        bending=(lam[0] if penalty[0][0] == 'b' else 0,
                 lam[1] if penalty[1][0] == 'b' else 0,
                 lam[2] if penalty[2][0] == 'b' else 0),
    )

    ll0 = lr0 = float('inf')
    iter_start = 3
    for n_iter in range(32):
        # reset tracking when switching from Gaussian to chi likelihood
        if df > 1 and n_iter == iter_start:
            ll0 = lr0 = float('inf')

        # derivatives of likelihood term
        df1 = 1 if n_iter < iter_start else df
        ll, g, h = derivatives(x, fa, tr, theta[0], theta[1], theta[2],
                               prec, df1, g, h, g1, h1)

        # derivatives of regularization term
        reg = spatial.regulariser(theta, **prm, absolute=1e-10)
        g += reg
        lr = 0.5 * dot(theta, reg)

        l, l0 = ll + lr, ll0 + lr0
        gain = l0 - l
        print(f'{n_iter:2d} | {ll/n:12.6g} + {lr/n:12.6g} = {l/n:12.6g} '
              f'| gain = {gain/n:12.6g}')

        # Gauss-Newton update, alternating between the (PD, R1) block
        # and the B1 block
        h[:3] += 1e-8 * h[:3].abs().max(0).values
        h[:3] += 1e-5
        if n_iter % 2:
            delta = torch.zeros_like(theta)
            prm1 = dict(membrane=prm['membrane'][:2],
                        bending=prm['bending'][:2])
            hh = torch.stack([h[0], h[1], h[3]])
            delta[:2] = spatial.solve_field(hh, g[:2], **prm1,
                                            max_iter=16, absolute=1e-10)
            del hh
        else:
            delta = torch.zeros_like(theta)
            prm1 = dict(membrane=prm['membrane'][-1],
                        bending=prm['bending'][-1])
            delta[-1:] = spatial.solve_field(h[2:3], g[2:3], **prm1,
                                             max_iter=16, absolute=1e-10)
        # delta = spatial.solve_field(h, g, **prm, max_iter=16, absolute=1e-10)
        # theta -= spatial.solve_field_fmg(h, g, **prm)

        # line search
        dd = spatial.regulariser(delta, **prm, absolute=1e-10)
        dt = dot(dd, theta)
        dd = dot(dd, delta)
        success = False
        armijo = 1
        theta0 = theta
        ll0, lr0 = ll, lr
        for n_ls in range(12):
            theta = theta0.sub(delta, alpha=armijo)
            ll = nll(x, fa, tr, *theta, prec, df1)
            lr = 0.5 * armijo * (armijo * dd - 2 * dt)
            if ll + lr < ll0:
                print(n_ls, 'success', ((ll + lr) / n).item(),
                      (ll0 / n).item())
                success = True
                break
            print(n_ls, 'failure', ((ll + lr) / n).item(), (ll0 / n).item())
            armijo /= 2
        if not success and n_iter > 5:
            theta = theta0
            break

        # debug visualization: observed vs fitted signals and B1-scaled
        # flip angles (runs every iteration)
        import matplotlib.pyplot as plt
        vmin, vmax = 0, 2 * mu
        y = flash_signals(fa, tr, *theta)
        ex = 3
        plt.rcParams["figure.figsize"] = (4, len(x) + ex)
        for i in range(len(x)):
            plt.subplot(len(x) + ex, 4, 4 * i + 1)
            plt.imshow(x[i, ..., x.shape[-1] // 2], vmin=vmin, vmax=vmax)
            plt.axis('off')
            plt.subplot(len(x) + ex, 4, 4 * i + 2)
            plt.imshow(y[i, ..., x.shape[-1] // 2], vmin=vmin, vmax=vmax)
            plt.axis('off')
            plt.subplot(len(x) + ex, 4, 4 * i + 3)
            plt.imshow(x[i, ..., x.shape[-1] // 2]
                       - y[i, ..., y.shape[-1] // 2],
                       cmap=plt.get_cmap('coolwarm'))
            plt.axis('off')
            plt.colorbar()
            plt.subplot(len(x) + ex, 4, 4 * i + 4)
            plt.imshow((theta[-1, ..., x.shape[-1] // 2].exp() * fa[i])
                       / pymath.pi)
            plt.axis('off')
            plt.colorbar()
        all_fa = torch.linspace(0, 2 * pymath.pi, 512)
        loc = [(theta.shape[1] // 2, theta.shape[2] // 2,
                theta.shape[3] // 2),
               (2 * theta.shape[1] // 3, theta.shape[2] // 2,
                theta.shape[3] // 2),
               (theta.shape[1] // 3, theta.shape[2] // 3,
                theta.shape[3] // 2)]
        for j, (nx, ny, nz) in enumerate(loc):
            plt.subplot(len(x) + ex, 1, len(x) + j + 1)
            plt.plot(all_fa,
                     flash_signals(all_fa, tr[:1] * 512,
                                   *theta[:, nx, ny, nz]))
            plt.scatter(fa, x[:, nx, ny, nz])
        plt.show()

        if gain < 1e-4 * n:
            break

    theta.exp_()
    return theta[0], theta[1], theta[2]
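# ----------------------------------------------------------------------
# Usage sketch for `flash_b1` (hypothetical tensors). All three maps are
# log-parameterized and updated by alternating Gauss-Newton steps; the
# heavier bending penalty on B1 shown here is an illustrative choice, not
# a documented default.
#
#     x = torch.rand(5, 96, 96, 64)        # 5 flip angles
#     fa = [4, 8, 15, 25, 40]              # degrees
#     tr = 0.025                           # seconds
#     s, r1, b1 = flash_b1(x, fa, tr, lam=(0, 0, 10),
#                          penalty=('m', 'm', 'b'))
# ----------------------------------------------------------------------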