Example #1
0
 def solve(h, g, w, optim):
     """Run the full-multigrid field solver on the system (h, g).

     Thin wrapper around ``spatial.solve_field_fmg``. The spatial
     dimension ``dim`` and the regularisation options ``prm`` are taken
     from the enclosing scope; ``w`` is forwarded as ``weights`` and
     ``optim`` is forwarded unchanged.
     """
     return spatial.solve_field_fmg(
         h, g, dim=dim, weights=w, **prm, optim=optim)
Example #2
0
 def solve_prm(self, hess, grad):
     """Compute the Gauss-Newton step for the parameter maps.

     The Hessian is NaN-checked and diagonally loaded before being
     passed to the multigrid solver together with the RLS weights
     ``self.rls``, the penalty options ``self.lam_prm`` and
     ``self.voxel_size``. The resulting step is NaN-checked as well.
     """
     # Sanitise, then stabilise the Hessian before inverting it.
     hess = hessian_loaddiag_(check_nans_(hess, warn='hessian'), 1e-6, 1e-8)
     step = spatial.solve_field_fmg(
         hess, grad, self.rls, **self.lam_prm, voxel_size=self.voxel_size)
     return check_nans_(step, warn='delta')
Example #3
0
 def solve_dist(self, hess, grad, vx, readout):
     """Solve the Gauss-Newton step for a distortion field.

     The regularisation factor stored in ``self.lam_dist['factor']`` is
     multiplied by the squared voxel size along the readout axis
     (``vx[readout] ** 2``) before solving.

     Parameters
     ----------
     hess : tensor
         Hessian of the data term (NaN-checked and diagonally loaded here).
     grad : tensor
         Gradient of the objective.
     vx : sequence[float]
         Voxel size, also forwarded to the solver.
     readout : int
         Index of the readout axis in ``vx``.

     Returns
     -------
     delta : tensor
         Gauss-Newton step (NaN-checked).
     """
     hess = check_nans_(hess, warn='hessian')
     hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
     # Copy so that scaling the factor never modifies self.lam_dist.
     # The readout voxel-size scaling presumably accounts for the field
     # being expressed in voxels along the readout axis -- confirm.
     lam = dict(self.lam_dist)
     lam['factor'] = lam['factor'] * (vx[readout]**2)
     # Pass the *scaled* options; passing self.lam_dist here would
     # silently discard the scaling computed above.
     delta = spatial.solve_field_fmg(hess,
                                     grad,
                                     **lam,
                                     dim=3,
                                     bound=self.DIST_BOUND,
                                     voxel_size=vx)
     delta = check_nans_(delta, warn='delta')
     return delta
Example #4
0
def solve_parameters(hess, grad, rls, lam, vx, opt):
    """Solve the regularised ESTATICS Gauss-Newton system.

    The ESTATICS Hessian is stored in a packed form (intercepts have no
    cross terms with each other), so custom ``matvec``, ``matsolve`` and
    ``matdiag`` callbacks are handed to the multigrid solver.

    Parameters
    ----------
    hess : (2*P+1, *shape) tensor
        Packed Hessian.
    grad : (P+1, *shape) tensor
        Gradient.
    rls : ([P+1], *shape) tensor or None
        Reweighting (RLS) weights.
    lam : (P,) sequence[float]
        Membrane regularisation factors.
    vx : (D,) sequence[float]
        Voxel size.
    opt : Options
        Reads ``opt.verbose``, ``opt.optim.max_iter_cg`` and
        ``opt.optim.tolerance_cg``.

    Returns
    -------
    delta : (P+1, *shape) tensor
        Gauss-Newton step returned by ``spatial.solve_field_fmg``.
    """

    def _swap(t):
        # Exchange dims -1 and -4 around the hessian_* helpers. The swap
        # is its own inverse, so the same call restores the layout.
        return t.transpose(-1, -4)

    def matvec(m, x):
        return _swap(hessian_matmul(_swap(m), _swap(x)))

    def matsolve(m, x):
        return _swap(hessian_solve(_swap(m), _swap(x)))

    def matdiag(m, d):
        # Every other element along the last dim; `d` is unused here
        # (presumably required by the solver's callback signature).
        # NOTE(review): confirm that [..., ::2] selects the diagonal in
        # the packing used by hessian_matmul -- with a "diagonals first,
        # then cross terms" layout it would not.
        return m[..., ::2]

    solver_options = dict(
        factor=lam,
        membrane=1,
        voxel_size=vx,
        verbose=opt.verbose - 1,
        nb_iter=opt.optim.max_iter_cg,
        tolerance=opt.optim.tolerance_cg,
        matvec=matvec,
        matsolve=matsolve,
        matdiag=matdiag,
    )
    return spatial.solve_field_fmg(hess, grad, rls, **solver_options)
Example #5
0
def greeq(data, transmit=None, receive=None, opt=None, **kwopt):
    """Fit a non-linear relaxometry model to multi-echo Gradient-Echo data.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    transmit : sequence[PrecomputedFieldMap], optional
        Map(s) of the transmit field (b1+). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
    receive : sequence[PrecomputedFieldMap], optional
        Map(s) of the receive field (b1-). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
        If no receive map is provided, the output `pd` map will have
        a remaining b1- bias field.
    opt : GREEQOptions or dict, optional
        Algorithm options.
        {'preproc': {'register':      True},     # Co-register contrasts
         'optim':   {'nb_levels':     1,         # Number of pyramid levels
                     'max_iter_rls':  10,        # Max reweighting iterations
                     'max_iter_gn':   5,         # Max Gauss-Newton iterations
                     'max_iter_cg':   32,        # Max Conjugate-Gradient iterations
                     'tolerance':     1e-05,     # Tolerance for early stopping (GN and RLS)
                     'tolerance_cg':  1e-03},    # Tolerance for early stopping (CG)
         'backend': {'dtype':  torch.float32,    # Data type
                     'device': 'cpu'},           # Device
         'penalty': {'norm':    'jtv',           # Type of penalty: {'tkh', 'tv', 'jtv', None}
                     'factor':  {'r1':  10,      # Penalty factor per (log) map
                                 'pd':  10,
                                 'r2s': 2,
                                 'mt':  2}},
         'verbose': 1}
        This function also reads `opt.optim.solver` ('fmg' selects the
        full-multigrid solver, any other value the CG solver) and
        `opt.uncertainty` (if true, uncertainty maps are computed);
        their defaults are not shown here (see GREEQOptions).
    **kwopt
        Keyword overrides passed to ``GREEQOptions().update``.

    Returns
    -------
    pd : ParameterMap
        Proton density
    r1 : ParameterMap
        Longitudinal relaxation rate
    r2s : ParameterMap
        Apparent transverse relaxation rate
    mt : ParameterMap, optional
        Magnetisation transfer saturation
        Only returned if MT-weighted data is provided.

    If `opt.uncertainty` is set, each map is given an `uncertainty`
    attribute before post-processing.

    """
    # Build the options from the defaults updated with `opt` and `kwopt`.
    opt = GREEQOptions().update(opt, **kwopt)
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    # --- estimate noise / register / initialize maps ---
    data, transmit, receive, maps = preproc(data, transmit, receive, opt)
    has_mt = hasattr(maps, 'mt')

    # --- prepare penalty factor ---
    # A dict of per-map factors becomes a list ordered (PD, R1, R2*, [MT]);
    # maps missing from the dict get a factor of 0.
    lam = opt.penalty.factor
    if isinstance(lam, dict):
        lam = [
            lam.get('pd', 0),
            lam.get('r1', 0),
            lam.get('r2s', 0),
            lam.get('mt', 0)
        ]
        if not has_mt:
            lam = lam[:3]
    lam = core.utils.make_vector(lam, 3 + has_mt,
                                 **backend)  # PD, R1, R2*, [MT]

    # --- initialize weights (RLS) ---
    # A norm whose name starts with 'no' (e.g. 'none' or None) or all-zero
    # factors disable the penalty (represented by an empty norm string).
    if str(opt.penalty.norm).lower().startswith('no') or all(lam == 0):
        opt.penalty.norm = ''
    opt.penalty.norm = opt.penalty.norm.lower()
    mean_shape = maps[0].shape
    rls = None
    sumrls = 0
    if opt.penalty.norm in ('tv', 'jtv'):
        # TV: one weight map per parameter map; JTV: a single weight map
        # shared by all parameter maps. Weights start at one, so the
        # initial weight term is half the number of weights.
        rls_shape = mean_shape
        if opt.penalty.norm == 'tv':
            rls_shape = (len(maps), ) + rls_shape
        rls = torch.ones(rls_shape, **backend)
        sumrls = 0.5 * core.py.prod(rls_shape)

    if opt.penalty.norm:
        print(f'With {opt.penalty.norm.upper()} penalty:')
        print(f'    - PD:  {lam[0]:.3g}')
        print(f'    - R1:  {lam[1]:.3g}')
        print(f'    - R2*: {lam[2]:.3g}')
        if has_mt:
            print(f'    - MT:  {lam[3]:.3g}')
    else:
        print('Without penalty')

    if opt.penalty.norm not in ('tv', 'jtv'):
        # no reweighting -> do more gauss-newton updates instead
        opt.optim.max_iter_gn *= opt.optim.max_iter_rls
        opt.optim.max_iter_rls = 1
    print('Optimization:')
    print(f'    - Tolerance:        {opt.optim.tolerance}')
    if opt.penalty.norm.endswith('tv'):
        print(f'    - IRLS iterations:  {opt.optim.max_iter_rls}')
    print(f'    - GN iterations:    {opt.optim.max_iter_gn}')
    if opt.optim.solver == 'fmg':
        # Matches the hard-coded nb_iter=2 of the FMG call below.
        print(f'    - FMG cycles:       2')
        print(f'    - CG iterations:    2')
    else:
        print(f'    - CG iterations:    {opt.optim.max_iter_cg}'
              f' (tolerance: {opt.optim.tolerance_cg})')
    if opt.optim.nb_levels > 1:
        print(f'    - Levels:           {opt.optim.nb_levels}')

    printer = CritPrinter(max_levels=opt.optim.nb_levels,
                          max_rls=opt.optim.max_iter_rls,
                          max_gn=opt.optim.max_iter_gn,
                          penalty=opt.penalty.norm,
                          verbose=opt.verbose)
    printer.print_head()

    shape0 = shape = maps.shape[1:]
    aff0 = aff = maps.affine
    vx0 = vx = spatial.voxel_size(aff0)
    vol0 = vx0.prod()
    # Voxel volume relative to the initial grid; it scales the penalty.
    vol = vx.prod() / vol0
    # Total number of observed data elements: tolerances and printed
    # gains are expressed relative to this count.
    ll_scl = sum(core.py.prod(dat.shape) for dat in data)

    # Levels are visited from `nb_levels` down to 1 (presumably coarse to
    # fine -- depends on _get_level, not visible here).
    for level in range(opt.optim.nb_levels, 0, -1):
        printer.level = level

        if opt.optim.nb_levels > 1:
            aff, shape = _get_level(level, aff0, shape0)
            vx = spatial.voxel_size(aff)
            vol = vx.prod() / vol0
            maps, rls = resize(maps, rls, aff, shape)
            if opt.penalty.norm in ('tv', 'jtv'):
                sumrls = 0.5 * vol * rls.reciprocal().sum(dtype=torch.double)

        # --- compute derivatives ---
        # The Hessian is stored packed: nb_prm diagonal entries followed
        # by the nb_prm*(nb_prm-1)/2 off-diagonal entries.
        nb_prm = len(maps)
        nb_hes = nb_prm * (nb_prm + 1) // 2
        grad = torch.empty((nb_prm, ) + shape, **backend)
        hess = torch.empty((nb_hes, ) + shape, **backend)

        ll_rls = []
        ll_gn = []
        ll_max = float('inf')

        # Fewer RLS iterations at higher levels (at least one).
        max_iter_rls = max(opt.optim.max_iter_rls // level, 1)
        for n_iter_rls in range(max_iter_rls):
            # --- Reweighted least-squares loop ---
            printer.rls = n_iter_rls

            # --- Gauss Newton loop ---
            for n_iter_gn in range(opt.optim.max_iter_gn):
                # start = time.time()
                printer.gn = n_iter_gn
                crit = 0
                grad.zero_()
                hess.zero_()
                # --- loop over contrasts ---
                for contrast, b1m, b1p in zip(data, receive, transmit):
                    crit1, g1, h1 = derivatives_parameters(
                        contrast, maps, b1m, b1p, opt)

                    # increment
                    if hasattr(maps, 'mt') and not contrast.mt:
                        # we optimize for mt but this particular contrast
                        # has no information about mt so g1/h1 are smaller
                        # than grad/hess.
                        grad[:-1] += g1
                        # Indices of the packed Hessian entries that do not
                        # involve MT (the last parameter): the first
                        # nb_prm-1 diagonal entries, then the off-diagonal
                        # (i, j>i) pairs, counted in row-major order.
                        hind = list(range(nb_prm - 1))
                        cnt = nb_prm
                        for i in range(nb_prm):
                            for j in range(i + 1, nb_prm):
                                if i != nb_prm - 1 and j != nb_prm - 1:
                                    hind.append(cnt)
                                cnt += 1
                        hess[hind] += h1
                        crit += crit1
                    else:
                        grad += g1
                        hess += h1
                        crit += crit1

                    del g1, h1, crit1
                # duration = time.time() - start
                # print('grad', duration)

                # start = time.time()
                # Penalty gradient; `reg` is half its inner product with the
                # maps (i.e. the value of the quadratic penalty).
                reg = 0.
                if opt.penalty.norm:
                    g = spatial.regulariser(maps.volume,
                                            weights=rls,
                                            dim=3,
                                            voxel_size=vx,
                                            membrane=1,
                                            factor=lam * vol)
                    grad += g
                    reg = 0.5 * dot(maps.volume, g)
                    del g
                # duration = time.time() - start
                # print('reg', duration)

                # --- gauss-newton ---
                # start = time.time()
                grad = check_nans_(grad, warn='gradient')
                hess = check_nans_(hess, warn='hessian')
                if opt.penalty.norm:
                    # hess = hessian_sym_loaddiag(hess, 1e-5, 1e-8)  # 1e-5 1e-8
                    if opt.optim.solver == 'fmg':
                        deltas = spatial.solve_field_fmg(hess,
                                                         grad,
                                                         rls,
                                                         factor=lam * vol,
                                                         membrane=1,
                                                         voxel_size=vx,
                                                         nb_iter=2)
                    else:
                        deltas = spatial.solve_field(
                            hess,
                            grad,
                            rls,
                            factor=lam * vol,
                            membrane=1,
                            voxel_size=vx,
                            verbose=max(0, opt.verbose - 1),
                            optim='cg',
                            max_iter=opt.optim.max_iter_cg,
                            tolerance=opt.optim.tolerance_cg,
                            stop='diff')
                else:
                    # No penalty: voxel-wise closed-form solve.
                    # hess = hessian_sym_loaddiag(hess, 1e-3, 1e-4)
                    deltas = spatial.solve_field_closedform(hess, grad)
                deltas = check_nans_(deltas, warn='deltas')
                # duration = time.time() - start
                # print('solve', duration)

                for map, delta in zip(maps, deltas):
                    map.volume -= delta
                    map.volume.clamp_(-64, 64)  # avoid exp overflow
                    del delta
                del deltas

                # --- Compute gain ---
                # NOTE(review): ll_max is initialised to +inf and
                # `max(ll_max, ll)` keeps it there, so the first GN
                # iteration of each level always has an infinite gain
                # (it never triggers the early stop).
                ll = crit + reg + sumrls
                ll_max = max(ll_max, ll)
                ll_prev = ll_gn[-1] if ll_gn else ll_max
                gain = ll_prev - ll
                ll_gn.append(ll)
                printer.print_crit(crit, reg, sumrls, gain / ll_scl)
                if gain < opt.optim.tolerance * ll_scl:
                    # print('GN converged: ', ll_prev.item(), '->', ll.item())
                    break

            # --- Update RLS weights ---
            if opt.penalty.norm in ('tv', 'jtv'):
                # NOTE(review): the weights use the unscaled `lam`, while
                # the penalty above uses `lam * vol`; equal only when
                # vol == 1 -- confirm this is intended.
                rls, sumrls = spatial.membrane_weights(
                    maps.volume,
                    lam,
                    vx,
                    return_sum=True,
                    dim=3,
                    joint=opt.penalty.norm == 'jtv',
                    eps=core.constants.eps(rls.dtype))
                sumrls *= 0.5 * vol

                # --- Compute gain ---
                # (we are late by one full RLS iteration when computing the
                #  gain but we save some computations)
                ll = ll_gn[-1]
                ll_prev = ll_rls[-1] if ll_rls else ll_max
                ll_rls.append(ll)
                gain = ll_prev - ll
                if abs(gain) < opt.optim.tolerance * ll_scl:
                    # print(f'RLS converged ({gain:7.2g})')
                    break

    del grad
    # Uncertainty uses the Hessian from the last GN iteration of the
    # last level visited.
    if opt.uncertainty:
        uncertainty = compute_uncertainty(hess, rls, lam * vol, vx, opt)
        maps.pd.uncertainty = uncertainty[0]
        maps.r1.uncertainty = uncertainty[1]
        maps.r2s.uncertainty = uncertainty[2]
        if hasattr(maps, 'mt'):
            maps.mt.uncertainty = uncertainty[3]

    # --- Prepare output ---
    return postproc(maps)
Example #6
0
def zcorrect_exp_const(x,
                       decay=None,
                       sigma=None,
                       lam=10,
                       mask=None,
                       max_iter=128,
                       tol=1e-6,
                       verbose=False,
                       snr=5):
    """Correct the z signal decay in a SPIM image.

    The signal is modelled as: f(z) = s * exp(-b * z) + eps
    where z=0 is (arbitrarily) the middle slice, s is the intercept
    and b is the decay coefficient. Both log(s) and b are maps with one
    value per in-plane voxel, regularised with a membrane penalty and
    fitted by Gauss-Newton.

    Parameters
    ----------
    x : (..., nz) tensor
        SPIM image with the z dimension last (slice index 0 first).
    decay : float, optional
        Initial guess for decay parameter. Default: educated guess
        from the median intensities of two slices.
    sigma : float, optional
        Noise standard deviation. Default: educated guess, i.e. the
        median intensity of the middle slice divided by `snr`.
    lam : float or (float, float), default=10
        Regularisation of (log-intercept, decay).
    mask : tensor[bool], optional
        Voxels used in the fit. It is combined with the z-first layout
        (nz, ...) used internally, so an in-plane (...) mask broadcasts
        correctly. Non-finite and non-positive voxels are always excluded.
    max_iter : int, default=128
        Maximum number of Gauss-Newton iterations.
    tol : float, default=1e-6
        Stop when the decrease of the objective, relative to its initial
        value, falls below `tol`.
    verbose : int or bool, default=False
        Print progress (one line per iteration if > 1).
    snr : float, default=5
        Assumed signal-to-noise ratio at the middle slice, used to
        guess `sigma`.

    Returns
    -------
    y : (...) tensor
        Intercept map (fitted signal at the middle slice)
    b : (...) tensor
        Decay map
    x : (..., nz) tensor
        Corrected image (voxels excluded by the mask are zero)

    """

    x = torch.as_tensor(x)
    if not x.dtype.is_floating_point:
        x = x.to(dtype=torch.get_default_dtype())
    backend = utils.backend(x)
    shape = x.shape
    dim = x.dim() - 1
    nz = shape[-1]
    b = decay

    # Work with z as the first dimension: (nz, ...)
    x = utils.movedim(x, -1, 0).clone()
    # Only finite, strictly positive voxels take part in the fit
    if mask is None:
        mask = torch.isfinite(x) & (x > 0)
    else:
        mask = mask & (torch.isfinite(x) & (x > 0))
    x[~mask] = 0

    # decay educated guess: closed form from two values
    #   log(x2) - log(x1) = -b * (z2 - z1)
    if b is None:
        z1 = 2 * nz // 5
        z2 = 3 * nz // 5
        x1 = x[z1]
        x1 = x1[x1 > 0].median()
        x2 = x[z2]
        x2 = x2[x2 > 0].median()
        z1 = float(z1)
        z2 = float(z2)
        b = (x2.log() - x1.log()) / (z1 - z2)
    b = b.item() if torch.is_tensor(b) else b

    # intercept educated guess: median signal of the middle slice (z=0)
    mid = x[(nz - 1) // 2]
    s_mid = mid[mid > 0].median()
    y = s_mid.log().item()
    s_mid = s_mid.item()
    if verbose:
        print(f'init: y = {y}, b = {b}')

    # noise educated guess: assume SNR=`snr` at z=0 (middle slice).
    # Residuals are in intensity units, so sigma must be derived from the
    # middle-slice intensity itself, not from its logarithm `y`.
    sigma = sigma or (s_mid / snr)
    lam_y, lam_b = py.make_list(lam, 2)
    lam_y = lam_y**2 * sigma**2
    lam_b = lam_b**2 * sigma**2
    reg = lambda t: spatial.regulariser(
        t, membrane=1, dim=dim, factor=(lam_y, lam_b))
    solve = lambda h, g: spatial.solve_field_fmg(
        h, g, membrane=1, dim=dim, factor=(lam_y, lam_b))

    # init: theta = [log(s), b], one value per in-plane voxel.
    # `logy` and `b` are views into theta, updated in place by the solver.
    z = torch.arange(nz, **backend) - (nz - 1) / 2
    z = utils.unsqueeze(z, -1, dim)
    theta = z.new_empty([2, *x.shape[1:]], **backend)
    logy = theta[0].fill_(y)
    b = theta[1].fill_(b)
    y = logy.exp()
    ll0 = (mask * y *
           (-b * z).exp_() - x).square_().sum() + (theta * reg(theta)).sum()
    ll1 = ll0

    g = torch.zeros_like(theta)
    # Hessian entries: [log(s) diagonal, b diagonal, cross term]
    h = theta.new_zeros([3, *theta.shape[1:]])
    for it in range(max_iter):

        # exponentiate
        y = torch.exp(logy, out=y)
        fit = (b * z).neg_().exp_().mul_(y).mul_(mask)
        res = fit - x

        # compute objective
        reg_theta = reg(theta)
        ll = res.square().sum() + (theta * reg_theta).sum()
        gain = (ll1 - ll) / ll0
        if verbose:
            end = '\n' if verbose > 1 else '\r'
            print(f'{it:3d} | {ll:12.6g} | gain = {gain:12.6g}', end=end)
        if it > 0 and gain < tol:
            break
        ll1 = ll

        # Gauss-Newton derivatives; the `+ |res|` term on the diagonal
        # appears to make the curvature a safe upper bound -- confirm.
        g[0] = (fit * res).sum(0)
        g[1] = -(fit * res * z).sum(0)
        h[0] = (fit * (fit + res.abs())).sum(0)
        h[1] = (fit * (fit + res.abs()) * (z * z)).sum(0)
        h[2] = -(z * fit * fit).sum(0)

        g += reg_theta
        theta -= solve(h, g)

    if verbose and verbose <= 1:
        # terminate the '\r'-refreshed progress line
        print('')

    y = torch.exp(logy, out=y)
    x = x * (b * z).exp_()
    x = utils.movedim(x, 0, -1)
    x = x.reshape(shape)
    return y, b, x
Example #7
0
def phase_fit(magnitude, phase, lam=(0, 1e1), penalty=('membrane', 'bending')):
    """Fit a complex image using a decreasing phase regularization.

    The log-magnitude (channel 0) and the phase (channel 1) are fitted
    jointly by Gauss-Newton. The penalty on the phase channel is
    multiplied by ``1 + 10 ** (5 - iteration)``, which shrinks at every
    iteration: the phase is first strongly smoothed, then progressively
    released. At most 20 iterations are run; the loop stops early when
    the absolute change of the objective drops below 1e-4 per voxel.

    Parameters
    ----------
    magnitude : tensor
        Magnitude image.
    phase : tensor
        Phase image.
    lam : (float, float), default=(0, 10)
        Regularisation factors for (log-magnitude, phase).
    penalty : (str, str), default=('membrane', 'bending')
        Penalty type for (log-magnitude, phase): 'membrane' or 'bending'.
        Any other value disables the penalty of that channel.

    Returns
    -------
    magnitude : tensor
        Fitted magnitude (exponential of the fitted log-magnitude).
    phase : tensor
        Fitted phase.

    """
    # Noise precision from the estimated noise standard deviation
    noise_sd = estimate_noise(magnitude)[0]['sd']
    precision = 1 / (noise_sd * noise_sd)

    # Parameters: [log-magnitude, phase]
    fit = magnitude.new_empty([2, *magnitude.shape])
    fit[0] = magnitude
    fit[0].clamp_min_(1e-8).log_()
    fit[1] = mean_phase(phase, magnitude)

    # Buffers for the gradient and (diagonal) Hessian
    grad = magnitude.new_empty([2, *magnitude.shape])
    hess = magnitude.new_empty([2, *magnitude.shape])
    n = magnitude.numel()

    # Per-channel penalty factors: a channel only gets the penalty it names
    prm = {
        'membrane': [lam[c] * int(penalty[c] == 'membrane') for c in range(2)],
        'bending': [lam[c] * int(penalty[c] == 'bending') for c in range(2)],
        'bound': 'dct2',
    }

    # Base (unscaled) phase penalties, rescaled at every iteration
    base_membrane = prm['membrane'][-1]
    base_bending = prm['bending'][-1]

    ll_prev = lr_prev = scale = float('inf')
    for n_iter in range(20):

        # Shrink the phase penalty; `ratio` rescales last iteration's
        # penalty value to the new scale when comparing objectives.
        scale_prev = scale
        scale = 1 + 10 ** (5 - n_iter)
        ratio = scale / scale_prev if n_iter else float('inf')
        prm['membrane'][-1] = base_membrane * scale
        prm['bending'][-1] = base_bending * scale

        # Data term derivatives, weighted by the noise precision
        ll, grad, hess = derivatives(magnitude, phase, fit[0], fit[1],
                                     grad, hess)
        ll *= precision
        grad *= precision
        hess *= precision

        # Penalty term
        reg = spatial.regulariser(fit, **prm)
        lr = 0.5 * dot(fit, reg)
        grad += reg

        # Gauss-Newton update
        fit -= spatial.solve_field_fmg(hess, grad, **prm)

        # Progress report and convergence check
        l = ll + lr
        gain = (ll_prev + ratio * lr_prev) - l
        print(f'{n_iter:3d} | {ll/n:12.6g} + {lr/n:6.3g} = {l/n:12.6g} | gain = {gain/n:6.3}')
        if abs(gain) < n * 1e-4:
            break

        ll_prev, lr_prev = ll, lr

    return fit[0].exp_(), fit[1]
Example #8
0
def meetup(data, dist=None, opt=None):
    """Fit the ESTATICS+MEETUP model to multi-echo Gradient-Echo data.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    dist : sequence[Optional[ParameterizedDistortion]], optional
        Pre-computed distortion fields
    opt : Options, optional
        Algorithm options.

    Returns
    -------
    intecepts : sequence[GradientEcho]
        Echo series extrapolated to TE=0
    decay : estatics.ParameterMap
        R2* decay map
    distortions : sequence[ParameterizedDistortion]
        B0-induced distortion fields

    """
    # --- prepare everything -------------------------------------------
    data, maps, dist, opt, rls = _prepare(data, dist, opt)
    sumrls = 0.5 * rls.sum(dtype=torch.double) if rls is not None else 0
    backend = dict(dtype=opt.backend.dtype, device=opt.backend.device)

    # --- initialize tracking of the objective function ----------------
    crit, vreg, reg = float('inf'), 0, 0
    ll_rls = crit + vreg + reg
    sumrls_prev = sumrls
    rls_changed = False
    ll_scl = sum(core.py.prod(dat.shape) for dat in data)

    # --- Multi-Resolution loop ----------------------------------------
    shape0 = shape = maps.shape[1:]
    aff0 = aff = maps.affine
    lam0 = lam = opt.regularization.factor
    vx0 = vx = spatial.voxel_size(aff0)
    scl0 = vx0.prod()
    armijo_prev = 1
    for level in range(opt.optim.nb_levels, 0, -1):

        if opt.optim.nb_levels > 1:
            aff, shape = _get_level(level, aff0, shape0)
            vx = spatial.voxel_size(aff)
            scl = vx.prod() / scl0
            lam = [float(l * scl) for l in lam0]
            maps, rls = _resize(maps, rls, aff, shape)
            if opt.regularization.norm in ('tv', 'jtv'):
                sumrls = 0.5 * scl * rls.reciprocal().sum(dtype=torch.double)

        grad = torch.empty((len(data) + 1, *shape), **backend)
        hess = torch.empty((len(data) * 2 + 1, *shape), **backend)

        # --- RLS loop -------------------------------------------------
        max_iter_rls = 1 if level > 1 else opt.optim.max_iter_rls
        for n_iter_rls in range(1, max_iter_rls + 1):

            # --- helpers ----------------------------------------------
            regularizer_prm = lambda x: spatial.regulariser(
                x, weights=rls, dim=3, voxel_size=vx, membrane=1, factor=lam)
            solve_prm = lambda H, g: spatial.solve_field(
                H,
                g,
                rls,
                factor=lam,
                membrane=1 if opt.regularization.norm else 0,
                voxel_size=vx,
                max_iter=opt.optim.max_iter_cg,
                tolerance=opt.optim.tolerance_cg,
                dim=3)
            reweight = lambda x, **k: spatial.membrane_weights(
                x,
                lam,
                vx,
                dim=3,
                **k,
                joint=opt.regularization.norm == 'jtv',
                eps=core.constants.eps(rls.dtype))

            # ----------------------------------------------------------
            #    Initial update of parameter maps
            # ----------------------------------------------------------
            crit_pre_prm = None
            max_iter_prm = opt.optim.max_iter_prm
            for n_iter_prm in range(1, max_iter_prm + 1):

                crit_prev = crit
                reg_prev = reg

                # --- loop over contrasts ------------------------------
                crit = 0
                grad.zero_()
                hess.zero_()
                for i, (contrast, intercept, distortion) \
                        in enumerate(zip(data, maps.intercepts, dist)):
                    # compute gradient
                    crit1, g1, h1 = derivatives_parameters(
                        contrast, distortion, intercept, maps.decay, opt)
                    # increment
                    gind, hind = [i, -1], [i, len(grad) - 1, len(grad) + i]
                    grad[gind] += g1
                    hess[hind] += h1
                    crit += crit1
                del g1, h1

                # --- regularization -----------------------------------
                reg = 0.
                if opt.regularization.norm:
                    g1 = regularizer_prm(maps.volume)
                    reg = 0.5 * dot(maps.volume, g1)
                    grad += g1
                    del g1

                if n_iter_prm == 1:
                    crit_pre_prm = crit

                # --- track RLS improvement ----------------------------
                # Updating the RLS weights changes `sumrls` and `reg`. Now
                # that we have the updated value of `reg` (before it is
                # modified by the map update), we can check that updating the
                # weights indeed improved the objective.

                if opt.verbose and rls_changed:
                    rls_changed = False
                    if reg + sumrls <= reg_prev + sumrls_prev:
                        evol = '<='
                    else:
                        evol = '>'
                    ll_tmp = crit_prev + reg + vreg + sumrls
                    pstr = (f'{n_iter_rls-1:3d} | {"---":3s} | {"rls":4s} | '
                            f'{crit_prev:12.6g} + {reg:12.6g} + '
                            f'{sumrls:12.6g} + {vreg:12.6g} '
                            f'= {ll_tmp:12.6g} | {evol}')
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)

                # --- gauss-newton -------------------------------------
                # Computing the GN step involves solving H\g
                hess = check_nans_(hess, warn='hessian')
                hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
                deltas = solve_prm(hess, grad)
                deltas = check_nans_(deltas, warn='delta')

                dd = regularizer_prm(deltas)
                dv = dot(dd, maps.volume)
                dd = dot(dd, deltas)
                delta_reg = 0.5 * (dd - 2 * dv)

                for map, delta in zip(maps, deltas):
                    map.volume -= delta
                    if map.min is not None or map.max is not None:
                        map.volume.clamp_(map.min, map.max)
                del deltas

                # --- track parameter map improvement --------------
                gain = (crit_prev + reg_prev) - (crit + reg)
                if n_iter_prm > 1 and gain < opt.optim.tolerance * ll_scl:
                    break

            # ----------------------------------------------------------
            #    Distortion update with line search
            # ----------------------------------------------------------
            max_iter_dist = opt.optim.max_iter_dist
            for n_iter_dist in range(1, max_iter_dist + 1):

                crit = 0
                vreg = 0
                new_crit = 0
                new_vreg = 0

                deltas, dd, dv = [], 0, 0
                # --- loop over contrasts ------------------------------
                for i, (contrast, intercept, distortion) \
                        in enumerate(zip(data, maps.intercepts, dist)):

                    # --- helpers --------------------------------------
                    vxr = distortion.voxel_size[contrast.readout]
                    lam_dist = dict(opt.distortion.factor)
                    lam_dist['factor'] *= vxr**2
                    regularizer_dist = lambda x: spatial.regulariser(
                        x[None],
                        **lam_dist,
                        dim=3,
                        bound=DIST_BOUND,
                        voxel_size=distortion.voxel_size)[0]
                    solve_dist = lambda h, g: spatial.solve_field_fmg(
                        h,
                        g,
                        **lam_dist,
                        dim=3,
                        bound=DIST_BOUND,
                        voxel_size=distortion.voxel_size)

                    # --- likelihood -----------------------------------
                    crit1, g, h = derivatives_distortion(
                        contrast, distortion, intercept, maps.decay, opt)
                    crit += crit1

                    # --- regularization -------------------------------
                    vol = distortion.volume
                    g1 = regularizer_dist(vol)
                    vreg1 = 0.5 * vol.flatten().dot(g1.flatten())
                    vreg += vreg1
                    g += g1
                    del g1

                    # --- gauss-newton ---------------------------------
                    h = check_nans_(h, warn='hessian (distortion)')
                    h = hessian_loaddiag_(h[None], 1e-32, 1e-32, sym=True)[0]
                    delta = solve_dist(h, g)
                    delta = check_nans_(delta, warn='delta (distortion)')
                    deltas.append(delta)
                    del g, h

                    deltas.append(delta)
                    dd1 = regularizer_dist(delta)
                    dv += dot(dd1, vol)
                    dd1 = dot(dd1, delta)
                    dd += dd1
                    del delta, vol

                # --- track parameters improvement ---------------------
                if opt.verbose and n_iter_dist == 1:
                    gain = crit_pre_prm - (crit + delta_reg)
                    evol = '<=' if gain > 0 else '>'
                    ll_tmp = crit + reg + vreg + sumrls
                    pstr = (f'{n_iter_rls:3d} | {"---":3} | {"prm":4s} | '
                            f'{crit:12.6g} + {reg + delta_reg:12.6g} + '
                            f'{sumrls:12.6g} + {vreg:12.6g} '
                            f'= {ll_tmp:12.6g} | '
                            f'gain = {gain / ll_scl:7.2g} | {evol}')
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)

                # --- line search ----------------------------------
                reg = reg + delta_reg
                new_crit, new_reg = crit, reg
                armijo, armijo_prev, ok = armijo_prev, 0, False
                maps0 = maps
                for n_ls in range(1, 12 + 1):
                    for delta1, dist1 in zip(deltas, dist):
                        dist1.volume.sub_(delta1, alpha=(armijo - armijo_prev))
                    armijo_prev = armijo
                    new_vreg = 0.5 * armijo * (armijo * dd - 2 * dv)

                    maps = maps0.deepcopy()
                    max_iter_prm = opt.optim.max_iter_prm
                    for n_iter_prm in range(1, max_iter_prm + 1):

                        crit_prev = new_crit
                        reg_prev = new_reg

                        # --- loop over contrasts ------------------
                        new_crit = 0
                        grad.zero_()
                        hess.zero_()
                        for i, (contrast, intercept, distortion) in enumerate(
                                zip(data, maps.intercepts, dist)):
                            # compute gradient
                            new_crit1, g1, h1 = derivatives_parameters(
                                contrast, distortion, intercept, maps.decay,
                                opt)
                            # increment
                            gind, hind = [i, -1
                                          ], [i,
                                              len(grad) - 1,
                                              len(grad) + i]
                            grad[gind] += g1
                            hess[hind] += h1
                            new_crit += new_crit1
                        del g1, h1

                        # --- regularization -----------------------
                        new_reg = 0.
                        if opt.regularization.norm:
                            g1 = regularizer_prm(maps.volume)
                            new_reg = 0.5 * dot(maps.volume, g1)
                            grad += g1
                            del g1

                        new_gain = (crit_prev + reg_prev) - (new_crit +
                                                             new_reg)
                        if new_gain < opt.optim.tolerance * ll_scl:
                            break

                        # --- gauss-newton -------------------------
                        # Computing the GN step involves solving H\g
                        hess = check_nans_(hess, warn='hessian')
                        hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
                        delta = solve_prm(hess, grad)
                        delta = check_nans_(delta, warn='delta')

                        dd = regularizer_prm(delta)
                        dv = dot(dd, maps.volume)
                        dd = dot(dd, delta)
                        delta_reg = 0.5 * (dd - 2 * dv)

                        for map, delta1 in zip(maps, delta):
                            map.volume -= delta1
                            if map.min is not None or map.max is not None:
                                map.volume.clamp_(map.min, map.max)
                        del delta

                    if new_crit + new_reg + new_vreg <= crit + reg:
                        ok = True
                        break
                    else:
                        armijo = armijo / 2

                if not ok:
                    for delta1, dist1 in zip(deltas, dist):
                        dist1.volume.add_(delta1, alpha=armijo_prev)
                    armijo_prev = 1
                    maps = maps0
                    new_crit = crit
                    new_vreg = 0
                    del delta
                else:
                    armijo_prev *= 1.5
                new_vreg = vreg + new_vreg

                # --- track distortion improvement ---------------------
                if not ok:
                    evol = '== (x)'
                elif new_crit + new_vreg <= crit + vreg:
                    evol = f'<= ({n_ls:2d})'
                else:
                    evol = f'> ({n_ls:2d})'
                gain = (crit + vreg + reg) - (new_crit + new_vreg + new_reg)
                crit, reg, reg = new_crit, new_vreg, new_reg
                if opt.verbose:
                    ll_tmp = crit + reg + vreg + sumrls
                    pstr = (
                        f'{n_iter_rls:3d} | {n_iter_dist:3d} | {"dist":4s} | '
                        f'{crit:12.6g} + {reg:12.6g} + {sumrls:12.6g} ')
                    pstr += f'+ {vreg:12.6g} '
                    pstr += f'= {ll_tmp:12.6g} | gain = {gain / ll_scl:7.2g} | '
                    pstr += f'{evol}'
                    if opt.optim.nb_levels > 1:
                        pstr = f'{level:3d} | ' + pstr
                    print(pstr)
                    if opt.plot:
                        _show_maps(maps, dist, data)
                if not ok:
                    break

                if n_iter_dist > 1 and gain < opt.optim.tolerance * ll_scl:
                    break

            # --------------------------------------------------------------
            #    Update RLS weights
            # --------------------------------------------------------------
            if level == 1 and opt.regularization.norm in ('tv', 'jtv'):
                rls_changed = True
                sumrls_prev = sumrls
                rls, sumrls = reweight(maps.volume, return_sum=True)
                sumrls = 0.5 * sumrls

                # --- compute gain -----------------------------------------
                # (we are late by one full RLS iteration when computing the
                #  gain but we save some computations)
                ll_rls_prev = ll_rls
                ll_rls = crit + reg + vreg
                gain = ll_rls_prev - ll_rls
                if gain < opt.optim.tolerance * ll_scl:
                    print(f'RLS converged ({gain / ll_scl:7.2g})')
                    break

    # --- prepare output -----------------------------------------------
    out = postproc(maps, data)
    if opt.distortion.enable:
        out = (*out, dist)
    return out
Example #9
0
def _sin_b1_plot(x, fa, theta, mu):
    """Display diagnostic plots of the current `sin_b1` fit (debugging aid).

    Parameters
    ----------
    x : (C, *spatial) tensor
        Observed images.
    fa : (C,) list[float]
        Nominal flip angles (in rad).
    theta : (2, *spatial) tensor
        Current log-parameters (log signal map, log B1 map).
    mu : tensor
        Reference intensity; images are displayed in the range [0, 2*mu].

    """
    import matplotlib.pyplot as plt

    nb_contrasts = len(x)
    nb_extra = 3  # extra rows for the voxel-wise signal curves below
    nb_rows = nb_contrasts + nb_extra
    mid = x.shape[-1] // 2  # middle slice along the last spatial dimension
    vmin, vmax = 0, 2 * mu

    plt.rcParams["figure.figsize"] = (4, nb_contrasts)
    y = sin_signals(fa, *theta)
    for i in range(nb_contrasts):
        # observed image
        plt.subplot(nb_rows, 4, 4 * i + 1)
        plt.imshow(x[i, ..., mid], vmin=vmin, vmax=vmax)
        plt.axis('off')
        # fitted signal
        plt.subplot(nb_rows, 4, 4 * i + 2)
        plt.imshow(y[i, ..., mid], vmin=vmin, vmax=vmax)
        plt.axis('off')
        # residual: both terms must be taken on the same slice
        plt.subplot(nb_rows, 4, 4 * i + 3)
        plt.imshow(x[i, ..., mid] - y[i, ..., mid],
                   cmap=plt.get_cmap('coolwarm'))
        plt.axis('off')
        plt.colorbar()
        # B1-scaled flip angle, in units of pi
        plt.subplot(nb_rows, 4, 4 * i + 4)
        plt.imshow((theta[-1, ..., mid].exp() * fa[i]) / pymath.pi)
        plt.axis('off')
        plt.colorbar()

    # Fitted signal as a function of flip angle at three fixed voxels,
    # with the observed values overlaid.
    all_fa = torch.linspace(0, 2 * pymath.pi, 512)
    loc = [(theta.shape[1] // 2, theta.shape[2] // 2, theta.shape[3] // 2),
           (2 * theta.shape[1] // 3, theta.shape[2] // 2,
            theta.shape[3] // 2),
           (theta.shape[1] // 3, theta.shape[2] // 3, theta.shape[3] // 2)]
    for j, (vx, vy, vz) in enumerate(loc):
        plt.subplot(nb_rows, 1, nb_contrasts + j + 1)
        plt.plot(all_fa, sin_signals(all_fa, *theta[:, vx, vy, vz]))
        plt.scatter(fa, x[:, vx, vy, vz])
    plt.show()


def sin_b1(x,
           fa,
           lam=(0, 1e3),
           penalty=('m', 'b'),
           chi=True,
           pd=None,
           b1=None,
           max_iter=32,
           verbose=True,
           plot=True):
    """Estimate B1+ from Variable flip angle data with long TR

    Two log-parameter maps (T2*-weighted signal and B1+) are fitted by
    regularized Gauss-Newton with a backtracking line search. The B1+
    regularization is strongly amplified at first and relaxed over the
    first iterations.

    Parameters
    ----------
    x : (C, *spatial) tensor
        Input flash images
    fa : (C,) sequence[float]
        Flip angle (in deg)
    lam : (float, float), default=(0, 1e3)
        Regularization value for the T2*w Signal and B1 map.
    penalty : 2x {'membrane', 'bending'}, default=('m', 'b')
        Regularization type for the T2*w Signal and B1 map.
        Only the first letter of each entry is checked.
    chi : bool, default=True
        If False, the noise degrees of freedom is forced to 1 instead of
        the value estimated from the image background.
    pd : (*spatial) tensor, optional
        Initial T2*-weighted signal map. Default: geometric mean (across
        contrasts) of the estimated foreground mean intensity.
    b1 : (*spatial) tensor, optional
        Initial B1+ map. Default: 1 everywhere.
    max_iter : int, default=32
        Number of Gauss-Newton iterations.
    verbose : bool, default=True
        Print the objective and line-search progress at each iteration.
    plot : bool, default=True
        Display a diagnostic figure (requires matplotlib, calls
        `plt.show()`) at each iteration.

    Returns
    -------
    s : (*spatial) tensor
        T2*-weighted PD map
    b1 : (*spatial) tensor
        B1+ map

    """

    fa = utils.make_vector(fa, len(x), dtype=torch.double)
    fa = fa.mul_(pymath.pi / 180).tolist()  # degrees -> radians
    lam = py.make_list(lam, 2)
    penalty = py.make_list(penalty, 2)

    # --- noise and signal statistics -----------------------------------
    # Per-contrast estimates are combined by geometric mean.
    sd, df, mu = 0, 0, 0
    for x1 in x:
        # NOTE(review): the noise fit always uses `chi=True`, whatever the
        # `chi` argument (which only overrides `df` below) -- confirm.
        bg, fg = estimate_noise(x1, chi=True)
        sd += bg['sd'].log()
        df += bg['dof'].log()
        mu += fg['mean'].log()
    sd = (sd / len(x)).exp()
    df = (df / len(x)).exp()
    mu = (mu / len(x)).exp()
    prec = 1 / (sd * sd)
    if not chi:
        df = 1

    # --- initialization (parameters are stored as logs) ----------------
    shape = x.shape[1:]
    theta = x.new_empty([2, *shape])
    theta[0] = mu.log() if pd is None else pd.log()
    theta[1] = 0 if b1 is None else b1.log()
    n = (x != 0).sum()  # number of nonzero samples; normalizes printed values

    # Gradient/Hessian buffers. The Hessian has 3 components: the first two
    # are diagonal terms (loaded below), the third presumably the cross term.
    g = torch.zeros_like(theta)
    h = theta.new_zeros([3, *theta.shape[1:]])
    g1 = torch.zeros_like(theta)
    h1 = theta.new_zeros([3, *theta.shape[1:]])

    # Per-parameter regularization weights; a penalty is selected by the
    # first letter of its name ('m'embrane or 'b'ending).
    prm = dict(
        membrane=[lam[k] if penalty[k][0] == 'm' else 0 for k in range(2)],
        bending=[lam[k] if penalty[k][0] == 'b' else 0 for k in range(2)],
    )

    # Target B1 regularization; multiplied by a decreasing `factor` below.
    lam0 = dict(membrane=prm['membrane'][-1], bending=prm['bending'][-1])
    ll0 = lr0 = factor = float('inf')
    for n_iter in range(max_iter):

        if n_iter == 1:
            # Iteration 0 uses df=1 whereas later iterations use the
            # estimated df: objectives are not comparable, reset reference.
            ll0 = lr0 = float('inf')

        # decrease the B1 (last parameter) regularization
        factor, factor_prev = 1 + 10**(5 - n_iter), factor
        factor_ratio = factor / factor_prev if n_iter else float('inf')
        prm['membrane'][-1] = lam0['membrane'] * factor
        prm['bending'][-1] = lam0['bending'] * factor

        # derivatives of likelihood term
        df1 = 1 if n_iter == 0 else df
        ll, g, h = sin_full_derivatives(x, fa, theta[0], theta[1], prec, df1,
                                        g, h, g1, h1)

        # derivatives of regularization term
        reg = spatial.regulariser(theta, **prm)
        g += reg
        lr = 0.5 * dot(theta, reg)

        # rescale the previous regularization energy by the change in
        # weights so that the reported gain is roughly comparable
        l, l0 = ll + lr, ll0 + factor_ratio * lr0
        gain = l0 - l
        if verbose:
            print(f'{n_iter:2d} | {ll/n:12.6g} + {lr/n:12.6g} = {l/n:12.6g} '
                  f'| gain = {gain/n:12.6g}')

        # Gauss-Newton update (load the diagonal of the Hessian)
        h[:2] += 1e-8 * h[:2].abs().max(0).values
        h[:2] += 1e-5
        delta = spatial.solve_field_fmg(h, g, **prm)

        # cap the largest update at 64 (in log units)
        mx = delta.abs().max()
        if mx > 64:
            delta *= 64 / mx

        # --- backtracking line search ---------------------------------
        # Regularization change for the step theta - armijo * delta
        # (assuming a symmetric linear regulariser L):
        #   0.5 * armijo * (armijo * <L delta, delta> - 2 * <L delta, theta>)
        dd = spatial.regulariser(delta, **prm)
        dt = dot(dd, theta)
        dd = dot(dd, delta)
        success = False
        armijo = 1
        theta0 = theta
        ll0, lr0 = ll, lr
        for n_ls in range(12):
            theta = theta0.sub(delta, alpha=armijo)
            ll = sin_nll(x, fa, theta[0], theta[1], prec, df1)
            lr = 0.5 * armijo * (armijo * dd - 2 * dt)
            if ll + lr < ll0:
                if verbose:
                    print(n_ls, 'success', ((ll + lr) / n).item(),
                          (ll0 / n).item())
                success = True
                break
            if verbose:
                print(n_ls, 'failure', ((ll + lr) / n).item(),
                      (ll0 / n).item())
            armijo /= 2
        if not success and n_iter > 5:
            theta = theta0
            break
        # NOTE(review): a failed search during the first iterations keeps
        # the last (smallest) trial step instead of reverting -- confirm.

        if plot:
            _sin_b1_plot(x, fa, theta, mu)

        # TODO: no convergence check (e.g. `gain < 1e-4 * n`) is performed.

    theta.exp_()
    return theta[0], theta[1]