Example #1
def _chi_fit_tv(dat, sigma=None, df=None, lam=10, max_iter=50, tol=1e-5):
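    """Fit a TV-regularised mean to noncentral-chi distributed data.

    (Summary inferred from the helpers used below: `dat` is a magnitude
    image; `sigma` and `df` are the noise standard deviation and degrees
    of freedom, estimated from `dat` when omitted; `lam` is the
    membrane/TV regularisation factor.)
    """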

    noise, tissue = estimate_noise(dat, chi=True)
    mu = tissue['mean']
    sigma = sigma or noise['sd']
    df = df or noise['dof']
    print(f'sigma = {sigma}, dof = {df}, mu = {mu}')
    isigma2 = 1 / (sigma * sigma)
    lam = lam / mu

    msk = torch.isfinite(dat)
    if not msk.all():
        dat = dat.masked_fill(~msk, 0)  # zero out non-finite voxels
    fit = chi_bias_correction(dat, sigma, df)[0]
    msk = dat > 0
    n = msk.sum()

    w = torch.ones_like(dat)
    h = w.new_full([1] * dat.dim(), isigma2)[None]

    ll_prev = float('inf')
    for n_iter in range(max_iter):

        w, tv = spatial.membrane_weights(fit[None],
                                         factor=lam,
                                         return_sum=True)

        l, delta = nll_chi(dat, fit, msk, isigma2, df)
        delta += spatial.regulariser(fit[None], membrane=lam, weights=w)[0]
        delta = spatial.solve_field(h, delta[None], membrane=lam, weights=w)[0]
        fit.sub_(delta).clamp_min_(1e-8 * sigma)

        ll = l + tv
        gain, ll_prev = ll_prev - ll, ll
        print(f'{n_iter:3d} | {l/n:12.6g} + {tv/n:12.6g} = {ll/n:12.6g} '
              f'| gain = {gain/n:6.3g}')
        if abs(gain) < tol * n:
            break

    return fit, sigma, df
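A minimal usage sketch for the fit above; the input tensor is a stand-in for a magnitude MR volume, and `estimate_noise`, `chi_bias_correction`, `nll_chi` and `spatial` are assumed to come from the surrounding library:

import torch

# Hypothetical input: a noisy magnitude volume (shape and values arbitrary).
dat = 100 * torch.rand(64, 64, 64)
fit, sigma, df = _chi_fit_tv(dat, lam=10, max_iter=50, tol=1e-5)
print(fit.shape, float(sigma), float(df))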
Example #2
def do_inpaint(dat,
               voxel_size=1,
               max_iter_rls=50,
               max_iter_cg=32,
               tol_rls=1e-5,
               tol_cg=1e-5,
               verbose=1):
    """Perform inpainting

    Parameters
    ----------
    dat : (channels, *spatial) tensor
        Tensor with missing data marked as NaN
    voxel_size : [sequence of] float, default=1
        Voxel size
    max_iter_rls : int, default=50
        Maximum number of reweighting (RLS) iterations
    max_iter_cg : int, default=32
        Maximum number of conjugate-gradient iterations
    tol_rls : float, default=1e-5
        Tolerance for early stopping (RLS)
    tol_cg : float, default=1e-5
        Tolerance for early stopping (CG)
    verbose : int, default=1
        Verbosity level

    Returns
    -------
    dat : (channels, *spatial) tensor

    """
    dat = torch.as_tensor(dat)
    backend = utils.backend(dat)
    voxel_size = utils.make_vector(voxel_size, dat.dim() - 1, **backend)
    voxel_size = voxel_size / voxel_size.mean()

    # initialize
    mask = torch.isnan(dat)
    for channel in dat:
        channel[torch.isnan(channel)] = channel[~torch.isnan(channel)].median()
    weights = dat.new_ones([1, *dat.shape[1:]])

    ll_w = weights.sum(dtype=torch.double)
    ll_x = spatial.membrane(dat, dim=3, voxel_size=voxel_size, weights=weights)
    ll_x = (ll_x * dat).sum(dtype=torch.double)
    if verbose > 0:
        print(f'{0:3d} | {ll_x} + {ll_w} = {ll_x + ll_w}')

    # Reweighted least squares loop
    zeros = torch.zeros_like(dat)
    for n_iter_rls in range(1, max_iter_rls + 1):

        ll_x0 = ll_x
        ll_w0 = ll_w

        grad = spatial.membrane(dat,
                                dim=3,
                                voxel_size=voxel_size,
                                weights=weights)
        grad = grad.masked_select(mask)

        def precond(x):
            p = spatial.membrane_diag(dim=3,
                                      voxel_size=voxel_size,
                                      weights=weights)
            p = p.expand_as(dat).masked_select(mask)
            return x / p

        def forward(x):
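            # NB: `zeros` is reused as a scratch buffer across calls; only the
            # masked entries are written, and they are overwritten each time.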
            m = zeros
            m.masked_scatter_(mask, x)
            m = spatial.membrane(m,
                                 dim=3,
                                 voxel_size=voxel_size,
                                 weights=weights)
            m = m.masked_select(mask)
            return m

        delta = optim.cg(forward,
                         grad,
                         precond=precond,
                         max_iter=max_iter_cg,
                         tolerance=tol_cg,
                         verbose=verbose > 1,
                         stop='norm')
        dat.masked_scatter_(mask, dat.masked_select(mask) - delta)

        weights, ll_w = spatial.membrane_weights(dat,
                                                 dim=3,
                                                 voxel_size=voxel_size,
                                                 return_sum=True)
        ll_x = spatial.membrane(dat,
                                dim=3,
                                voxel_size=voxel_size,
                                weights=weights)
        ll_x = (ll_x * dat).sum(dtype=torch.double)
        if verbose > 0:
            print(f'{n_iter_rls:3d} | {ll_x} + {ll_w} = {ll_x + ll_w}')

        if abs(((ll_x0 + ll_w0) - (ll_x + ll_w)) / (ll_x0 + ll_w0)) < tol_rls:
            if verbose > 0:
                print('Converged')
            break

    return dat
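A minimal usage sketch, assuming the `spatial`, `utils` and `optim` modules used above are importable from the surrounding library; note that `do_inpaint` writes into `dat` in place, hence the `clone()`:

import torch

# Hypothetical input: a 3-channel volume with a missing block marked as NaN.
img = torch.rand(3, 32, 32, 32)
img[:, 10:15, 10:15, 10:15] = float('nan')
restored = do_inpaint(img.clone(), voxel_size=[1., 1., 2.], verbose=1)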
Example #3
def greeq(data, transmit=None, receive=None, opt=None, **kwopt):
    """Fit a non-linear relaxometry model to multi-echo Gradient-Echo data.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    transmit : sequence[PrecomputedFieldMap], optional
        Map(s) of the transmit field (b1+). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
    receive : sequence[PrecomputedFieldMap], optional
        Map(s) of the receive field (b1-). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
        If no receive map is provided, the output `pd` map will have
        a remaining b1- bias field.
    opt : GREEQOptions or dict, optional
        Algorithm options.
        {'preproc': {'register':      True},     # Co-register contrasts
         'optim':   {'nb_levels':     1,         # Number of pyramid levels
                     'max_iter_rls':  10,        # Max reweighting iterations
                     'max_iter_gn':   5,         # Max Gauss-Newton iterations
                     'max_iter_cg':   32,        # Max Conjugate-Gradient iterations
                     'tolerance':     1e-05,     # Tolerance for early stopping (RLS/GN)
                     'tolerance_cg':  1e-03},    # Tolerance for early stopping (CG)
         'backend': {'dtype':  torch.float32,    # Data type
                     'device': 'cpu'},           # Device
         'penalty': {'norm':    'jtv',           # Type of penalty: {'tkh', 'tv', 'jtv', None}
                     'factor':  {'r1':  10,      # Penalty factor per (log) map
                                 'pd':  10,
                                 'r2s': 2,
                                 'mt':  2}},
         'verbose': 1}

    Returns
    -------
    pd : ParameterMap
        Proton density
    r1 : ParameterMap
        Longitudinal relaxation rate
    r2s : ParameterMap
        Apparent transverse relaxation rate
    mt : ParameterMap, optional
        Magnetisation transfer saturation.
        Only returned if MT-weighted data is provided.

    """
    opt = GREEQOptions().update(opt, **kwopt)
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    # --- estimate noise / register / initialize maps ---
    data, transmit, receive, maps = preproc(data, transmit, receive, opt)
    has_mt = hasattr(maps, 'mt')

    # --- prepare penalty factor ---
    lam = opt.penalty.factor
    if isinstance(lam, dict):
        lam = [
            lam.get('pd', 0),
            lam.get('r1', 0),
            lam.get('r2s', 0),
            lam.get('mt', 0)
        ]
        if not has_mt:
            lam = lam[:3]
    lam = core.utils.make_vector(lam, 3 + has_mt,
                                 **backend)  # PD, R1, R2*, [MT]

    # --- initialize weights (RLS) ---
    if str(opt.penalty.norm).lower().startswith('no') or all(lam == 0):
        opt.penalty.norm = ''
    opt.penalty.norm = opt.penalty.norm.lower()
    mean_shape = maps[0].shape
    rls = None
    sumrls = 0
    if opt.penalty.norm in ('tv', 'jtv'):
        rls_shape = mean_shape
        if opt.penalty.norm == 'tv':
            rls_shape = (len(maps), ) + rls_shape
        rls = torch.ones(rls_shape, **backend)
        sumrls = 0.5 * core.py.prod(rls_shape)

    if opt.penalty.norm:
        print(f'With {opt.penalty.norm.upper()} penalty:')
        print(f'    - PD:  {lam[0]:.3g}')
        print(f'    - R1:  {lam[1]:.3g}')
        print(f'    - R2*: {lam[2]:.3g}')
        if has_mt:
            print(f'    - MT:  {lam[3]:.3g}')
    else:
        print('Without penalty')

    if opt.penalty.norm not in ('tv', 'jtv'):
        # no reweighting -> do more gauss-newton updates instead
        opt.optim.max_iter_gn *= opt.optim.max_iter_rls
        opt.optim.max_iter_rls = 1
    print('Optimization:')
    print(f'    - Tolerance:        {opt.optim.tolerance}')
    if opt.penalty.norm.endswith('tv'):
        print(f'    - IRLS iterations:  {opt.optim.max_iter_rls}')
    print(f'    - GN iterations:    {opt.optim.max_iter_gn}')
    if opt.optim.solver == 'fmg':
        print(f'    - FMG cycles:       2')
        print(f'    - CG iterations:    2')
    else:
        print(f'    - CG iterations:    {opt.optim.max_iter_cg}'
              f' (tolerance: {opt.optim.tolerance_cg})')
    if opt.optim.nb_levels > 1:
        print(f'    - Levels:           {opt.optim.nb_levels}')

    printer = CritPrinter(max_levels=opt.optim.nb_levels,
                          max_rls=opt.optim.max_iter_rls,
                          max_gn=opt.optim.max_iter_gn,
                          penalty=opt.penalty.norm,
                          verbose=opt.verbose)
    printer.print_head()

    shape0 = shape = maps.shape[1:]
    aff0 = aff = maps.affine
    vx0 = vx = spatial.voxel_size(aff0)
    vol0 = vx0.prod()
    vol = vx.prod() / vol0
    ll_scl = sum(core.py.prod(dat.shape) for dat in data)

    for level in range(opt.optim.nb_levels, 0, -1):
        printer.level = level

        if opt.optim.nb_levels > 1:
            aff, shape = _get_level(level, aff0, shape0)
            vx = spatial.voxel_size(aff)
            vol = vx.prod() / vol0
            maps, rls = resize(maps, rls, aff, shape)
            if opt.penalty.norm in ('tv', 'jtv'):
                sumrls = 0.5 * vol * rls.reciprocal().sum(dtype=torch.double)

        # --- compute derivatives ---
        nb_prm = len(maps)
        nb_hes = nb_prm * (nb_prm + 1) // 2
        grad = torch.empty((nb_prm, ) + shape, **backend)
        hess = torch.empty((nb_hes, ) + shape, **backend)

        ll_rls = []
        ll_gn = []
        ll_max = float('inf')

        max_iter_rls = max(opt.optim.max_iter_rls // level, 1)
        for n_iter_rls in range(max_iter_rls):
            # --- Reweighted least-squares loop ---
            printer.rls = n_iter_rls

            # --- Gauss Newton loop ---
            for n_iter_gn in range(opt.optim.max_iter_gn):
                # start = time.time()
                printer.gn = n_iter_gn
                crit = 0
                grad.zero_()
                hess.zero_()
                # --- loop over contrasts ---
                for contrast, b1m, b1p in zip(data, receive, transmit):
                    crit1, g1, h1 = derivatives_parameters(
                        contrast, maps, b1m, b1p, opt)

                    # increment
                    if hasattr(maps, 'mt') and not contrast.mt:
                        # we optimize for mt but this particular contrast
                        # has no information about mt so g1/h1 are smaller
                        # than grad/hess.
                        grad[:-1] += g1
                        hind = list(range(nb_prm - 1))
                        cnt = nb_prm
                        for i in range(nb_prm):
                            for j in range(i + 1, nb_prm):
                                if i != nb_prm - 1 and j != nb_prm - 1:
                                    hind.append(cnt)
                                cnt += 1
                        hess[hind] += h1
                        crit += crit1
                    else:
                        grad += g1
                        hess += h1
                        crit += crit1

                    del g1, h1, crit1
                # duration = time.time() - start
                # print('grad', duration)

                # start = time.time()
                reg = 0.
                if opt.penalty.norm:
                    g = spatial.regulariser(maps.volume,
                                            weights=rls,
                                            dim=3,
                                            voxel_size=vx,
                                            membrane=1,
                                            factor=lam * vol)
                    grad += g
                    reg = 0.5 * dot(maps.volume, g)
                    del g
                # duration = time.time() - start
                # print('reg', duration)

                # --- gauss-newton ---
                # start = time.time()
                grad = check_nans_(grad, warn='gradient')
                hess = check_nans_(hess, warn='hessian')
                if opt.penalty.norm:
                    # hess = hessian_sym_loaddiag(hess, 1e-5, 1e-8)  # 1e-5 1e-8
                    if opt.optim.solver == 'fmg':
                        deltas = spatial.solve_field_fmg(hess,
                                                         grad,
                                                         rls,
                                                         factor=lam * vol,
                                                         membrane=1,
                                                         voxel_size=vx,
                                                         nb_iter=2)
                    else:
                        deltas = spatial.solve_field(
                            hess,
                            grad,
                            rls,
                            factor=lam * vol,
                            membrane=1,
                            voxel_size=vx,
                            verbose=max(0, opt.verbose - 1),
                            optim='cg',
                            max_iter=opt.optim.max_iter_cg,
                            tolerance=opt.optim.tolerance_cg,
                            stop='diff')
                else:
                    # hess = hessian_sym_loaddiag(hess, 1e-3, 1e-4)
                    deltas = spatial.solve_field_closedform(hess, grad)
                deltas = check_nans_(deltas, warn='deltas')
                # duration = time.time() - start
                # print('solve', duration)

                for map, delta in zip(maps, deltas):
                    map.volume -= delta
                    map.volume.clamp_(-64, 64)  # avoid exp overflow
                    del delta
                del deltas

                # --- Compute gain ---
                ll = crit + reg + sumrls
                ll_max = max(ll_max, ll)
                ll_prev = ll_gn[-1] if ll_gn else ll_max
                gain = ll_prev - ll
                ll_gn.append(ll)
                printer.print_crit(crit, reg, sumrls, gain / ll_scl)
                if gain < opt.optim.tolerance * ll_scl:
                    # print('GN converged: ', ll_prev.item(), '->', ll.item())
                    break

            # --- Update RLS weights ---
            if opt.penalty.norm in ('tv', 'jtv'):
                rls, sumrls = spatial.membrane_weights(
                    maps.volume,
                    lam,
                    vx,
                    return_sum=True,
                    dim=3,
                    joint=opt.penalty.norm == 'jtv',
                    eps=core.constants.eps(rls.dtype))
                sumrls *= 0.5 * vol

                # --- Compute gain ---
                # (we are late by one full RLS iteration when computing the
                #  gain but we save some computations)
                ll = ll_gn[-1]
                ll_prev = ll_rls[-1] if ll_rls else ll_max
                ll_rls.append(ll)
                gain = ll_prev - ll
                if abs(gain) < opt.optim.tolerance * ll_scl:
                    # print(f'RLS converged ({gain:7.2g})')
                    break

    del grad
    if opt.uncertainty:
        uncertainty = compute_uncertainty(hess, rls, lam * vol, vx, opt)
        maps.pd.uncertainty = uncertainty[0]
        maps.r1.uncertainty = uncertainty[1]
        maps.r2s.uncertainty = uncertainty[2]
        if hasattr(maps, 'mt'):
            maps.mt.uncertainty = uncertainty[3]

    # --- Prepare output ---
    return postproc(maps)
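The Hessian in `greeq` is stored in compact symmetric form: the `nb_prm` diagonal rows first, then the off-diagonal `(i, j)` pairs in row-major order, hence `nb_hes = nb_prm * (nb_prm + 1) // 2`. A standalone check of the index arithmetic used in the MT-less branch above (keeping only rows that do not involve the last parameter):

def sub_hessian_indices(nb_prm):
    # Diagonal rows for all parameters except the last (MT).
    hind = list(range(nb_prm - 1))
    cnt = nb_prm
    for i in range(nb_prm):
        for j in range(i + 1, nb_prm):
            if i != nb_prm - 1 and j != nb_prm - 1:
                hind.append(cnt)  # off-diagonal row not touching the last map
            cnt += 1
    return hind

print(sub_hessian_indices(4))  # [0, 1, 2, 4, 5, 7] -> all MT rows skipped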
Example #4
def correct_smooth(x,
                   sigma=None,
                   lam=10,
                   gamma=10,
                   downsample=None,
                   max_iter=16,
                   max_rls=8,
                   tol=1e-6,
                   verbose=False,
                   device=None):
    """Correct the intensity non-uniformity in a SPIM image.

    The signal is modelled as: f = exp(s + b) + eps, with a penalty on
    the (squared) gradients of s and on the (squared) curvature of b.

    Parameters
    ----------
    x : tensor
        SPIM image with the z dimension last and the z=0 plane first
    sigma : float, optional
        Noise standard deviation. Default: educated guess.
    lam : float, default=10
        Regularisation on the signal.
    gamma : float, default=10
        Regularisation on the bias field.
    max_iter : int, default=16
        Maximum number of Newton iterations.
    max_rls : int, default=8
        Maximum number of reweighting iterations.
        If 1, this is effectively an l2 regularisation.
    tol : float, default=1e-6
        Tolerance for early stopping.
    verbose : int or bool, default=False
        Verbosity level
    device : torch.device, default=x.device
        Use this device during fitting.

    Returns
    -------
    y : tensor
        Fitted image
    bias : tensor
        Fitted bias field
    x : tensor
        Corrected image

    """

    x = torch.as_tensor(x)
    if not x.dtype.is_floating_point:
        x = x.to(dtype=torch.get_default_dtype())
    dim = x.dim()

    # downsampling
    if downsample:
        x0 = x
        downsample = py.make_list(downsample, dim)
        x = spatial.pool(dim, x, downsample)
    shape = x.shape
    x = x.to(device)

    # noise educated guess: assume SNR=5 at z=1/2
    center = tuple(slice(s // 3, 2 * s // 3) for s in shape)
    sigma = sigma or x[center].median() / 5
    lam = lam**2 * sigma**2
    gamma = gamma**2 * sigma**2
    regy = lambda y, w: spatial.regulariser(
        y[None], membrane=lam, dim=dim, weights=w)[0]
    regb = lambda b: spatial.regulariser(b[None], bending=gamma, dim=dim)[0]
    solvey = lambda h, g, w: spatial.solve_field_sym(
        h[None], g[None], membrane=lam, dim=dim, weights=w)[0]
    solveb = lambda h, g: spatial.solve_field_sym(
        h[None], g[None], bending=gamma, dim=dim)[0]

    # init
    l1 = max_rls > 1
    if l1:
        w = torch.ones_like(x)[None]
        llw = w.sum()
    else:
        w = None
        llw = 0
        max_rls = 1
    logb = torch.zeros_like(x)
    logy = x.clamp_min(1e-3).log_()
    y = logy.exp()
    b = logb.exp()
    fit = y * b
    res = fit - x
    llx = res.square().sum()
    lly = (regy(logy, w).mul_(logy)).sum()
    llb = (regb(logb).mul_(logb)).sum()
    ll0 = llx + lly + llb + llw
    ll1 = ll0

    for it_ls in range(max_rls):
        for it in range(max_iter):

            # update bias
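            # NB: `g` aliases `fit` to reuse its memory; the in-place
            # `g *= res` below clobbers `fit`, which is recomputed right
            # after the solve (the `update y` block plays the same trick).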
            g = h = fit
            h = (h * res).abs_()
            h.addcmul_(g, g)
            g *= res
            g += regb(logb)
            logb -= solveb(h, g)
            logb0 = logb.mean()
            logb -= logb0
            logy += logb0

            # update fit / ll
            llb = (regb(logb).mul_(logb)).sum()
            b = torch.exp(logb, out=b)
            y = torch.exp(logy, out=y)
            fit = y * b
            res = fit - x

            # update y
            g = h = fit
            h = (h * res).abs_()
            h.addcmul_(g, g)
            g *= res
            g += regy(logy, w)
            logy -= solvey(h, g, w)

            # update fit / ll
            y = torch.exp(logy, out=y)
            fit = y * b
            res = fit - x
            lly = (regy(logy, w).mul_(logy)).sum()

            # compute objective
            llx = res.square().sum()
            ll = llx + lly + llb + llw
            gain = (ll1 - ll) / ll0
            ll1 = ll
            if verbose:
                end = '\n' if verbose > 1 else '\r'
                pre = f'{it_ls:3d} | ' if l1 else ''
                print(pre + f'{it:3d} | {ll:12.6g} | gain = {gain:12.6g}',
                      end=end)
            if it > 0 and abs(gain) < tol:
                break

        if l1:
            w, llw = spatial.membrane_weights(logy[None],
                                              lam,
                                              dim=dim,
                                              return_sum=True)
            ll0 = ll
    if verbose:
        print('')

    if downsample:
        b = spatial.resize(logb.to(x0.device)[None, None],
                           downsample,
                           shape=x0.shape,
                           anchor='f')[0, 0].exp_()
        y = spatial.resize(logy.to(x0.device)[None, None],
                           downsample,
                           shape=x0.shape,
                           anchor='f')[0, 0].exp_()
        x = x0
    else:
        y = torch.exp(logy, out=y)
    x = x / b
    return y, b, x
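A minimal usage sketch; the input is a stand-in for a SPIM volume (z last), and `spatial` / `py` are assumed to come from the surrounding library:

import torch

# Hypothetical input: a positive-valued volume with a smooth bias to remove.
x = 100 * torch.rand(64, 64, 32)
y, bias, corrected = correct_smooth(x, lam=10, gamma=10, downsample=2,
                                    verbose=1)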
Example #5
def run_epic_noreverse(echoes,
                       voxshift=None,
                       lam=1,
                       sigma=1,
                       max_iter=(10, 32),
                       tol=1e-5,
                       verbose=False):
    """Run EPIC on pre-loaded tensors (no reverse gradients or synthesis)

    Parameters
    ----------
    echoes : (N, *spatial) tensor
        Echoes acquired with a bipolar readout. The readout direction
        should be last.
    lam : [list of] float
        Regularization factor (per echo)
    sigma : float
        Noise standard deviation
    max_iter : [pair of] int
        Maximum number of RLS and CG iterations
    tol : float
        Tolerance for early stopping
    verbose : int
        Verbosity level

    Returns
    -------
    echoes : (N, *spatial) tensor
        Undistorted + denoised echoes

    """
    ne = len(echoes)  # number of echoes
    nv = echoes.shape[1:].numel()  # number of voxels
    nd = echoes.dim() - 1  # number of dimensions

    # initialize denoised echoes
    fit = echoes.clone()
    fwd_fit = torch.empty_like(fit)

    # prepare voxel shift maps
    do_voxshift = voxshift is not None
    if do_voxshift:
        ivoxshift = add_identity_1d(-voxshift)
        voxshift = add_identity_1d(voxshift)
    else:
        ivoxshift = None

    # prepare parameters
    max_iter, sub_iter = py.make_list(max_iter, 2)
    tol, sub_tol = py.make_list(tol, 2)
    lam = [l / ne for l in py.make_list(lam, ne)]
    isigma2 = 1 / (sigma * sigma)

    # compute hessian once and for all
    if do_voxshift:
        one = torch.ones_like(voxshift)[None]
        h = torch.empty_like(echoes)
        h[0::2] = push1d(pull1d(one, voxshift), voxshift)
        h[1::2] = push1d(pull1d(one, ivoxshift), ivoxshift)
        del one
    else:
        h = fit.new_ones([1] * (nd + 1))
    h *= isigma2

    loss = float('inf')
    for n_iter in range(max_iter):

        # update weights
        w, jtv = membrane_weights(fit, factor=lam, return_sum=True)

        # gradient of likelihood
        pull_forward(fit, voxshift, ivoxshift, out=fwd_fit)
        fwd_fit.sub_(echoes)
        ll = ssq(fwd_fit)
        push_forward(fwd_fit, voxshift, ivoxshift, out=fwd_fit)

        g = fwd_fit.mul_(isigma2)
        ll *= 0.5 * isigma2

        # gradient of prior
        g += regulariser(fit, membrane=1, factor=lam, weights=w)

        # solve
        fit -= solve_field(h,
                           g,
                           w,
                           membrane=1,
                           factor=lam,
                           max_iter=sub_iter,
                           tolerance=sub_tol)

        # track objective
        ll, jtv = ll.item() / (ne * nv), jtv.item() / (ne * nv)
        loss, loss_prev = ll + jtv, loss
        if n_iter:
            gain = (loss_prev - loss) / max((loss_max - loss), 1e-8)
        else:
            gain = float('inf')
            loss_max = loss
        if verbose:
            end = '\n' if verbose > 1 else '\r'
            print(
                f'{n_iter+1:02d} | {ll:12.6g} + {jtv:12.6g} = {loss:12.6g} '
                f'| gain = {gain:12.6g}',
                end=end)
        if gain < tol:
            break

    if verbose == 1:
        print('')

    return fit
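A minimal usage sketch without a voxel shift map (so the Hessian reduces to a constant); helpers such as `membrane_weights`, `regulariser` and `solve_field` are assumed importable from the surrounding library:

import torch

# Hypothetical input: six echoes acquired with a bipolar readout.
echoes = torch.rand(6, 48, 48, 48)
denoised = run_epic_noreverse(echoes, lam=1, sigma=0.05, verbose=1)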
Example #6
def run_epic(echoes,
             reverse_echoes=None,
             voxshift=None,
             extrapolate=True,
             lam=1,
             sigma=1,
             max_iter=(10, 32),
             tol=1e-5,
             verbose=False):
    """Run EPIC on pre-loaded tensors.

    Parameters
    ----------
    echoes : (N, *spatial) tensor
        Echoes acquired with a bipolar readout. The readout direction
        should be last.
    reverse_echoes : (N, *spatial) tensor, optional
        Echoes acquired with a reversed bipolar readout. If not provided,
        they are synthesized from `echoes`. If `False`, the no-reverse
        variant `run_epic_noreverse` is used instead.
    voxshift : (*spatial) tensor
        Voxel shift map used to deform towards even (0, 2, ...) echoes.
        Its inverse is used to deform towards odd (1, 3, ...) echoes.
    extrapolate : bool
        Extrapolate first/last echo when reverse_echoes is None.
        Otherwise, only use interpolated echoes.
    lam : [list of] float
        Regularization factor (per echo)
    sigma : float
        Noise standard deviation
    max_iter : [pair of] int
        Maximum number of RLS and CG iterations
    tol : float
        Tolerance for early stopping
    verbose : int
        Verbosity level

    Returns
    -------
    echoes : (N, *spatial) tensor
        Undistorted + denoised echoes

    """
    if reverse_echoes is False:
        return run_epic_noreverse(echoes, voxshift, lam, sigma, max_iter, tol,
                                  verbose)

    ne = len(echoes)  # number of echoes
    nv = echoes.shape[1:].numel()  # number of voxels
    nd = echoes.dim() - 1  # number of dimensions

    # synthesize echoes
    synth = not torch.is_tensor(reverse_echoes)
    if synth:
        neg = synthesize_neg(echoes[0::2])
        pos = synthesize_pos(echoes[1::2])
        reverse_echoes = torch.stack([x for y in zip(pos, neg) for x in y])
        del pos, neg
    else:
        extrapolate = True

    # initialize denoised echoes
    fit = (echoes + reverse_echoes).div_(2)
    if not extrapolate:
        fit[0] = echoes[0]
        fit[-1] = echoes[-1]
    fwd_fit = torch.zeros_like(fit)
    bwd_fit = torch.zeros_like(fit)

    # prepare voxel shift maps
    if voxshift is not None:
        ivoxshift = add_identity_1d(-voxshift)
        voxshift = add_identity_1d(voxshift)
    else:
        ivoxshift = None

    # prepare parameters
    max_iter, sub_iter = py.make_list(max_iter, 2)
    tol, sub_tol = py.make_list(tol, 2)
    lam = [l / ne for l in py.make_list(lam, ne)]
    isigma2 = 1 / (sigma * sigma)

    # compute hessian once and for all
    if voxshift is not None:
        one = torch.ones_like(voxshift)[None]
        if extrapolate:
            h = push1d(pull1d(one, voxshift), voxshift)
            h += push1d(pull1d(one, ivoxshift), ivoxshift)
            weight_ = lambda x: x.mul_(0.5)
            halfweight_ = lambda x: x.mul_(math.sqrt(0.5))
        else:
            h = torch.zeros_like(fit)
            h[:-1] += push1d(pull1d(one, voxshift), voxshift)
            h[1:] += push1d(pull1d(one, ivoxshift), ivoxshift)
            weight_ = lambda x: x[1:-1].mul_(0.5)
            halfweight_ = lambda x: x[1:-1].mul_(math.sqrt(0.5))
        del one
    else:
        h = fit.new_ones([ne] + [1] * nd)
        if extrapolate:
            h *= 2
            weight_ = lambda x: x.mul_(0.5)
            halfweight_ = lambda x: x.mul_(math.sqrt(0.5))
        else:
            h[1:-1] *= 2
            weight_ = lambda x: x[1:-1].mul_(0.5)
            halfweight_ = lambda x: x[1:-1].mul_(math.sqrt(0.5))
    weight_(h)
    h *= isigma2

    loss = float('inf')
    for n_iter in range(max_iter):

        # update weights
        w, jtv = membrane_weights(fit, factor=lam, return_sum=True)

        # gradient of likelihood (forward)
        pull_forward(fit, voxshift, ivoxshift, out=fwd_fit)
        fwd_fit.sub_(echoes)
        halfweight_(fwd_fit)
        ll = ssq(fwd_fit)
        halfweight_(fwd_fit)
        push_forward(fwd_fit, voxshift, ivoxshift, out=fwd_fit)

        # gradient of likelihood (reversed)
        pull_backward(fit, voxshift, ivoxshift, extrapolate, out=bwd_fit)
        if extrapolate:
            bwd_fit.sub_(reverse_echoes)
        else:
            bwd_fit[1:-1].sub_(reverse_echoes[1:-1])
        halfweight_(bwd_fit)
        ll += ssq(bwd_fit)
        halfweight_(bwd_fit)
        push_backward(bwd_fit, voxshift, ivoxshift, extrapolate, out=bwd_fit)

        g = fwd_fit.add_(bwd_fit).mul_(isigma2)
        ll *= 0.5 * isigma2

        # gradient of prior
        g += regulariser(fit, membrane=1, factor=lam, weights=w)

        # solve
        fit -= solve_field(h,
                           g,
                           w,
                           membrane=1,
                           factor=lam,
                           max_iter=sub_iter,
                           tolerance=sub_tol)

        # track objective
        ll, jtv = ll.item() / (ne * nv), jtv.item() / (ne * nv)
        loss, loss_prev = ll + jtv, loss
        if n_iter:
            gain = (loss_prev - loss) / max((loss_max - loss), 1e-8)
        else:
            gain = float('inf')
            loss_max = loss
        if verbose:
            end = '\n' if verbose > 1 else '\r'
            print(
                f'{n_iter+1:02d} | {ll:12.6g} + {jtv:12.6g} = {loss:12.6g} '
                f'| gain = {gain:12.6g}',
                end=end)
        if gain < tol:
            break

    if verbose == 1:
        print('')

    return fit
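A minimal usage sketch with synthesized reverse echoes (`reverse_echoes=None`) and a small random voxel shift map; `add_identity_1d`, `pull1d`/`push1d` and the synthesis helpers are assumed importable from the surrounding library:

import torch

# Hypothetical input: six bipolar echoes plus a B0-induced shift map (voxels).
echoes = torch.rand(6, 48, 48, 48)
voxshift = 0.5 * torch.randn(48, 48, 48)
undistorted = run_epic(echoes, voxshift=voxshift, lam=1, sigma=0.05)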
Example #7
def denoise(image=None,
            lam=1,
            sigma=1,
            max_iter=64,
            sub_iter=32,
            optim='cg',
            tol=1e-5,
            sub_tol=1e-5,
            plot=False,
            jtv=True,
            dim=None,
            **prm):
    """Denoise an image using a (joint) total variation prior.

    This implementation uses a reweighted least squares approach.

    Parameters
    ----------
    image : (..., K, *spatial) tensor, Image to denoise with K channels
    lam : [list of] float, default=1, Regularisation factor
    sigma : [list of] float, default=1, Noise standard deviation
    max_iter : int, default=64, Number of RLS iterations
    sub_iter : int, default=32, Number of relaxation/cg iterations
    tol : float, default=1e-5, Tolerance for early stopping (RLS)
    sub_tol : float, default=1e-5, Tolerance for early stopping (relax/cg)
    optim : {'cg', 'relax', 'fmg+cg', 'fmg+relax'}, default='cg'
    plot : bool, default=False
    jtv : bool, default=True, Joint TV across channels ($\ell_{1,2}$)
    dim : int, default=image.dim()-1

    Returns
    -------
    denoised : (..., K, *spatial) tensor, Denoised image

    """

    if image is None:
        torch.random.manual_seed(1234)
        image = phantoms.augment(phantoms.circle(), fwhm=0)[None]

    image = torch.as_tensor(image)
    dim = dim or (image.dim() - 1)
    nvox = image.shape[-dim:].numel()

    # regularization (prior)
    lam = make_list(lam, image.shape[-dim - 1])
    if jtv:
        lam = [l / image.shape[-dim - 1] for l in lam]
    prm['membrane'] = 1
    prm['factor'] = lam

    # noise variance (likelihood)
    sigma = make_vector(sigma,
                        image.shape[-dim - 1],
                        dtype=image.dtype,
                        device=image.device)
    isigma2 = 1 / (sigma * sigma)

    # solver
    optim, lr = make_list(optim, 2, default=1)
    do_fmg, optim = ('fmg' in optim), ('relax' if 'relax' in optim else 'cg')
    if do_fmg:

        def solve(h, g, w, optim):
            return spatial.solve_field_fmg(h,
                                           g,
                                           dim=dim,
                                           weights=w,
                                           **prm,
                                           optim=optim)
    else:

        def solve(h, g, w, optim):
            return spatial.solve_field(h,
                                       g,
                                       dim=dim,
                                       weights=w,
                                       **prm,
                                       optim=optim,
                                       max_iter=sub_iter,
                                       tolerance=sub_tol)

    # initialize
    denoised = image.clone()

    l_prev = None
    l_max = None
    for n_iter in range(1, max_iter + 1):

        # update weight map / tv loss
        weights, lw = spatial.membrane_weights(denoised,
                                               dim=dim,
                                               return_sum=True,
                                               joint=jtv,
                                               factor=lam)

        # gradient / hessian of least square problem
        ll, g, h = losses.mse(denoised, image, dim=dim, lam=isigma2)
        ll.mul_(nvox), g.mul_(nvox), h.mul_(nvox)
        g += spatial.regulariser(denoised, dim=dim, weights=weights, **prm)

        # solve least square problem
        optim0 = 'cg' if n_iter < 10 else optim  # it seems it helps
        g = solve(h, g, weights, optim0)
        if lr != 1:
            g.mul_(lr)
        denoised -= g

        ll = ll.item() / image.numel()
        lw = lw.item() / image.numel()
        l = ll + lw

        # print objective
        if l_prev is None:
            l_prev = l_max = l
            gain = float('inf')
            print(f'{n_iter:03d} | {ll:12.6g} + {lw:12.6g} = {l:12.6g}',
                  end='\r')
        else:
            gain = (l_prev - l) / max(abs(l_max - l), 1e-8)
            l_prev = l
            print(
                f'{n_iter:03d} | {ll:12.6g} + {lw:12.6g} = {l:12.6g} '
                f'| gain = {gain:12.6g}',
                end='\r')

        if plot and (n_iter - 1) % (max_iter // 10 + 1) == 0:
            import matplotlib.pyplot as plt
            img = image[0, image.shape[1] // 2] if dim == 3 else image[0]
            den = denoised[0, denoised.shape[1] // 2] if dim == 3 else denoised[0]
            wgt = weights[0, weights.shape[1] // 2] if dim == 3 else weights[0]
            plt.subplot(1, 3, 1)
            plt.imshow(img.cpu())
            plt.subplot(1, 3, 2)
            plt.imshow(den.cpu())
            plt.subplot(1, 3, 3)
            plt.imshow(wgt.reciprocal().cpu())
            plt.colorbar()
            plt.show()

        if gain < tol:
            break
    print('')

    return denoised
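A minimal usage sketch on a random two-channel image; with `jtv=True` the channels share a single weight map ($\ell_{1,2}$ penalty):

import torch

# Hypothetical input: a noisy 2-channel 2D image.
img = torch.rand(2, 128, 128)
clean = denoise(img, lam=5, sigma=0.1, max_iter=32)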
Example #8
def meetup(data, dist=None, opt=None):
    """Fit the ESTATICS+MEETUP model to multi-echo Gradient-Echo data.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    dist : sequence[Optional[ParameterizedDistortion]], optional
        Pre-computed distortion fields
    opt : Options, optional
        Algorithm options.

    Returns
    -------
    intercepts : sequence[GradientEcho]
        Echo series extrapolated to TE=0
    decay : estatics.ParameterMap
        R2* decay map
    distortions : sequence[ParameterizedDistortion]
        B0-induced distortion fields

    """
    # --- prepare everything -------------------------------------------
    data, maps, dist, opt, rls = _prepare(data, dist, opt)
    sumrls = 0.5 * rls.sum(dtype=torch.double) if rls is not None else 0
    backend = dict(dtype=opt.backend.dtype, device=opt.backend.device)

    # --- initialize tracking of the objective function ----------------
    crit, vreg, reg = float('inf'), 0, 0
    ll_rls = crit + vreg + reg
    sumrls_prev = sumrls
    rls_changed = False
    ll_scl = sum(core.py.prod(dat.shape) for dat in data)

    # --- Multi-Resolution loop ----------------------------------------
    shape0 = shape = maps.shape[1:]
    aff0 = aff = maps.affine
    lam0 = lam = opt.regularization.factor
    vx0 = vx = spatial.voxel_size(aff0)
    scl0 = vx0.prod()
    armijo_prev = 1
    for level in range(opt.optim.nb_levels, 0, -1):

        if opt.optim.nb_levels > 1:
            aff, shape = _get_level(level, aff0, shape0)
            vx = spatial.voxel_size(aff)
            scl = vx.prod() / scl0
            lam = [float(l * scl) for l in lam0]
            maps, rls = _resize(maps, rls, aff, shape)
            if opt.regularization.norm in ('tv', 'jtv'):
                sumrls = 0.5 * scl * rls.reciprocal().sum(dtype=torch.double)

        grad = torch.empty((len(data) + 1, *shape), **backend)
        hess = torch.empty((len(data) * 2 + 1, *shape), **backend)

        # --- RLS loop -------------------------------------------------
        max_iter_rls = 1 if level > 1 else opt.optim.max_iter_rls
        for n_iter_rls in range(1, max_iter_rls + 1):

            # --- helpers ----------------------------------------------
            regularizer_prm = lambda x: spatial.regulariser(
                x, weights=rls, dim=3, voxel_size=vx, membrane=1, factor=lam)
            solve_prm = lambda H, g: spatial.solve_field(
                H,
                g,
                rls,
                factor=lam,
                membrane=1 if opt.regularization.norm else 0,
                voxel_size=vx,
                max_iter=opt.optim.max_iter_cg,
                tolerance=opt.optim.tolerance_cg,
                dim=3)
            reweight = lambda x, **k: spatial.membrane_weights(
                x,
                lam,
                vx,
                dim=3,
                **k,
                joint=opt.regularization.norm == 'jtv',
                eps=core.constants.eps(rls.dtype))

            # ----------------------------------------------------------
            #    Initial update of parameter maps
            # ----------------------------------------------------------
            crit_pre_prm = None
            max_iter_prm = opt.optim.max_iter_prm
            for n_iter_prm in range(1, max_iter_prm + 1):

                crit_prev = crit
                reg_prev = reg

                # --- loop over contrasts ------------------------------
                crit = 0
                grad.zero_()
                hess.zero_()
                for i, (contrast, intercept, distortion) \
                        in enumerate(zip(data, maps.intercepts, dist)):
                    # compute gradient
                    crit1, g1, h1 = derivatives_parameters(
                        contrast, distortion, intercept, maps.decay, opt)
                    # increment
                    gind, hind = [i, -1], [i, len(grad) - 1, len(grad) + i]
                    grad[gind] += g1
                    hess[hind] += h1
                    crit += crit1
                del g1, h1

                # --- regularization -----------------------------------
                reg = 0.
                if opt.regularization.norm:
                    g1 = regularizer_prm(maps.volume)
                    reg = 0.5 * dot(maps.volume, g1)
                    grad += g1
                    del g1

                if n_iter_prm == 1:
                    crit_pre_prm = crit

                # --- track RLS improvement ----------------------------
                # Updating the RLS weights changes `sumrls` and `reg`. Now
                # that we have the updated value of `reg` (before it is
                # modified by the map update), we can check that updating the
                # weights indeed improved the objective.

                if opt.verbose and rls_changed:
                    rls_changed = False
                    if reg + sumrls <= reg_prev + sumrls_prev:
                        evol = '<='
                    else:
                        evol = '>'
                    ll_tmp = crit_prev + reg + vreg + sumrls
                    pstr = (f'{n_iter_rls-1:3d} | {"---":3s} | {"rls":4s} | '
                            f'{crit_prev:12.6g} + {reg:12.6g} + '
                            f'{sumrls:12.6g} + {vreg:12.6g} '
                            f'= {ll_tmp:12.6g} | {evol}')
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)

                # --- gauss-newton -------------------------------------
                # Computing the GN step involves solving H\g
                hess = check_nans_(hess, warn='hessian')
                hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
                deltas = solve_prm(hess, grad)
                deltas = check_nans_(deltas, warn='delta')

                dd = regularizer_prm(deltas)
                dv = dot(dd, maps.volume)
                dd = dot(dd, deltas)
                delta_reg = 0.5 * (dd - 2 * dv)

                for map, delta in zip(maps, deltas):
                    map.volume -= delta
                    if map.min is not None or map.max is not None:
                        map.volume.clamp_(map.min, map.max)
                del deltas

                # --- track parameter map improvement --------------
                gain = (crit_prev + reg_prev) - (crit + reg)
                if n_iter_prm > 1 and gain < opt.optim.tolerance * ll_scl:
                    break

            # ----------------------------------------------------------
            #    Distortion update with line search
            # ----------------------------------------------------------
            max_iter_dist = opt.optim.max_iter_dist
            for n_iter_dist in range(1, max_iter_dist + 1):

                crit = 0
                vreg = 0
                new_crit = 0
                new_vreg = 0

                deltas, dd, dv = [], 0, 0
                # --- loop over contrasts ------------------------------
                for i, (contrast, intercept, distortion) \
                        in enumerate(zip(data, maps.intercepts, dist)):

                    # --- helpers --------------------------------------
                    vxr = distortion.voxel_size[contrast.readout]
                    lam_dist = dict(opt.distortion.factor)
                    lam_dist['factor'] *= vxr**2
                    regularizer_dist = lambda x: spatial.regulariser(
                        x[None],
                        **lam_dist,
                        dim=3,
                        bound=DIST_BOUND,
                        voxel_size=distortion.voxel_size)[0]
                    solve_dist = lambda h, g: spatial.solve_field_fmg(
                        h,
                        g,
                        **lam_dist,
                        dim=3,
                        bound=DIST_BOUND,
                        voxel_size=distortion.voxel_size)

                    # --- likelihood -----------------------------------
                    crit1, g, h = derivatives_distortion(
                        contrast, distortion, intercept, maps.decay, opt)
                    crit += crit1

                    # --- regularization -------------------------------
                    vol = distortion.volume
                    g1 = regularizer_dist(vol)
                    vreg1 = 0.5 * vol.flatten().dot(g1.flatten())
                    vreg += vreg1
                    g += g1
                    del g1

                    # --- gauss-newton ---------------------------------
                    h = check_nans_(h, warn='hessian (distortion)')
                    h = hessian_loaddiag_(h[None], 1e-32, 1e-32, sym=True)[0]
                    delta = solve_dist(h, g)
                    delta = check_nans_(delta, warn='delta (distortion)')
                    deltas.append(delta)
                    del g, h
                    dd1 = regularizer_dist(delta)
                    dv += dot(dd1, vol)
                    dd1 = dot(dd1, delta)
                    dd += dd1
                    del delta, vol

                # --- track parameters improvement ---------------------
                if opt.verbose and n_iter_dist == 1:
                    gain = crit_pre_prm - (crit + delta_reg)
                    evol = '<=' if gain > 0 else '>'
                    ll_tmp = crit + reg + vreg + sumrls
                    pstr = (f'{n_iter_rls:3d} | {"---":3} | {"prm":4s} | '
                            f'{crit:12.6g} + {reg + delta_reg:12.6g} + '
                            f'{sumrls:12.6g} + {vreg:12.6g} '
                            f'= {ll_tmp:12.6g} | '
                            f'gain = {gain / ll_scl:7.2g} | {evol}')
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)

                # --- line search ----------------------------------
                reg = reg + delta_reg
                new_crit, new_reg = crit, reg
                armijo, armijo_prev, ok = armijo_prev, 0, False
                maps0 = maps
                for n_ls in range(1, 12 + 1):
                    for delta1, dist1 in zip(deltas, dist):
                        dist1.volume.sub_(delta1, alpha=(armijo - armijo_prev))
                    armijo_prev = armijo
                    new_vreg = 0.5 * armijo * (armijo * dd - 2 * dv)

                    maps = maps0.deepcopy()
                    max_iter_prm = opt.optim.max_iter_prm
                    for n_iter_prm in range(1, max_iter_prm + 1):

                        crit_prev = new_crit
                        reg_prev = new_reg

                        # --- loop over contrasts ------------------
                        new_crit = 0
                        grad.zero_()
                        hess.zero_()
                        for i, (contrast, intercept, distortion) in enumerate(
                                zip(data, maps.intercepts, dist)):
                            # compute gradient
                            new_crit1, g1, h1 = derivatives_parameters(
                                contrast, distortion, intercept, maps.decay,
                                opt)
                            # increment
                            gind = [i, -1]
                            hind = [i, len(grad) - 1, len(grad) + i]
                            grad[gind] += g1
                            hess[hind] += h1
                            new_crit += new_crit1
                        del g1, h1

                        # --- regularization -----------------------
                        new_reg = 0.
                        if opt.regularization.norm:
                            g1 = regularizer_prm(maps.volume)
                            new_reg = 0.5 * dot(maps.volume, g1)
                            grad += g1
                            del g1

                        new_gain = (crit_prev + reg_prev) - (new_crit +
                                                             new_reg)
                        if new_gain < opt.optim.tolerance * ll_scl:
                            break

                        # --- gauss-newton -------------------------
                        # Computing the GN step involves solving H\g
                        hess = check_nans_(hess, warn='hessian')
                        hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
                        delta = solve_prm(hess, grad)
                        delta = check_nans_(delta, warn='delta')

                        dd = regularizer_prm(delta)
                        dv = dot(dd, maps.volume)
                        dd = dot(dd, delta)
                        delta_reg = 0.5 * (dd - 2 * dv)

                        for map, delta1 in zip(maps, delta):
                            map.volume -= delta1
                            if map.min is not None or map.max is not None:
                                map.volume.clamp_(map.min, map.max)
                        del delta

                    if new_crit + new_reg + new_vreg <= crit + reg:
                        ok = True
                        break
                    else:
                        armijo = armijo / 2

                if not ok:
                    for delta1, dist1 in zip(deltas, dist):
                        dist1.volume.add_(delta1, alpha=armijo_prev)
                    armijo_prev = 1
                    maps = maps0
                    new_crit = crit
                    new_vreg = 0
                    del deltas
                else:
                    armijo_prev *= 1.5
                new_vreg = vreg + new_vreg

                # --- track distortion improvement ---------------------
                if not ok:
                    evol = '== (x)'
                elif new_crit + new_vreg <= crit + vreg:
                    evol = f'<= ({n_ls:2d})'
                else:
                    evol = f'> ({n_ls:2d})'
                gain = (crit + vreg + reg) - (new_crit + new_vreg + new_reg)
                crit, vreg, reg = new_crit, new_vreg, new_reg
                if opt.verbose:
                    ll_tmp = crit + reg + vreg + sumrls
                    pstr = (
                        f'{n_iter_rls:3d} | {n_iter_dist:3d} | {"dist":4s} | '
                        f'{crit:12.6g} + {reg:12.6g} + {sumrls:12.6g} ')
                    pstr += f'+ {vreg:12.6g} '
                    pstr += f'= {ll_tmp:12.6g} | gain = {gain / ll_scl:7.2g} | '
                    pstr += f'{evol}'
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)
                    if opt.plot:
                        _show_maps(maps, dist, data)
                if not ok:
                    break

                if n_iter_dist > 1 and gain < opt.optim.tolerance * ll_scl:
                    break

            # --------------------------------------------------------------
            #    Update RLS weights
            # --------------------------------------------------------------
            if level == 1 and opt.regularization.norm in ('tv', 'jtv'):
                rls_changed = True
                sumrls_prev = sumrls
                rls, sumrls = reweight(maps.volume, return_sum=True)
                sumrls = 0.5 * sumrls

                # --- compute gain -----------------------------------------
                # (we are late by one full RLS iteration when computing the
                #  gain but we save some computations)
                ll_rls_prev = ll_rls
                ll_rls = crit + reg + vreg
                gain = ll_rls_prev - ll_rls
                if gain < opt.optim.tolerance * ll_scl:
                    print(f'RLS converged ({gain / ll_scl:7.2g})')
                    break

    # --- prepare output -----------------------------------------------
    out = postproc(maps, data)
    if opt.distortion.enable:
        out = (*out, dist)
    return out
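The line search in `meetup` avoids keeping a backup copy of the distortion fields: each trial moves the volume by the difference between successive step sizes, `alpha=(armijo - armijo_prev)`, and a fully rejected search is undone by adding back `armijo_prev` times the step. A toy sketch of that pattern on a one-dimensional quadratic:

import torch

# Incremental backtracking: x always holds x0 - armijo_prev * delta, so no
# backup copy of x0 is needed.
x = torch.tensor([3.0])
delta = torch.tensor([2.0])        # a (possibly too large) descent step
f = lambda t: float((t - 1) ** 2)  # objective with its minimum at t = 1

f0, armijo, armijo_prev = f(x), 1.0, 0.0
for _ in range(12):
    x.sub_(delta, alpha=(armijo - armijo_prev))  # move to x0 - armijo * delta
    armijo_prev = armijo
    if f(x) <= f0:                 # accept the first improving step
        break
    armijo /= 2                    # otherwise halve the step and retry
else:
    x.add_(delta, alpha=armijo_prev)  # all trials rejected: undo the move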
Example #9
def nonlin(data, dist=None, opt=None):
    """Fit the ESTATICS model to multi-echo Gradient-Echo data.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    dist : sequence[Optional[ParameterizedDistortion]], optional
        Pre-computed distortion fields
    opt : Options, optional
        Algorithm options.

    Returns
    -------
    intercepts : sequence[GradientEcho]
        Echo series extrapolated to TE=0
    decay : estatics.ParameterMap
        R2* decay map
    distortions : sequence[ParameterizedDistortion], if opt.distortion.enable
        B0-induced distortion fields

    """
    if opt.distortion.enable:
        return meetup(data, dist, opt)

    # --- prepare everything -------------------------------------------
    data, maps, dist, opt, rls = _prepare(data, dist, opt)
    sumrls = 0.5 * rls.sum(dtype=torch.double) if rls is not None else 0
    backend = dict(dtype=opt.backend.dtype, device=opt.backend.device)

    # --- initialize tracking of the objective function ----------------
    crit = float('inf')
    vreg = 0
    reg = 0
    sumrls_prev = sumrls
    rls_changed = False
    ll_gn = ll_rls = float('inf')
    ll_scl = sum(core.py.prod(dat.shape) for dat in data)

    # --- Multi-Resolution loop ----------------------------------------
    shape0 = shape = maps.shape[1:]
    aff0 = aff = maps.affine
    lam0 = lam = opt.regularization.factor
    vx0 = vx = spatial.voxel_size(aff0)
    scl0 = vx0.prod()
    for level in range(opt.optim.nb_levels, 0, -1):

        if opt.optim.nb_levels > 1:
            aff, shape = _get_level(level, aff0, shape0)
            vx = spatial.voxel_size(aff)
            scl = vx.prod() / scl0
            lam = [float(l * scl) for l in lam0]
            maps, rls = _resize(maps, rls, aff, shape)
            if opt.regularization.norm in ('tv', 'jtv'):
                sumrls = 0.5 * scl * rls.reciprocal().sum(dtype=torch.double)

        grad = torch.empty((len(data) + 1, *shape), **backend)
        hess = torch.empty((len(data) * 2 + 1, *shape), **backend)

        # --- RLS loop -------------------------------------------------
        #   > max_iter_rls == 1 if regularization is not (J)TV
        max_iter_rls = opt.optim.max_iter_rls
        if level != 1:
            max_iter_rls = 1
        for n_iter_rls in range(1, max_iter_rls + 1):

            # --- helpers ----------------------------------------------
            regularizer = lambda x: spatial.regulariser(
                x, weights=rls, dim=3, voxel_size=vx, membrane=1, factor=lam)
            solve = lambda h, g: spatial.solve_field(
                h,
                g,
                rls,
                factor=lam,
                membrane=1 if opt.regularization.norm else 0,
                voxel_size=vx,
                max_iter=opt.optim.max_iter_cg,
                tolerance=opt.optim.tolerance_cg,
                dim=3)
            reweight = lambda x, **k: spatial.membrane_weights(
                x,
                lam,
                vx,
                dim=3,
                **k,
                joint=opt.regularization.norm == 'jtv',
                eps=core.constants.eps(rls.dtype))

            # ----------------------------------------------------------
            #    Update parameter maps
            # ----------------------------------------------------------

            max_iter_prm = opt.optim.max_iter_prm
            for n_iter_prm in range(1, max_iter_prm + 1):

                crit_prev = crit
                reg_prev = reg

                # --- loop over contrasts ------------------------------
                crit = 0
                grad.zero_()
                hess.zero_()
                for i, (contrast, intercept, distortion) \
                        in enumerate(zip(data, maps.intercepts, dist)):
                    # compute gradient
                    crit1, g1, h1 = derivatives_parameters(
                        contrast, distortion, intercept, maps.decay, opt)
                    # increment
                    gind = [i, -1]
                    grad[gind] += g1
                    hind = [i, len(grad) - 1, len(grad) + i]
                    hess[hind] += h1
                    crit += crit1
                del g1, h1

                # --- regularization -----------------------------------
                reg = 0.
                if opt.regularization.norm:
                    g1 = regularizer(maps.volume)
                    reg = 0.5 * dot(maps.volume, g1)
                    grad += g1
                    del g1

                # --- track RLS improvement ----------------------------
                # Updating the RLS weights changes `sumrls` and `reg`. Now
                # that we have the updated value of `reg` (before it is
                # modified by the map update), we can check that updating the
                # weights indeed improved the objective.
                if opt.verbose and rls_changed:
                    rls_changed = False
                    if reg + sumrls <= reg_prev + sumrls_prev:
                        evol = '<='
                    else:
                        evol = '>'
                    ll_tmp = crit_prev + reg + vreg + sumrls
                    pstr = (
                        f'{n_iter_rls:3d} | {"---":3s} | {"rls":4s} | '
                        f'{crit_prev:12.6g} + {reg:12.6g} + {sumrls:12.6g} ')
                    if opt.distortion.enable:
                        pstr += f'+ {vreg:12.6g} '
                    pstr += f'= {ll_tmp:12.6g} | '
                    pstr += f'{evol}'
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)

                # --- gauss-newton -------------------------------------
                # Computing the GN step involves solving H\g
                hess = check_nans_(hess, warn='hessian')
                hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
                deltas = solve(hess, grad)
                deltas = check_nans_(deltas, warn='delta')

                for map, delta in zip(maps, deltas):
                    map.volume -= delta

                    if map.min is not None or map.max is not None:
                        map.volume.clamp_(map.min, map.max)
                del deltas

                # --- compute GN gain ----------------------------------
                ll_gn_prev = ll_gn
                ll_gn = crit + reg + vreg + sumrls
                gain = ll_gn_prev - ll_gn
                if opt.verbose:
                    pstr = f'{n_iter_rls:3d} | {n_iter_prm:3d} | '
                    if opt.distortion.enable:
                        pstr += f'{"----":4s} | '
                        pstr += f'{"-"*72:72s} | '
                    else:
                        pstr += f'{"prm":4s} | '
                        pstr += f'{crit:12.6g} + {reg:12.6g} + {sumrls:12.6g} '
                        pstr += f'= {ll_gn:12.6g} | '
                    pstr += f'gain = {gain/ll_scl:7.2g}'
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)
                    if opt.plot:
                        _show_maps(maps, dist, data)
                if gain < opt.optim.tolerance * ll_scl:
                    break

            # --------------------------------------------------------------
            #    Update RLS weights
            # --------------------------------------------------------------
            if level == 1 and opt.regularization.norm in ('tv', 'jtv'):
                reg = 0.5 * dot(maps.volume, regularizer(maps.volume))
                rls_changed = True
                sumrls_prev = sumrls
                rls, sumrls = reweight(maps.volume, return_sum=True)
                sumrls = 0.5 * sumrls

                # --- compute gain -----------------------------------------
                # (we are late by one full RLS iteration when computing the
                #  gain but we save some computations)
                ll_rls_prev = ll_rls
                ll_rls = ll_gn
                gain = ll_rls_prev - ll_rls
                if gain < opt.optim.tolerance * ll_scl:
                    print(f'Converged ({gain/ll_scl:7.2g})')
                    break

    # --- prepare output -----------------------------------------------
    out = postproc(maps, data)
    if opt.distortion.enable:
        out = (*out, dist)
    return out
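All of these examples follow the same reweighted-least-squares recipe for TV: alternate a weighted-L2 solve with a weight update of the form w = 1 / sqrt(|grad x|^2 + eps), which is what the `membrane_weights` calls presumably compute in vectorized form. A self-contained sketch of that weight update:

import torch

def irls_tv_weights(x, eps=1e-8):
    """IRLS weights for an (isotropic) TV prior: 1 / sqrt(|grad x|^2 + eps)."""
    grads = torch.gradient(x)           # finite differences along each dim
    sq = sum(g * g for g in grads)      # squared gradient magnitude
    return (sq + eps).rsqrt()

x = torch.rand(32, 32)
w = irls_tv_weights(x)
print(w.shape, float(w.min()), float(w.max()))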