Example 1
# Closure: `dim`, `prm`, `sub_iter` and `sub_tol` are bound in the enclosing scope.
def solve(h, g, w, optim):
    return spatial.solve_field(h, g, dim=dim, weights=w, **prm,
                               optim=optim, max_iter=sub_iter,
                               tolerance=sub_tol)
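
The snippet above is a closure: `dim`, `prm`, `sub_iter` and `sub_tol` must be bound in the enclosing scope. A minimal sketch of that assumed scope and a call site (every value here is hypothetical):

# Hypothetical enclosing scope for the closure above (a sketch, not the original code).
dim = 3                        # spatial dimensionality (assumed)
prm = dict(membrane=1)         # extra options forwarded to solve_field (assumed)
sub_iter, sub_tol = 32, 1e-5   # inner-solver iterations / tolerance (assumed)

delta = solve(hess, grad, weights, optim='cg')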
Example 2
def _chi_fit_tv(dat, sigma=None, df=None, lam=10, max_iter=50, tol=1e-5):

    noise, tissue = estimate_noise(dat, chi=True)
    mu = tissue['mean']
    sigma = sigma or noise['sd']
    df = df or noise['dof']
    print(f'sigma = {sigma}, dof = {df}, mu = {mu}')
    isigma2 = 1 / (sigma * sigma)
    lam = lam / mu

    msk = torch.isfinite(dat)
    if not msk.all():  # zero out non-finite voxels
        dat = dat.masked_fill(msk.bitwise_not_(), 0)
    fit = chi_bias_correction(dat, sigma, df)[0]
    msk = dat > 0
    n = msk.sum()

    w = torch.ones_like(dat)
    h = w.new_full([1] * dat.dim(), isigma2)[None]

    ll_prev = float('inf')
    for n_iter in range(max_iter):

        w, tv = spatial.membrane_weights(fit[None],
                                         factor=lam,
                                         return_sum=True)

        l, delta = nll_chi(dat, fit, msk, isigma2, df)
        delta += spatial.regulariser(fit[None], membrane=lam, weights=w)[0]
        delta = spatial.solve_field(h, delta[None], membrane=lam, weights=w)[0]
        fit.sub_(delta).clamp_min_(1e-8 * sigma)

        ll = l + tv
        gain, ll_prev = ll_prev - ll, ll
        print(f'{n_iter:3d} | {l/n:12.6g} + {tv/n:12.6g} = {ll/n:12.6g} '
              f'| gain = {gain/n:6.3}')
        if abs(gain) < tol * n:
            break

    return fit, sigma, df
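
A minimal usage sketch for `_chi_fit_tv`, assuming `spatial`, `estimate_noise`, `chi_bias_correction` and `nll_chi` are importable from the surrounding library; the input volume here is a random placeholder:

import torch

dat = 100 * torch.rand(64, 64, 64)      # placeholder magnitude volume
fit, sigma, df = _chi_fit_tv(dat, lam=10, max_iter=50, tol=1e-5)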
Example 3
def greeq(data, transmit=None, receive=None, opt=None, **kwopt):
    """Fit a non-linear relaxometry model to multi-echo Gradient-Echo data.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    transmit : sequence[PrecomputedFieldMap], optional
        Map(s) of the transmit field (b1+). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
    receive : sequence[PrecomputedFieldMap], optional
        Map(s) of the receive field (b1-). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
        If no receive map is provided, the output `pd` map will have
        a remaining b1- bias field.
    opt : GREEQOptions or dict, optional
        Algorithm options.
        {'preproc': {'register':      True},     # Co-register contrasts
         'optim':   {'nb_levels':     1,         # Number of pyramid levels
                     'max_iter_rls':  10,        # Max reweighting iterations
                     'max_iter_gn':   5,         # Max Gauss-Newton iterations
                     'max_iter_cg':   32,        # Max Conjugate-Gradient iterations
                     'tolerance':     1e-05,     # Tolerance for early stopping (RLS/GN)
                     'tolerance_cg':  1e-03},    # Tolerance for early stopping (CG)
         'backend': {'dtype':  torch.float32,    # Data type
                     'device': 'cpu'},           # Device
         'penalty': {'norm':    'jtv',           # Type of penalty: {'tkh', 'tv', 'jtv', None}
                     'factor':  {'r1':  10,      # Penalty factor per (log) map
                                 'pd':  10,
                                 'r2s': 2,
                                 'mt':  2}},
         'verbose': 1}

    Returns
    -------
    pd : ParameterMap
        Proton density
    r1 : ParameterMap
        Longitudinal relaxation rate
    r2s : ParameterMap
        Apparent transverse relaxation rate
    mt : ParameterMap, optional
        Magnetisation transfer saturation.
        Only returned if MT-weighted data is provided.

    """
    opt = GREEQOptions().update(opt, **kwopt)
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    # --- estimate noise / register / initialize maps ---
    data, transmit, receive, maps = preproc(data, transmit, receive, opt)
    has_mt = hasattr(maps, 'mt')

    # --- prepare penalty factor ---
    lam = opt.penalty.factor
    if isinstance(lam, dict):
        lam = [
            lam.get('pd', 0),
            lam.get('r1', 0),
            lam.get('r2s', 0),
            lam.get('mt', 0)
        ]
        if not has_mt:
            lam = lam[:3]
    lam = core.utils.make_vector(lam, 3 + has_mt,
                                 **backend)  # PD, R1, R2*, [MT]

    # --- initialize weights (RLS) ---
    if str(opt.penalty.norm).lower().startswith('no') or all(lam == 0):
        opt.penalty.norm = ''
    opt.penalty.norm = opt.penalty.norm.lower()
    mean_shape = maps[0].shape
    rls = None
    sumrls = 0
    if opt.penalty.norm in ('tv', 'jtv'):
        rls_shape = mean_shape
        if opt.penalty.norm == 'tv':
            rls_shape = (len(maps), ) + rls_shape
        rls = torch.ones(rls_shape, **backend)
        sumrls = 0.5 * core.py.prod(rls_shape)

    if opt.penalty.norm:
        print(f'With {opt.penalty.norm.upper()} penalty:')
        print(f'    - PD:  {lam[0]:.3g}')
        print(f'    - R1:  {lam[1]:.3g}')
        print(f'    - R2*: {lam[2]:.3g}')
        if has_mt:
            print(f'    - MT:  {lam[3]:.3g}')
    else:
        print('Without penalty')

    if opt.penalty.norm not in ('tv', 'jtv'):
        # no reweighting -> do more gauss-newton updates instead
        opt.optim.max_iter_gn *= opt.optim.max_iter_rls
        opt.optim.max_iter_rls = 1
    print('Optimization:')
    print(f'    - Tolerance:        {opt.optim.tolerance}')
    if opt.penalty.norm.endswith('tv'):
        print(f'    - IRLS iterations:  {opt.optim.max_iter_rls}')
    print(f'    - GN iterations:    {opt.optim.max_iter_gn}')
    if opt.optim.solver == 'fmg':
        print(f'    - FMG cycles:       2')
        print(f'    - CG iterations:    2')
    else:
        print(f'    - CG iterations:    {opt.optim.max_iter_cg}'
              f' (tolerance: {opt.optim.tolerance_cg})')
    if opt.optim.nb_levels > 1:
        print(f'    - Levels:           {opt.optim.nb_levels}')

    printer = CritPrinter(max_levels=opt.optim.nb_levels,
                          max_rls=opt.optim.max_iter_rls,
                          max_gn=opt.optim.max_iter_gn,
                          penalty=opt.penalty.norm,
                          verbose=opt.verbose)
    printer.print_head()

    shape0 = shape = maps.shape[1:]
    aff0 = aff = maps.affine
    vx0 = vx = spatial.voxel_size(aff0)
    vol0 = vx0.prod()
    vol = vx.prod() / vol0
    ll_scl = sum(core.py.prod(dat.shape) for dat in data)

    for level in range(opt.optim.nb_levels, 0, -1):
        printer.level = level

        if opt.optim.nb_levels > 1:
            aff, shape = _get_level(level, aff0, shape0)
            vx = spatial.voxel_size(aff)
            vol = vx.prod() / vol0
            maps, rls = resize(maps, rls, aff, shape)
            if opt.penalty.norm in ('tv', 'jtv'):
                sumrls = 0.5 * vol * rls.reciprocal().sum(dtype=torch.double)

        # --- compute derivatives ---
        nb_prm = len(maps)
        nb_hes = nb_prm * (nb_prm + 1) // 2
        grad = torch.empty((nb_prm, ) + shape, **backend)
        hess = torch.empty((nb_hes, ) + shape, **backend)

        ll_rls = []
        ll_gn = []
        ll_max = float('inf')

        max_iter_rls = max(opt.optim.max_iter_rls // level, 1)
        for n_iter_rls in range(max_iter_rls):
            # --- Reweighted least-squares loop ---
            printer.rls = n_iter_rls

            # --- Gauss Newton loop ---
            for n_iter_gn in range(opt.optim.max_iter_gn):
                # start = time.time()
                printer.gn = n_iter_gn
                crit = 0
                grad.zero_()
                hess.zero_()
                # --- loop over contrasts ---
                for contrast, b1m, b1p in zip(data, receive, transmit):
                    crit1, g1, h1 = derivatives_parameters(
                        contrast, maps, b1m, b1p, opt)

                    # increment
                    if hasattr(maps, 'mt') and not contrast.mt:
                        # we optimize for mt but this particular contrast
                        # has no information about mt so g1/h1 are smaller
                        # than grad/hess.
                        grad[:-1] += g1
                        hind = list(range(nb_prm - 1))
                        cnt = nb_prm
                        for i in range(nb_prm):
                            for j in range(i + 1, nb_prm):
                                if i != nb_prm - 1 and j != nb_prm - 1:
                                    hind.append(cnt)
                                cnt += 1
                        hess[hind] += h1
                        crit += crit1
                    else:
                        grad += g1
                        hess += h1
                        crit += crit1

                    del g1, h1, crit1
                # duration = time.time() - start
                # print('grad', duration)

                # start = time.time()
                reg = 0.
                if opt.penalty.norm:
                    g = spatial.regulariser(maps.volume,
                                            weights=rls,
                                            dim=3,
                                            voxel_size=vx,
                                            membrane=1,
                                            factor=lam * vol)
                    grad += g
                    reg = 0.5 * dot(maps.volume, g)
                    del g
                # duration = time.time() - start
                # print('reg', duration)

                # --- gauss-newton ---
                # start = time.time()
                grad = check_nans_(grad, warn='gradient')
                hess = check_nans_(hess, warn='hessian')
                if opt.penalty.norm:
                    # hess = hessian_sym_loaddiag(hess, 1e-5, 1e-8)  # 1e-5 1e-8
                    if opt.optim.solver == 'fmg':
                        deltas = spatial.solve_field_fmg(hess,
                                                         grad,
                                                         rls,
                                                         factor=lam * vol,
                                                         membrane=1,
                                                         voxel_size=vx,
                                                         nb_iter=2)
                    else:
                        deltas = spatial.solve_field(
                            hess,
                            grad,
                            rls,
                            factor=lam * vol,
                            membrane=1,
                            voxel_size=vx,
                            verbose=max(0, opt.verbose - 1),
                            optim='cg',
                            max_iter=opt.optim.max_iter_cg,
                            tolerance=opt.optim.tolerance_cg,
                            stop='diff')
                else:
                    # hess = hessian_sym_loaddiag(hess, 1e-3, 1e-4)
                    deltas = spatial.solve_field_closedform(hess, grad)
                deltas = check_nans_(deltas, warn='deltas')
                # duration = time.time() - start
                # print('solve', duration)

                for map, delta in zip(maps, deltas):
                    map.volume -= delta
                    map.volume.clamp_(-64, 64)  # avoid exp overflow
                    del delta
                del deltas

                # --- Compute gain ---
                ll = crit + reg + sumrls
                ll_max = max(ll_max, ll)
                ll_prev = ll_gn[-1] if ll_gn else ll_max
                gain = ll_prev - ll
                ll_gn.append(ll)
                printer.print_crit(crit, reg, sumrls, gain / ll_scl)
                if gain < opt.optim.tolerance * ll_scl:
                    # print('GN converged: ', ll_prev.item(), '->', ll.item())
                    break

            # --- Update RLS weights ---
            if opt.penalty.norm in ('tv', 'jtv'):
                rls, sumrls = spatial.membrane_weights(
                    maps.volume,
                    lam,
                    vx,
                    return_sum=True,
                    dim=3,
                    joint=opt.penalty.norm == 'jtv',
                    eps=core.constants.eps(rls.dtype))
                sumrls *= 0.5 * vol

                # --- Compute gain ---
                # (we are late by one full RLS iteration when computing the
                #  gain but we save some computations)
                ll = ll_gn[-1]
                ll_prev = ll_rls[-1] if ll_rls else ll_max
                ll_rls.append(ll)
                gain = ll_prev - ll
                if abs(gain) < opt.optim.tolerance * ll_scl:
                    # print(f'RLS converged ({gain:7.2g})')
                    break

    del grad
    if opt.uncertainty:
        uncertainty = compute_uncertainty(hess, rls, lam * vol, vx, opt)
        maps.pd.uncertainty = uncertainty[0]
        maps.r1.uncertainty = uncertainty[1]
        maps.r2s.uncertainty = uncertainty[2]
        if hasattr(maps, 'mt'):
            maps.mt.uncertainty = uncertainty[3]

    # --- Prepare output ---
    return postproc(maps)
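
A hedged usage sketch for `greeq`: `data`, `b1p_maps` and `b1m_maps` stand for pre-loaded sequences of GradientEchoMulti and PrecomputedFieldMap objects, whose construction happens elsewhere in the library:

opt = {'penalty': {'norm': 'jtv',
                   'factor': {'pd': 10, 'r1': 10, 'r2s': 2, 'mt': 2}},
       'optim': {'nb_levels': 2},
       'backend': {'device': 'cuda'}}
# a fourth (mt) map is returned only when MT-weighted data is present
pd, r1, r2s, mt = greeq(data, transmit=b1p_maps, receive=b1m_maps, opt=opt)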
Example 4
def run_epic_noreverse(echoes,
                       voxshift=None,
                       lam=1,
                       sigma=1,
                       max_iter=(10, 32),
                       tol=1e-5,
                       verbose=False):
    """Run EPIC on pre-loaded tensors (no reverse gradients or synthesis)

    Parameters
    ----------
    echoes : (N, *spatial) tensor
        Echoes acquired with bipolar readout. The readout direction must be last.
    voxshift : (*spatial) tensor, optional
        Voxel shift map along the readout direction (assumed precomputed).
    lam : [list of] float
        Regularization factor (per echo)
    sigma : float
        Noise standard deviation
    max_iter : [pair of] int
        Maximum number of RLS and CG iterations
    tol : float
        Tolerance for early stopping
    verbose : int
        Verbosity level

    Returns
    -------
    echoes : (N, *spatial) tensor
        Undistorted + denoised echoes

    """
    ne = len(echoes)  # number of echoes
    nv = echoes.shape[1:].numel()  # number of voxels
    nd = echoes.dim() - 1  # number of dimensions

    # initialize denoised echoes
    fit = echoes.clone()
    fwd_fit = torch.empty_like(fit)

    # prepare voxel shift maps
    do_voxshift = voxshift is not None
    if do_voxshift:
        ivoxshift = add_identity_1d(-voxshift)
        voxshift = add_identity_1d(voxshift)
    else:
        ivoxshift = None

    # prepare parameters
    max_iter, sub_iter = py.make_list(max_iter, 2)
    tol, sub_tol = py.make_list(tol, 2)
    lam = [l / ne for l in py.make_list(lam, ne)]
    isigma2 = 1 / (sigma * sigma)

    # compute hessian once and for all
    if do_voxshift:
        one = torch.ones_like(voxshift)[None]
        h = torch.empty_like(echoes)
        h[0::2] = push1d(pull1d(one, voxshift), voxshift)
        h[1::2] = push1d(pull1d(one, ivoxshift), ivoxshift)
        del one
    else:
        h = fit.new_ones([1] * (nd + 1))
    h *= isigma2

    loss = float('inf')
    for n_iter in range(max_iter):

        # update weights
        w, jtv = membrane_weights(fit, factor=lam, return_sum=True)

        # gradient of likelihood
        pull_forward(fit, voxshift, ivoxshift, out=fwd_fit)
        fwd_fit.sub_(echoes)
        ll = ssq(fwd_fit)
        push_forward(fwd_fit, voxshift, ivoxshift, out=fwd_fit)

        g = fwd_fit.mul_(isigma2)
        ll *= 0.5 * isigma2

        # gradient of prior
        g += regulariser(fit, membrane=1, factor=lam, weights=w)

        # solve
        fit -= solve_field(h,
                           g,
                           w,
                           membrane=1,
                           factor=lam,
                           max_iter=sub_iter,
                           tolerance=sub_tol)

        # track objective
        ll, jtv = ll.item() / (ne * nv), jtv.item() / (ne * nv)
        loss, loss_prev = ll + jtv, loss
        if n_iter:
            gain = (loss_prev - loss) / max((loss_max - loss), 1e-8)
        else:
            gain = float('inf')
            loss_max = loss
        if verbose:
            end = '\n' if verbose > 1 else '\r'
            print(
                f'{n_iter+1:02d} | {ll:12.6g} + {jtv:12.6g} = {loss:12.6g} '
                f'| gain = {gain:12.6g}',
                end=end)
        if gain < tol:
            break

    if verbose == 1:
        print('')

    return fit
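
A usage sketch with placeholder data (a real `voxshift` would come from a precomputed field-map estimate):

import torch

echoes = torch.rand(8, 128, 128, 64)    # placeholder bipolar echoes
fit = run_epic_noreverse(echoes, voxshift=None, lam=1, sigma=0.05,
                         max_iter=(10, 32), tol=1e-5, verbose=1)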
Example 5
def run_epic(echoes,
             reverse_echoes=None,
             voxshift=None,
             extrapolate=True,
             lam=1,
             sigma=1,
             max_iter=(10, 32),
             tol=1e-5,
             verbose=False):
    """Run EPIC on pre-loaded tensors.

    Parameters
    ----------
    echoes : (N, *spatial) tensor
        Echoes acquired with bipolar readout. The readout direction must be last.
    reverse_echoes : (N, *spatial) tensor or bool, optional
        Echoes acquired with reversed bipolar readout. If None, they are
        synthesized from the forward echoes; if False, run without them.
    voxshift : (*spatial) tensor
        Voxel shift map used to deform towards even (0, 2, ...) echoes.
        Its inverse is used to deform towards odd (1, 3, ...) echoes.
    extrapolate : bool
        Extrapolate first/last echo when reverse_echoes is None.
        Otherwise, only use interpolated echoes.
    lam : [list of] float
        Regularization factor (per echo)
    sigma : float
        Noise standard deviation
    max_iter : [pair of] int
        Maximum number of RLS and CG iterations
    tol : float
        Tolerance for early stopping
    verbose : int
        Verbosity level

    Returns
    -------
    echoes : (N, *spatial) tensor
        Undistorted + denoised echoes

    """
    if reverse_echoes is False:
        return run_epic_noreverse(echoes, voxshift, lam, sigma, max_iter, tol,
                                  verbose)

    ne = len(echoes)  # number of echoes
    nv = echoes.shape[1:].numel()  # number of voxels
    nd = echoes.dim() - 1  # number of dimensions

    # synthesize echoes
    synth = not torch.is_tensor(reverse_echoes)
    if synth:
        neg = synthesize_neg(echoes[0::2])
        pos = synthesize_pos(echoes[1::2])
        reverse_echoes = torch.stack([x for y in zip(pos, neg) for x in y])
        del pos, neg
    else:
        extrapolate = True

    # initialize denoised echoes
    fit = (echoes + reverse_echoes).div_(2)
    if not extrapolate:
        fit[0] = echoes[0]
        fit[-1] = echoes[-1]
    fwd_fit = torch.zeros_like(fit)
    bwd_fit = torch.zeros_like(fit)

    # prepare voxel shift maps
    if voxshift is not None:
        ivoxshift = add_identity_1d(-voxshift)
        voxshift = add_identity_1d(voxshift)
    else:
        ivoxshift = None

    # prepare parameters
    max_iter, sub_iter = py.make_list(max_iter, 2)
    tol, sub_tol = py.make_list(tol, 2)
    lam = [l / ne for l in py.make_list(lam, ne)]
    isigma2 = 1 / (sigma * sigma)

    # compute hessian once and for all
    if voxshift is not None:
        one = torch.ones_like(voxshift)[None]
        if extrapolate:
            h = push1d(pull1d(one, voxshift), voxshift)
            h += push1d(pull1d(one, ivoxshift), ivoxshift)
            weight_ = lambda x: x.mul_(0.5)
            halfweight_ = lambda x: x.mul_(math.sqrt(0.5))
        else:
            h = torch.zeros_like(fit)
            h[:-1] += push1d(pull1d(one, voxshift), voxshift)
            h[1:] += push1d(pull1d(one, ivoxshift), ivoxshift)
            weight_ = lambda x: x[1:-1].mul_(0.5)
            halfweight_ = lambda x: x[1:-1].mul_(math.sqrt(0.5))
        del one
    else:
        h = fit.new_ones([ne] + [1] * nd)
        if extrapolate:
            h *= 2
            weight_ = lambda x: x.mul_(0.5)
            halfweight_ = lambda x: x.mul_(math.sqrt(0.5))
        else:
            h[1:-1] *= 2
            weight_ = lambda x: x[1:-1].mul_(0.5)
            halfweight_ = lambda x: x[1:-1].mul_(math.sqrt(0.5))
    weight_(h)
    h *= isigma2

    loss = float('inf')
    for n_iter in range(max_iter):

        # update weights
        w, jtv = membrane_weights(fit, factor=lam, return_sum=True)

        # gradient of likelihood (forward)
        pull_forward(fit, voxshift, ivoxshift, out=fwd_fit)
        fwd_fit.sub_(echoes)
        halfweight_(fwd_fit)
        ll = ssq(fwd_fit)
        halfweight_(fwd_fit)
        push_forward(fwd_fit, voxshift, ivoxshift, out=fwd_fit)

        # gradient of likelihood (reversed)
        pull_backward(fit, voxshift, ivoxshift, extrapolate, out=bwd_fit)
        if extrapolate:
            bwd_fit.sub_(reverse_echoes)
        else:
            bwd_fit[1:-1].sub_(reverse_echoes[1:-1])
        halfweight_(bwd_fit)
        ll += ssq(bwd_fit)
        halfweight_(bwd_fit)
        push_backward(bwd_fit, voxshift, ivoxshift, extrapolate, out=bwd_fit)

        g = fwd_fit.add_(bwd_fit).mul_(isigma2)
        ll *= 0.5 * isigma2

        # gradient of prior
        g += regulariser(fit, membrane=1, factor=lam, weights=w)

        # solve
        fit -= solve_field(h,
                           g,
                           w,
                           membrane=1,
                           factor=lam,
                           max_iter=sub_iter,
                           tolerance=sub_tol)

        # track objective
        ll, jtv = ll.item() / (ne * nv), jtv.item() / (ne * nv)
        loss, loss_prev = ll + jtv, loss
        if n_iter:
            gain = (loss_prev - loss) / max((loss_max - loss), 1e-8)
        else:
            gain = float('inf')
            loss_max = loss
        if verbose:
            end = '\n' if verbose > 1 else '\r'
            print(
                f'{n_iter+1:02d} | {ll:12.6g} + {jtv:12.6g} = {loss:12.6g} '
                f'| gain = {gain:12.6g}',
                end=end)
        if gain < tol:
            break

    if verbose == 1:
        print('')

    return fit
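
A usage sketch mirroring the no-reverse variant; `voxshift` is a hypothetical precomputed shift map with the same spatial shape as one echo:

fit = run_epic(echoes,
               reverse_echoes=None,    # None: synthesize reversed echoes
               voxshift=voxshift,
               extrapolate=True,
               lam=1, sigma=0.05, verbose=1)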
Example 6
def meetup(data, dist=None, opt=None):
    """Fit the ESTATICS+MEETUP model to multi-echo Gradient-Echo data.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    dist : sequence[Optional[ParameterizedDistortion]], optional
        Pre-computed distortion fields
    opt : Options, optional
        Algorithm options.

    Returns
    -------
    intercepts : sequence[GradientEcho]
        Echo series extrapolated to TE=0
    decay : estatics.ParameterMap
        R2* decay map
    distortions : sequence[ParameterizedDistortion]
        B0-induced distortion fields

    """
    # --- prepare everything -------------------------------------------
    data, maps, dist, opt, rls = _prepare(data, dist, opt)
    sumrls = 0.5 * rls.sum(dtype=torch.double) if rls is not None else 0
    backend = dict(dtype=opt.backend.dtype, device=opt.backend.device)

    # --- initialize tracking of the objective function ----------------
    crit, vreg, reg = float('inf'), 0, 0
    ll_rls = crit + vreg + reg
    sumrls_prev = sumrls
    rls_changed = False
    ll_scl = sum(core.py.prod(dat.shape) for dat in data)

    # --- Multi-Resolution loop ----------------------------------------
    shape0 = shape = maps.shape[1:]
    aff0 = aff = maps.affine
    lam0 = lam = opt.regularization.factor
    vx0 = vx = spatial.voxel_size(aff0)
    scl0 = vx0.prod()
    armijo_prev = 1
    for level in range(opt.optim.nb_levels, 0, -1):

        if opt.optim.nb_levels > 1:
            aff, shape = _get_level(level, aff0, shape0)
            vx = spatial.voxel_size(aff)
            scl = vx.prod() / scl0
            lam = [float(l * scl) for l in lam0]
            maps, rls = _resize(maps, rls, aff, shape)
            if opt.regularization.norm in ('tv', 'jtv'):
                sumrls = 0.5 * scl * rls.reciprocal().sum(dtype=torch.double)

        grad = torch.empty((len(data) + 1, *shape), **backend)
        hess = torch.empty((len(data) * 2 + 1, *shape), **backend)

        # --- RLS loop -------------------------------------------------
        max_iter_rls = 1 if level > 1 else opt.optim.max_iter_rls
        for n_iter_rls in range(1, max_iter_rls + 1):

            # --- helpers ----------------------------------------------
            regularizer_prm = lambda x: spatial.regulariser(
                x, weights=rls, dim=3, voxel_size=vx, membrane=1, factor=lam)
            solve_prm = lambda H, g: spatial.solve_field(
                H,
                g,
                rls,
                factor=lam,
                membrane=1 if opt.regularization.norm else 0,
                voxel_size=vx,
                max_iter=opt.optim.max_iter_cg,
                tolerance=opt.optim.tolerance_cg,
                dim=3)
            reweight = lambda x, **k: spatial.membrane_weights(
                x,
                lam,
                vx,
                dim=3,
                **k,
                joint=opt.regularization.norm == 'jtv',
                eps=core.constants.eps(rls.dtype))

            # ----------------------------------------------------------
            #    Initial update of parameter maps
            # ----------------------------------------------------------
            crit_pre_prm = None
            max_iter_prm = opt.optim.max_iter_prm
            for n_iter_prm in range(1, max_iter_prm + 1):

                crit_prev = crit
                reg_prev = reg

                # --- loop over contrasts ------------------------------
                crit = 0
                grad.zero_()
                hess.zero_()
                for i, (contrast, intercept, distortion) \
                        in enumerate(zip(data, maps.intercepts, dist)):
                    # compute gradient
                    crit1, g1, h1 = derivatives_parameters(
                        contrast, distortion, intercept, maps.decay, opt)
                    # increment
                    gind, hind = [i, -1], [i, len(grad) - 1, len(grad) + i]
                    grad[gind] += g1
                    hess[hind] += h1
                    crit += crit1
                del g1, h1

                # --- regularization -----------------------------------
                reg = 0.
                if opt.regularization.norm:
                    g1 = regularizer_prm(maps.volume)
                    reg = 0.5 * dot(maps.volume, g1)
                    grad += g1
                    del g1

                if n_iter_prm == 1:
                    crit_pre_prm = crit

                # --- track RLS improvement ----------------------------
                # Updating the RLS weights changes `sumrls` and `reg`. Now
                # that we have the updated value of `reg` (before it is
                # modified by the map update), we can check that updating the
                # weights indeed improved the objective.

                if opt.verbose and rls_changed:
                    rls_changed = False
                    if reg + sumrls <= reg_prev + sumrls_prev:
                        evol = '<='
                    else:
                        evol = '>'
                    ll_tmp = crit_prev + reg + vreg + sumrls
                    pstr = (f'{n_iter_rls-1:3d} | {"---":3s} | {"rls":4s} | '
                            f'{crit_prev:12.6g} + {reg:12.6g} + '
                            f'{sumrls:12.6g} + {vreg:12.6g} '
                            f'= {ll_tmp:12.6g} | {evol}')
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)

                # --- gauss-newton -------------------------------------
                # Computing the GN step involves solving H\g
                hess = check_nans_(hess, warn='hessian')
                hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
                deltas = solve_prm(hess, grad)
                deltas = check_nans_(deltas, warn='delta')

                dd = regularizer_prm(deltas)
                dv = dot(dd, maps.volume)
                dd = dot(dd, deltas)
                delta_reg = 0.5 * (dd - 2 * dv)

                for map, delta in zip(maps, deltas):
                    map.volume -= delta
                    if map.min is not None or map.max is not None:
                        map.volume.clamp_(map.min, map.max)
                del deltas

                # --- track parameter map improvement --------------
                gain = (crit_prev + reg_prev) - (crit + reg)
                if n_iter_prm > 1 and gain < opt.optim.tolerance * ll_scl:
                    break

            # ----------------------------------------------------------
            #    Distortion update with line search
            # ----------------------------------------------------------
            max_iter_dist = opt.optim.max_iter_dist
            for n_iter_dist in range(1, max_iter_dist + 1):

                crit = 0
                vreg = 0
                new_crit = 0
                new_vreg = 0

                deltas, dd, dv = [], 0, 0
                # --- loop over contrasts ------------------------------
                for i, (contrast, intercept, distortion) \
                        in enumerate(zip(data, maps.intercepts, dist)):

                    # --- helpers --------------------------------------
                    vxr = distortion.voxel_size[contrast.readout]
                    lam_dist = dict(opt.distortion.factor)
                    lam_dist['factor'] *= vxr**2
                    regularizer_dist = lambda x: spatial.regulariser(
                        x[None],
                        **lam_dist,
                        dim=3,
                        bound=DIST_BOUND,
                        voxel_size=distortion.voxel_size)[0]
                    solve_dist = lambda h, g: spatial.solve_field_fmg(
                        h,
                        g,
                        **lam_dist,
                        dim=3,
                        bound=DIST_BOUND,
                        voxel_size=distortion.voxel_size)

                    # --- likelihood -----------------------------------
                    crit1, g, h = derivatives_distortion(
                        contrast, distortion, intercept, maps.decay, opt)
                    crit += crit1

                    # --- regularization -------------------------------
                    vol = distortion.volume
                    g1 = regularizer_dist(vol)
                    vreg1 = 0.5 * vol.flatten().dot(g1.flatten())
                    vreg += vreg1
                    g += g1
                    del g1

                    # --- gauss-newton ---------------------------------
                    h = check_nans_(h, warn='hessian (distortion)')
                    h = hessian_loaddiag_(h[None], 1e-32, 1e-32, sym=True)[0]
                    delta = solve_dist(h, g)
                    delta = check_nans_(delta, warn='delta (distortion)')
                    deltas.append(delta)
                    del g, h

                    dd1 = regularizer_dist(delta)
                    dv += dot(dd1, vol)
                    dd1 = dot(dd1, delta)
                    dd += dd1
                    del delta, vol

                # --- track parameters improvement ---------------------
                if opt.verbose and n_iter_dist == 1:
                    gain = crit_pre_prm - (crit + delta_reg)
                    evol = '<=' if gain > 0 else '>'
                    ll_tmp = crit + reg + vreg + sumrls
                    pstr = (f'{n_iter_rls:3d} | {"---":3} | {"prm":4s} | '
                            f'{crit:12.6g} + {reg + delta_reg:12.6g} + '
                            f'{sumrls:12.6g} + {vreg:12.6g} '
                            f'= {ll_tmp:12.6g} | '
                            f'gain = {gain / ll_scl:7.2g} | {evol}')
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)

                # --- line search ----------------------------------
                reg = reg + delta_reg
                new_crit, new_reg = crit, reg
                armijo, armijo_prev, ok = armijo_prev, 0, False
                maps0 = maps
                for n_ls in range(1, 12 + 1):
                    for delta1, dist1 in zip(deltas, dist):
                        dist1.volume.sub_(delta1, alpha=(armijo - armijo_prev))
                    armijo_prev = armijo
                    new_vreg = 0.5 * armijo * (armijo * dd - 2 * dv)

                    maps = maps0.deepcopy()
                    max_iter_prm = opt.optim.max_iter_prm
                    for n_iter_prm in range(1, max_iter_prm + 1):

                        crit_prev = new_crit
                        reg_prev = new_reg

                        # --- loop over contrasts ------------------
                        new_crit = 0
                        grad.zero_()
                        hess.zero_()
                        for i, (contrast, intercept, distortion) in enumerate(
                                zip(data, maps.intercepts, dist)):
                            # compute gradient
                            new_crit1, g1, h1 = derivatives_parameters(
                                contrast, distortion, intercept, maps.decay,
                                opt)
                            # increment
                            gind = [i, -1]
                            hind = [i, len(grad) - 1, len(grad) + i]
                            grad[gind] += g1
                            hess[hind] += h1
                            new_crit += new_crit1
                        del g1, h1

                        # --- regularization -----------------------
                        new_reg = 0.
                        if opt.regularization.norm:
                            g1 = regularizer_prm(maps.volume)
                            new_reg = 0.5 * dot(maps.volume, g1)
                            grad += g1
                            del g1

                        new_gain = (crit_prev + reg_prev) - (new_crit +
                                                             new_reg)
                        if new_gain < opt.optim.tolerance * ll_scl:
                            break

                        # --- gauss-newton -------------------------
                        # Computing the GN step involves solving H\g
                        hess = check_nans_(hess, warn='hessian')
                        hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
                        delta = solve_prm(hess, grad)
                        delta = check_nans_(delta, warn='delta')

                        dd = regularizer_prm(delta)
                        dv = dot(dd, maps.volume)
                        dd = dot(dd, delta)
                        delta_reg = 0.5 * (dd - 2 * dv)

                        for map, delta1 in zip(maps, delta):
                            map.volume -= delta1
                            if map.min is not None or map.max is not None:
                                map.volume.clamp_(map.min, map.max)
                        del delta

                    if new_crit + new_reg + new_vreg <= crit + reg:
                        ok = True
                        break
                    else:
                        armijo = armijo / 2

                if not ok:
                    for delta1, dist1 in zip(deltas, dist):
                        dist1.volume.add_(delta1, alpha=armijo_prev)
                    armijo_prev = 1
                    maps = maps0
                    new_crit = crit
                    new_vreg = 0
                    del deltas
                else:
                    armijo_prev *= 1.5
                new_vreg = vreg + new_vreg

                # --- track distortion improvement ---------------------
                if not ok:
                    evol = '== (x)'
                elif new_crit + new_vreg <= crit + vreg:
                    evol = f'<= ({n_ls:2d})'
                else:
                    evol = f'> ({n_ls:2d})'
                gain = (crit + vreg + reg) - (new_crit + new_vreg + new_reg)
                crit, vreg, reg = new_crit, new_vreg, new_reg
                if opt.verbose:
                    ll_tmp = crit + reg + vreg + sumrls
                    pstr = (
                        f'{n_iter_rls:3d} | {n_iter_dist:3d} | {"dist":4s} | '
                        f'{crit:12.6g} + {reg:12.6g} + {sumrls:12.6g} ')
                    pstr += f'+ {vreg:12.6g} '
                    pstr += f'= {ll_tmp:12.6g} | gain = {gain / ll_scl:7.2g} | '
                    pstr += f'{evol}'
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)
                    if opt.plot:
                        _show_maps(maps, dist, data)
                if not ok:
                    break

                if n_iter_dist > 1 and gain < opt.optim.tolerance * ll_scl:
                    break

            # --------------------------------------------------------------
            #    Update RLS weights
            # --------------------------------------------------------------
            if level == 1 and opt.regularization.norm in ('tv', 'jtv'):
                rls_changed = True
                sumrls_prev = sumrls
                rls, sumrls = reweight(maps.volume, return_sum=True)
                sumrls = 0.5 * sumrls

                # --- compute gain -----------------------------------------
                # (we are late by one full RLS iteration when computing the
                #  gain but we save some computations)
                ll_rls_prev = ll_rls
                ll_rls = crit + reg + vreg
                gain = ll_rls_prev - ll_rls
                if gain < opt.optim.tolerance * ll_scl:
                    print(f'RLS converged ({gain / ll_scl:7.2g})')
                    break

    # --- prepare output -----------------------------------------------
    out = postproc(maps, data)
    if opt.distortion.enable:
        out = (*out, dist)
    return out
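
A usage sketch; as with `greeq`, `data` is a pre-loaded sequence of GradientEchoMulti objects and `opt` an Options instance configured elsewhere:

intercepts, decay, distortions = meetup(data, dist=None, opt=opt)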
Example 7
def nonlin(data, dist=None, opt=None):
    """Fit the ESTATICS model to multi-echo Gradient-Echo data.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    dist : sequence[Optional[ParameterizedDistortion]], optional
        Pre-computed distortion fields
    opt : Options, optional
        Algorithm options.

    Returns
    -------
    intercepts : sequence[GradientEcho]
        Echo series extrapolated to TE=0
    decay : estatics.ParameterMap
        R2* decay map
    distortions : sequence[ParameterizedDistortion], if opt.distortion.enable
        B0-induced distortion fields

    """
    if opt is not None and opt.distortion.enable:
        return meetup(data, dist, opt)

    # --- prepare everything -------------------------------------------
    data, maps, dist, opt, rls = _prepare(data, dist, opt)
    sumrls = 0.5 * rls.sum(dtype=torch.double) if rls is not None else 0
    backend = dict(dtype=opt.backend.dtype, device=opt.backend.device)

    # --- initialize tracking of the objective function ----------------
    crit = float('inf')
    vreg = 0
    reg = 0
    sumrls_prev = sumrls
    rls_changed = False
    ll_gn = ll_rls = float('inf')
    ll_scl = sum(core.py.prod(dat.shape) for dat in data)

    # --- Multi-Resolution loop ----------------------------------------
    shape0 = shape = maps.shape[1:]
    aff0 = aff = maps.affine
    lam0 = lam = opt.regularization.factor
    vx0 = vx = spatial.voxel_size(aff0)
    scl0 = vx0.prod()
    for level in range(opt.optim.nb_levels, 0, -1):

        if opt.optim.nb_levels > 1:
            aff, shape = _get_level(level, aff0, shape0)
            vx = spatial.voxel_size(aff)
            scl = vx.prod() / scl0
            lam = [float(l * scl) for l in lam0]
            maps, rls = _resize(maps, rls, aff, shape)
            if opt.regularization.norm in ('tv', 'jtv'):
                sumrls = 0.5 * scl * rls.reciprocal().sum(dtype=torch.double)

        grad = torch.empty((len(data) + 1, *shape), **backend)
        hess = torch.empty((len(data) * 2 + 1, *shape), **backend)

        # --- RLS loop -------------------------------------------------
        #   > max_iter_rls == 1 if regularization is not (J)TV
        max_iter_rls = opt.optim.max_iter_rls
        if level != 1:
            max_iter_rls = 1
        for n_iter_rls in range(1, max_iter_rls + 1):

            # --- helpers ----------------------------------------------
            regularizer = lambda x: spatial.regulariser(
                x, weights=rls, dim=3, voxel_size=vx, membrane=1, factor=lam)
            solve = lambda h, g: spatial.solve_field(
                h,
                g,
                rls,
                factor=lam,
                membrane=1 if opt.regularization.norm else 0,
                voxel_size=vx,
                max_iter=opt.optim.max_iter_cg,
                tolerance=opt.optim.tolerance_cg,
                dim=3)
            reweight = lambda x, **k: spatial.membrane_weights(
                x,
                lam,
                vx,
                dim=3,
                **k,
                joint=opt.regularization.norm == 'jtv',
                eps=core.constants.eps(rls.dtype))

            # ----------------------------------------------------------
            #    Update parameter maps
            # ----------------------------------------------------------

            max_iter_prm = opt.optim.max_iter_prm
            for n_iter_prm in range(1, max_iter_prm + 1):

                crit_prev = crit
                reg_prev = reg

                # --- loop over contrasts ------------------------------
                crit = 0
                grad.zero_()
                hess.zero_()
                for i, (contrast, intercept, distortion) \
                        in enumerate(zip(data, maps.intercepts, dist)):
                    # compute gradient
                    crit1, g1, h1 = derivatives_parameters(
                        contrast, distortion, intercept, maps.decay, opt)
                    # increment
                    gind = [i, -1]
                    grad[gind] += g1
                    hind = [i, len(grad) - 1, len(grad) + i]
                    hess[hind] += h1
                    crit += crit1
                del g1, h1

                # --- regularization -----------------------------------
                reg = 0.
                if opt.regularization.norm:
                    g1 = regularizer(maps.volume)
                    reg = 0.5 * dot(maps.volume, g1)
                    grad += g1
                    del g1

                # --- track RLS improvement ----------------------------
                # Updating the RLS weights changes `sumrls` and `reg`. Now
                # that we have the updated value of `reg` (before it is
                # modified by the map update), we can check that updating the
                # weights indeed improved the objective.
                if opt.verbose and rls_changed:
                    rls_changed = False
                    if reg + sumrls <= reg_prev + sumrls_prev:
                        evol = '<='
                    else:
                        evol = '>'
                    ll_tmp = crit_prev + reg + vreg + sumrls
                    pstr = (
                        f'{n_iter_rls:3d} | {"---":3s} | {"rls":4s} | '
                        f'{crit_prev:12.6g} + {reg:12.6g} + {sumrls:12.6g} ')
                    if opt.distortion.enable:
                        pstr += f'+ {vreg:12.6g} '
                    pstr += f'= {ll_tmp:12.6g} | '
                    pstr += f'{evol}'
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)

                # --- gauss-newton -------------------------------------
                # Computing the GN step involves solving H\g
                hess = check_nans_(hess, warn='hessian')
                hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
                deltas = solve(hess, grad)
                deltas = check_nans_(deltas, warn='delta')

                for map, delta in zip(maps, deltas):
                    map.volume -= delta

                    if map.min is not None or map.max is not None:
                        map.volume.clamp_(map.min, map.max)
                del deltas

                # --- compute GN gain ----------------------------------
                ll_gn_prev = ll_gn
                ll_gn = crit + reg + vreg + sumrls
                gain = ll_gn_prev - ll_gn
                if opt.verbose:
                    pstr = f'{n_iter_rls:3d} | {n_iter_prm:3d} | '
                    if opt.distortion.enable:
                        pstr += f'{"----":4s} | '
                        pstr += f'{"-"*72:72s} | '
                    else:
                        pstr += f'{"prm":4s} | '
                        pstr += f'{crit:12.6g} + {reg:12.6g} + {sumrls:12.6g} '
                        pstr += f'= {ll_gn:12.6g} | '
                    pstr += f'gain = {gain/ll_scl:7.2g}'
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)
                    if opt.plot:
                        _show_maps(maps, dist, data)
                if gain < opt.optim.tolerance * ll_scl:
                    break

            # --------------------------------------------------------------
            #    Update RLS weights
            # --------------------------------------------------------------
            if level == 1 and opt.regularization.norm in ('tv', 'jtv'):
                reg = 0.5 * dot(maps.volume, regularizer(maps.volume))
                rls_changed = True
                sumrls_prev = sumrls
                rls, sumrls = reweight(maps.volume, return_sum=True)
                sumrls = 0.5 * sumrls

                # --- compute gain -----------------------------------------
                # (we are late by one full RLS iteration when computing the
                #  gain but we save some computations)
                ll_rls_prev = ll_rls
                ll_rls = ll_gn
                gain = ll_rls_prev - ll_rls
                if gain < opt.optim.tolerance * ll_scl:
                    print(f'Converged ({gain/ll_scl:7.2g})')
                    break

    # --- prepare output -----------------------------------------------
    out = postproc(maps, data)
    if opt.distortion.enable:
        out = (*out, dist)
    return out
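
A usage sketch; when `opt.distortion.enable` is set, `nonlin` defers to `meetup` and the distortion fields are appended to the output:

intercepts, decay = nonlin(data, opt=opt)    # here opt.distortion.enable is False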
Example 8
def flash_b1(x,
             fa,
             tr,
             lam=(0, 0, 0),
             penalty=('m', 'm', 'm'),
             chi=True,
             pd=None,
             r1=None,
             b1=None):
    """Estimate B1+ from Variable flip angle data

    Parameters
    ----------
    x : (C, *spatial) tensor
        Input flash images
    fa : (C,) sequence[float]
        Flip angle (in deg)
    tr : float or (C,) sequence[float]
        Repetition time (in sec)
    lam : (float, float, float), default=(0, 0, 0)
        Regularization value for the T2*w signal, T1 map and B1 map.
    penalty : 3x {'membrane', 'bending'}, default=('m', 'm', 'm')
        Regularization type for the T2*w signal, T1 map and B1 map.

    Returns
    -------
    s : (*spatial) tensor
        T2*-weighted PD map
    r1 : (*spatial) tensor
        R1 map
    b1 : (*spatial) tensor
        B1+ map

    """

    tr = py.make_list(tr, len(x))
    fa = utils.make_vector(fa, len(x), dtype=torch.double)
    fa = fa.mul_(pymath.pi / 180).tolist()
    lam = py.make_list(lam, 3)
    penalty = py.make_list(penalty, 3)

    sd, df, mu = 0, 0, 0
    for x1 in x:
        bg, fg = estimate_noise(x1, chi=True)
        sd += bg['sd'].log()
        df += bg['dof'].log()
        mu += fg['mean'].log()
    sd = (sd / len(x)).exp()
    df = (df / len(x)).exp()
    mu = (mu / len(x)).exp()
    prec = 1 / (sd * sd)
    if not chi:
        df = 1

    shape = x.shape[1:]
    theta = x.new_empty([3, *shape])
    theta[0] = mu.log() if pd is None else pd.log()
    theta[1] = 1 if r1 is None else r1.log()
    theta[2] = 0 if b1 is None else b1.log()
    n = (x != 0).sum()

    g = torch.zeros_like(theta)
    h = theta.new_zeros([6, *theta.shape[1:]])
    g1 = torch.zeros_like(theta)
    h1 = theta.new_zeros([6, *theta.shape[1:]])

    prm = dict(
        membrane=(lam[0] if penalty[0][0] == 'm' else 0,
                  lam[1] if penalty[1][0] == 'm' else 0,
                  lam[2] if penalty[2][0] == 'm' else 0),
        bending=(lam[0] if penalty[0][0] == 'b' else 0,
                 lam[1] if penalty[1][0] == 'b' else 0,
                 lam[2] if penalty[2][0] == 'b' else 0),
    )

    ll0 = lr0 = float('inf')
    iter_start = 3
    for n_iter in range(32):

        if df > 1 and n_iter == iter_start:
            ll0 = lr0 = float('inf')

        # derivatives of likelihood term
        df1 = 1 if n_iter < iter_start else df
        ll, g, h = derivatives(x, fa, tr, theta[0], theta[1], theta[2], prec,
                               df1, g, h, g1, h1)

        # derivatives of regularization term
        reg = spatial.regulariser(theta, **prm, absolute=1e-10)
        g += reg
        lr = 0.5 * dot(theta, reg)

        l, l0 = ll + lr, ll0 + lr0
        gain = l0 - l
        print(
            f'{n_iter:2d} | {ll/n:12.6g} + {lr/n:12.6g} = {l/n:12.6g} | gain = {gain/n:12.6g}'
        )

        # Gauss-Newton update
        h[:3] += 1e-8 * h[:3].abs().max(0).values
        h[:3] += 1e-5
        if n_iter % 2:
            # odd iterations: update the T2*w signal and R1 components
            delta = torch.zeros_like(theta)
            prm1 = dict(membrane=prm['membrane'][:2],
                        bending=prm['bending'][:2])
            hh = torch.stack([h[0], h[1], h[3]])
            delta[:2] = spatial.solve_field(hh,
                                            g[:2],
                                            **prm1,
                                            max_iter=16,
                                            absolute=1e-10)
            del hh
        else:
            # even iterations: update only the B1 component
            delta = torch.zeros_like(theta)
            prm1 = dict(membrane=prm['membrane'][-1],
                        bending=prm['bending'][-1])
            delta[-1:] = spatial.solve_field(h[2:3],
                                             g[2:3],
                                             **prm1,
                                             max_iter=16,
                                             absolute=1e-10)
        # delta = spatial.solve_field(h, g, **prm, max_iter=16, absolute=1e-10)
        # theta -= spatial.solve_field_fmg(h, g, **prm)

        # line search
        dd = spatial.regulariser(delta, **prm, absolute=1e-10)
        dt = dot(dd, theta)
        dd = dot(dd, delta)
        success = False
        armijo = 1
        theta0 = theta
        ll0, lr0 = ll, lr
        for n_ls in range(12):
            theta = theta0.sub(delta, alpha=armijo)
            ll = nll(x, fa, tr, *theta, prec, df1)
            lr = 0.5 * armijo * (armijo * dd - 2 * dt)
            if ll + lr < ll0:  # and theta[1].max() < 0.69:
                print(n_ls, 'success', ((ll + lr) / n).item(),
                      (ll0 / n).item())
                success = True
                break
            print(n_ls, 'failure', ((ll + lr) / n).item(), (ll0 / n).item())
            armijo /= 2
        if not success and n_iter > 5:
            theta = theta0
            break


        import matplotlib.pyplot as plt  # imported lazily: only needed for debug plots
        # plt.subplot(2, 2, 1)
        # plt.imshow(theta[0, :, :, theta.shape[-1]//2].exp())
        # plt.axis('off')
        # plt.colorbar()
        # plt.title('PD * exp(-TE * R2*)')
        # plt.subplot(2, 2, 2)
        # plt.imshow(theta[1, :, :, theta.shape[-1]//2].exp())
        # plt.axis('off')
        # plt.colorbar()
        # plt.title('R1')
        # plt.subplot(2, 2, 3)
        # plt.imshow(theta[2, :, :, theta.shape[-1]//2].exp())
        # plt.axis('off')
        # plt.colorbar()
        # plt.title('B1+')
        # plt.show()

        vmin, vmax = 0, 2 * mu
        y = flash_signals(fa, tr, *theta)
        ex = 3
        plt.rcParams["figure.figsize"] = (4, len(x) + ex)
        for i in range(len(x)):
            plt.subplot(len(x) + ex, 4, 4 * i + 1)
            plt.imshow(x[i, ..., x.shape[-1] // 2], vmin=vmin, vmax=vmax)
            plt.axis('off')
            plt.subplot(len(x) + ex, 4, 4 * i + 2)
            plt.imshow(y[i, ..., x.shape[-1] // 2], vmin=vmin, vmax=vmax)
            plt.axis('off')
            plt.subplot(len(x) + ex, 4, 4 * i + 3)
            plt.imshow(x[i, ..., x.shape[-1] // 2] -
                       y[i, ..., y.shape[-1] // 2],
                       cmap=plt.get_cmap('coolwarm'))
            plt.axis('off')
            plt.colorbar()
            plt.subplot(len(x) + ex, 4, 4 * i + 4)
            plt.imshow(
                (theta[-1, ..., x.shape[-1] // 2].exp() * fa[i]) / pymath.pi)
            plt.axis('off')
            plt.colorbar()
        all_fa = torch.linspace(0, 2 * pymath.pi, 512)
        loc = [(theta.shape[1] // 2, theta.shape[2] // 2, theta.shape[3] // 2),
               (2 * theta.shape[1] // 3, theta.shape[2] // 2,
                theta.shape[3] // 2),
               (theta.shape[1] // 3, theta.shape[2] // 3, theta.shape[3] // 2)]
        for j, (nx, ny, nz) in enumerate(loc):
            plt.subplot(len(x) + ex, 1, len(x) + j + 1)
            plt.plot(
                all_fa,
                flash_signals(all_fa, tr[:1] * 512, *theta[:, nx, ny, nz]))
            plt.scatter(fa, x[:, nx, ny, nz])
        plt.show()

        if gain < 1e-4 * n:
            break

    theta.exp_()
    return theta[0], theta[1], theta[2]
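
A usage sketch with placeholder volumes (flip angles in degrees, TR in seconds); the regularization values are illustrative, putting a membrane penalty on the signal and R1 components and a bending penalty on B1:

import torch

x = 100 * torch.rand(3, 96, 96, 64)    # three FLASH volumes (placeholder)
fa = [6., 21., 35.]                    # flip angles (deg)
tr = 25e-3                             # repetition time (sec)
s, r1, b1 = flash_b1(x, fa, tr, lam=(0, 0, 10), penalty=('m', 'm', 'b'))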