def __init__(self, *args, **kwargs):
    """Split a stack of 3 or 4 maps into named ``ParameterMap`` attributes.

    After the parent constructor runs, ``self.volume[0..2]`` become
    ``self.pd``, ``self.r1`` and ``self.r2s``; when a fourth map is present
    it becomes ``self.mt``. Every map shares ``self.affine``.

    Raises
    ------
    ValueError
        If the stack does not contain exactly 3 or 4 maps.
    """
    # Reset before the parent constructor runs -- presumably because the
    # parent (or a property it triggers) reads these attributes; confirm.
    self._greeq_volume = None
    self._greeq_affine = None
    super().__init__(*args, **kwargs)
    if len(self) not in (3, 4):
        raise ValueError('Expected 3 or 4 maps')
    self.pd = ParameterMap(self.volume[0], affine=self.affine)
    self.r1 = ParameterMap(self.volume[1], affine=self.affine)
    self.r2s = ParameterMap(self.volume[2], affine=self.affine)
    if len(self) == 4:
        self.mt = ParameterMap(self.volume[3], affine=self.affine)
def fit(self, data):
    """Fit a (multi-)exponential decay model to a list of contrasts.

    Pipeline: optional summary of echo times, ``preproc`` (noise
    estimation / registration / map initialisation), construction of the
    per-map regularisation factors ``[*intercepts, decay]``, allocation of
    IRLS weights for (J)TV, the nested optimisation ``loop`` and finally
    ``postproc``.

    Parameters
    ----------
    data : sequence of contrasts, each exposing ``.te`` (in seconds)

    Returns
    -------
    Output of ``postproc(self.maps, self.data)``, with ``self.dist``
    appended when distortion correction is enabled.
    """
    opt = self.opt

    # --- be polite ---
    if opt.verbose > 0:
        print(f'Fitting a (multi) exponential decay model with '
              f'{len(data)} contrasts. Echo times:')
        for idx, contrast in enumerate(data):
            echo_times = ', '.join(f'{t * 1e3:.1f}' for t in contrast.te)
            print(f' - contrast {idx:2d}: [' + echo_times + '] ms')

    # --- estimate noise / register / initialize maps ---
    self.data, self.maps, self.dist = preproc(data, opt)
    self.affine0 = self.maps.affine
    self.shape0 = self.maps.decay.shape

    # --- regularisation factors, ordered [*lam_intercepts, lam_decay] ---
    *lam_inter, lam_decay = opt.regularization.factor
    factors = core.py.ensure_list(lam_inter, self.nb_contrasts)
    factors.append(lam_decay)
    if not any(factors):
        # all factors are zero: switch regularisation off entirely
        opt.regularization.norm = ''
    self.lam = factors

    # --- IRLS weights (TV: one per map, JTV: shared) ---
    self.rls = None
    if opt.regularization.norm.endswith('tv'):
        weights_shape = self.shape
        if opt.regularization.norm == 'tv':
            weights_shape = (len(self.maps), *weights_shape)
        self.rls = ParameterMap(weights_shape, fill=1, **self.backend).volume
    else:
        # no reweighting -> spend the IRLS budget on Gauss-Newton updates
        opt.optim.max_iter_gn *= opt.optim.max_iter_rls
        opt.optim.max_iter_rls = 1

    # --- be polite (bis) + main loop ---
    self.print_opt()
    self.print_header()
    self.loop()

    # --- prepare output ---
    result = postproc(self.maps, self.data)
    if self.opt.distortion.enable:
        result = (*result, self.dist)
    return result
def intercepts(self):
    """Return one ``ParameterMap`` per stacked map except the last one.

    The last stacked map is presumably the decay map (see ESTATICS
    ``preproc``). Returns an empty list while no volume is allocated.
    """
    if self.volume is not None:
        nb_intercepts = len(self) - 1
        return [ParameterMap(self.volume[k], affine=self.affine)
                for k in range(nb_intercepts)]
    return []
def preproc(data, opt=None):
    """Estimate noise variance + register + compute recon space + init maps

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Multi-echo contrasts. Each echo gets ``mean``/``sd`` attributes,
        each contrast gets a ``noise`` attribute and, when registration is
        enabled, its ``affine`` is updated in place.
    opt : Options, optional
        ESTATICS options (default: ``ESTATICSOptions()``). When unset,
        ``opt.recon.affine`` and ``opt.recon.fov`` are filled in place from
        ``opt.recon.space``.

    Returns
    -------
    data : sequence[GradientEchoMulti]
    maps : ParametersMaps
        One log-intercept map per contrast plus a decay map, filled with the
        values returned by ``_loglin_minifit`` on the log mean intensities.
    """
    if opt is None:
        opt = ESTATICSOptions()

    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    # --- estimate hyper parameters ---
    logmeans = []
    te = []
    for c, contrast in enumerate(data):
        means = []
        variances = []
        for e, echo in enumerate(contrast):
            if opt.verbose:
                print(f'Estimate noise: contrast {c+1:d} - echo {e+1:2d}', end='\r')
            dat = echo.fdata(**backend, rand=True, cache=False)
            sd0, _, _, mu1 = estimate_noise(dat)
            echo.mean = mu1.item()
            echo.sd = sd0.item()
            means.append(mu1)
            variances.append(sd0.square())
        means = torch.stack(means)
        variances = torch.stack(variances)
        # signal-weighted average of the per-echo noise variances
        var = (means * variances).sum() / means.sum()
        contrast.noise = var.item()

        te.append(contrast.te)
        logmeans.append(means.log())
    if opt.verbose:
        print('')

    # --- initial minifit ---
    print('Compute initial parameters')
    inter, decay = _loglin_minifit(logmeans, te)
    print(' - log intercepts: [' + ', '.join([f'{i:.1f}' for i in inter.tolist()]) + ']')
    print(f' - decay: {decay.tolist():.3g}')

    # --- initial align ---
    if opt.preproc.register and len(data) > 1:
        print('Register volumes')
        data_reg = [(contrast.echo(0).fdata(rand=True, cache=False, **backend),
                     contrast.affine) for contrast in data]
        dats, affines, _ = affine_align(data_reg, device=device)
        if opt.verbose > 1 and plt:
            plt.figure()
            for i in range(len(dats)):
                plt.subplot(1, len(dats), i+1)
                plt.imshow(dats[i, :, dats.shape[2]//2, :].cpu())
            plt.show()
        for contrast, aff in zip(data, affines):
            aff, contrast.affine = core.utils.to_max_device(aff, contrast.affine)
            contrast.affine = torch.matmul(aff.inverse(), contrast.affine)

    # --- compute recon space ---
    affines = [contrast.affine for contrast in data]
    shapes = [dat.volume.shape[1:] for dat in data]
    if opt.recon.affine is None:
        opt.recon.affine = opt.recon.space
    if opt.recon.fov is None:
        opt.recon.fov = opt.recon.space
    if isinstance(opt.recon.affine, int):
        mean_affine = affines[opt.recon.affine]
    else:
        mean_affine = torch.as_tensor(opt.recon.affine)
    if isinstance(opt.recon.fov, int):
        mean_shape = shapes[opt.recon.fov]
    else:
        mean_shape = tuple(opt.recon.fov)

    # --- allocate maps ---
    maps = ESTATICSParameterMaps()
    maps.intercepts = [ParameterMap(mean_shape, fill=inter[c], affine=mean_affine, **backend)
                       for c in range(len(data))]
    maps.decay = ParameterMap(mean_shape, fill=decay, affine=mean_affine, min=0, **backend)
    maps.affine = mean_affine
    return data, maps
def gre(pd, r1, r2s=None, mt=None, transmit=None, receive=None,
        gfactor=None, te=0, tr=25e-3, fa=20, mtpulse=False, sigma=None,
        noise='rician', affine=None, shape=None, device=None):
    """Simulate data generated by a Gradient-Echo (FLASH) sequence.

    Tissue parameters
    -----------------
    pd : ParameterMap or tensor_like
        Proton density
    r1 : ParameterMap or tensor_like
        Longitudinal relaxation rate, in 1/sec
    r2s : ParameterMap, optional
        Transverse relaxation rate, in 1/sec. Mandatory if any `te > 0`.
    mt : ParameterMap, optional
        MTsat. Mandatory if any `mtpulse == True`.
        A raw tensor (not a ParameterMap) is assumed to be in percent.

    Fields
    ------
    transmit : (N-sequence of) PrecomputedFieldMap or tensor_like, optional
        Transmit B1 field
    receive : (N-sequence of) PrecomputedFieldMap or tensor_like, optional
        Receive B1 field
    gfactor : (N-sequence of) PrecomputedFieldMap or tensor_like, optional
        G-factor map.
        If provided and `sigma` is not `None`, the g-factor map is used
        to sample non-stationary noise.

    Maps and fields whose `unit` is one of '%', 'pct', 'p.u', 'p.u.' are
    divided by 100.

    Sequence parameters
    -------------------
    te : ((N-sequence of) M-sequence of) float, default=0
        Echo time, in sec
    tr : (N-sequence of) float default=25e-3
        Repetition time, in sec
    fa : (N-sequence of) float, default=20
        Flip angle, in deg
    mtpulse : (N-sequence of) bool, default=False
        Presence of an off-resonance pulse

    Noise
    -----
    sigma : (N-sequence of) float, optional
        Standard-deviation of the sampled noise (no sampling if `None`)
    noise : {'rician', 'gaussian'}, default='rician'
        Noise distribution. Rician noise is the magnitude of the signal
        corrupted by complex Gaussian noise.

    Space
    -----
    affine : ([N], 4, 4) tensor, optional
        Orientation matrix of the simulation space
    shape : (N-sequence of) sequence[int], default=pd.shape
        Shape of the simulation space

    Returns
    -------
    sim : (N-sequence of) GradientEchoMulti
        Simulated series of multi-echo GRE images

    Raises
    ------
    ValueError
        If an MT pulse is requested without `mt`, or a positive echo time
        is requested without `r2s`.
    """
    # Units whose values are interpreted as percentages
    pct_units = ('%', 'pct', 'p.u', 'p.u.')

    # 1) Find out the number of contrasts requested
    te = make_list(te)
    if any(isinstance(x, (list, tuple)) for x in te):
        te = [make_list(t) for t in te]
    else:
        te = [te]
    tr = make_list(tr)
    fa = make_list(fa)
    mtpulse = [bool(p) for p in make_list(mtpulse)]
    sigma = make_list(sigma)
    # compare to None: `field or []` fails on multi-element tensors
    transmit = [] if transmit is None else make_list(transmit)
    receive = [] if receive is None else make_list(receive)
    gfactor = [] if gfactor is None else make_list(gfactor)
    if shape is None:
        # keep a bare None so that it falls back to `pd.shape` below
        shape = [None]
    else:
        shape = make_list(shape)
        if any(isinstance(x, (list, tuple)) or x is None for x in shape):
            shape = [None if s is None else make_list(s) for s in shape]
        else:
            shape = [shape]
    if torch.is_tensor(affine):
        affine = [affine] if affine.dim() == 2 else affine.unbind(0)
    else:
        affine = make_list(affine)
    nb_contrasts = max(len(te), len(tr), len(fa), len(mtpulse), len(sigma),
                       len(transmit), len(receive), len(gfactor),
                       len(shape), len(affine))

    # 2) Pad all lists up to `nb_contrasts`
    te = make_list(te, nb_contrasts)
    tr = make_list(tr, nb_contrasts)
    fa = make_list(fa, nb_contrasts)
    mtpulse = make_list(mtpulse, nb_contrasts)
    sigma = make_list(sigma, nb_contrasts)
    transmit = make_list(transmit or [None], nb_contrasts)
    receive = make_list(receive or [None], nb_contrasts)
    gfactor = make_list(gfactor or [None], nb_contrasts)
    shape = make_list(shape, nb_contrasts)
    affine = make_list(affine, nb_contrasts)

    # 3) ensure parameters are `ParameterMap`s
    has_r2s = r2s is not None
    has_mt = mt is not None
    if not isinstance(pd, ParameterMap):
        pd = ParameterMap(pd)
    if not isinstance(r1, ParameterMap):
        r1 = ParameterMap(r1)
    if has_r2s and not isinstance(r2s, ParameterMap):
        r2s = ParameterMap(r2s)
    if has_mt and not isinstance(mt, ParameterMap):
        mt = ParameterMap(mt)
        mt.unit = '%'

    # 4) ensure all fields are `PrecomputedFieldMap`s
    for fields in (transmit, receive, gfactor):
        for n, field in enumerate(fields):
            if field is not None and not isinstance(field, PrecomputedFieldMap):
                fields[n] = PrecomputedFieldMap(field)

    # 5) choose backend
    all_var = [te, tr, fa, mtpulse, sigma, affine]
    for fields in (transmit, receive, gfactor):
        all_var += [f.volume for f in fields
                    if f is not None and torch.is_tensor(f.volume)]
    for prm_map in (pd, r1, r2s, mt):
        if prm_map is not None and torch.is_tensor(prm_map.volume):
            all_var.append(prm_map.volume)
    backend = core.utils.max_backend(*all_var)
    if device:
        backend['device'] = device

    # 6) prepare parameter maps
    prm = stack_maps(pd.fdata(**backend), r1.fdata(**backend),
                     r2s.fdata(**backend) if r2s is not None else None,
                     mt.fdata(**backend) if mt is not None else None)
    fpd, fr1, fr2s, fmt = unstack_maps(prm, has_r2s, has_mt)
    if has_mt and mt.unit in pct_units:
        fmt /= 100.
    if any(aff is not None for aff in affine):
        # Resample in log space (logit for MT, a fraction); `exp_maps` is
        # presumably the inverse transform -- confirm.
        logprm = stack_maps(
            safelog(fpd), safelog(fr1),
            safelog(fr2s) if r2s is not None else None,
            safelog(fmt) - safelog(1 - fmt) if mt is not None else None)

    # 7) generate signal
    contrasts = []
    for n in range(nb_contrasts):
        shape1 = shape[n]
        if shape1 is None:
            shape1 = pd.shape
        aff1 = affine[n]
        if aff1 is not None:
            aff1 = aff1.to(**backend)
        te1 = torch.as_tensor(te[n], **backend)
        tr1 = torch.as_tensor(tr[n], **backend)
        fa1 = torch.as_tensor(fa[n], **backend) / 180. * core.constants.pi
        sigma1 = torch.as_tensor(sigma[n], **backend) if sigma[n] else None
        mtpulse1 = mtpulse[n]
        transmit1 = transmit[n]
        receive1 = receive[n]
        gfactor1 = gfactor[n]

        # --- parameter maps in the simulation space ---
        if aff1 is not None:
            mat = core.linalg.lmdiv(pd.affine.to(**backend), aff1)
            grid = smart_grid(mat, shape1, pd.shape, force=True)
            prm1 = smart_pull(logprm, grid)
            inplace = grid is not None
            f1pd, f1r1, f1r2s, f1mt = unstack_maps(prm1, has_r2s, has_mt)
            f1pd, f1r1, f1r2s, f1mt = exp_maps(f1pd, f1r1, f1r2s, f1mt,
                                               inplace=inplace)
        else:
            # clone so that we can work in-place later
            f1pd = fpd.clone()
            f1r1 = fr1.clone()
            f1r2s = fr2s.clone() if fr2s is not None else None
            f1mt = fmt.clone() if fmt is not None else None

        # --- B1 fields ---
        if transmit1 is not None:
            fa1 = fa1 * _gre_resample_field(transmit1, backend, aff1, shape1,
                                            pct_units)
        if receive1 is not None:
            f1pd = f1pd * _gre_resample_field(receive1, backend, aff1, shape1,
                                              pct_units)

        # --- generate signal (all operations in-place on fresh tensors) ---
        flash = f1pd
        flash *= fa1.sin()
        cosfa = fa1.cos_()
        e1 = f1r1
        e1 *= -tr1
        e1 = e1.exp_()
        flash *= (1 - e1)
        if mtpulse1:
            if not has_mt:
                raise ValueError('Cannot simulate an MT pulse: '
                                 'an MT must be provided.')
            omt = f1mt.neg_()
            omt += 1
            flash *= omt
            flash /= (1 - cosfa * omt * e1)
            del omt
        else:
            flash /= (1 - cosfa * e1)
        del e1, cosfa, fa1, f1r1, f1mt

        # --- multiply with r2* decay (adds a leading echo dimension) ---
        if any(t > 0 for t in te1):
            if not has_r2s:
                raise ValueError('Cannot simulate an R2* decay: '
                                 'an R2* must be provided.')
            te_bcast = te1.reshape([-1] + [1] * f1r2s.dim())
            flash = flash * (-te_bcast * f1r2s).exp_()
            del te_bcast
        del f1r2s

        # --- sample noise ---
        if sigma1:
            if gfactor1 is not None:
                sigma1 = sigma1 * _gre_resample_field(gfactor1, backend, aff1,
                                                      shape1, pct_units)
            noise_shape = flash.shape
            if noise == 'rician':
                noise_shape = (2, ) + noise_shape
            sample = torch.randn(noise_shape, **backend)
            sample *= sigma1
            del sigma1
            if noise == 'rician':
                # |S + n_re + i n_im| = sqrt((S + n_re)^2 + n_im^2)
                flash = (flash + sample[0]).square_()
                flash += sample[1].square_()
                flash = flash.sqrt_()
            else:
                flash += sample
            del sample

        # te1 is still 1D -> flat list of echo times
        te1 = te1.tolist()
        tr1 = tr1.item()
        fa1 = torch.as_tensor(fa[n]).item()
        mtpulse1 = torch.as_tensor(mtpulse1).item()
        flash = GradientEchoMulti(flash, affine=aff1, tr=tr1, fa=fa1, te=te1,
                                  mt=mtpulse1)
        contrasts.append(flash)

    return contrasts[0] if len(contrasts) == 1 else contrasts


def _gre_resample_field(field, backend, aff1, shape1, pct_units):
    """Load a PrecomputedFieldMap as a tensor defined in the simulation space.

    Values are divided by 100 when ``field.unit`` is in `pct_units`; when a
    target orientation `aff1` is given, the field is resampled onto a grid
    of shape `shape1`.
    """
    # read metadata *before* replacing the object by its data tensor
    unit = field.unit
    field_affine = field.affine.to(**backend)
    dat = field.fdata(**backend)
    if unit in pct_units:
        dat = dat / 100.
    if aff1 is not None:
        mat = core.linalg.lmdiv(field_affine, aff1)
        grid = smart_grid(mat, shape1, dat.shape)
        dat = smart_pull(dat[None], grid)[0]
    return dat
class _ESTATICS_nonlin: IMAGE_BOUND = 'dft' DIST_BOUND = 'dct2' def __init__(self, opt): self.opt = ESTATICSOptions().update(opt).cleanup_() self.data = None self.maps = None self.dist = None self.rls = None self.lam = None self.shape0 = None self.affine0 = None self.lam_scale = 1 self.last_step = '' @property def numel(self): return sum(core.py.prod(dat.shape) for dat in self.data) @property def shape(self): return self.maps.decay.shape @property def affine(self): return self.maps.affine @property def voxel_size(self): return spatial.voxel_size(self.affine) @property def nb_contrasts(self): return len(self.maps) - 1 @property def lam_dist(self): return dict(factor=self.opt.distortion.factor, absolute=self.opt.distortion.absolute, membrane=self.opt.distortion.membrane, bending=self.opt.distortion.bending) @property def backend(self): return dict(dtype=self.opt.backend.dtype, device=self.opt.backend.device) def iter_rls(self): if self.rls is None: for _ in range(self.nb_contrasts + 1): yield None elif self.rls.dim() == 3: for _ in range(self.nb_contrasts + 1): yield self.rls else: assert self.rls.dim() == 4 for rls1 in self.rls: yield rls1 def fit(self, data): # --- be polite ------------------------------------------------ if self.opt.verbose > 0: print(f'Fitting a (multi) exponential decay model with ' f'{len(data)} contrasts. 
Echo times:') for i, contrast in enumerate(data): print(f' - contrast {i:2d}: [' + ', '.join([f'{te * 1e3:.1f}' for te in contrast.te]) + '] ms') # --- estimate noise / register / initialize maps -------------- self.data, self.maps, self.dist = preproc(data, self.opt) self.affine0 = self.maps.affine self.shape0 = self.maps.decay.shape # --- prepare regularization factor ---------------------------- # -> we want lam = [*lam_intercepts, lam_decay] *lam, lam_decay = self.opt.regularization.factor lam = core.py.ensure_list(lam, self.nb_contrasts) lam.append(lam_decay) if not any(lam): self.opt.regularization.norm = '' self.lam = lam # --- initialize weights (RLS) --------------------------------- self.rls = None if self.opt.regularization.norm.endswith('tv'): rls_shape = self.shape if self.opt.regularization.norm == 'tv': rls_shape = (len(self.maps), *rls_shape) self.rls = ParameterMap(rls_shape, fill=1, **self.backend).volume # --- initialize nb of iterations ------------------------------ if not self.opt.regularization.norm.endswith('tv'): # no reweighting -> do more gauss-newton updates instead self.opt.optim.max_iter_gn *= self.opt.optim.max_iter_rls self.opt.optim.max_iter_rls = 1 # --- be polite (bis) ------------------------------------------ self.print_opt() self.print_header() # --- main loop ------------------------------------------------ self.loop() # --- prepare output ----------------------------------------------- out = postproc(self.maps, self.data) if self.opt.distortion.enable: out = (*out, self.dist) return out def print_opt(self): if self.opt.verbose <= 1: return if self.opt.regularization.norm: print('Regularization:') print( f' - type: {self.opt.regularization.norm.upper()}' ) print(f' - log intercepts: [' + ', '.join([f'{i:.3g}' for i in self.lam[:-1]]) + ']') print(f' - decay: {self.lam[-1]:.3g}') else: print('Without regularization') if self.opt.distortion.enable: print('Distortion correction:') print(f' - model: 
{self.opt.distortion.model.lower()}') print( f' - absolute: {self.opt.distortion.absolute * self.opt.distortion.factor}' ) print( f' - membrane: {self.opt.distortion.membrane * self.opt.distortion.factor}' ) print( f' - bending: {self.opt.distortion.bending * self.opt.distortion.factor}' ) print( f' - te_scaling: {self.opt.distortion.te_scaling or "no"}' ) else: print('Without distortion correction') print('Optimization:') if self.opt.regularization.norm.endswith('tv'): print(f' - IRLS iterations: {self.opt.optim.max_iter_rls}' f' (tolerance: {self.opt.optim.tolerance_rls})') print(f' - GN iterations: {self.opt.optim.max_iter_gn}' f' (tolerance: {self.opt.optim.tolerance_gn})') print(f' - FMG cycles: 2') print(f' - CG iterations: {self.opt.optim.max_iter_cg}' f' (tolerance: {self.opt.optim.tolerance_cg})') def loop(self): """Nested optimization loops""" rls = self.rls.reciprocal().sum() if self.rls is not None else 0 self.nll = dict(obs=float('inf'), obs_prev=float('inf'), reg=0, reg_prev=0, vreg=0, vreg_prev=0, rls=rls, rls_prev=rls, all=[]) nll = float('inf') nll_scl = self.numel * len(self.data) # --- Multi-Resolution loop ------------------------------------ for self.level in range(self.opt.optim.nb_levels, 0, -1): if self.opt.optim.nb_levels > 1: self.resize() # --- RLS loop --------------------------------------------- # > max_iter_rls == 1 if regularization is not (J)TV for self.n_iter_rls in range(1, self.opt.optim.max_iter_rls + 1): # --- Gauss-Newton (prm + dist) ------------------------ for self.n_iter_gn in range(1, self.opt.optim.max_iter_gn + 1): nll0_gn = nll self.n_iter_dist = 0 self.n_iter_prm = 0 # --- Gauss-Newton (prm) --------------------------- max_iter_prm = self.opt.optim.max_iter_prm if self.n_iter_gn == 1 and self.n_iter_rls == 1: max_iter_prm = max_iter_prm * 2 for self.n_iter_prm in range(1, max_iter_prm + 1): nll0_prm = nll nll = self.update_prm() if (self.n_iter_gn > 1 and (nll0_prm - nll) < self.opt.optim.tolerance * nll_scl): break # 
---------------------------------------------- # this is where we should check for RLS (== global) gain if (self.n_iter_prm == 0 and self.n_iter_gn == 0 and self.n_iter_rls > 1 and (nll0_rls - nll) < self.opt.optim.tolerance * nll_scl * 2): return # ---------------------------------------------- # --- Gauss-Newton (dist) -------------------------- for self.n_iter_dist in range( 1, self.opt.optim.max_iter_dist + 1): nll0_dist = nll nll = self.update_dist() if (self.n_iter_dist > 1 and (nll0_dist - nll) < self.opt.optim.tolerance * nll_scl): break if (self.n_iter_gn > 1 and (nll0_gn - nll) < self.opt.optim.tolerance * nll_scl): break nll0_rls = nll nll = self.update_rls() # ------------------------------------------------------------------ # UPDATE PARAMETERS MAPS # ------------------------------------------------------------------ def momentum_prm(self, dat): """Momentum of the parameter maps""" return spatial.regulariser(dat, weights=self.rls, dim=3, **self.lam_prm, voxel_size=self.voxel_size) def solve_prm(self, hess, grad): """Solve Newton step""" hess = check_nans_(hess, warn='hessian') hess = hessian_loaddiag_(hess, 1e-6, 1e-8) delta = spatial.solve_field_fmg(hess, grad, self.rls, **self.lam_prm, voxel_size=self.voxel_size) delta = check_nans_(delta, warn='delta') return delta def update_prm(self): """Update parameter maps (log-intercept and decay)""" nmaps = len(self.data) grad = torch.zeros((nmaps + 1, *self.shape), **self.backend) hess = torch.zeros((nmaps * 2 + 1, *self.shape), **self.backend) # --- loop over contrasts -------------------------------------- iterator = zip(self.data, self.maps.intercepts, self.dist) nll = 0 for i, (contrast, intercept, distortion) in enumerate(iterator): # compute gradient nll1, g1, h1 = self.derivatives_prm(contrast, distortion, intercept, self.maps.decay) # increment gind = [i, -1] grad[gind] += g1 hind = [i, len(grad) - 1, len(grad) + i] hess[hind] += h1 nll += nll1 del g1, h1 # --- regularization 
------------------------------------------ reg = 0 if self.opt.regularization.norm: g1 = self.momentum_prm(self.data.volume) reg = 0.5 * dot(g1, self.data.volume) grad += g1 del g1 # --- gauss-newton --------------------------------------------- # Computing the GN step involves solving H\g deltas = self.solve_prm(hess, grad) # No need for a line search (hopefully) for map, delta in zip(self.maps, deltas): map.volume -= delta if map.min is not None or map.max is not None: map.volume.clamp_(map.min, map.max) # --- track general improvement -------------------------------- self.nll['obs_prev'] = self.nll['obs'] self.nll['obs'] = nll self.nll['reg_prev'] = self.nll['reg'] self.nll['reg'] = reg nll = self.print_nll() self.last_step = 'prm' return nll # ------------------------------------------------------------------ # UPDATE DISTORTION MAPS # ------------------------------------------------------------------ def momentum_dist(self, dat, vx, readout): """Momentum of the distortion maps""" lam = dict(self.lam_dist) lam['factor'] = lam['factor'] * (vx[readout]**2) return spatial.regulariser(dat[None], **self.lam_dist, dim=3, bound=self.DIST_BOUND, voxel_size=vx)[0] def solve_dist(self, hess, grad, vx, readout): """Solve Newton step""" hess = check_nans_(hess, warn='hessian') hess = hessian_loaddiag_(hess, 1e-6, 1e-8) lam = dict(self.lam_dist) lam['factor'] = lam['factor'] * (vx[readout]**2) delta = spatial.solve_field_fmg(hess, grad, **self.lam_dist, dim=3, bound=self.DIST_BOUND, voxel_size=vx) delta = check_nans_(delta, warn='delta') return delta def update_dist(self): """Update distortions""" nll = 0 reg = 0 # --- loop over contrasts -------------------------------------- iterator = zip(self.data, self.maps.intercepts, self.dist) for i, (contrast, intercept, distortion) in enumerate(iterator): momentum = lambda dat: self.momentum_dist( dat, distortion.voxel_size, contrast.readout) solve = lambda h, g: self.solve_dist(h, g, distortion.voxel_size, contrast.readout) vol = 
distortion.volume # --- likelihood ------------------------------------------- nll1, g, h = self.derivatives_dist(contrast, distortion, intercept, self.maps.decay) # --- regularization --------------------------------------- g1 = momentum(vol) reg1 = 0.5 * dot(g1, vol) g += g1 del g1 # --- gauss-newton ----------------------------------------- delta = solve(h, g) del g, h # --- line search ------------------------------------------ armijo, armijo_prev = 1, 0 dd = momentum(delta) dv = dot(dd, vol) dd = dot(dd, delta) success = False for n_ls in range(12): vol.sub_(delta, alpha=(armijo - armijo_prev)) armijo_prev = armijo delta_reg1 = 0.5 * armijo * (armijo * dd - 2 * dv) new_nll1 = self.derivatives_dist(contrast, distortion, intercept, self.maps.decay, do_grad=False) if new_nll1 + delta_reg1 <= nll1: success = True break armijo /= 2 if not success: vol.add_(delta, alpha=armijo_prev) new_nll1 = nll1 delta_reg1 = 0 nll += new_nll1 reg += reg1 + delta_reg1 del delta # --- track general improvement -------------------------------- self.nll['obs_prev'] = self.nll['obs'] self.nll['obs'] = nll self.nll['vreg_prev'] = self.nll['vreg'] self.nll['vreg'] = reg nll = self.print_nll() self.last_step = 'dist' return nll # ------------------------------------------------------------------ # UPDATE WEIGHT MAP # ------------------------------------------------------------------ def update_rls(self): if self.opt.regularization.norm not in ('tv', 'jtv'): return sum(self.nll[k] for k in ['obs', 'reg', 'vreg', 'rls']) rls, sumrls = update_rls(self.maps, self.lam, self.opt.regularization.norm) self.nll['rls_prev'] = self.nll['rls'] self.nll['rls'] = 0.5 * sumrls self.last_step = 'rls' return sum(self.nll[k] for k in ['obs', 'reg', 'vreg', 'rls']) def print_header(self): if self.opt.verbose <= 0: return pstr = '' if self.opt.optim.max_iter_rls > 1: pstr += f'{"rls":^3s} | ' if self.opt.optim.max_iter_gn > 1: pstr += f'{"gn":^3s} | ' if self.opt.optim.max_iter_prm > 1 or 
self.opt.optim.max_iter_dist > 1: pstr += f'{"sub":^3s} | ' pstr += f'{"step":^4s} | ' pstr += f'{"fit":^12s} ' if self.opt.regularization.norm: pstr += f'+ {"reg":^12s} ' if self.opt.regularization.norm.endswith('tv'): pstr += f'+ {"rls":^12s} ' if self.opt.distortion.enable: pstr += f'+ {"dist":^12s} ' pstr += f'= {"crit":^12s}' print('\n' + pstr) print('-' * len(pstr)) def print_nll(self): nll = sum(self.nll[k] for k in ['obs', 'reg', 'vreg', 'rls']) self.nll['all'].append(nll) if self.opt.verbose <= 0: return nll obs = self.nll['obs'] reg = self.nll['reg'] vreg = self.nll['vreg'] rls = self.nll['rls'] obs0 = self.nll['obs_prev'] reg0 = self.nll['reg_prev'] vreg0 = self.nll['vreg_prev'] rls0 = self.nll['rls_prev'] if self.last_step == 'rls': nll = obs0 + reg + rls + vreg if reg + rls <= reg0 + rls0: evol = '<=' else: evol = '>' elif self.last_step == 'prm': nll = obs0 + reg + rls + vreg if obs + reg <= obs0 + reg0: evol = '<=' else: evol = '>' elif self.last_step == 'dist': if obs + vreg <= obs0 + vreg0: evol = '<=' else: evol = '>' else: evol = '' nll = obs + reg + vreg + rls pstr = '' if self.opt.optim.max_iter_rls > 1: pstr += f'{self.n_iter_rls:3d} | ' if self.opt.optim.max_iter_gn > 1: pstr += f'{self.n_iter_gn:3d} | ' if self.opt.optim.max_iter_prm > 1 or self.opt.optim.max_iter_dist > 1: if self.last_step == 'rls': pstr += f'{"-"*3} | ' else: pstr += f'{self.n_iter_prm+self.n_iter_dist:3d} | ' pstr += f'{self.last_step:4s} | ' pstr += f'{obs:12.6g} ' if self.opt.regularization.norm: pstr += f'+ {reg:12.6g} ' if self.opt.regularization.norm.endswith('tv'): pstr += f'+ {rls:12.6g} ' if self.opt.distortion.enable: pstr += f'+ {vreg:12.6g} ' pstr += f'= {nll:12.6g} | ' pstr += f'{evol}' print(pstr) self.show_maps() return nll def resize(self): affine, shape = spatial.affine_resize(self.affine0, self.shape0, 1 / (2**(self.level - 1))) scl0 = spatial.voxel_size(self.affine0).prod() scl = spatial.voxel_size(affine).prod() / scl0 self.lam_scale = scl for map in 
self.maps: map.volume = spatial.resize(map.volume[None, None, ...], shape=shape)[0, 0] map.affine = affine self.maps.affine = affine if self.rls is not None: if self.rls.dim() == len(shape): self.rls = spatial.resize(self.rls[None, None], hape=shape)[0, 0] else: self.rls = spatial.resize(self.rls[None], shape=shape)[0] self.nll['rls'] = self.rls.reciprocal().sum(dtype=torch.double) def show_maps(self): if not self.opt.plot: return import matplotlib.pyplot as plt has_dist = any([d is not None for d in self.dist]) ncol = max(len(self.maps), len(self.dist)) for i, map in enumerate(self.maps): plt.subplot(2 + has_dist, ncol, i + 1) vol = map.volume[:, :, map.shape[-1] // 2] if i < len(self.maps) - 1: vol = vol.exp() plt.imshow(vol.cpu()) plt.axis('off') if i < len(self.maps) - 1: plt.title('TE=0') else: plt.title('R2*') plt.colorbar() if has_dist: for i, (dat, dst) in enumerate(zip(self.data, self.dist)): if dst is None: continue readout = dat.readout vol = dst.volume plt.subplot(2 + has_dist, ncol, i + 1 + ncol) vol = vol[:, :, dst.shape[-2] // 2] if readout is not None: vol = vol[..., readout] else: vol = vol.square().sum(-1).sqrt() plt.imshow(vol.cpu()) plt.axis('off') plt.colorbar() plt.subplot(2 + has_dist, 1, 2 + has_dist) plt.plot(self.nll['all']) plt.show()
def preproc(data, transmit=None, receive=None, opt=None):
    """Estimate noise variance + register + compute recon space + init maps

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Multi-echo contrasts. Each echo gets ``mean``/``sd`` attributes,
        each contrast gets a ``noise`` attribute and, when registration is
        enabled, its ``affine`` is updated in place.
    transmit : (sequence of) PrecomputedFieldMap, optional
        Transmit field map(s); ``affine`` updated in place on registration.
    receive : (sequence of) PrecomputedFieldMap, optional
        Receive field map(s); ``affine`` updated in place on registration.
    opt : Options, optional
        GREEQ options, merged into ``GREEQOptions()``.

    Returns
    -------
    data : sequence[GradientEchoMulti]
    transmit : list[PrecomputedFieldMap or None]
        Padded to one entry per contrast (all ``None`` if not provided).
    receive : list[PrecomputedFieldMap or None]
        Padded to one entry per contrast (all ``None`` if not provided).
    maps : ParametersMaps
        Maps filled with constant initial values: log PD, log R1, log R2*
        and, when ``_rational_minifit`` returns an MT value, logit MT.
    """
    opt = GREEQOptions().update(opt)
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    # --- estimate hyper parameters ---
    logmeans = []
    te = []
    tr = []
    fa = []
    mtpulse = []
    for c, contrast in enumerate(data):
        means = []
        variances = []
        for e, echo in enumerate(contrast):
            if opt.verbose:
                print(f'Estimate noise: contrast {c+1:d} - echo {e+1:2d}', end='\r')
            dat = echo.fdata(**backend, rand=True, cache=False)
            sd0, _, _, mu1 = estimate_noise(dat)
            echo.mean = mu1.item()
            echo.sd = sd0.item()
            means.append(mu1)
            variances.append(sd0.square())
        means = torch.stack(means)
        variances = torch.stack(variances)
        # signal-weighted average of the per-echo noise variances
        var = (means * variances).sum() / means.sum()
        contrast.noise = var.item()

        te.append(contrast.te)
        tr.append(contrast.tr)
        fa.append(contrast.fa / 180 * core.constants.pi)  # deg -> rad
        mtpulse.append(contrast.mt)
        logmeans.append(means.log())
    if opt.verbose:
        print('')

    print('Estimating maps from volumes:')
    for i in range(len(data)):
        print(f' - Contrast {i:d}: ', end='')
        print(f'FA = {fa[i]*180/core.constants.pi:2.0f} deg / ', end='')
        print(f'TR = {tr[i]*1e3:4.1f} ms / ', end='')
        print('TE = [' + ', '.join([f'{t*1e3:.1f}' for t in te[i]]) + '] ms', end='')
        if mtpulse[i]:
            print(f' / MT = True', end='')
        print()

    # --- initial minifit ---
    print('Compute initial parameters')
    inter, r2s = _loglin_minifit(logmeans, te)
    pd, r1, mt = _rational_minifit(inter, tr, fa, mtpulse)
    print(f' - PD: {pd.tolist():9.3g} a.u.')
    print(f' - R1: {r1.tolist():9.3g} 1/s')
    print(f' - R2*: {r2s.tolist():9.3g} 1/s')
    pd = pd.log()
    r1 = r1.log()
    r2s = r2s.log()
    if mt is not None:
        print(f' - MT: {100*mt.tolist():9.3g} %')
        mt = mt.log() - (1 - mt).log()  # logit

    # --- initial align ---
    transmit = core.utils.make_list(transmit or [])
    receive = core.utils.make_list(receive or [])
    if opt.preproc.register and len(data) > 1:
        print('Register volumes')
        data_reg = [(contrast.echo(0).fdata(rand=True, cache=False, **backend),
                     contrast.affine) for contrast in data]
        data_reg += [(fmap.magnitude.fdata(rand=True, cache=False, **backend),
                      fmap.magnitude.affine) for fmap in transmit]
        data_reg += [(fmap.magnitude.fdata(rand=True, cache=False, **backend),
                      fmap.magnitude.affine) for fmap in receive]
        dats, affines, _ = affine_align(data_reg, device=device)
        if opt.verbose > 1 and plt:
            plt.figure()
            for i in range(len(dats)):
                plt.subplot(1, len(dats), i + 1)
                plt.imshow(dats[i, :, dats.shape[2] // 2, :].cpu())
                plt.axis('off')
            plt.show()
        # unpack into a list so that tuple inputs work as well
        for vol, aff in zip([*data, *transmit, *receive], affines):
            aff, vol.affine = core.utils.to_max_device(aff, vol.affine)
            vol.affine = torch.matmul(aff.inverse(), vol.affine)

    # --- compute recon space ---
    affines = [contrast.affine for contrast in data]
    shapes = [dat.volume.shape[1:] for dat in data]
    if opt.recon.affine is None:
        opt.recon.affine = opt.recon.space
    if opt.recon.fov is None:
        opt.recon.fov = opt.recon.space
    if isinstance(opt.recon.affine, int):
        mean_affine = affines[opt.recon.affine]
    else:
        mean_affine = torch.as_tensor(opt.recon.affine)
    if isinstance(opt.recon.fov, int):
        mean_shape = shapes[opt.recon.fov]
    else:
        mean_shape = tuple(opt.recon.fov)

    # --- allocate maps ---
    maps = GREEQParameterMaps()
    maps.pd = ParameterMap(mean_shape, fill=pd, affine=mean_affine, **backend)
    maps.r1 = ParameterMap(mean_shape, fill=r1, affine=mean_affine, **backend)
    maps.r2s = ParameterMap(mean_shape, fill=r2s, affine=mean_affine, **backend)
    if mt is not None:
        maps.mt = ParameterMap(mean_shape, fill=mt, affine=mean_affine, **backend)
    maps.affine = mean_affine

    # --- repeat fields if not enough ---
    if transmit:
        transmit = core.py.make_list(transmit, len(data))
    else:
        transmit = [None] * len(data)
    if receive:
        receive = core.py.make_list(receive, len(data))
    else:
        receive = [None] * len(data)

    return data, transmit, receive, maps
def _vfa_load_field(field, mean_affine, mean_shape, label, backend):
    """Resample a precomputed field map to the recon space and print its stats.

    Parameters
    ----------
    field : PrecomputedFieldMap
        Field to load; its `unit` attribute is returned alongside the data.
    mean_affine, mean_shape
        Reconstruction space, forwarded to `load_and_pull`.
    label : str
        Field name used in the printed summary ('B1-' or 'B1+').
    backend : dict
        `dtype`/`device` keyword arguments forwarded to `load_and_pull`.

    Returns
    -------
    fdata : tensor
        Field resampled to the recon space.
    unit : str
        `field.unit`.
    """
    fdata = load_and_pull(field, mean_affine, mean_shape, **backend)
    unit = field.unit
    # statistics are computed over strictly positive voxels only
    positive = fdata[fdata > 0]
    minval = positive.min()
    maxval = positive.max()
    meanval = positive.mean()
    print(f' with {label} map ('
          f'min= {minval:.2f}, '
          f'max = {maxval:.2f}, '
          f'mean = {meanval:.2f} {unit})')
    return fdata, unit


def _vfa_load_contrast(idx, struct, receive, transmit,
                       mean_affine, mean_shape, backend):
    """Load one contrast and apply its receive/transmit corrections.

    The same index `idx` is used both to test for and to load each field,
    so a contrast is only ever corrected by its own field maps.

    Parameters
    ----------
    idx : int
        Index of the contrast in `receive` / `transmit`.
    struct : GradientEcho
        Contrast to load (reads `fa` in degrees, `tr`, `te`).
    receive, transmit : list[PrecomputedFieldMap or None]
        One entry per contrast.
    mean_affine, mean_shape
        Reconstruction space.
    backend : dict
        `dtype`/`device` keyword arguments.

    Returns
    -------
    dat : tensor
        Image in the recon space, divided by the B1- field if present.
    fa : tensor
        Flip angle in radians, scaled by the B1+ field if present.
    tr : float
        `struct.tr` (printed x1e3 as ms).
    """
    print(f' - '
          f'FA = {struct.fa:2.0f} deg / '
          f'TR = {struct.tr*1e3:4.1f} ms / '
          f'TE = {struct.te*1e3:4.1f} ms')
    dat = load_and_pull(struct, mean_affine, mean_shape, **backend)
    fa = struct.fa / 180. * core.constants.pi

    if receive[idx] is not None:
        b1m, unit = _vfa_load_field(receive[idx], mean_affine, mean_shape,
                                    'B1-', backend)
        dat /= b1m
        # a field stored in percent is 100x too large -> compensate
        if unit in ('%', 'pct', 'p.u.'):
            dat *= 100
        del b1m

    if transmit[idx] is not None:
        b1p, unit = _vfa_load_field(transmit[idx], mean_affine, mean_shape,
                                    'B1+', backend)
        # the effective flip angle is the nominal one scaled by B1+
        fa = b1p * fa
        if unit in ('%', 'pct', 'p.u.'):
            fa /= 100
        del b1p
    fa = torch.as_tensor(fa, **backend)
    return dat, fa, struct.tr


def vfa(data, transmit=None, receive=None, opt=None, **kwopt):
    """Compute PD, R1 (and MTsat) from two GRE at two flip angles using
    rational approximations of the Ernst equations.

    Parameters
    ----------
    data : sequence[GradientEcho]
        Volumes with different contrasts (flip angle or MT pulse) but with
        the same echo time. Note that they do not need to be real echoes;
        they often are images extrapolated to TE = 0.
    transmit : sequence[PrecomputedFieldMap], optional
        Map(s) of the transmit field (b1+). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
    receive : sequence[PrecomputedFieldMap], optional
        Map(s) of the receive field (b1-). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
        If no receive map is provided, the output `pd` map will have
        a remaining b1- bias field.
    opt : VFAOptions
        Algorithm options.
        {'preproc': {'register':  True},     # Co-register contrasts
         'backend': {'dtype':  torch.float32,  # Data type
                     'device': 'cpu'},         # Device
         'verbose': 1,                         # Verbosity: 1=print, 2=plot
         'rational': False}                    # Force rational approximation
        The rational approximation is also used automatically for PD/R1
        when the PDw and T1w repetition times differ.

    Returns
    -------
    pd : ParameterMap
        Proton density (potentially, with R2* bias)
    r1 : ParameterMap
        Longitudinal relaxation rate (unit '1/s')
    mt : ParameterMap, optional
        Magnetization transfer saturation, in percent.
        Only returned if one of the contrasts has `mt` set.

    Raises
    ------
    ValueError
        If fewer than two images are given, if echo times differ across
        contrasts, or if fewer than two non-MT volumes are available.

    References
    ----------
    ..[1] Tabelow et al., "hMRI - A toolbox for quantitative MRI in
          neuroscience and clinical research", NeuroImage (2019).
          https://www.sciencedirect.com/science/article/pii/S1053811919300291
    ..[2] Helms et al., "Quantitative FLASH MRI at 3T using a rational
          approximation of the Ernst equation", MRM (2008).
          https://onlinelibrary.wiley.com/doi/full/10.1002/mrm.21542
    ..[3] Helms et al., "High-resolution maps of magnetization transfer
          with inherent correction for RF inhomogeneity and T1 relaxation
          obtained from 3D FLASH MRI", MRM (2008).
          https://onlinelibrary.wiley.com/doi/full/10.1002/mrm.21732
    """
    opt = VFAOptions().update(opt, **kwopt)
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    data = core.py.make_list(data)
    if len(data) < 2:
        raise ValueError('Expected at least two input images')
    transmit = core.py.make_list(transmit or [])
    receive = core.py.make_list(receive or [])

    # --- Copy instances to avoid modifying the inputs ---
    data = [x.copy() for x in data]
    transmit = [x.copy() for x in transmit]
    receive = [x.copy() for x in receive]

    # --- check TEs ---
    if len({contrast.te for contrast in data}) > 1:
        raise ValueError('Echo times not consistent across contrasts')

    # --- register ---
    if opt.preproc.register:
        print('Register volumes')
        data_reg = [(contrast.fdata(rand=True, cache=False, **backend),
                     contrast.affine) for contrast in data]
        data_reg += [(fmap.magnitude.fdata(rand=True, cache=False, **backend),
                      fmap.magnitude.affine) for fmap in transmit]
        data_reg += [(fmap.magnitude.fdata(rand=True, cache=False, **backend),
                      fmap.magnitude.affine) for fmap in receive]
        dats, affines, _ = affine_align(data_reg, device=device, fwhm=3)
        if opt.verbose > 1 and plt:
            plt.figure()
            for i in range(len(dats)):
                plt.subplot(1, len(dats), i+1)
                plt.imshow(dats[i, :, dats.shape[2]//2, :].cpu())
                plt.axis('off')
            plt.suptitle('Registered magnitude images')
            plt.show()
        del dats
        # compose the estimated rigid/affine correction into each affine
        for vol, aff in zip(data + transmit + receive, affines):
            aff, vol.affine = core.utils.to_max_device(aff, vol.affine)
            vol.affine = torch.matmul(aff.inverse(), vol.affine)

    # --- repeat fields if not enough ---
    # after this, both lists hold one entry (field map or None) per contrast
    transmit = core.py.make_list(transmit or [None], len(data))
    receive = core.py.make_list(receive or [None], len(data))

    # --- compute recon space ---
    affines = [contrast.affine for contrast in data]
    shapes = [dat.volume.shape for dat in data]
    if opt.recon.affine is None:
        opt.recon.affine = opt.recon.space
    if opt.recon.fov is None:
        opt.recon.fov = opt.recon.space
    if isinstance(opt.recon.affine, int):
        mean_affine = affines[opt.recon.affine]
    else:
        mean_affine = torch.as_tensor(opt.recon.affine)
    if isinstance(opt.recon.fov, int):
        mean_shape = shapes[opt.recon.fov]
    else:
        mean_shape = tuple(opt.recon.fov)

    # --- compute PD/R1 ---
    pdt1 = [(idx, contrast) for idx, contrast in enumerate(data)
            if not contrast.mt]
    if len(pdt1) > 2:
        warnings.warn('More than two volumes could be used to compute PD+R1')
        pdt1 = pdt1[:2]
    if len(pdt1) < 2:
        raise ValueError('Not enough volumes to compute PD+R1')
    (pdw_idx, pdw_struct), (t1w_idx, t1w_struct) = pdt1

    if t1w_struct.te != pdw_struct.te:
        warnings.warn('Echo times not consistent across volumes')

    # the analytical solution below assumes a shared TR
    rational = opt.rational or t1w_struct.tr != pdw_struct.tr
    method = 'rational' if rational else 'analytical'
    print(f'Computing PD and R1 ({method}) from volumes:')
    pdw, pdw_fa, pdw_tr = _vfa_load_contrast(
        pdw_idx, pdw_struct, receive, transmit, mean_affine, mean_shape,
        backend)
    t1w, t1w_fa, t1w_tr = _vfa_load_contrast(
        t1w_idx, t1w_struct, receive, transmit, mean_affine, mean_shape,
        backend)

    if rational:
        # we must use rational approximations (small-angle Ernst equation,
        # S = A a R1 TR / (a^2/2 + R1 TR), solved for two acquisitions)
        r1 = 0.5 * (t1w * (t1w_fa / t1w_tr) - pdw * (pdw_fa / pdw_tr))
        r1 /= ((pdw / pdw_fa) - (t1w / t1w_fa))
        pd = (t1w * pdw) * (t1w_tr * (pdw_fa / t1w_fa)
                            - pdw_tr * (t1w_fa / pdw_fa))
        # NOTE: each signal is paired with the *other* volume's TR here;
        # pairing it with its own TR is only correct when TRs are equal.
        pd /= (pdw * (t1w_tr * pdw_fa) - t1w * (pdw_tr * t1w_fa))
        del t1w_fa, pdw_fa, t1w, pdw
    else:
        # there is an analytical solution
        cosfa_t1w = t1w_fa.cos()
        sinfa_t1w = t1w_fa.sin_()   # in place: cos() was taken first
        del t1w_fa
        cosfa_pdw = pdw_fa.cos()
        sinfa_pdw = pdw_fa.sin_()
        del pdw_fa
        e1 = sinfa_pdw * t1w - sinfa_t1w * pdw
        e1 /= sinfa_pdw * cosfa_t1w * t1w - sinfa_t1w * cosfa_pdw * pdw
        # average the PD estimates obtained from each volume
        pd = t1w / sinfa_t1w * (1 - cosfa_t1w * e1) / (1 - e1)
        pd += pdw / sinfa_pdw * (1 - cosfa_pdw * e1) / (1 - e1)
        pd /= 2.
        r1 = e1.log_().neg_().div_(t1w_tr)
        del t1w, pdw, cosfa_pdw, cosfa_t1w

    r1[~torch.isfinite(r1)] = 0
    pd[~torch.isfinite(pd)] = 0

    # --- compute MTsat ---
    mtw_struct = [(idx, contrast) for idx, contrast in enumerate(data)
                  if contrast.mt]
    if len(mtw_struct) == 0:
        return (ParameterMap(pd, affine=mean_affine, unit=None),
                ParameterMap(r1, affine=mean_affine, unit='1/s'))
    if len(mtw_struct) > 1:
        warnings.warn('More than one volume could be used to compute MTsat')
    mtw_idx, mtw_struct = mtw_struct[0]

    if mtw_struct.te != pdw_struct.te:
        warnings.warn('Echo times not consistent across volumes')

    method = 'rational' if opt.rational else 'analytical'
    print(f'Computing MTsat ({method}) from PD/R1 maps and volume:')
    mtw, mtw_fa, mtw_tr = _vfa_load_contrast(
        mtw_idx, mtw_struct, receive, transmit, mean_affine, mean_shape,
        backend)

    if opt.rational:
        # we must use rational approximations (yields MTsat directly)
        mtsat = (mtw_fa * pd / mtw - 1) * r1 * mtw_tr - 0.5 * (mtw_fa ** 2)
        del mtw_fa, mtw
    else:
        # we have an analytical solution (work backward from PD/R1):
        # x = mtw / (...) solves S = PD sin(a) (1-E1) x / (1 - x cos(a) E1)
        # for x = 1 - MTsat
        cosfa_mtw = mtw_fa.cos()
        sinfa_mtw = mtw_fa.sin()
        del mtw_fa
        e1 = (-mtw_tr * r1).exp()
        mtsat = mtw / (cosfa_mtw * e1 * mtw + pd * sinfa_mtw * (1 - e1))
        del mtw, cosfa_mtw, sinfa_mtw
        mtsat = 1 - mtsat
    mtsat *= 100
    mtsat[~torch.isfinite(mtsat)] = 0

    return (ParameterMap(pd, affine=mean_affine, unit=None),
            ParameterMap(r1, affine=mean_affine, unit='1/s'),
            ParameterMap(mtsat, affine=mean_affine, unit='%'))
def _prepare(data, dist, opt):
    """Set up an ESTATICS fit: options, preprocessing, weights and logging.

    The options are re-instantiated (`ESTATICSOptions().update(opt)`) and
    then simplified in place: the regularization factors are expanded to
    one value per intercept plus one for the decay, the distortion factor
    is replaced by a dict of its components, and when no reweighting is
    used the IRLS budget is folded into the parameter iterations.

    Parameters
    ----------
    data : sequence
        Input contrasts (each exposes `te`, and `affine`/`readout` after
        preprocessing).
    dist
        Distortion fields, forwarded to `preproc`.
    opt
        User options.

    Returns
    -------
    data, maps, dist
        Outputs of `preproc`.
    opt : ESTATICSOptions
        The simplified options.
    rls : tensor or None
        IRLS weights initialised to one when the norm ends with 'tv',
        else None.
    """
    # --- options ------------------------------------------------------
    # a fresh options object, so that it can be simplified in place
    opt = ESTATICSOptions().update(opt).cleanup_()
    backend = dict(dtype=opt.backend.dtype, device=opt.backend.device)

    # --- be polite ----------------------------------------------------
    nb_inputs = len(data)
    print(f'Fitting a (shared) exponential decay model with {nb_inputs} contrasts.'
          if nb_inputs > 1 else 'Fitting an exponential decay model.')
    print('Echo times:')
    for idx, contrast in enumerate(data):
        echoes = ', '.join(f'{te*1e3:.1f}' for te in contrast.te)
        print(f' - contrast {idx:2d}: [{echoes}] ms')

    # --- estimate noise / register / initialize maps ------------------
    data, maps, dist = preproc(data, dist, opt)
    nb_contrasts = len(maps) - 1  # last map is the decay

    if opt.distortion.enable:
        print('Readout directions:')
        # first matching pair of layout letters wins
        axis_names = (('LR', 'left-right'),
                      ('IS', 'infra-supra'),
                      ('AP', 'antero-posterior'))
        for idx, contrast in enumerate(data):
            layout_name = spatial.volume_layout_to_name(
                spatial.affine_to_layout(contrast.affine))
            axis = layout_name[contrast.readout]
            direction = 'unknown'
            for letters, name in axis_names:
                if any(letter in axis for letter in letters):
                    direction = name
                    break
            print(f' - contrast {idx:2d}: {direction}')

    # --- prepare regularization factor --------------------------------
    # 1. Parameter maps regularization:
    #    expand to [*lam_intercepts, lam_decay]
    *lam_intercepts, lam_decay = opt.regularization.factor
    lam = core.py.make_list(lam_intercepts, nb_contrasts) + [lam_decay]
    if not any(lam):
        opt.regularization.norm = ''
    opt.regularization.factor = lam
    # 2. Distortion fields regularization
    opt.distortion.factor = dict(factor=opt.distortion.factor,
                                 absolute=opt.distortion.absolute,
                                 membrane=opt.distortion.membrane,
                                 bending=opt.distortion.bending)

    # --- initialize weights (RLS) -------------------------------------
    norm = opt.regularization.norm
    reweighted = norm.endswith('tv')
    mean_shape = maps.decay.volume.shape
    rls = None
    if reweighted:
        # 'tv' -> one weight map per parameter map; 'jtv' -> a shared one
        rls_shape = (len(maps), *mean_shape) if norm == 'tv' else mean_shape
        rls = ParameterMap(rls_shape, fill=1, **backend).volume

    if not norm:
        print('Without regularization')
    else:
        print('Regularization:')
        print(f' - type: {norm.upper()}')
        intercept_factors = ', '.join(f'{x:.3g}' for x in lam[:-1])
        print(f' - log intercepts: [{intercept_factors}]')
        print(f' - decay: {lam[-1]:.3g}')

    if not opt.distortion.enable:
        print('Without distortion correction')
    else:
        print('Distortion correction:')
        print(f' - model: {opt.distortion.model.lower()}')
        dist_scale = opt.distortion.factor["factor"]
        for term in ('absolute', 'membrane', 'bending'):
            print(f' - {term}: {getattr(opt.distortion, term) * dist_scale}')

    # --- initialize nb of iterations ----------------------------------
    if not reweighted:
        # no reweighting -> do more gauss-newton updates instead
        opt.optim.max_iter_prm *= opt.optim.max_iter_rls
        opt.optim.max_iter_rls = 1

    print('Optimization:')
    print(f' - Tolerance: {opt.optim.tolerance}')
    if reweighted:
        print(f' - IRLS iterations: {opt.optim.max_iter_rls}')
    print(f' - Param iterations: {opt.optim.max_iter_prm}')
    if opt.distortion.enable:
        print(f' - Dist iterations: {opt.optim.max_iter_dist}')
    print(' - FMG cycles: 2')
    print(f' - CG iterations: {opt.optim.max_iter_cg}'
          f' (tolerance: {opt.optim.tolerance_cg})')
    if opt.optim.nb_levels > 1:
        print(f' - Levels: {opt.optim.nb_levels}')

    # ------------------------------------------------------------------
    #                       MAIN OPTIMIZATION LOOP
    # ------------------------------------------------------------------
    if opt.verbose:
        columns = [f'{"rls":^3s} | {"gn":^3s} | {"step":^4s} | ',
                   f'{"fit":^12s} + {"reg":^12s} + {"rls":^12s} ']
        if opt.distortion.enable:
            columns.append(f'+ {"dist":^12s} ')
        columns.append(f'= {"crit":^12s}')
        header = ''.join(columns)
        if opt.optim.nb_levels > 1:
            header = f'{"lvl":3s} | ' + header
        print('\n' + header)
        print('-' * len(header))

    return data, maps, dist, opt, rls
def nonlin(data, opt=None):
    """Fit the ESTATICS model to multi-echo Gradient-Echo data.

    Runs an iteratively-reweighted (IRLS, for 'tv'/'jtv' norms) outer
    loop around a Gauss-Newton inner loop. Without a TV-type norm the
    IRLS loop runs once and its iteration budget is transferred to the
    Gauss-Newton loop.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    opt : Options, optional
        Algorithm options.

    Returns
    -------
    intercepts : sequence[GradientEcho]
        Echo series extrapolated to TE=0
    decay : estatics.ParameterMap
        R2* decay map
    (both as produced by `postproc(maps, data)`)
    """
    opt = ESTATICSOptions().update(opt)
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    def relative_gain(prev, cur, worst):
        # Relative decrease of the objective, normalised by the distance
        # between the previous value and the largest value seen so far.
        # That distance is zero on the first iteration of a loop; the
        # bare division would then give nan/inf (tensors) or raise
        # ZeroDivisionError (floats). +inf keeps the "do not stop yet"
        # decision without the crash.
        denom = worst - prev
        if denom == 0:
            return float('inf')
        return (prev - cur) / denom

    # --- be polite ---
    print(f'Fitting a (multi) exponential decay model with {len(data)} contrasts. Echo times:')
    for i, contrast in enumerate(data):
        print(f' - contrast {i:2d}: ['
              + ', '.join([f'{te*1e3:.1f}' for te in contrast.te])
              + '] ms')

    # --- estimate noise / register / initialize maps ---
    data, maps = preproc(data, opt)
    vx = spatial.voxel_size(maps.affine)

    # --- prepare regularization factor ---
    # -> lam = [*lam_intercepts, lam_decay]; a single value is shared
    lam = core.utils.make_list(opt.regularization.factor)
    if len(lam) > 1:
        *lam, lam_decay = lam
    else:
        lam_decay = lam[0]
    lam = core.utils.make_list(lam, len(maps) - 1)
    lam.append(lam_decay)

    # --- initialize weights (RLS) ---
    if (not opt.regularization.norm
            or opt.regularization.norm.lower() == 'none'
            or all(lam1 == 0 for lam1 in lam)):
        opt.regularization.norm = ''
    opt.regularization.norm = opt.regularization.norm.lower()
    mean_shape = maps.decay.volume.shape
    rls = None
    sumrls = 0
    if opt.regularization.norm in ('tv', 'jtv'):
        rls_shape = mean_shape
        if opt.regularization.norm == 'tv':
            # one weight map per parameter map
            rls_shape = (len(maps),) + rls_shape
        rls = ParameterMap(rls_shape, fill=1, **backend).volume
        sumrls = 0.5 * rls.sum(dtype=torch.double)

    if opt.regularization.norm:
        print(f'With {opt.regularization.norm.upper()} regularization:')
        print(' - log intercepts: ['
              + ', '.join([f'{i:.3g}' for i in lam[:-1]]) + ']')
        print(f' - decay: {lam[-1]:.3g}')
    else:
        print('Without regularization:')

    # --- compute derivatives ---
    # grad: [intercept_0, ..., intercept_{K-1}, decay]
    # hess: for contrast i, entries 2*i, 2*i+1 and the shared last entry
    grad = torch.empty((len(data) + 1,) + mean_shape, **backend)
    hess = torch.empty((len(data)*2 + 1,) + mean_shape, **backend)

    if opt.regularization.norm not in ('tv', 'jtv'):
        # no reweighting -> do more gauss-newton updates instead
        opt.optim.max_iter_gn *= opt.optim.max_iter_rls
        opt.optim.max_iter_rls = 1

    if opt.verbose:
        print('{:^3s} | {:^3s} | {:^12s} + {:^12s} + {:^12s} = {:^12s}'
              .format('rls', 'gn', 'fit', 'reg', 'rls', 'crit'))

    ll_rls = []
    ll_max = core.constants.ninf

    for n_iter_rls in range(opt.optim.max_iter_rls):

        multi_rls = (rls if opt.regularization.norm == 'tv'
                     else [rls] * len(maps))

        # --- Gauss Newton loop ---
        ll_gn = []
        for n_iter_gn in range(opt.optim.max_iter_gn):
            decay = maps.decay
            crit = 0
            grad.zero_()
            hess.zero_()

            # --- loop over contrasts ---
            for i, (contrast, intercept) in enumerate(zip(data, maps.intercepts)):
                crit1, g1, h1 = _nonlin_gradient(contrast, intercept, decay, opt)
                grad[[i, -1], ...] += g1
                hess[[2*i, -1, 2*i+1], ...] += h1
                crit += crit1

            # --- regularization ---
            reg = 0.
            if opt.regularization.norm:
                for i, (prm, weight, lam1) in enumerate(zip(maps, multi_rls, lam)):
                    if not lam1:
                        continue
                    reg1, g1 = _nonlin_reg(prm.volume, vx, weight, lam1)
                    reg += reg1
                    grad[i] += g1

            # --- gauss-newton ---
            if not hess.isfinite().all():
                print('WARNING: NaNs in hess (??)')
            hess = hessian_loaddiag(hess, 1e-6, 1e-8)
            if opt.regularization.norm:
                deltas = _nonlin_solve(hess, grad, multi_rls, lam, vx, opt)
            else:
                deltas = hessian_solve(hess, grad)
            if not deltas.isfinite().all():
                print('WARNING: NaNs in delta (non stable Hessian)')

            for prm, delta in zip(maps, deltas):
                prm.volume -= delta
                if prm.min is not None or prm.max is not None:
                    prm.volume.clamp_(prm.min, prm.max)

            # --- Compute gain ---
            ll = crit + reg + sumrls
            ll_max = max(ll_max, ll)
            ll_prev = ll_gn[-1] if ll_gn else ll_max
            gain = relative_gain(ll_prev, ll, ll_max)
            ll_gn.append(ll)
            if opt.verbose:
                print('{:3d} | {:3d} | {:12.6g} + {:12.6g} + {:12.6g} = {:12.6g} | gain = {:7.2g}'
                      .format(n_iter_rls, n_iter_gn, crit, reg, sumrls,
                              crit + reg + sumrls, gain))
            if gain < opt.optim.tolerance_gn:
                break

        # --- Update RLS weights ---
        if opt.regularization.norm in ('tv', 'jtv'):
            rls = _nonlin_rls(maps, lam, opt.regularization.norm)
            sumrls = 0.5 * rls.sum(dtype=torch.double)
            eps = core.constants.eps(rls.dtype)
            rls = rls.clamp_min_(eps).reciprocal_()

        # --- Compute gain ---
        # (we are late by one full RLS iteration when computing the
        #  gain but we save some computations)
        ll = ll_gn[-1]
        ll_prev = ll_rls[-1][-1] if ll_rls else ll_max
        ll_rls.append(ll_gn)
        gain = relative_gain(ll_prev, ll, ll_max)
        if gain < opt.optim.tolerance_rls:
            print(f'Converged ({gain:7.2g})')
            break

    # --- Prepare output ---
    return postproc(maps, data)
def decay(self):
    """Wrap the last entry of `self.volume` as a ParameterMap.

    Returns None when no volume has been set.
    """
    if self.volume is not None:
        return ParameterMap(self.volume[-1], affine=self.affine)
    return None