Example #1
 def momentum_prm(self, dat):
     """Momentum of the parameter maps"""
     return spatial.regulariser(dat,
                                weights=self.rls,
                                dim=3,
                                **self.lam_prm,
                                voxel_size=self.voxel_size)
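
The method forwards to `spatial.regulariser`, with the penalty keywords
stored in `self.lam_prm`. For context, a minimal standalone sketch of the
same call is shown below; the `membrane=1, factor=...` keywords are an
assumption based on how such dictionaries are built in the other examples:

import torch
from nitorch import spatial

dat = torch.randn(2, 32, 32, 32)   # two 3D parameter maps
rls = torch.ones_like(dat)         # RLS weights (all ones = plain L2)
mom = spatial.regulariser(dat, weights=rls, dim=3,
                          membrane=1, factor=(1., 1.),
                          voxel_size=(1., 1., 1.))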
Example #2
 def momentum_dist(self, dat, vx, readout):
     """Momentum of the distortion maps"""
     lam = dict(self.lam_dist)
     lam['factor'] = lam['factor'] * (vx[readout]**2)
     return spatial.regulariser(dat[None],
                                **lam,
                                dim=3,
                                bound=self.DIST_BOUND,
                                voxel_size=vx)[0]
Example #3
def _chi_fit_tv(dat, sigma=None, df=None, lam=10, max_iter=50, tol=1e-5):

    noise, tissue = estimate_noise(dat, chi=True)
    mu = tissue['mean']
    sigma = sigma or noise['sd']
    df = df or noise['dof']
    print(f'sigma = {sigma}, dof = {df}, mu = {mu}')
    isigma2 = 1 / (sigma * sigma)
    lam = lam / mu

    msk = torch.isfinite(dat)
    if not msk.all():
        dat = dat.masked_fill(msk.bitwise_not_(), 0)
    fit = chi_bias_correction(dat, sigma, df)[0]
    msk = dat > 0
    n = msk.sum()

    w = torch.ones_like(dat)
    h = w.new_full([1] * dat.dim(), isigma2)[None]

    ll_prev = float('inf')
    for n_iter in range(max_iter):

        w, tv = spatial.membrane_weights(fit[None],
                                         factor=lam,
                                         return_sum=True)

        l, delta = nll_chi(dat, fit, msk, isigma2, df)
        delta += spatial.regulariser(fit[None], membrane=lam, weights=w)[0]
        delta = spatial.solve_field(h, delta[None], membrane=lam, weights=w)[0]
        fit.sub_(delta).clamp_min_(1e-8 * sigma)

        ll = l + tv
        gain, ll_prev = ll_prev - ll, ll
        print(f'{n_iter:3d} | {l/n:12.6g} + {tv/n:12.6g} = {ll/n:12.6g} '
              f'| gain = {gain/n:6.3g}')
        if abs(gain) < tol * n:
            break

    return fit, sigma, df
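
A hypothetical smoke test for `_chi_fit_tv` on synthetic Rician data (a chi
distribution with 2 degrees of freedom); the helpers `estimate_noise`,
`chi_bias_correction` and `nll_chi` come from the surrounding module:

import torch
truth = torch.zeros(64, 64, 64)
truth[16:48, 16:48, 16:48] = 100.
dat = ((truth + 5. * torch.randn_like(truth)) ** 2
       + (5. * torch.randn_like(truth)) ** 2).sqrt()
fit, sigma, df = _chi_fit_tv(dat, lam=10, max_iter=20)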
Example #4
def greeq(data, transmit=None, receive=None, opt=None, **kwopt):
    """Fit a non-linear relaxometry model to multi-echo Gradient-Echo data.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    transmit : sequence[PrecomputedFieldMap], optional
        Map(s) of the transmit field (b1+). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
    receive : sequence[PrecomputedFieldMap], optional
        Map(s) of the receive field (b1-). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
        If no receive map is provided, the output `pd` map will have
        a remaining b1- bias field.
    opt : GREEQOptions or dict, optional
        Algorithm options.
        {'preproc': {'register':      True},     # Co-register contrasts
         'optim':   {'nb_levels':     1,         # Number of pyramid levels
                     'max_iter_rls':  10,        # Max reweighting iterations
                     'max_iter_gn':   5,         # Max Gauss-Newton iterations
                     'max_iter_cg':   32,        # Max Conjugate-Gradient iterations
                     'tolerance':     1e-05,     # Tolerance for early stopping (RLS/GN)
                     'tolerance_cg':  1e-03},    # Tolerance for early stopping (CG)
         'backend': {'dtype':  torch.float32,    # Data type
                     'device': 'cpu'},           # Device
         'penalty': {'norm':    'jtv',           # Type of penalty: {'tkh', 'tv', 'jtv', None}
                     'factor':  {'r1':  10,      # Penalty factor per (log) map
                                 'pd':  10,
                                 'r2s': 2,
                                 'mt':  2}},
         'verbose': 1}

    Returns
    -------
    pd : ParameterMap
        Proton density
    r1 : ParameterMap
        Longitudinal relaxation rate
    r2s : ParameterMap
        Apparent transverse relaxation rate
    mt : ParameterMap, optional
        Magnetisation transfer saturation
        Only returned if MT-weighted data is provided.

    """
    opt = GREEQOptions().update(opt, **kwopt)
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    # --- estimate noise / register / initialize maps ---
    data, transmit, receive, maps = preproc(data, transmit, receive, opt)
    has_mt = hasattr(maps, 'mt')

    # --- prepare penalty factor ---
    lam = opt.penalty.factor
    if isinstance(lam, dict):
        lam = [
            lam.get('pd', 0),
            lam.get('r1', 0),
            lam.get('r2s', 0),
            lam.get('mt', 0)
        ]
        if not has_mt:
            lam = lam[:3]
    lam = core.utils.make_vector(lam, 3 + has_mt,
                                 **backend)  # PD, R1, R2*, [MT]

    # --- initialize weights (RLS) ---
    if str(opt.penalty.norm).lower().startswith('no') or all(lam == 0):
        opt.penalty.norm = ''
    opt.penalty.norm = opt.penalty.norm.lower()
    mean_shape = maps[0].shape
    rls = None
    sumrls = 0
    if opt.penalty.norm in ('tv', 'jtv'):
        rls_shape = mean_shape
        if opt.penalty.norm == 'tv':
            rls_shape = (len(maps), ) + rls_shape
        rls = torch.ones(rls_shape, **backend)
        sumrls = 0.5 * core.py.prod(rls_shape)

    if opt.penalty.norm:
        print(f'With {opt.penalty.norm.upper()} penalty:')
        print(f'    - PD:  {lam[0]:.3g}')
        print(f'    - R1:  {lam[1]:.3g}')
        print(f'    - R2*: {lam[2]:.3g}')
        if has_mt:
            print(f'    - MT:  {lam[3]:.3g}')
    else:
        print('Without penalty')

    if opt.penalty.norm not in ('tv', 'jtv'):
        # no reweighting -> do more gauss-newton updates instead
        opt.optim.max_iter_gn *= opt.optim.max_iter_rls
        opt.optim.max_iter_rls = 1
    print('Optimization:')
    print(f'    - Tolerance:        {opt.optim.tolerance}')
    if opt.penalty.norm.endswith('tv'):
        print(f'    - IRLS iterations:  {opt.optim.max_iter_rls}')
    print(f'    - GN iterations:    {opt.optim.max_iter_gn}')
    if opt.optim.solver == 'fmg':
        print(f'    - FMG cycles:       2')
        print(f'    - CG iterations:    2')
    else:
        print(f'    - CG iterations:    {opt.optim.max_iter_cg}'
              f' (tolerance: {opt.optim.tolerance_cg})')
    if opt.optim.nb_levels > 1:
        print(f'    - Levels:           {opt.optim.nb_levels}')

    printer = CritPrinter(max_levels=opt.optim.nb_levels,
                          max_rls=opt.optim.max_iter_rls,
                          max_gn=opt.optim.max_iter_gn,
                          penalty=opt.penalty.norm,
                          verbose=opt.verbose)
    printer.print_head()

    shape0 = shape = maps.shape[1:]
    aff0 = aff = maps.affine
    vx0 = vx = spatial.voxel_size(aff0)
    vol0 = vx0.prod()
    vol = vx.prod() / vol0
    ll_scl = sum(core.py.prod(dat.shape) for dat in data)

    for level in range(opt.optim.nb_levels, 0, -1):
        printer.level = level

        if opt.optim.nb_levels > 1:
            aff, shape = _get_level(level, aff0, shape0)
            vx = spatial.voxel_size(aff)
            vol = vx.prod() / vol0
            maps, rls = resize(maps, rls, aff, shape)
            if opt.penalty.norm in ('tv', 'jtv'):
                sumrls = 0.5 * vol * rls.reciprocal().sum(dtype=torch.double)

        # --- compute derivatives ---
        nb_prm = len(maps)
        nb_hes = nb_prm * (nb_prm + 1) // 2
        grad = torch.empty((nb_prm, ) + shape, **backend)
        hess = torch.empty((nb_hes, ) + shape, **backend)

        ll_rls = []
        ll_gn = []
        ll_max = float('inf')

        max_iter_rls = max(opt.optim.max_iter_rls // level, 1)
        for n_iter_rls in range(max_iter_rls):
            # --- Reweighted least-squares loop ---
            printer.rls = n_iter_rls

            # --- Gauss Newton loop ---
            for n_iter_gn in range(opt.optim.max_iter_gn):
                # start = time.time()
                printer.gn = n_iter_gn
                crit = 0
                grad.zero_()
                hess.zero_()
                # --- loop over contrasts ---
                for contrast, b1m, b1p in zip(data, receive, transmit):
                    crit1, g1, h1 = derivatives_parameters(
                        contrast, maps, b1m, b1p, opt)

                    # increment
                    if hasattr(maps, 'mt') and not contrast.mt:
                        # we optimize for mt but this particular contrast
                        # has no information about mt so g1/h1 are smaller
                        # than grad/hess.
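                        # Assumed packed layout of the symmetric Hessian:
                        # the first nb_prm planes hold the diagonal, the
                        # remaining ones the upper triangle, row by row.
                        # E.g. for nb_prm=4 (PD, R1, R2*, MT):
                        #   [h00, h11, h22, h33, h01, h02, h03, h12, h13, h23]
                        # so dropping MT keeps indices [0, 1, 2, 4, 5, 7],
                        # which is what the loop below computes.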
                        grad[:-1] += g1
                        hind = list(range(nb_prm - 1))
                        cnt = nb_prm
                        for i in range(nb_prm):
                            for j in range(i + 1, nb_prm):
                                if i != nb_prm - 1 and j != nb_prm - 1:
                                    hind.append(cnt)
                                cnt += 1
                        hess[hind] += h1
                        crit += crit1
                    else:
                        grad += g1
                        hess += h1
                        crit += crit1

                    del g1, h1, crit1
                # duration = time.time() - start
                # print('grad', duration)

                # start = time.time()
                reg = 0.
                if opt.penalty.norm:
                    g = spatial.regulariser(maps.volume,
                                            weights=rls,
                                            dim=3,
                                            voxel_size=vx,
                                            membrane=1,
                                            factor=lam * vol)
                    grad += g
                    reg = 0.5 * dot(maps.volume, g)
                    del g
                # duration = time.time() - start
                # print('reg', duration)

                # --- gauss-newton ---
                # start = time.time()
                grad = check_nans_(grad, warn='gradient')
                hess = check_nans_(hess, warn='hessian')
                if opt.penalty.norm:
                    # hess = hessian_sym_loaddiag(hess, 1e-5, 1e-8)  # 1e-5 1e-8
                    if opt.optim.solver == 'fmg':
                        deltas = spatial.solve_field_fmg(hess,
                                                         grad,
                                                         rls,
                                                         factor=lam * vol,
                                                         membrane=1,
                                                         voxel_size=vx,
                                                         nb_iter=2)
                    else:
                        deltas = spatial.solve_field(
                            hess,
                            grad,
                            rls,
                            factor=lam * vol,
                            membrane=1,
                            voxel_size=vx,
                            verbose=max(0, opt.verbose - 1),
                            optim='cg',
                            max_iter=opt.optim.max_iter_cg,
                            tolerance=opt.optim.tolerance_cg,
                            stop='diff')
                else:
                    # hess = hessian_sym_loaddiag(hess, 1e-3, 1e-4)
                    deltas = spatial.solve_field_closedform(hess, grad)
                deltas = check_nans_(deltas, warn='deltas')
                # duration = time.time() - start
                # print('solve', duration)

                for map, delta in zip(maps, deltas):
                    map.volume -= delta
                    map.volume.clamp_(-64, 64)  # avoid exp overflow
                    del delta
                del deltas

                # --- Compute gain ---
                ll = crit + reg + sumrls
                ll_max = max(ll_max, ll)
                ll_prev = ll_gn[-1] if ll_gn else ll_max
                gain = ll_prev - ll
                ll_gn.append(ll)
                printer.print_crit(crit, reg, sumrls, gain / ll_scl)
                if gain < opt.optim.tolerance * ll_scl:
                    # print('GN converged: ', ll_prev.item(), '->', ll.item())
                    break

            # --- Update RLS weights ---
            if opt.penalty.norm in ('tv', 'jtv'):
                rls, sumrls = spatial.membrane_weights(
                    maps.volume,
                    lam,
                    vx,
                    return_sum=True,
                    dim=3,
                    joint=opt.penalty.norm == 'jtv',
                    eps=core.constants.eps(rls.dtype))
                sumrls *= 0.5 * vol

                # --- Compute gain ---
                # (we are late by one full RLS iteration when computing the
                #  gain but we save some computations)
                ll = ll_gn[-1]
                ll_prev = ll_rls[-1] if ll_rls else ll_max
                ll_rls.append(ll)
                gain = ll_prev - ll
                if abs(gain) < opt.optim.tolerance * ll_scl:
                    # print(f'RLS converged ({gain:7.2g})')
                    break

    del grad
    if opt.uncertainty:
        uncertainty = compute_uncertainty(hess, rls, lam * vol, vx, opt)
        maps.pd.uncertainty = uncertainty[0]
        maps.r1.uncertainty = uncertainty[1]
        maps.r2s.uncertainty = uncertainty[2]
        if hasattr(maps, 'mt'):
            maps.mt.uncertainty = uncertainty[3]

    # --- Prepare output ---
    return postproc(maps)
Example #5
def zcorrect_square(x,
                    decay=None,
                    sigma=None,
                    lam=10,
                    max_iter=128,
                    tol=1e-6,
                    verbose=False):
    """Correct the z signal decay in a SPIM image.

    The signal is modelled as: f(z) = s(z) / (1 + b * z**2) + eps
    where z=0 is the top slice, s(z) is the theoretical signal if there
    was no absorption and b is the decay coefficient.

    Parameters
    ----------
    x : (..., nz) tensor
        SPIM image with the z dimension last and the z=0 plane first
    decay : float, optional
        Initial guess for decay parameter. Default: educated guess.
    sigma : float, optional
        Noise standard deviation. Default: educated guess.
    lam : float, default=10
        Regularisation.
    max_iter : int, default=128
    tol : float, default=1e-6
    verbose : int or bool, default=False

    Returns
    -------
    y : tensor
        Fitted image
    decay : tensor
        Decay parameter
    x : tensor
        Corrected image

    """

    x = torch.as_tensor(x)
    if not x.dtype.is_floating_point:
        x = x.to(dtype=torch.get_default_dtype())
    backend = utils.backend(x)
    shape = x.shape
    nz = shape[-1]
    x = x.reshape([-1, nz])
    b = decay

    # decay educated guess: closed form two values at z=1/3 and z=2/3
    z1 = nz // 3
    z2 = 2 * nz // 3
    x1 = x[:, z1].median()
    x2 = x[:, z2].median()
    z1 = float(z1)**2
    z2 = float(z2)**2
    b = b or (x2 - x1) / (x1 * z1 - x2 * z2)
    b = abs(b)
    y0 = x1 * (1 + b * z1)

    y0 = y0.item()
    b = b.item() if torch.is_tensor(b) else b

    # noise educated guess: assume SNR=5 at z=1/2
    sigma = sigma or (y0 / (1 + b * (nz / 2)**2)) / 5
    lam = lam**2 * sigma**2
    reg = lambda y: spatial.regulariser(y[:, None], membrane=lam, dim=1)[:, 0]
    solve = lambda h, g: spatial.solve_field_sym(
        h[:, None], g[:, None], membrane=lam, dim=1)[:, 0]

    if verbose:
        print(f'init: y0={y0}, b={b}, sigma={sigma}, lam={lam}')

    # init
    z2 = torch.arange(nz, **backend).square_()
    logy = torch.full_like(x, y0).log_()
    logb = torch.as_tensor(b, **backend)
    y = logy.exp()
    b = logb.exp()
    ll0 = (y / (1 + b * z2) - x).square_().sum() + (logy * reg(logy)).sum()
    ll1 = ll0
    for it in range(max_iter):

        # exponentiate
        y = torch.exp(logy, out=y)
        fit = y / (1 + b * z2)
        res = fit - x

        # compute objective
        ll = res.square().sum() + (logy * reg(logy)).sum()
        gain = (ll1 - ll) / ll0
        if verbose:
            end = '\n' if verbose > 1 else '\r'
            print(f'{it:3d} | {ll:12.6g} | gain = {gain:12.6g}', end=end)
        if it > 0 and gain < tol:
            break
        ll1 = ll

        # update decay
        g = -(z2 * b * y) / (z2 * b + 1).square()
        h = b * y - z2 * b.square() * y
        h *= z2 / (z2 * b + 1).pow(3)
        h = h.abs_() * res.abs()
        h += g.square()
        g *= res

        g = g.sum()
        h = h.sum()
        logb -= g / h

        # update fit
        b = torch.exp(logb, out=b)
        fit = y / (1 + b * z2)
        res = fit - x

        # ll = (fit - x).square().sum() + 1e3 * (logy[1:] - logy[:-1]).sum().square()
        # gain = (ll1 - ll) / ll0
        # print(f'{it} | {ll.item()} | {gain.item()}', end='\n')

        # update y
        g = h = y / (z2 * b + 1)
        h = h.abs() * res.abs()
        h += g.square()
        g *= res
        g += reg(logy)
        logy -= solve(h, g)

    y = torch.exp(logy, out=y)
    y = y.reshape(shape)
    x = x * (1 + b * z2)
    x = x.reshape(shape)
    return y, b, x
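
A hypothetical usage sketch on synthetic data, following the docstring's
convention that z is the last dimension and the z=0 plane comes first:

import torch
nz = 64
z2 = torch.arange(nz, dtype=torch.float32) ** 2
signal = torch.full([32, 32, nz], 100.)
observed = signal / (1 + 5e-4 * z2) + torch.randn(32, 32, nz)
y, decay, corrected = zcorrect_square(observed, verbose=True)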
Example #6
def correct_smooth(x,
                   sigma=None,
                   lam=10,
                   gamma=10,
                   downsample=None,
                   max_iter=16,
                   max_rls=8,
                   tol=1e-6,
                   verbose=False,
                   device=None):
    """Correct the intensity non-uniformity in a SPIM image.

    The signal is modelled as: f = exp(s + b) + eps, with a penalty on
    the (squared) gradients of s and on the (squared) curvature of b.

    Parameters
    ----------
    x : tensor
        SPIM image with the z dimension last and the z=0 plane first
    sigma : float, optional
        Noise standard deviation. Default: educated guess.
    lam : float, default=10
        Regularisation on the signal.
    gamma : float, default=10
        Regularisation on the bias field.
    max_iter : int, default=16
        Maximum number of Newton iterations.
    max_rls : int, default=8
        Maximum number of reweighting iterations.
        If 1, this is effectively an l2 regularisation.
    tol : float, default=1e-6
        Tolerance for early stopping.
    verbose : int or bool, default=False
        Verbosity level
    device : torch.device, default=x.device
        Use this device during fitting.

    Returns
    -------
    y : tensor
        Fitted image
    bias : tensor
        Fitted bias
    x : tensor
        Corrected image

    """

    x = torch.as_tensor(x)
    if not x.dtype.is_floating_point:
        x = x.to(dtype=torch.get_default_dtype())
    dim = x.dim()

    # downsampling
    if downsample:
        x0 = x
        downsample = py.make_list(downsample, dim)
        x = spatial.pool(dim, x, downsample)
    shape = x.shape
    x = x.to(device)

    # noise educated guess: assume SNR=5 in the central region
    center = tuple(slice(s // 3, 2 * s // 3) for s in shape)
    sigma = sigma or x[center].median() / 5
    lam = lam**2 * sigma**2
    gamma = gamma**2 * sigma**2
    regy = lambda y, w: spatial.regulariser(
        y[None], membrane=lam, dim=dim, weights=w)[0]
    regb = lambda b: spatial.regulariser(b[None], bending=gamma, dim=dim)[0]
    solvey = lambda h, g, w: spatial.solve_field_sym(
        h[None], g[None], membrane=lam, dim=dim, weights=w)[0]
    solveb = lambda h, g: spatial.solve_field_sym(
        h[None], g[None], bending=gamma, dim=dim)[0]

    # init
    l1 = max_rls > 1
    if l1:
        w = torch.ones_like(x)[None]
        llw = w.sum()
    else:
        w = None
        llw = 0
        max_rls = 1
    logb = torch.zeros_like(x)
    logy = x.clamp_min(1e-3).log_()
    y = logy.exp()
    b = logb.exp()
    fit = y * b
    res = fit - x
    llx = res.square().sum()
    lly = (regy(logy, w).mul_(logy)).sum()
    llb = (regb(logb).mul_(logb)).sum()
    ll0 = llx + lly + llb + llw
    ll1 = ll0

    for it_ls in range(max_rls):
        for it in range(max_iter):

            # update bias
            g = h = fit
            h = (h * res).abs_()
            h.addcmul_(g, g)
            g *= res
            g += regb(logb)
            logb -= solveb(h, g)
            logb0 = logb.mean()
            logb -= logb0
            logy += logb0

            # update fit / ll
            llb = (regb(logb).mul_(logb)).sum()
            b = torch.exp(logb, out=b)
            y = torch.exp(logy, out=y)
            fit = y * b
            res = fit - x

            # update y
            g = h = fit
            h = (h * res).abs_()
            h.addcmul_(g, g)
            g *= res
            g += regy(logy, w)
            logy -= solvey(h, g, w)

            # update fit / ll
            y = torch.exp(logy, out=y)
            fit = y * b
            res = fit - x
            lly = (regy(logy, w).mul_(logy)).sum()

            # compute objective
            llx = res.square().sum()
            ll = llx + lly + llb + llw
            gain = (ll1 - ll) / ll0
            ll1 = ll
            if verbose:
                end = '\n' if verbose > 1 else '\r'
                pre = f'{it_ls:3d} | ' if l1 else ''
                print(pre + f'{it:3d} | {ll:12.6g} | gain = {gain:12.6g}',
                      end=end)
            if it > 0 and abs(gain) < tol:
                break

        if l1:
            w, llw = spatial.membrane_weights(logy[None],
                                              lam,
                                              dim=dim,
                                              return_sum=True)
            ll0 = ll
    if verbose:
        print('')

    if downsample:
        b = spatial.resize(logb.to(x0.device)[None, None],
                           downsample,
                           shape=x0.shape,
                           anchor='f')[0, 0].exp_()
        y = spatial.resize(logy.to(x0.device)[None, None],
                           downsample,
                           shape=x0.shape,
                           anchor='f')[0, 0].exp_()
        x = x0
    else:
        y = torch.exp(logy, out=y)
    x = x / b
    return y, b, x
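
A hypothetical usage sketch with a synthetic multiplicative bias along one
axis; `downsample=2` speeds up the fit and is undone on output:

import torch
bias = torch.linspace(-0.5, 0.5, 64).expand(64, 64, 64)
img = 100. * bias.exp() + torch.randn(64, 64, 64)
y, b, corrected = correct_smooth(img, downsample=2, verbose=True)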
Example #7
def zcorrect_exp_const(x,
                       decay=None,
                       sigma=None,
                       lam=10,
                       mask=None,
                       max_iter=128,
                       tol=1e-6,
                       verbose=False,
                       snr=5):
    """Correct the z signal decay in a SPIM image.

    The signal is modelled as: f(z) = s * exp(-b * z) + eps
    where z=0 is (arbitrarily) the middle slice, s is the intercept
    and b is the decay coefficient.

    Parameters
    ----------
    x : (..., nz) tensor
        SPIM image with the z dimension last and the z=0 plane first
    decay : float, optional
        Initial guess for decay parameter. Default: educated guess.
    sigma : float, optional
        Noise standard deviation. Default: educated guess.
    lam : float or (float, float), default=10
        Regularisation (intercept, decay).
    mask : tensor, optional
        Mask of voxels to include in the fit.
    max_iter : int, default=128
    tol : float, default=1e-6
    verbose : int or bool, default=False
    snr : float, default=5
        Assumed SNR, used for the noise educated guess.

    Returns
    -------
    y : tensor
        Fitted intercept image
    decay : tensor
        Decay map
    x : tensor
        Corrected image

    """

    x = torch.as_tensor(x)
    if not x.dtype.is_floating_point:
        x = x.to(dtype=torch.get_default_dtype())
    backend = utils.backend(x)
    shape = x.shape
    dim = x.dim() - 1
    nz = shape[-1]
    b = decay

    x = utils.movedim(x, -1, 0).clone()
    if mask is None:
        mask = torch.isfinite(x) & (x > 0)
    else:
        mask = mask & (torch.isfinite(x) & (x > 0))
    x[~mask] = 0

    # decay educated guess: closed form from two values
    if b is None:
        z1 = 2 * nz // 5
        z2 = 3 * nz // 5
        x1 = x[z1]
        x1 = x1[x1 > 0].median()
        x2 = x[z2]
        x2 = x2[x2 > 0].median()
        z1 = float(z1)
        z2 = float(z2)
        b = (x2.log() - x1.log()) / (z1 - z2)
    y = x[(nz - 1) // 2]
    y = y[y > 0].median().log()

    b = b.item() if torch.is_tensor(b) else b
    y = y.item()
    if verbose:
        print(f'init: y = {y}, b = {b}')

    # noise educated guess: assume SNR=snr at the middle slice
    sigma = sigma or (y / snr)
    lam_y, lam_b = py.make_list(lam, 2)
    lam_y = lam_y**2 * sigma**2
    lam_b = lam_b**2 * sigma**2
    reg = lambda t: spatial.regulariser(
        t, membrane=1, dim=dim, factor=(lam_y, lam_b))
    solve = lambda h, g: spatial.solve_field_fmg(
        h, g, membrane=1, dim=dim, factor=(lam_y, lam_b))

    # init
    z = torch.arange(nz, **backend) - (nz - 1) / 2
    z = utils.unsqueeze(z, -1, dim)
    theta = z.new_empty([2, *x.shape[1:]], **backend)
    logy = theta[0].fill_(y)
    b = theta[1].fill_(b)
    y = logy.exp()
    ll0 = (mask * y *
           (-b * z).exp_() - x).square_().sum() + (theta * reg(theta)).sum()
    ll1 = ll0

    g = torch.zeros_like(theta)
    h = theta.new_zeros([3, *theta.shape[1:]])
    for it in range(max_iter):

        # exponentiate
        y = torch.exp(logy, out=y)
        fit = (b * z).neg_().exp_().mul_(y).mul_(mask)
        res = fit - x

        # compute objective
        reg_theta = reg(theta)
        ll = res.square().sum() + (theta * reg_theta).sum()
        gain = (ll1 - ll) / ll0
        if verbose:
            end = '\n' if verbose > 1 else '\r'
            print(f'{it:3d} | {ll:12.6g} | gain = {gain:12.6g}', end=end)
        if it > 0 and gain < tol:
            break
        ll1 = ll

        g[0] = (fit * res).sum(0)
        g[1] = -(fit * res * z).sum(0)
        h[0] = (fit * (fit + res.abs())).sum(0)
        h[1] = (fit * (fit + res.abs()) * (z * z)).sum(0)
        h[2] = -(z * fit * fit).sum(0)

        g += reg_theta
        theta -= solve(h, g)

    y = torch.exp(logy, out=y)
    x = x * (b * z).exp_()
    x = utils.movedim(x, 0, -1)
    x = x.reshape(shape)
    return y, b, x
Example #8
File: phase.py (project: balbasty/nitorch)
def phase_fit(magnitude, phase, lam=(0, 1e1), penalty=('membrane', 'bending')):
    """Fit a complex image using a decreasing phase regularization

    Parameters
    ----------
    magnitude : tensor
    phase : tensor
    lam : (float, float)
    penalty : (str, str)

    Returns
    -------
    magnitude : tensor
    phase : tensor

    """

    # estimate noise precision
    sd = estimate_noise(magnitude)[0]['sd']
    prec = 1 / (sd * sd)

    # initialize fit
    fit = magnitude.new_empty([2, *magnitude.shape])
    fit[0] = magnitude
    fit[0].clamp_min_(1e-8).log_()
    fit[1] = mean_phase(phase, magnitude)

    # allocate placeholders
    g = magnitude.new_empty([2, *magnitude.shape])
    h = magnitude.new_empty([2, *magnitude.shape])
    n = magnitude.numel()

    # prepare regularizer options
    prm = dict(
        membrane=[lam[0] * int(penalty[0] == 'membrane'),
                  lam[1] * int(penalty[1] == 'membrane')],
        bending=[lam[0] * int(penalty[0] == 'bending'),
                 lam[1] * int(penalty[1] == 'bending')],
        bound='dct2')

    lam0 = dict(membrane=prm['membrane'][-1], bending=prm['bending'][-1])
    ll0 = lr0 = factor = float('inf')
    for n_iter in range(20):

        # decrease regularization
        factor, factor_prev = 1 + 10 ** (5 - n_iter), factor
        factor_ratio = factor / factor_prev if n_iter else float('inf')
        prm['membrane'][-1] = lam0['membrane'] * factor
        prm['bending'][-1] = lam0['bending'] * factor

        # compute derivatives
        ll, g, h = derivatives(magnitude, phase, fit[0], fit[1], g, h)
        ll *= prec
        g *= prec
        h *= prec

        # compute regularization
        reg = spatial.regulariser(fit, **prm)
        lr = 0.5 * dot(fit, reg)
        g += reg

        # Gauss-Newton step
        fit -= spatial.solve_field_fmg(h, g, **prm)

        # Compute progress
        l0 = ll0 + factor_ratio * lr0
        l = ll + lr
        gain = l0 - l
        print(f'{n_iter:3d} | {ll/n:12.6g} + {lr/n:6.3g} = {l/n:12.6g} '
              f'| gain = {gain/n:6.3g}')
        if abs(gain) < n * 1e-4:
            break

        ll0, lr0 = ll, lr
        # plot_fit(magnitude, phase, fit)

    return fit[0].exp_(), fit[1]
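
A hypothetical smoke test with random inputs; `derivatives`, `mean_phase`
and `estimate_noise` come from the surrounding module:

import torch
magnitude = 50. + 100. * torch.rand(32, 32, 32)
phase = (2. * torch.rand(32, 32, 32) - 1.) * 3.14159
mag_fit, phase_map = phase_fit(magnitude, phase)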
Example #9
def run_epic_noreverse(echoes,
                       voxshift=None,
                       lam=1,
                       sigma=1,
                       max_iter=(10, 32),
                       tol=1e-5,
                       verbose=False):
    """Run EPIC on pre-loaded tensors (no reverse gradients or synthesis)

    Parameters
    ----------
    echoes : (N, *spatial) tensor
        Echoes acquired with a bipolar readout.
        The readout direction should be last.
    voxshift : (*spatial) tensor, optional
        Voxel shift map used to deform towards even (0, 2, ...) echoes.
        Its inverse is used to deform towards odd (1, 3, ...) echoes.
    lam : [list of] float
        Regularization factor (per echo)
    sigma : float
        Noise standard deviation
    max_iter : [pair of] int
        Maximum number of RLS and CG iterations
    tol : float
        Tolerance for early stopping
    verbose : int
        Verbosity level

    Returns
    -------
    echoes : (N, *spatial) tensor
        Undistorted + denoised echoes

    """
    ne = len(echoes)  # number of echoes
    nv = echoes.shape[1:].numel()  # number of voxels
    nd = echoes.dim() - 1  # number of dimensions

    # initialize denoised echoes
    fit = echoes.clone()
    fwd_fit = torch.empty_like(fit)

    # prepare voxel shift maps
    do_voxshift = voxshift is not None
    if do_voxshift:
        ivoxshift = add_identity_1d(-voxshift)
        voxshift = add_identity_1d(voxshift)
    else:
        ivoxshift = None

    # prepare parameters
    max_iter, sub_iter = py.make_list(max_iter, 2)
    tol, sub_tol = py.make_list(tol, 2)
    lam = [l / ne for l in py.make_list(lam, ne)]
    isigma2 = 1 / (sigma * sigma)

    # compute hessian once and for all
    if do_voxshift:
        one = torch.ones_like(voxshift)[None]
        h = torch.empty_like(echoes)
        h[0::2] = push1d(pull1d(one, voxshift), voxshift)
        h[1::2] = push1d(pull1d(one, ivoxshift), ivoxshift)
        del one
    else:
        h = fit.new_ones([1] * (nd + 1))
    h *= isigma2

    loss = float('inf')
    for n_iter in range(max_iter):

        # update weights
        w, jtv = membrane_weights(fit, factor=lam, return_sum=True)

        # gradient of likelihood
        pull_forward(fit, voxshift, ivoxshift, out=fwd_fit)
        fwd_fit.sub_(echoes)
        ll = ssq(fwd_fit)
        push_forward(fwd_fit, voxshift, ivoxshift, out=fwd_fit)

        g = fwd_fit.mul_(isigma2)
        ll *= 0.5 * isigma2

        # gradient of prior
        g += regulariser(fit, membrane=1, factor=lam, weights=w)

        # solve
        fit -= solve_field(h,
                           g,
                           w,
                           membrane=1,
                           factor=lam,
                           max_iter=sub_iter,
                           tolerance=sub_tol)

        # track objective
        ll, jtv = ll.item() / (ne * nv), jtv.item() / (ne * nv)
        loss, loss_prev = ll + jtv, loss
        if n_iter:
            gain = (loss_prev - loss) / max((loss_max - loss), 1e-8)
        else:
            gain = float('inf')
            loss_max = loss
        if verbose:
            end = '\n' if verbose > 1 else '\r'
            print(
                f'{n_iter+1:02d} | {ll:12.6g} + {jtv:12.6g} = {loss:12.6g} '
                f'| gain = {gain:12.6g}',
                end=end)
        if gain < tol:
            break

    if verbose == 1:
        print('')

    return fit
Example #10
File: topup.py (project: balbasty/nitorch)
 def regulariser(v):
     return spatial.regulariser(v, dim=ndim, **{penalty: 1},
                                factor=lam, voxel_size=vx, bound=BND)
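
The `**{penalty: 1}` idiom turns the penalty name into a keyword argument,
so the same closure serves both membrane and bending penalties. With
hypothetical bindings for the free variables, the call expands as follows:

import torch
penalty = 'membrane'               # -> regulariser(..., membrane=1, ...)
lam, vx, ndim, BND = 1., (1., 1., 1.), 3, 'dct2'
mom = regulariser(torch.randn(1, 16, 16, 16))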
Example #11
def run_epic(echoes,
             reverse_echoes=None,
             voxshift=None,
             extrapolate=True,
             lam=1,
             sigma=1,
             max_iter=(10, 32),
             tol=1e-5,
             verbose=False):
    """Run EPIC on pre-loaded tensors.

    Parameters
    ----------
    echoes : (N, *spatial) tensor
        Echoes acquired with a bipolar readout.
        The readout direction should be last.
    reverse_echoes : (N, *spatial) tensor or bool, optional
        Echoes acquired with a reversed bipolar readout.
        If None, they are synthesized from `echoes`.
        If False, `run_epic_noreverse` is used instead.
    voxshift : (*spatial) tensor
        Voxel shift map used to deform towards even (0, 2, ...) echoes.
        Its inverse is used to deform towards odd (1, 3, ...) echoes.
    extrapolate : bool
        Extrapolate first/last echo when reverse_echoes is None.
        Otherwise, only use interpolated echoes.
    lam : [list of] float
        Regularization factor (per echo)
    sigma : float
        Noise standard deviation
    max_iter : [pair of] int
        Maximum number of RLS and CG iterations
    tol : float
        Tolerance for early stopping
    verbose : int
        Verbosity level

    Returns
    -------
    echoes : (N, *spatial) tensor
        Undistorted + denoised echoes

    """
    if reverse_echoes is False:
        return run_epic_noreverse(echoes, voxshift, lam, sigma, max_iter, tol,
                                  verbose)

    ne = len(echoes)  # number of echoes
    nv = echoes.shape[1:].numel()  # number of voxels
    nd = echoes.dim() - 1  # number of dimensions

    # synthesize echoes
    synth = not torch.is_tensor(reverse_echoes)
    if synth:
        neg = synthesize_neg(echoes[0::2])
        pos = synthesize_pos(echoes[1::2])
        reverse_echoes = torch.stack([x for y in zip(pos, neg) for x in y])
        del pos, neg
    else:
        extrapolate = True

    # initialize denoised echoes
    fit = (echoes + reverse_echoes).div_(2)
    if not extrapolate:
        fit[0] = echoes[0]
        fit[-1] = echoes[-1]
    fwd_fit = torch.zeros_like(fit)
    bwd_fit = torch.zeros_like(fit)

    # prepare voxel shift maps
    if voxshift is not None:
        ivoxshift = add_identity_1d(-voxshift)
        voxshift = add_identity_1d(voxshift)
    else:
        ivoxshift = None

    # prepare parameters
    max_iter, sub_iter = py.make_list(max_iter, 2)
    tol, sub_tol = py.make_list(tol, 2)
    lam = [l / ne for l in py.make_list(lam, ne)]
    isigma2 = 1 / (sigma * sigma)

    # compute hessian once and for all
    if voxshift is not None:
        one = torch.ones_like(voxshift)[None]
        if extrapolate:
            h = push1d(pull1d(one, voxshift), voxshift)
            h += push1d(pull1d(one, ivoxshift), ivoxshift)
            weight_ = lambda x: x.mul_(0.5)
            halfweight_ = lambda x: x.mul_(math.sqrt(0.5))
        else:
            h = torch.zeros_like(fit)
            h[:-1] += push1d(pull1d(one, voxshift), voxshift)
            h[1:] += push1d(pull1d(one, ivoxshift), ivoxshift)
            weight_ = lambda x: x[1:-1].mul_(0.5)
            halfweight_ = lambda x: x[1:-1].mul_(math.sqrt(0.5))
        del one
    else:
        h = fit.new_ones([ne] + [1] * nd)
        if extrapolate:
            h *= 2
            weight_ = lambda x: x.mul_(0.5)
            halfweight_ = lambda x: x.mul_(math.sqrt(0.5))
        else:
            h[1:-1] *= 2
            weight_ = lambda x: x[1:-1].mul_(0.5)
            halfweight_ = lambda x: x[1:-1].mul_(math.sqrt(0.5))
    weight_(h)
    h *= isigma2

    loss = float('inf')
    for n_iter in range(max_iter):

        # update weights
        w, jtv = membrane_weights(fit, factor=lam, return_sum=True)

        # gradient of likelihood (forward)
        pull_forward(fit, voxshift, ivoxshift, out=fwd_fit)
        fwd_fit.sub_(echoes)
        halfweight_(fwd_fit)
        ll = ssq(fwd_fit)
        halfweight_(fwd_fit)
        push_forward(fwd_fit, voxshift, ivoxshift, out=fwd_fit)

        # gradient of likelihood (reversed)
        pull_backward(fit, voxshift, ivoxshift, extrapolate, out=bwd_fit)
        if extrapolate:
            bwd_fit.sub_(reverse_echoes)
        else:
            bwd_fit[1:-1].sub_(reverse_echoes[1:-1])
        halfweight_(bwd_fit)
        ll += ssq(bwd_fit)
        halfweight_(bwd_fit)
        push_backward(bwd_fit, voxshift, ivoxshift, extrapolate, out=bwd_fit)

        g = fwd_fit.add_(bwd_fit).mul_(isigma2)
        ll *= 0.5 * isigma2

        # gradient of prior
        g += regulariser(fit, membrane=1, factor=lam, weights=w)

        # solve
        fit -= solve_field(h,
                           g,
                           w,
                           membrane=1,
                           factor=lam,
                           max_iter=sub_iter,
                           tolerance=sub_tol)

        # track objective
        ll, jtv = ll.item() / (ne * nv), jtv.item() / (ne * nv)
        loss, loss_prev = ll + jtv, loss
        if n_iter:
            gain = (loss_prev - loss) / max((loss_max - loss), 1e-8)
        else:
            gain = float('inf')
            loss_max = loss
        if verbose:
            end = '\n' if verbose > 1 else '\r'
            print(
                f'{n_iter+1:02d} | {ll:12.6g} + {jtv:12.6g} = {loss:12.6g} '
                f'| gain = {gain:12.6g}',
                end=end)
        if gain < tol:
            break

    if verbose == 1:
        print('')

    return fit
Example #12
File: tv.py (project: balbasty/nitorch)
def denoise(image=None,
            lam=1,
            sigma=1,
            max_iter=64,
            sub_iter=32,
            optim='cg',
            tol=1e-5,
            sub_tol=1e-5,
            plot=False,
            jtv=True,
            dim=None,
            **prm):
    """Denoise an image using a (joint) total variation prior.

    This implementation uses a reweighted least squares approach.

    Parameters
    ----------
    image : (..., K, *spatial) tensor, Image to denoise with K channels
    lam : [list of] float, default=1, Regularisation factor
    sigma : [list of] float, default=1, Noise standard deviation
    max_iter : int, default=64, Number of RLS iterations
    sub_iter : int, default=32, Number of relaxation/cg iterations
    optim : {'cg', 'relax', 'fmg+cg', 'fmg+relax'}, default='cg'
    tol : float, default=1e-5, Tolerance for early stopping (RLS)
    sub_tol : float, default=1e-5, Tolerance for early stopping (relax/cg)
    plot : bool, default=False
    jtv : bool, default=True, Joint TV across channels ($\ell_{1,2}$)
    dim : int, default=image.dim()-1

    Returns
    -------
    denoised : (..., K, *spatial) tensor, Denoised image

    """

    if image is None:
        torch.random.manual_seed(1234)
        image = phantoms.augment(phantoms.circle(), fwhm=0)[None]

    image = torch.as_tensor(image)
    dim = dim or (image.dim() - 1)
    nvox = image.shape[-dim:].numel()

    # regularization (prior)
    lam = make_list(lam, image.shape[-dim - 1])
    if jtv:
        lam = [l / image.shape[-dim - 1] for l in lam]
    prm['membrane'] = 1
    prm['factor'] = lam

    # noise variance (likelihood)
    sigma = make_vector(sigma,
                        image.shape[-dim - 1],
                        dtype=image.dtype,
                        device=image.device)
    isigma2 = 1 / (sigma * sigma)

    # solver
    optim, lr = make_list(optim, 2, default=1)
    do_fmg, optim = ('fmg' in optim), ('relax' if 'relax' in optim else 'cg')
    if do_fmg:

        def solve(h, g, w, optim):
            return spatial.solve_field_fmg(h,
                                           g,
                                           dim=dim,
                                           weights=w,
                                           **prm,
                                           optim=optim)
    else:

        def solve(h, g, w, optim):
            return spatial.solve_field(h,
                                       g,
                                       dim=dim,
                                       weights=w,
                                       **prm,
                                       optim=optim,
                                       max_iter=sub_iter,
                                       tolerance=sub_tol)

    # initialize
    denoised = image.clone()

    l_prev = None
    l_max = None
    for n_iter in range(1, max_iter + 1):

        # update weight map / tv loss
        weights, lw = spatial.membrane_weights(denoised,
                                               dim=dim,
                                               return_sum=True,
                                               joint=jtv,
                                               factor=lam)

        # gradient / hessian of least square problem
        ll, g, h = losses.mse(denoised, image, dim=dim, lam=isigma2)
        ll.mul_(nvox), g.mul_(nvox), h.mul_(nvox)
        g += spatial.regulariser(denoised, dim=dim, weights=weights, **prm)

        # solve least square problem
        optim0 = 'cg' if n_iter < 10 else optim  # cg seems to help early on
        g = solve(h, g, weights, optim0)
        if lr != 1:
            g.mul_(lr)
        denoised -= g

        ll = ll.item() / image.numel()
        lw = lw.item() / image.numel()
        l = ll + lw

        # print objective
        if l_prev is None:
            l_prev = l_max = l
            gain = float('inf')
            print(f'{n_iter:03d} | {ll:12.6g} + {lw:12.6g} = {l:12.6g}',
                  end='\r')
        else:
            gain = (l_prev - l) / max(abs(l_max - l), 1e-8)
            l_prev = l
            print(
                f'{n_iter:03d} | {ll:12.6g} + {lw:12.6g} = {l:12.6g} '
                f'| gain = {gain:12.6g}',
                end='\r')

        if plot and (n_iter - 1) % (max_iter // 10 + 1) == 0:
            import matplotlib.pyplot as plt
            img = image[0, image.shape[1] // 2] if dim == 3 else image[0]
            den = denoised[0,
                           denoised.shape[1] // 2] if dim == 3 else denoised[0]
            wgt = weights[0, weights.shape[1] // 2] if dim == 3 else weights[0]
            plt.subplot(1, 3, 1)
            plt.imshow(img.cpu())
            plt.subplot(1, 3, 2)
            plt.imshow(den.cpu())
            plt.subplot(1, 3, 3)
            plt.imshow(wgt.reciprocal().cpu())
            plt.colorbar()
            plt.show()

        if gain < tol:
            break
    print('')

    return denoised
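
Because `image=None` falls back to a built-in circle phantom (see the first
lines of the function), a minimal smoke test needs no data:

denoised = denoise(lam=5, sigma=0.1, max_iter=32)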
Example #13
File: _nonlin.py (project: balbasty/nitorch)
def meetup(data, dist=None, opt=None):
    """Fit the ESTATICS+MEETUP model to multi-echo Gradient-Echo data.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    dist : sequence[Optional[ParameterizedDistortion]], optional
        Pre-computed distortion fields
    opt : Options, optional
        Algorithm options.

    Returns
    -------
    intercepts : sequence[GradientEcho]
        Echo series extrapolated to TE=0
    decay : estatics.ParameterMap
        R2* decay map
    distortions : sequence[ParameterizedDistortion]
        B0-induced distortion fields

    """
    # --- prepare everything -------------------------------------------
    data, maps, dist, opt, rls = _prepare(data, dist, opt)
    sumrls = 0.5 * rls.sum(dtype=torch.double) if rls is not None else 0
    backend = dict(dtype=opt.backend.dtype, device=opt.backend.device)

    # --- initialize tracking of the objective function ----------------
    crit, vreg, reg = float('inf'), 0, 0
    ll_rls = crit + vreg + reg
    sumrls_prev = sumrls
    rls_changed = False
    ll_scl = sum(core.py.prod(dat.shape) for dat in data)

    # --- Multi-Resolution loop ----------------------------------------
    shape0 = shape = maps.shape[1:]
    aff0 = aff = maps.affine
    lam0 = lam = opt.regularization.factor
    vx0 = vx = spatial.voxel_size(aff0)
    scl0 = vx0.prod()
    armijo_prev = 1
    for level in range(opt.optim.nb_levels, 0, -1):

        if opt.optim.nb_levels > 1:
            aff, shape = _get_level(level, aff0, shape0)
            vx = spatial.voxel_size(aff)
            scl = vx.prod() / scl0
            lam = [float(l * scl) for l in lam0]
            maps, rls = _resize(maps, rls, aff, shape)
            if opt.regularization.norm in ('tv', 'jtv'):
                sumrls = 0.5 * scl * rls.reciprocal().sum(dtype=torch.double)

        grad = torch.empty((len(data) + 1, *shape), **backend)
        hess = torch.empty((len(data) * 2 + 1, *shape), **backend)

        # --- RLS loop -------------------------------------------------
        max_iter_rls = 1 if level > 1 else opt.optim.max_iter_rls
        for n_iter_rls in range(1, max_iter_rls + 1):

            # --- helpers ----------------------------------------------
            regularizer_prm = lambda x: spatial.regulariser(
                x, weights=rls, dim=3, voxel_size=vx, membrane=1, factor=lam)
            solve_prm = lambda H, g: spatial.solve_field(
                H,
                g,
                rls,
                factor=lam,
                membrane=1 if opt.regularization.norm else 0,
                voxel_size=vx,
                max_iter=opt.optim.max_iter_cg,
                tolerance=opt.optim.tolerance_cg,
                dim=3)
            reweight = lambda x, **k: spatial.membrane_weights(
                x,
                lam,
                vx,
                dim=3,
                **k,
                joint=opt.regularization.norm == 'jtv',
                eps=core.constants.eps(rls.dtype))

            # ----------------------------------------------------------
            #    Initial update of parameter maps
            # ----------------------------------------------------------
            crit_pre_prm = None
            max_iter_prm = opt.optim.max_iter_prm
            for n_iter_prm in range(1, max_iter_prm + 1):

                crit_prev = crit
                reg_prev = reg

                # --- loop over contrasts ------------------------------
                crit = 0
                grad.zero_()
                hess.zero_()
                for i, (contrast, intercept, distortion) \
                        in enumerate(zip(data, maps.intercepts, dist)):
                    # compute gradient
                    crit1, g1, h1 = derivatives_parameters(
                        contrast, distortion, intercept, maps.decay, opt)
                    # increment
                    gind, hind = [i, -1], [i, len(grad) - 1, len(grad) + i]
                    grad[gind] += g1
                    hess[hind] += h1
                    crit += crit1
                del g1, h1

                # --- regularization -----------------------------------
                reg = 0.
                if opt.regularization.norm:
                    g1 = regularizer_prm(maps.volume)
                    reg = 0.5 * dot(maps.volume, g1)
                    grad += g1
                    del g1

                if n_iter_prm == 1:
                    crit_pre_prm = crit

                # --- track RLS improvement ----------------------------
                # Updating the RLS weights changes `sumrls` and `reg`. Now
                # that we have the updated value of `reg` (before it is
                # modified by the map update), we can check that updating the
                # weights indeed improved the objective.

                if opt.verbose and rls_changed:
                    rls_changed = False
                    if reg + sumrls <= reg_prev + sumrls_prev:
                        evol = '<='
                    else:
                        evol = '>'
                    ll_tmp = crit_prev + reg + vreg + sumrls
                    pstr = (f'{n_iter_rls-1:3d} | {"---":3s} | {"rls":4s} | '
                            f'{crit_prev:12.6g} + {reg:12.6g} + '
                            f'{sumrls:12.6g} + {vreg:12.6g} '
                            f'= {ll_tmp:12.6g} | {evol}')
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)

                # --- gauss-newton -------------------------------------
                # Computing the GN step involves solving H\g
                hess = check_nans_(hess, warn='hessian')
                hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
                deltas = solve_prm(hess, grad)
                deltas = check_nans_(deltas, warn='delta')

                dd = regularizer_prm(deltas)
                dv = dot(dd, maps.volume)
                dd = dot(dd, deltas)
                delta_reg = 0.5 * (dd - 2 * dv)

                for map, delta in zip(maps, deltas):
                    map.volume -= delta
                    if map.min is not None or map.max is not None:
                        map.volume.clamp_(map.min, map.max)
                del deltas

                # --- track parameter map improvement --------------
                gain = (crit_prev + reg_prev) - (crit + reg)
                if n_iter_prm > 1 and gain < opt.optim.tolerance * ll_scl:
                    break

            # ----------------------------------------------------------
            #    Distortion update with line search
            # ----------------------------------------------------------
            max_iter_dist = opt.optim.max_iter_dist
            for n_iter_dist in range(1, max_iter_dist + 1):

                crit = 0
                vreg = 0
                new_crit = 0
                new_vreg = 0

                deltas, dd, dv = [], 0, 0
                # --- loop over contrasts ------------------------------
                for i, (contrast, intercept, distortion) \
                        in enumerate(zip(data, maps.intercepts, dist)):

                    # --- helpers --------------------------------------
                    vxr = distortion.voxel_size[contrast.readout]
                    lam_dist = dict(opt.distortion.factor)
                    lam_dist['factor'] *= vxr**2
                    regularizer_dist = lambda x: spatial.regulariser(
                        x[None],
                        **lam_dist,
                        dim=3,
                        bound=DIST_BOUND,
                        voxel_size=distortion.voxel_size)[0]
                    solve_dist = lambda h, g: spatial.solve_field_fmg(
                        h,
                        g,
                        **lam_dist,
                        dim=3,
                        bound=DIST_BOUND,
                        voxel_size=distortion.voxel_size)

                    # --- likelihood -----------------------------------
                    crit1, g, h = derivatives_distortion(
                        contrast, distortion, intercept, maps.decay, opt)
                    crit += crit1

                    # --- regularization -------------------------------
                    vol = distortion.volume
                    g1 = regularizer_dist(vol)
                    vreg1 = 0.5 * vol.flatten().dot(g1.flatten())
                    vreg += vreg1
                    g += g1
                    del g1

                    # --- gauss-newton ---------------------------------
                    h = check_nans_(h, warn='hessian (distortion)')
                    h = hessian_loaddiag_(h[None], 1e-32, 1e-32, sym=True)[0]
                    delta = solve_dist(h, g)
                    delta = check_nans_(delta, warn='delta (distortion)')
                    deltas.append(delta)
                    del g, h

                    dd1 = regularizer_dist(delta)
                    dd1 = regularizer_dist(delta)
                    dv += dot(dd1, vol)
                    dd1 = dot(dd1, delta)
                    dd += dd1
                    del delta, vol

                # --- track parameters improvement ---------------------
                if opt.verbose and n_iter_dist == 1:
                    gain = crit_pre_prm - (crit + delta_reg)
                    evol = '<=' if gain > 0 else '>'
                    ll_tmp = crit + reg + vreg + sumrls
                    pstr = (f'{n_iter_rls:3d} | {"---":3} | {"prm":4s} | '
                            f'{crit:12.6g} + {reg + delta_reg:12.6g} + '
                            f'{sumrls:12.6g} + {vreg:12.6g} '
                            f'= {ll_tmp:12.6g} | '
                            f'gain = {gain / ll_scl:7.2g} | {evol}')
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)

                # --- line search ----------------------------------
                reg = reg + delta_reg
                new_crit, new_reg = crit, reg
                armijo, armijo_prev, ok = armijo_prev, 0, False
                maps0 = maps
                for n_ls in range(1, 12 + 1):
                    for delta1, dist1 in zip(deltas, dist):
                        dist1.volume.sub_(delta1, alpha=(armijo - armijo_prev))
                    armijo_prev = armijo
                    new_vreg = 0.5 * armijo * (armijo * dd - 2 * dv)

                    maps = maps0.deepcopy()
                    max_iter_prm = opt.optim.max_iter_prm
                    for n_iter_prm in range(1, max_iter_prm + 1):

                        crit_prev = new_crit
                        reg_prev = new_reg

                        # --- loop over contrasts ------------------
                        new_crit = 0
                        grad.zero_()
                        hess.zero_()
                        for i, (contrast, intercept, distortion) in enumerate(
                                zip(data, maps.intercepts, dist)):
                            # compute gradient
                            new_crit1, g1, h1 = derivatives_parameters(
                                contrast, distortion, intercept, maps.decay,
                                opt)
                            # increment
                            gind = [i, -1]
                            hind = [i, len(grad) - 1, len(grad) + i]
                            grad[gind] += g1
                            hess[hind] += h1
                            new_crit += new_crit1
                        del g1, h1

                        # --- regularization -----------------------
                        new_reg = 0.
                        if opt.regularization.norm:
                            g1 = regularizer_prm(maps.volume)
                            new_reg = 0.5 * dot(maps.volume, g1)
                            grad += g1
                            del g1

                        new_gain = (crit_prev + reg_prev) - (new_crit + new_reg)
                        if new_gain < opt.optim.tolerance * ll_scl:
                            break

                        # --- gauss-newton -------------------------
                        # Computing the GN step involves solving H\g
                        hess = check_nans_(hess, warn='hessian')
                        hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
                        delta = solve_prm(hess, grad)
                        delta = check_nans_(delta, warn='delta')

                        dd = regularizer_prm(delta)
                        dv = dot(dd, maps.volume)
                        dd = dot(dd, delta)
                        delta_reg = 0.5 * (dd - 2 * dv)

                        for map, delta1 in zip(maps, delta):
                            map.volume -= delta1
                            if map.min is not None or map.max is not None:
                                map.volume.clamp_(map.min, map.max)
                        del delta

                    if new_crit + new_reg + new_vreg <= crit + reg:
                        ok = True
                        break
                    else:
                        armijo = armijo / 2

                if not ok:
                    for delta1, dist1 in zip(deltas, dist):
                        dist1.volume.add_(delta1, alpha=armijo_prev)
                    armijo_prev = 1
                    maps = maps0
                    new_crit = crit
                    new_vreg = 0
                    del deltas
                else:
                    armijo_prev *= 1.5
                new_vreg = vreg + new_vreg

                # --- track distortion improvement ---------------------
                if not ok:
                    evol = '== (x)'
                elif new_crit + new_vreg <= crit + vreg:
                    evol = f'<= ({n_ls:2d})'
                else:
                    evol = f'> ({n_ls:2d})'
                gain = (crit + vreg + reg) - (new_crit + new_vreg + new_reg)
                crit, vreg, reg = new_crit, new_vreg, new_reg
                if opt.verbose:
                    ll_tmp = crit + reg + vreg + sumrls
                    pstr = (
                        f'{n_iter_rls:3d} | {n_iter_dist:3d} | {"dist":4s} | '
                        f'{crit:12.6g} + {reg:12.6g} + {sumrls:12.6g} ')
                    pstr += f'+ {vreg:12.6g} '
                    pstr += f'= {ll_tmp:12.6g} | gain = {gain / ll_scl:7.2g} | '
                    pstr += f'{evol}'
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)
                    if opt.plot:
                        _show_maps(maps, dist, data)
                if not ok:
                    break

                if n_iter_dist > 1 and gain < opt.optim.tolerance * ll_scl:
                    break

            # --------------------------------------------------------------
            #    Update RLS weights
            # --------------------------------------------------------------
            if level == 1 and opt.regularization.norm in ('tv', 'jtv'):
                rls_changed = True
                sumrls_prev = sumrls
                rls, sumrls = reweight(maps.volume, return_sum=True)
                sumrls = 0.5 * sumrls

                # --- compute gain -----------------------------------------
                # (we are late by one full RLS iteration when computing the
                #  gain but we save some computations)
                ll_rls_prev = ll_rls
                ll_rls = crit + reg + vreg
                gain = ll_rls_prev - ll_rls
                if gain < opt.optim.tolerance * ll_scl:
                    print(f'RLS converged ({gain / ll_scl:7.2g})')
                    break

    # --- prepare output -----------------------------------------------
    out = postproc(maps, data)
    if opt.distortion.enable:
        out = (*out, dist)
    return out
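
The `delta_reg = 0.5 * (dd - 2 * dv)` bookkeeping above (and the matching
`lr = 0.5 * armijo * (armijo * dd - 2 * dt)` in the line searches of the
later examples) is just the expansion of a quadratic prior. Below is a
minimal sketch of that identity, assuming only torch and that `Ld` is the
regulariser applied to the step (e.g. the output of
`spatial.regulariser(delta, ...)`); it is an illustration, not code from
the project.

import torch

def prior_energy_change(Ld, delta, x, alpha=1.0):
    """Change in E(x) = 0.5*<x, Lx> for the step x -> x - alpha*delta.

    E(x - a*d) - E(x) = 0.5*a**2*<d, Ld> - a*<d, Lx>, so two dot products
    suffice instead of re-applying the regulariser to the updated x.
    """
    dd = torch.dot(Ld.flatten(), delta.flatten())  # <d, Ld>
    dv = torch.dot(Ld.flatten(), x.flatten())      # <d, Lx>
    return 0.5 * alpha * (alpha * dd - 2 * dv)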
Example #14
File: _nonlin.py Project: balbasty/nitorch
def nonlin(data, dist=None, opt=None):
    """Fit the ESTATICS model to multi-echo Gradient-Echo data.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    dist : sequence[Optional[ParameterizedDistortion]], optional
        Pre-computed distortion fields
    opt : Options, optional
        Algorithm options.

    Returns
    -------
    intercepts : sequence[GradientEcho]
        Echo series extrapolated to TE=0
    decay : estatics.ParameterMap
        R2* decay map
    distortions : sequence[ParameterizedDistortion], if opt.distortion.enable
        B0-induced distortion fields

    """
    if opt.distortion.enable:
        return meetup(data, dist, opt)

    # --- prepare everything -------------------------------------------
    data, maps, dist, opt, rls = _prepare(data, dist, opt)
    sumrls = 0.5 * rls.sum(dtype=torch.double) if rls is not None else 0
    backend = dict(dtype=opt.backend.dtype, device=opt.backend.device)

    # --- initialize tracking of the objective function ----------------
    crit = float('inf')
    vreg = 0
    reg = 0
    sumrls_prev = sumrls
    rls_changed = False
    ll_gn = ll_rls = float('inf')
    ll_scl = sum(core.py.prod(dat.shape) for dat in data)

    # --- Multi-Resolution loop ----------------------------------------
    shape0 = shape = maps.shape[1:]
    aff0 = aff = maps.affine
    lam0 = lam = opt.regularization.factor
    vx0 = vx = spatial.voxel_size(aff0)
    scl0 = vx0.prod()
    for level in range(opt.optim.nb_levels, 0, -1):

        if opt.optim.nb_levels > 1:
            aff, shape = _get_level(level, aff0, shape0)
            vx = spatial.voxel_size(aff)
            scl = vx.prod() / scl0
            lam = [float(l * scl) for l in lam0]
            maps, rls = _resize(maps, rls, aff, shape)
            if opt.regularization.norm in ('tv', 'jtv'):
                sumrls = 0.5 * scl * rls.reciprocal().sum(dtype=torch.double)

        grad = torch.empty((len(data) + 1, *shape), **backend)
        hess = torch.empty((len(data) * 2 + 1, *shape), **backend)

        # --- RLS loop -------------------------------------------------
        #   > max_iter_rls == 1 if regularization is not (J)TV
        max_iter_rls = opt.optim.max_iter_rls
        if level != 1:
            max_iter_rls = 1
        for n_iter_rls in range(1, max_iter_rls + 1):

            # --- helpers ----------------------------------------------
            regularizer = lambda x: spatial.regulariser(
                x, weights=rls, dim=3, voxel_size=vx, membrane=1, factor=lam)
            solve = lambda h, g: spatial.solve_field(
                h,
                g,
                rls,
                factor=lam,
                membrane=1 if opt.regularization.norm else 0,
                voxel_size=vx,
                max_iter=opt.optim.max_iter_cg,
                tolerance=opt.optim.tolerance_cg,
                dim=3)
            reweight = lambda x, **k: spatial.membrane_weights(
                x,
                lam,
                vx,
                dim=3,
                **k,
                joint=opt.regularization.norm == 'jtv',
                eps=core.constants.eps(rls.dtype))
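            # `solve` approximately inverts the regularised system
            # (H + L) delta = g with conjugate gradient, capped by
            # opt.optim.max_iter_cg / tolerance_cg; here L is the membrane
            # energy weighted by the RLS map.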

            # ----------------------------------------------------------
            #    Update parameter maps
            # ----------------------------------------------------------

            max_iter_prm = opt.optim.max_iter_prm
            for n_iter_prm in range(1, max_iter_prm + 1):

                crit_prev = crit
                reg_prev = reg

                # --- loop over contrasts ------------------------------
                crit = 0
                grad.zero_()
                hess.zero_()
                for i, (contrast, intercept, distortion) \
                        in enumerate(zip(data, maps.intercepts, dist)):
                    # compute gradient
                    crit1, g1, h1 = derivatives_parameters(
                        contrast, distortion, intercept, maps.decay, opt)
                    # increment
                    gind = [i, -1]
                    grad[gind] += g1
                    hind = [i, len(grad) - 1, len(grad) + i]
                    hess[hind] += h1
                    crit += crit1
                del g1, h1

                # --- regularization -----------------------------------
                reg = 0.
                if opt.regularization.norm:
                    g1 = regularizer(maps.volume)
                    reg = 0.5 * dot(maps.volume, g1)
                    grad += g1
                    del g1

                # --- track RLS improvement ----------------------------
                # Updating the RLS weights changes `sumrls` and `reg`. Now
                # that we have the updated value of `reg` (before it is
                # modified by the map update), we can check that updating the
                # weights indeed improved the objective.
                if opt.verbose and rls_changed:
                    rls_changed = False
                    if reg + sumrls <= reg_prev + sumrls_prev:
                        evol = '<='
                    else:
                        evol = '>'
                    ll_tmp = crit_prev + reg + vreg + sumrls
                    pstr = (
                        f'{n_iter_rls:3d} | {"---":3s} | {"rls":4s} | '
                        f'{crit_prev:12.6g} + {reg:12.6g} + {sumrls:12.6g} ')
                    if opt.distortion.enable:
                        pstr += f'+ {vreg:12.6g} '
                    pstr += f'= {ll_tmp:12.6g} | '
                    pstr += f'{evol}'
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)

                # --- gauss-newton -------------------------------------
                # Computing the GN step involves solving H\g
                hess = check_nans_(hess, warn='hessian')
                hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
                deltas = solve(hess, grad)
                deltas = check_nans_(deltas, warn='delta')

                for map, delta in zip(maps, deltas):
                    map.volume -= delta

                    if map.min is not None or map.max is not None:
                        map.volume.clamp_(map.min, map.max)
                del deltas

                # --- compute GN gain ----------------------------------
                ll_gn_prev = ll_gn
                ll_gn = crit + reg + vreg + sumrls
                gain = ll_gn_prev - ll_gn
                if opt.verbose:
                    pstr = f'{n_iter_rls:3d} | {n_iter_prm:3d} | '
                    if opt.distortion.enable:
                        pstr += f'{"----":4s} | '
                        pstr += f'{"-"*72:72s} | '
                    else:
                        pstr += f'{"prm":4s} | '
                        pstr += f'{crit:12.6g} + {reg:12.6g} + {sumrls:12.6g} '
                        pstr += f'= {ll_gn:12.6g} | '
                    pstr += f'gain = {gain/ll_scl:7.2g}'
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)
                    if opt.plot:
                        _show_maps(maps, dist, data)
                if gain < opt.optim.tolerance * ll_scl:
                    break

            # --------------------------------------------------------------
            #    Update RLS weights
            # --------------------------------------------------------------
            if level == 1 and opt.regularization.norm in ('tv', 'jtv'):
                reg = 0.5 * dot(maps.volume, regularizer(maps.volume))
                rls_changed = True
                sumrls_prev = sumrls
                rls, sumrls = reweight(maps.volume, return_sum=True)
                sumrls = 0.5 * sumrls

                # --- compute gain -----------------------------------------
                # (we are late by one full RLS iteration when computing the
                #  gain but we save some computations)
                ll_rls_prev = ll_rls
                ll_rls = ll_gn
                gain = ll_rls_prev - ll_rls
                if gain < opt.optim.tolerance * ll_scl:
                    print(f'Converged ({gain/ll_scl:7.2g})')
                    break

    # --- prepare output -----------------------------------------------
    out = postproc(maps, data)
    if opt.distortion.enable:
        out = (*out, dist)
    return out
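
The `reweight`/`sumrls` pair above is the classic reweighted-least-squares
bound on total variation: |t| <= 0.5*w*t**2 + 0.5/w, tight at w = 1/|t|.
Below is a minimal 1-D sketch of that bound, assuming only torch (the
actual `spatial.membrane_weights` additionally handles voxel size, joint
channels and an epsilon floor); it is an illustration, not the project's
implementation.

import torch

def tv_irls_weights(x, eps=1e-8):
    """IRLS weights for a 1-D TV penalty sum(|D x|).

    Returns the weights w = 1/|Dx| and the bound term 0.5*sum(1/w),
    which plays the role of `sumrls` in the example above.
    """
    g = x[1:] - x[:-1]                        # forward differences D x
    w = g.abs().clamp_min(eps).reciprocal()   # w = 1/|Dx|
    return w, 0.5 * w.reciprocal().sum(dtype=torch.double)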
Example #15
def sin_b1(x,
           fa,
           lam=(0, 1e3),
           penalty=('m', 'b'),
           chi=True,
           pd=None,
           b1=None):
    """Estimate B1+ from Variable flip angle data with long TR

    Parameters
    ----------
    x : (C, *spatial) tensor
        Input flash images
    fa : (C,) sequence[float]
        Flip angle (in deg)
    lam : (float, float), default=(0, 1e3)
        Regularization value for the T2*w signal and the B1 map.
    penalty : 2x {'membrane', 'bending'}, default=('m', 'b')
        Regularization type for the T2*w signal and the B1 map.

    Returns
    -------
    s : (*spatial) tensor
        T2*-weighted PD map
    b1 : (*spatial) tensor
        B1+ map

    """

    fa = utils.make_vector(fa, len(x), dtype=torch.double)
    fa = fa.mul_(pymath.pi / 180).tolist()
    lam = py.make_list(lam, 2)
    penalty = py.make_list(penalty, 2)

    sd, df, mu = 0, 0, 0
    for x1 in x:
        bg, fg = estimate_noise(x1, chi=True)
        sd += bg['sd'].log()
        df += bg['dof'].log()
        mu += fg['mean'].log()
    sd = (sd / len(x)).exp()
    df = (df / len(x)).exp()
    mu = (mu / len(x)).exp()
    prec = 1 / (sd * sd)
    if not chi:
        df = 1

    # mask low SNR voxels
    # x = x * (x > 5 * sd)

    shape = x.shape[1:]
    theta = x.new_empty([2, *shape])
    theta[0] = mu.log() if pd is None else pd.log()
    theta[1] = 0 if b1 is None else b1.log()
    n = (x != 0).sum()

    g = torch.zeros_like(theta)
    h = theta.new_zeros([3, *theta.shape[1:]])
    g1 = torch.zeros_like(theta)
    h1 = theta.new_zeros([3, *theta.shape[1:]])

    prm = dict(
        membrane=[
            lam[0] if penalty[0][0] == 'm' else 0,
            lam[1] if penalty[1][0] == 'm' else 0
        ],
        bending=[
            lam[0] if penalty[0][0] == 'b' else 0,
            lam[1] if penalty[1][0] == 'b' else 0
        ],
    )

    lam0 = dict(membrane=prm['membrane'][-1], bending=prm['bending'][-1])
    ll0 = lr0 = factor = float('inf')
    for n_iter in range(32):

        if n_iter == 1:
            ll0 = lr0 = float('inf')

        # anneal regularization: start with a strongly smoothed B1 channel
        # (factor ~ 1e5) and decay the extra factor towards 1 over iterations
        factor, factor_prev = 1 + 10**(5 - n_iter), factor
        factor_ratio = factor / factor_prev if n_iter else float('inf')
        prm['membrane'][-1] = lam0['membrane'] * factor
        prm['bending'][-1] = lam0['bending'] * factor

        # derivatives of likelihood term
        df1 = 1 if n_iter == 0 else df
        ll, g, h = sin_full_derivatives(x, fa, theta[0], theta[1], prec, df1,
                                        g, h, g1, h1)

        # derivatives of regularization term
        reg = spatial.regulariser(theta, **prm)
        g += reg
        lr = 0.5 * dot(theta, reg)

        l, l0 = ll + lr, ll0 + factor_ratio * lr0
        gain = l0 - l
        print(
            f'{n_iter:2d} | {ll/n:12.6g} + {lr/n:12.6g} = {l/n:12.6g} | gain = {gain/n:12.6g}'
        )

        # Gauss-Newton update
        h[:2] += 1e-8 * h[:2].abs().max(0).values
        h[:2] += 1e-5
        delta = spatial.solve_field_fmg(h, g, **prm)

        mx = delta.abs().max()
        if mx > 64:
            delta *= 64 / mx

        # theta -= delta
        # ll0, lr0 = ll, lr

        # line search
        dd = spatial.regulariser(delta, **prm)
        dt = dot(dd, theta)
        dd = dot(dd, delta)
        success = False
        armijo = 1
        theta0 = theta
        ll0, lr0 = ll, lr
        for n_ls in range(12):
            theta = theta0.sub(delta, alpha=armijo)
            ll = sin_nll(x, fa, theta[0], theta[1], prec, df1)
            lr = 0.5 * armijo * (armijo * dd - 2 * dt)
            if ll + lr < ll0:  # and theta[1].max() < 0.69:
                print(n_ls, 'success', ((ll + lr) / n).item(),
                      (ll0 / n).item())
                success = True
                break
            print(n_ls, 'failure', ((ll + lr) / n).item(), (ll0 / n).item())
            armijo /= 2
        if not success and n_iter > 5:
            theta = theta0
            break

        import matplotlib.pyplot as plt

        # plt.subplot(1, 2, 1)
        # plt.imshow(theta[0, :, :, theta.shape[-1]//2].exp())
        # plt.colorbar()
        # plt.subplot(1, 2, 2)
        # plt.imshow(theta[1, :, :, theta.shape[-1]//2].exp())
        # plt.colorbar()
        # plt.show()

        plt.rcParams["figure.figsize"] = (4, len(x))
        vmin, vmax = 0, 2 * mu
        y = sin_signals(fa, *theta)
        ex = 3
        for i in range(len(x)):
            plt.subplot(len(x) + ex, 4, 4 * i + 1)
            plt.imshow(x[i, ..., x.shape[-1] // 2], vmin=vmin, vmax=vmax)
            plt.axis('off')
            plt.subplot(len(x) + ex, 4, 4 * i + 2)
            plt.imshow(y[i, ..., x.shape[-1] // 2], vmin=vmin, vmax=vmax)
            plt.axis('off')
            plt.subplot(len(x) + ex, 4, 4 * i + 3)
            plt.imshow(x[i, ..., x.shape[-1] // 2] -
                       y[i, ..., y.shape[-1] // 2],
                       cmap=plt.get_cmap('coolwarm'))
            plt.axis('off')
            plt.colorbar()
            plt.subplot(len(x) + ex, 4, 4 * i + 4)
            plt.imshow(
                (theta[-1, ..., x.shape[-1] // 2].exp() * fa[i]) / pymath.pi)
            plt.axis('off')
            plt.colorbar()
        all_fa = torch.linspace(0, 2 * pymath.pi, 512)
        loc = [(theta.shape[1] // 2, theta.shape[2] // 2, theta.shape[3] // 2),
               (2 * theta.shape[1] // 3, theta.shape[2] // 2,
                theta.shape[3] // 2),
               (theta.shape[1] // 3, theta.shape[2] // 3, theta.shape[3] // 2)]
        for j, (nx, ny, nz) in enumerate(loc):
            plt.subplot(len(x) + ex, 1, len(x) + j + 1)
            plt.plot(all_fa, sin_signals(all_fa, *theta[:, nx, ny, nz]))
            plt.scatter(fa, x[:, nx, ny, nz])
        plt.show()

        # if gain < 1e-4 * n:
        #     break

    theta.exp_()
    return theta[0], theta[1]
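
For reference, the long-TR model that `sin_b1` fits is
S_c = s * sin(b1 * fa_c), with theta = (log s, log b1) so that both maps
stay positive. A hedged sketch of the forward model follows (the actual
`sin_signals` called above may differ in detail):

import torch

def sin_signal_sketch(fa, log_s, log_b1):
    """Predicted long-TR spoiled-GRE magnitude at nominal flip angles `fa` (rad)."""
    fa = torch.as_tensor(fa, dtype=torch.double)
    return log_s.exp() * torch.sin(log_b1.exp() * fa)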
Example #16
def flash_b1(x,
             fa,
             tr,
             lam=(0, 0, 0),
             penalty=('m', 'm', 'm'),
             chi=True,
             pd=None,
             r1=None,
             b1=None):
    """Estimate B1+ from Variable flip angle data

    Parameters
    ----------
    x : (C, *spatial) tensor
        Input flash images
    fa : (C,) sequence[float]
        Flip angle (in deg)
    tr : float or (C,) sequence[float]
        Repetition time (in sec)
    lam : (float, float, float), default=(0, 0, 0)
        Regularization value for the T2*w signal, R1 map and B1 map.
    penalty : 3x {'membrane', 'bending'}, default=('m', 'm', 'm')
        Regularization type for the T2*w signal, R1 map and B1 map.

    Returns
    -------
    s : (*spatial) tensor
        T2*-weighted PD map
    r1 : (*spatial) tensor
        R1 map
    b1 : (*spatial) tensor
        B1+ map

    """

    tr = py.make_list(tr, len(x))
    fa = utils.make_vector(fa, len(x), dtype=torch.double)
    fa = fa.mul_(pymath.pi / 180).tolist()
    lam = py.make_list(lam, 3)
    penalty = py.make_list(penalty, 3)

    sd, df, mu = 0, 0, 0
    for x1 in x:
        bg, fg = estimate_noise(x1, chi=True)
        sd += bg['sd'].log()
        df += bg['dof'].log()
        mu += fg['mean'].log()
    sd = (sd / len(x)).exp()
    df = (df / len(x)).exp()
    mu = (mu / len(x)).exp()
    prec = 1 / (sd * sd)
    if not chi:
        df = 1

    shape = x.shape[1:]
    theta = x.new_empty([3, *shape])
    theta[0] = mu.log() if pd is None else pd.log()
    theta[1] = 1 if r1 is None else r1.log()
    theta[2] = 0 if b1 is None else b1.log()
    n = (x != 0).sum()

    g = torch.zeros_like(theta)
    h = theta.new_zeros([6, *theta.shape[1:]])
    g1 = torch.zeros_like(theta)
    h1 = theta.new_zeros([6, *theta.shape[1:]])

    prm = dict(
        membrane=(lam[0] if penalty[0][0] == 'm' else 0,
                  lam[1] if penalty[1][0] == 'm' else 0,
                  lam[2] if penalty[2][0] == 'm' else 0),
        bending=(lam[0] if penalty[0][0] == 'b' else 0,
                 lam[1] if penalty[1][0] == 'b' else 0,
                 lam[2] if penalty[2][0] == 'b' else 0),
    )

    ll0 = lr0 = float('inf')
    iter_start = 3
    for n_iter in range(32):

        if df > 1 and n_iter == iter_start:
            ll0 = lr0 = float('inf')

        # derivatives of likelihood term
        df1 = 1 if n_iter < iter_start else df
        ll, g, h = derivatives(x, fa, tr, theta[0], theta[1], theta[2], prec,
                               df1, g, h, g1, h1)

        # derivatives of regularization term
        reg = spatial.regulariser(theta, **prm, absolute=1e-10)
        g += reg
        lr = 0.5 * dot(theta, reg)

        l, l0 = ll + lr, ll0 + lr0
        gain = l0 - l
        print(
            f'{n_iter:2d} | {ll/n:12.6g} + {lr/n:12.6g} = {l/n:12.6g} | gain = {gain/n:12.6g}'
        )

        # Gauss-Newton update
        h[:3] += 1e-8 * h[:3].abs().max(0).values
        h[:3] += 1e-5
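        # Assumed rationale (not stated in the source): alternate
        # block-coordinate Gauss-Newton steps, updating B1 alone on even
        # iterations (h[2] is its diagonal Hessian block) and the
        # (signal, R1) pair on odd ones, so each solve stays small.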
        if n_iter % 2:
            delta = torch.zeros_like(theta)
            prm1 = dict(membrane=prm['membrane'][:2],
                        bending=prm['bending'][:2])
            hh = torch.stack([h[0], h[1], h[3]])
            delta[:2] = spatial.solve_field(hh,
                                            g[:2],
                                            **prm1,
                                            max_iter=16,
                                            absolute=1e-10)
            del hh
        else:
            delta = torch.zeros_like(theta)
            prm1 = dict(membrane=prm['membrane'][-1],
                        bending=prm['bending'][-1])
            delta[-1:] = spatial.solve_field(h[2:3],
                                             g[2:3],
                                             **prm1,
                                             max_iter=16,
                                             absolute=1e-10)
        # delta = spatial.solve_field(h, g, **prm, max_iter=16, absolute=1e-10)
        # theta -= spatial.solve_field_fmg(h, g, **prm)

        # line search
        dd = spatial.regulariser(delta, **prm, absolute=1e-10)
        dt = dot(dd, theta)
        dd = dot(dd, delta)
        success = False
        armijo = 1
        theta0 = theta
        ll0, lr0 = ll, lr
        for n_ls in range(12):
            theta = theta0.sub(delta, alpha=armijo)
            ll = nll(x, fa, tr, *theta, prec, df1)
            lr = 0.5 * armijo * (armijo * dd - 2 * dt)
            if ll + lr < ll0:  # and theta[1].max() < 0.69:
                print(n_ls, 'success', ((ll + lr) / n).item(),
                      (ll0 / n).item())
                success = True
                break
            print(n_ls, 'failure', ((ll + lr) / n).item(), (ll0 / n).item())
            armijo /= 2
        if not success and n_iter > 5:
            theta = theta0
            break

        # delta = spatial.solve_field_fmg(h, g, **prm)
        #
        # # line search
        # dd = spatial.regulariser(delta, **prm)
        # dt = dot(dd, theta)
        # dd = dot(dd, delta)
        # success = False
        # armijo = 1
        # theta0 = theta
        # ll0, lr0 = ll, lr
        # for n_ls in range(12):
        #     theta = theta0.sub(delta, alpha=armijo)
        #     ll = nll(x, fa, tr, theta[0], theta[1], theta[2]) * prec
        #     lr = 0.5 * armijo * (armijo * dd - 2 * dt)
        #     if ll + lr < ll0:
        #         print('success', n_ls)
        #         success = True
        #         break
        #     armijo /= 2
        # if not success:
        #     theta = theta0
        #     break

        import matplotlib.pyplot as plt
        # plt.subplot(2, 2, 1)
        # plt.imshow(theta[0, :, :, theta.shape[-1]//2].exp())
        # plt.axis('off')
        # plt.colorbar()
        # plt.title('PD * exp(-TE * R2*)')
        # plt.subplot(2, 2, 2)
        # plt.imshow(theta[1, :, :, theta.shape[-1]//2].exp())
        # plt.axis('off')
        # plt.colorbar()
        # plt.title('R1')
        # plt.subplot(2, 2, 3)
        # plt.imshow(theta[2, :, :, theta.shape[-1]//2].exp())
        # plt.axis('off')
        # plt.colorbar()
        # plt.title('B1+')
        # plt.show()

        vmin, vmax = 0, 2 * mu
        y = flash_signals(fa, tr, *theta)
        ex = 3
        plt.rcParams["figure.figsize"] = (4, len(x) + ex)
        for i in range(len(x)):
            plt.subplot(len(x) + ex, 4, 4 * i + 1)
            plt.imshow(x[i, ..., x.shape[-1] // 2], vmin=vmin, vmax=vmax)
            plt.axis('off')
            plt.subplot(len(x) + ex, 4, 4 * i + 2)
            plt.imshow(y[i, ..., x.shape[-1] // 2], vmin=vmin, vmax=vmax)
            plt.axis('off')
            plt.subplot(len(x) + ex, 4, 4 * i + 3)
            plt.imshow(x[i, ..., x.shape[-1] // 2] -
                       y[i, ..., y.shape[-1] // 2],
                       cmap=plt.get_cmap('coolwarm'))
            plt.axis('off')
            plt.colorbar()
            plt.subplot(len(x) + ex, 4, 4 * i + 4)
            plt.imshow(
                (theta[-1, ..., x.shape[-1] // 2].exp() * fa[i]) / pymath.pi)
            plt.axis('off')
            plt.colorbar()
        all_fa = torch.linspace(0, 2 * pymath.pi, 512)
        loc = [(theta.shape[1] // 2, theta.shape[2] // 2, theta.shape[3] // 2),
               (2 * theta.shape[1] // 3, theta.shape[2] // 2,
                theta.shape[3] // 2),
               (theta.shape[1] // 3, theta.shape[2] // 3, theta.shape[3] // 2)]
        for j, (nx, ny, nz) in enumerate(loc):
            plt.subplot(len(x) + ex, 1, len(x) + j + 1)
            plt.plot(
                all_fa,
                flash_signals(all_fa, tr[:1] * 512, *theta[:, nx, ny, nz]))
            plt.scatter(fa, x[:, nx, ny, nz])
        plt.show()

        if gain < 1e-4 * n:
            break

    theta.exp_()
    return theta[0], theta[1], theta[2]
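
For reference, the steady-state spoiled-GRE (FLASH) equation that
`flash_b1` fits is S = s * sin(b1*fa) * (1 - E1) / (1 - cos(b1*fa) * E1)
with E1 = exp(-tr * r1) and theta = (log s, log r1, log b1). A hedged
sketch follows (the actual `flash_signals` called above may differ in
detail):

import torch

def flash_signal_sketch(fa, tr, log_s, log_r1, log_b1):
    """Predicted FLASH magnitude at nominal flip angles `fa` (rad), `tr` (sec)."""
    fa = torch.as_tensor(fa, dtype=torch.double)
    e1 = torch.exp(-tr * log_r1.exp())
    a = log_b1.exp() * fa                       # effective flip angle
    return log_s.exp() * torch.sin(a) * (1 - e1) / (1 - torch.cos(a) * e1)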