Esempio n. 1
0
 def __init__(self, *args, **kwargs):
     """Build the object and expose its channels as named parameter maps.

     After the parent constructor runs, the channels are exposed, in
     order, as `pd`, `r1`, `r2s` and (only when four channels are
     present) `mt`. Each is wrapped in a `ParameterMap` sharing
     `self.affine`.

     Raises
     ------
     ValueError
         If the object does not hold exactly 3 or 4 maps.
     """
     # Set before calling the parent constructor. NOTE(review): presumably
     # the parent's __init__ goes through overridden `volume`/`affine`
     # accessors backed by these attributes; confirm.
     self._greeq_volume = None
     self._greeq_affine = None
     super().__init__(*args, **kwargs)
     if len(self) not in (3, 4):
         raise ValueError('Expected 3 or 4 maps')
     self.pd = ParameterMap(self.volume[0], affine=self.affine)
     self.r1 = ParameterMap(self.volume[1], affine=self.affine)
     self.r2s = ParameterMap(self.volume[2], affine=self.affine)
     if len(self) == 4:
         self.mt = ParameterMap(self.volume[3], affine=self.affine)
Esempio n. 2
0
 def intercepts(self):
     """List the intercept maps held by this object.

     Returns
     -------
     list[ParameterMap]
         One map built from ``self.volume[k]`` for every ``k`` in
         ``range(len(self) - 1)`` (i.e. every entry except the last),
         each sharing ``self.affine``. Empty when ``self.volume`` is None.
     """
     if self.volume is not None:
         intercept_maps = []
         for k in range(len(self) - 1):
             intercept_maps.append(
                 ParameterMap(self.volume[k], affine=self.affine))
         return intercept_maps
     return []
Esempio n. 3
0
    def fit(self, data):
        """Fit the (multi) exponential decay model to `data`.

        Parameters
        ----------
        data : sequence of contrasts
            Each contrast exposes `te`, an iterable of echo times (they are
            printed multiplied by 1e3 with an 'ms' label, so presumably
            they are in seconds).

        Returns
        -------
        out
            Whatever `postproc(self.maps, self.data)` returns. When
            `self.opt.distortion.enable` is set, it is unpacked into a
            tuple and `self.dist` is appended as the last element.

        Notes
        -----
        This method mutates `self.opt` in place:
        - the regularization norm is set to '' when every factor is zero;
        - when the norm does not end with 'tv', `max_iter_gn` is multiplied
          by `max_iter_rls`, and `max_iter_rls` is set to 1.
        It also sets `self.data`, `self.maps`, `self.dist`, `self.affine0`,
        `self.shape0`, `self.lam` and `self.rls`.
        """

        # --- be polite ------------------------------------------------
        if self.opt.verbose > 0:
            print(f'Fitting a (multi) exponential decay model with '
                  f'{len(data)} contrasts. Echo times:')
            for i, contrast in enumerate(data):
                print(f'    - contrast {i:2d}: [' +
                      ', '.join([f'{te * 1e3:.1f}'
                                 for te in contrast.te]) + '] ms')

        # --- estimate noise / register / initialize maps --------------
        self.data, self.maps, self.dist = preproc(data, self.opt)
        # remember the initial reconstruction space
        self.affine0 = self.maps.affine
        self.shape0 = self.maps.decay.shape

        # --- prepare regularization factor ----------------------------
        # -> we want lam = [*lam_intercepts, lam_decay]
        # The last configured factor is used for the decay map. The other
        # factors are expanded to one per contrast (NOTE(review): presumably
        # ensure_list repeats the last value up to nb_contrasts; confirm).
        *lam, lam_decay = self.opt.regularization.factor
        lam = core.py.ensure_list(lam, self.nb_contrasts)
        lam.append(lam_decay)
        if not any(lam):
            # All factors are zero: disable regularization. The empty norm
            # also disables the RLS weights and the iteration split below.
            self.opt.regularization.norm = ''
        self.lam = lam

        # --- initialize weights (RLS) ---------------------------------
        # Reweighting weights are only needed for TV-type norms. They are
        # initialized to 1.
        self.rls = None
        if self.opt.regularization.norm.endswith('tv'):
            rls_shape = self.shape
            if self.opt.regularization.norm == 'tv':
                # plain 'tv': one weight map per parameter map; other
                # '*tv' norms keep a single weight map
                rls_shape = (len(self.maps), *rls_shape)
            self.rls = ParameterMap(rls_shape, fill=1, **self.backend).volume

        # --- initialize nb of iterations ------------------------------
        if not self.opt.regularization.norm.endswith('tv'):
            # no reweighting -> do more gauss-newton updates instead
            # (the product max_iter_gn * max_iter_rls is preserved)
            self.opt.optim.max_iter_gn *= self.opt.optim.max_iter_rls
            self.opt.optim.max_iter_rls = 1

        # --- be polite (bis) ------------------------------------------
        self.print_opt()
        self.print_header()

        # --- main loop ------------------------------------------------
        self.loop()

        # --- prepare output -------------------------------------------
        out = postproc(self.maps, self.data)
        if self.opt.distortion.enable:
            out = (*out, self.dist)
        return out
Esempio n. 4
0
def preproc(data, opt):
    """Estimate noise variance + register + compute recon space + init maps

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Multi-echo contrasts. The following are written in place:
        - per-echo noise statistics (`echo.mean`, `echo.sd`);
        - per-contrast noise variance (`contrast.noise`);
        - `contrast.affine`, when registration is enabled.
    opt : Options or None
        ESTATICS options. `ESTATICSOptions()` is used when None.
        `opt.recon.affine` and `opt.recon.fov` are filled in place from
        `opt.recon.space` when they are None.

    Returns
    -------
    data : sequence[GradientEchoMulti]
        The input contrasts (same objects).
    maps : ESTATICSParameterMaps
        One log-intercept map per contrast and a shared decay map
        (created with `min=0`), all defined on the reconstruction space.

    """

    if opt is None:
        opt = ESTATICSOptions()

    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    # --- estimate hyper parameters ---
    logmeans = []
    te = []
    for c, contrast in enumerate(data):
        means = []
        variances = []  # not `vars`: avoid shadowing the builtin
        for e, echo in enumerate(contrast):
            if opt.verbose:
                print(f'Estimate noise: contrast {c+1:d} - echo {e+1:2d}', end='\r')
            dat = echo.fdata(**backend, rand=True, cache=False)
            sd0, _, _, mu1 = estimate_noise(dat)
            echo.mean = mu1.item()
            echo.sd = sd0.item()
            means.append(mu1)
            variances.append(sd0.square())
        means = torch.stack(means)
        variances = torch.stack(variances)
        # Contrast-level noise variance: per-echo variances averaged with
        # weights proportional to each echo's mean (`mu1`).
        var = (means * variances).sum() / means.sum()
        contrast.noise = var.item()

        te.append(contrast.te)
        logmeans.append(means.log())
    if opt.verbose:
        print('')

    # --- initial minifit ---
    print('Compute initial parameters')
    inter, decay = _loglin_minifit(logmeans, te)
    print('    - log intercepts: [' + ', '.join([f'{i:.1f}' for i in inter.tolist()]) + ']')
    print(f'    - decay:          {decay.tolist():.3g}')

    # --- initial align ---
    if opt.preproc.register and len(data) > 1:
        print('Register volumes')
        data_reg = [(contrast.echo(0).fdata(rand=True, cache=False, **backend),
                     contrast.affine) for contrast in data]
        dats, affines, _ = affine_align(data_reg, device=device)
        if opt.verbose > 1 and plt:
            plt.figure()
            for i in range(len(dats)):
                plt.subplot(1, len(dats), i+1)
                plt.imshow(dats[i, :, dats.shape[2]//2, :].cpu())
            plt.show()
        # compose each contrast's orientation with its registration matrix
        for contrast, aff in zip(data, affines):
            aff, contrast.affine = core.utils.to_max_device(aff, contrast.affine)
            contrast.affine = torch.matmul(aff.inverse(), contrast.affine)

    # --- compute recon space ---
    affines = [contrast.affine for contrast in data]
    # drop the leading dimension (presumably the echo dimension)
    shapes = [contrast.volume.shape[1:] for contrast in data]
    if opt.recon.affine is None:
        opt.recon.affine = opt.recon.space
    if opt.recon.fov is None:
        opt.recon.fov = opt.recon.space
    # an int selects one of the input contrasts; anything else is used as-is
    if isinstance(opt.recon.affine, int):
        mean_affine = affines[opt.recon.affine]
    else:
        mean_affine = torch.as_tensor(opt.recon.affine)
    if isinstance(opt.recon.fov, int):
        mean_shape = shapes[opt.recon.fov]
    else:
        mean_shape = tuple(opt.recon.fov)

    # --- allocate maps ---
    maps = ESTATICSParameterMaps()
    maps.intercepts = [ParameterMap(mean_shape, fill=inter[c], affine=mean_affine, **backend)
                       for c in range(len(data))]
    maps.decay = ParameterMap(mean_shape, fill=decay, affine=mean_affine, min=0, **backend)
    maps.affine = mean_affine

    return data, maps
Esempio n. 5
0
def gre(pd,
        r1,
        r2s=None,
        mt=None,
        transmit=None,
        receive=None,
        gfactor=None,
        te=0,
        tr=25e-3,
        fa=20,
        mtpulse=False,
        sigma=None,
        noise='rician',
        affine=None,
        shape=None,
        device=None):
    """Simulate data generated by a Gradient-Echo (FLASH) sequence.

    Tissue parameters
    -----------------
    pd : ParameterMap or tensor_like
        Proton density
    r1 : ParameterMap or tensor_like
        Longitudinal relaxation rate, in 1/sec
    r2s : ParameterMap, optional
        Transverse relaxation rate, in 1/sec. Mandatory if any `te > 0`.
    mt : ParameterMap, optional
        MTsat. Mandatory if any `mtpulse == True`.
        Maps whose unit is a percentage ('%', 'pct', 'p.u', 'p.u.') are
        divided by 100.

    Fields
    ------
    transmit : (N-sequence of) PrecomputedFieldMap or tensor_like, optional
        Transmit B1 field (multiplies the nominal flip angle)
    receive : (N-sequence of) PrecomputedFieldMap or tensor_like, optional
        Receive B1 field (multiplies the proton density)
    gfactor : (N-sequence of) PrecomputedFieldMap or tensor_like, optional
        G-factor map.
        If provided and `sigma` is not `None`, the g-factor map is used
        to sample non-stationary noise.
    Field maps whose unit is a percentage are divided by 100. They are
    resampled to the simulation space when an `affine` is provided.

    Sequence parameters
    -------------------
    te : ((N-sequence of) M-sequence of) float, default=0
        Echo time, in sec
    tr : (N-sequence of) float, default=25e-3
        Repetition time, in sec
    fa : (N-sequence of) float, default=20
        Flip angle, in deg
    mtpulse : (N-sequence of) bool, default=False
        Presence of an off-resonance pulse

    Noise
    -----
    sigma : (N-sequence of) float, optional
        Standard-deviation of the sampled noise (no sampling if `None`)
    noise : {'rician', 'gaussian'}, default='rician'
        Noise distribution.
        - 'rician': the output is the magnitude of the signal plus
          independent Gaussian noise on its real and imaginary channels.
        - any other value: real Gaussian noise is added.

    Space
    -----
    affine : ([N], 4, 4) tensor, optional
        Orientation matrix of the simulation space
    shape : (N-sequence of) sequence[int], default=pd.shape
        Shape of the simulation space

    Returns
    -------
    sim : (N-sequence of) GradientEchoMulti
        Simulated series of multi-echo GRE images. A single object
        (not a list) is returned when only one contrast is simulated.

    Raises
    ------
    ValueError
        If an MT pulse is requested without `mt`, or if a positive echo
        time is requested without `r2s`.

    """
    # units in which a map is expressed in percent
    pct_units = ('%', 'pct', 'p.u', 'p.u.')

    # 1) Find out the number of contrasts requested
    te = make_list(te)
    if any(isinstance(x, (list, tuple)) for x in te):
        te = [make_list(t) for t in te]
    else:
        te = [te]
    tr = make_list(tr)
    fa = make_list(fa)
    mtpulse = [bool(p) for p in make_list(mtpulse)]
    sigma = make_list(sigma)
    # `x or []` would call bool() on tensors, which raises for
    # multi-element tensors: compare against None instead.
    transmit = make_list([] if transmit is None else transmit)
    receive = make_list([] if receive is None else receive)
    gfactor = make_list([] if gfactor is None else gfactor)
    if shape is None:
        # keep a bare None so that the `pd.shape` fallback below triggers
        shape = [None]
    else:
        shape = make_list(shape)
        if any(isinstance(x, (list, tuple)) for x in shape):
            shape = [s if s is None else make_list(s) for s in shape]
        else:
            shape = [shape]
    if torch.is_tensor(affine):
        affine = [affine] if affine.dim() == 2 else affine.unbind(0)
    else:
        affine = make_list(affine)
    nb_contrasts = max(len(te), len(tr), len(fa), len(mtpulse), len(sigma),
                       len(transmit), len(receive), len(gfactor), len(shape),
                       len(affine))

    # 2) Pad all lists up to `nb_contrasts`
    te = make_list(te, nb_contrasts)
    tr = make_list(tr, nb_contrasts)
    fa = make_list(fa, nb_contrasts)
    mtpulse = make_list(mtpulse, nb_contrasts)
    sigma = make_list(sigma, nb_contrasts)
    transmit = make_list(transmit or [None], nb_contrasts)
    receive = make_list(receive or [None], nb_contrasts)
    gfactor = make_list(gfactor or [None], nb_contrasts)
    shape = make_list(shape, nb_contrasts)
    affine = make_list(affine, nb_contrasts)

    # 3) ensure parameters are `ParameterMap`s
    has_r2s = r2s is not None
    has_mt = mt is not None
    if not isinstance(pd, ParameterMap):
        pd = ParameterMap(pd)
    if not isinstance(r1, ParameterMap):
        r1 = ParameterMap(r1)
    if has_r2s and not isinstance(r2s, ParameterMap):
        r2s = ParameterMap(r2s)
    if has_mt and not isinstance(mt, ParameterMap):
        mt = ParameterMap(mt)
        mt.unit = '%'

    # 4) ensure all fields are `PrecomputedFieldMap`s
    for fields in (transmit, receive, gfactor):
        for n in range(nb_contrasts):
            if (fields[n] is not None
                    and not isinstance(fields[n], PrecomputedFieldMap)):
                fields[n] = PrecomputedFieldMap(fields[n])

    # 5) choose backend
    all_var = [te, tr, fa, mtpulse, sigma, affine]
    for fields in (transmit, receive, gfactor):
        all_var += [f.volume for f in fields
                    if f is not None and torch.is_tensor(f.volume)]
    for prm_map in (pd, r1, r2s, mt):
        if prm_map is not None and torch.is_tensor(prm_map.volume):
            all_var.append(prm_map.volume)
    backend = core.utils.max_backend(*all_var)
    if device:
        backend['device'] = device

    # 6) prepare parameter maps
    prm = stack_maps(pd.fdata(**backend), r1.fdata(**backend),
                     r2s.fdata(**backend) if r2s is not None else None,
                     mt.fdata(**backend) if mt is not None else None)
    fpd, fr1, fr2s, fmt = unstack_maps(prm, has_r2s, has_mt)
    if has_mt and mt.unit in pct_units:
        fmt /= 100.
    if any(aff is not None for aff in affine):
        # Parameters are resampled in log space. MTsat lies in (0, 1), so
        # it is resampled in logit space: log(mt) - log(1 - mt).
        # NOTE(review): `exp_maps` is assumed to invert this with a
        # sigmoid; confirm.
        logprm = stack_maps(
            safelog(fpd), safelog(fr1),
            safelog(fr2s) if r2s is not None else None,
            safelog(fmt) - safelog(1 - fmt) if mt is not None else None)

    def pull_field(field, aff1, shape1):
        """Load `field` on `backend`, rescale percents, resample to aff1."""
        unit = field.unit  # read before `field` data is materialized
        dat = field.fdata(**backend)
        if unit in pct_units:
            dat = dat / 100.
        if aff1 is not None:
            mat = core.linalg.lmdiv(field.affine.to(**backend), aff1)
            grid = smart_grid(mat, shape1, dat.shape)
            dat = smart_pull(dat[None], grid)[0]
        return dat

    # 7) generate noise-free signal
    contrasts = []
    for n in range(nb_contrasts):

        shape1 = shape[n]
        if shape1 is None:
            shape1 = pd.shape
        aff1 = affine[n]
        if aff1 is not None:
            aff1 = aff1.to(**backend)
        te1 = torch.as_tensor(te[n], **backend)
        tr1 = torch.as_tensor(tr[n], **backend)
        fa1 = torch.as_tensor(fa[n], **backend) / 180. * core.constants.pi
        sigma1 = torch.as_tensor(sigma[n], **backend) if sigma[n] else None
        mtpulse1 = mtpulse[n]

        if aff1 is not None:
            mat = core.linalg.lmdiv(pd.affine.to(**backend), aff1)
            grid = smart_grid(mat, shape1, pd.shape, force=True)
            prm1 = smart_pull(logprm, grid)
            # resampled maps are fresh tensors -> safe to exponentiate in place
            inplace = grid is not None
            f1pd, f1r1, f1r2s, f1mt = unstack_maps(prm1, has_r2s, has_mt)
            f1pd, f1r1, f1r2s, f1mt = exp_maps(f1pd,
                                               f1r1,
                                               f1r2s,
                                               f1mt,
                                               inplace=inplace)
        else:
            # clone so that we can work in-place later
            f1pd = fpd.clone()
            f1r1 = fr1.clone()
            f1r2s = fr2s.clone() if fr2s is not None else None
            f1mt = fmt.clone() if fmt is not None else None

        if transmit[n] is not None:
            fa1 = fa1 * pull_field(transmit[n], aff1, shape1)
        if receive[n] is not None:
            f1pd = f1pd * pull_field(receive[n], aff1, shape1)

        # generate signal (Ernst equation, optionally with MT saturation)
        flash = f1pd
        flash *= fa1.sin()
        cosfa = fa1.cos_()
        e1 = f1r1
        e1 *= -tr1
        e1 = e1.exp_()
        flash *= (1 - e1)
        if mtpulse1:
            if not has_mt:
                raise ValueError('Cannot simulate an MT pulse: '
                                 'an MT must be provided.')
            omt = f1mt.neg_()
            omt += 1
            flash *= omt
            flash /= (1 - cosfa * omt * e1)
            del omt
        else:
            flash /= (1 - cosfa * e1)
        del e1, cosfa, fa1, f1r1, f1mt

        # multiply with r2*
        if any(t > 0 for t in te1):
            if not has_r2s:
                raise ValueError('Cannot simulate an R2* decay: '
                                 'an R2* must be provided.')
            # broadcast echo times along a new leading (echo) dimension;
            # keep `te1` itself 1D so that it can be exported as a flat list
            te_map = te1.reshape([-1] + [1] * f1r2s.dim())
            flash = flash * (-te_map * f1r2s).exp_()
            del f1r2s, te_map

        # sample noise
        if sigma1:
            if gfactor[n] is not None:
                sigma1 = sigma1 * pull_field(gfactor[n], aff1, shape1)

            noise_shape = flash.shape
            if noise == 'rician':
                noise_shape = (2, ) + noise_shape
            sample = torch.randn(noise_shape, **backend)
            sample *= sigma1
            del sigma1
            if noise == 'rician':
                # |signal + n_real + i * n_imag|
                flash = flash.add_(sample[0]).square_()
                flash = flash.add_(sample[1].square_()).sqrt_()
            else:
                flash += sample
            del sample

        flash = GradientEchoMulti(flash,
                                  affine=aff1,
                                  tr=tr1.item(),
                                  fa=torch.as_tensor(fa[n]).item(),
                                  te=te1.tolist(),
                                  mt=torch.as_tensor(mtpulse1).item())
        contrasts.append(flash)

    return contrasts[0] if len(contrasts) == 1 else contrasts
Esempio n. 6
0
def preproc(data, transmit=None, receive=None, opt=None):
    """Estimate noise variance + register + compute recon space + init maps

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Multi-echo contrasts, each exposing `te`, `tr`, `fa` (in degrees)
        and `mt`. The following are written in place:
        - noise statistics (`echo.mean`, `echo.sd`, `contrast.noise`);
        - `affine`, when registration is enabled.
    transmit : (sequence of) PrecomputedFieldMap, optional
    receive : (sequence of) PrecomputedFieldMap, optional
    opt : Options, optional
        Merged into a fresh `GREEQOptions()`.

    Returns
    -------
    data : sequence[GradientEchoMulti]
        The input contrasts (same objects).
    transmit : list[PrecomputedFieldMap or None]
        Transmit maps, expanded to one entry per contrast (all None if
        none were given).
    receive : list[PrecomputedFieldMap or None]
        Receive maps, expanded to one entry per contrast (all None if
        none were given).
    maps : GREEQParameterMaps
        Initial maps on the reconstruction space: log PD, log R1, log R2*,
        and logit MTsat when `_rational_minifit` returns an MT estimate.

    """

    opt = GREEQOptions().update(opt)
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    # --- estimate hyper parameters ---
    logmeans = []
    te = []
    tr = []
    fa = []
    mt = []
    for c, contrast in enumerate(data):
        means = []
        variances = []  # not `vars`: avoid shadowing the builtin
        for e, echo in enumerate(contrast):
            if opt.verbose:
                print(f'Estimate noise: contrast {c+1:d} - echo {e+1:2d}',
                      end='\r')
            dat = echo.fdata(**backend, rand=True, cache=False)
            sd0, _, _, mu1 = estimate_noise(dat)
            echo.mean = mu1.item()
            echo.sd = sd0.item()
            means.append(mu1)
            variances.append(sd0.square())
        means = torch.stack(means)
        variances = torch.stack(variances)
        # Contrast-level noise variance: per-echo variances weighted by
        # each echo's mean (`mu1`).
        var = (means * variances).sum() / means.sum()
        contrast.noise = var.item()

        te.append(contrast.te)
        tr.append(contrast.tr)
        fa.append(contrast.fa / 180 * core.constants.pi)  # deg -> rad
        mt.append(contrast.mt)
        logmeans.append(means.log())
    if opt.verbose:
        print('')

    print('Estimating maps from volumes:')
    for i in range(len(data)):
        print(f'    - Contrast {i:d}: ', end='')
        print(f'FA = {fa[i]*180/core.constants.pi:2.0f} deg  / ', end='')
        print(f'TR = {tr[i]*1e3:4.1f} ms / ', end='')
        print('TE = [' + ', '.join([f'{t*1e3:.1f}' for t in te[i]]) + '] ms',
              end='')
        if mt[i]:
            print(' / MT = True', end='')
        print()

    # --- initial minifit ---
    print('Compute initial parameters')
    inter, r2s = _loglin_minifit(logmeans, te)
    pd, r1, mt = _rational_minifit(inter, tr, fa, mt)
    print(f'    - PD:    {pd.tolist():9.3g} a.u.')
    print(f'    - R1:    {r1.tolist():9.3g} 1/s')
    print(f'    - R2*:   {r2s.tolist():9.3g} 1/s')
    # maps are initialized in log space (logit space for MTsat)
    pd = pd.log()
    r1 = r1.log()
    r2s = r2s.log()
    if mt is not None:
        print(f'    - MT:    {100*mt.tolist():9.3g} %')
        mt = mt.log() - (1 - mt).log()

    # --- initial align ---
    transmit = core.utils.make_list(transmit or [])
    receive = core.utils.make_list(receive or [])
    if opt.preproc.register and len(data) > 1:
        print('Register volumes')
        data_reg = [(contrast.echo(0).fdata(rand=True, cache=False,
                                            **backend), contrast.affine)
                    for contrast in data]
        data_reg += [(fmap.magnitude.fdata(rand=True, cache=False,
                                           **backend), fmap.magnitude.affine)
                     for fmap in transmit]
        data_reg += [(fmap.magnitude.fdata(rand=True, cache=False,
                                           **backend), fmap.magnitude.affine)
                     for fmap in receive]
        dats, affines, _ = affine_align(data_reg, device=device)
        if opt.verbose > 1 and plt:
            plt.figure()
            for i in range(len(dats)):
                plt.subplot(1, len(dats), i + 1)
                plt.imshow(dats[i, :, dats.shape[2] // 2, :].cpu())
                plt.axis('off')
            plt.show()
        # Unpack into a list so that tuple-valued `data` also works. The
        # updates are still applied to the original objects.
        for vol, aff in zip([*data, *transmit, *receive], affines):
            aff, vol.affine = core.utils.to_max_device(aff, vol.affine)
            vol.affine = torch.matmul(aff.inverse(), vol.affine)

    # --- compute recon space ---
    affines = [contrast.affine for contrast in data]
    # drop the leading dimension (presumably the echo dimension)
    shapes = [contrast.volume.shape[1:] for contrast in data]
    if opt.recon.affine is None:
        opt.recon.affine = opt.recon.space
    if opt.recon.fov is None:
        opt.recon.fov = opt.recon.space
    # an int selects one of the input contrasts; anything else is used as-is
    if isinstance(opt.recon.affine, int):
        mean_affine = affines[opt.recon.affine]
    else:
        mean_affine = torch.as_tensor(opt.recon.affine)
    if isinstance(opt.recon.fov, int):
        mean_shape = shapes[opt.recon.fov]
    else:
        mean_shape = tuple(opt.recon.fov)

    # --- allocate maps ---
    maps = GREEQParameterMaps()
    maps.pd = ParameterMap(mean_shape, fill=pd, affine=mean_affine, **backend)
    maps.r1 = ParameterMap(mean_shape, fill=r1, affine=mean_affine, **backend)
    maps.r2s = ParameterMap(mean_shape,
                            fill=r2s,
                            affine=mean_affine,
                            **backend)
    if mt is not None:
        maps.mt = ParameterMap(mean_shape,
                               fill=mt,
                               affine=mean_affine,
                               **backend)
    maps.affine = mean_affine

    # --- repeat fields if not enough ---
    if transmit:
        transmit = core.py.make_list(transmit, len(data))
    else:
        transmit = [None] * len(data)
    if receive:
        receive = core.py.make_list(receive, len(data))
    else:
        receive = [None] * len(data)

    return data, transmit, receive, maps
Esempio n. 7
0
def vfa(data, transmit=None, receive=None, opt=None, **kwopt):
    """Compute PD, R1 (and MTsat) from two GRE at two flip angles using
    rational approximations of the Ernst equations.

    Parameters
    ----------
    data : sequence[GradientEcho]
        Volumes with different contrasts (flip angle or MT pulse) but with
        the same echo time. Note that they do not need to be real echoes;
        they often are images extrapolated to TE = 0.
    transmit : sequence[PrecomputedFieldMap], optional
        Map(s) of the transmit field (b1+). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
    receive : sequence[PrecomputedFieldMap], optional
        Map(s) of the receive field (b1-). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
        If no receive map is provided, the output `pd` map will have
        a remaining b1- bias field.
    opt : VFAOptions
        Algorithm options (individual options may also be passed as
        keyword arguments, through `**kwopt`).
        {'preproc': {'register':      True},     # Co-register contrasts
         'backend': {'dtype':  torch.float32,    # Data type
                     'device': 'cpu'},           # Device
         'verbose': 1,                           # Verbosity: 1=print, 2=plot
         'rational': False}                      # Force rational approximation

    Returns
    -------
    pd : ParameterMap
        Proton density (potentially, with R2* bias)
    r1 : ParameterMap
        Longitudinal relaxation rate, in 1/s
    mt : ParameterMap, optional
        Magnetization transfer saturation, in %. Only returned when at
        least one input contrast has an MT pulse.

    Raises
    ------
    ValueError
        If fewer than two images are given, if echo times differ across
        contrasts, or if fewer than two non-MT volumes are available.

    Notes
    -----
    The inputs are copied, so they are not modified. The first two non-MT
    volumes, in input order, are used as the PD-weighted and T1-weighted
    volumes. When both share the same TR (and `opt.rational` is not set),
    the analytical solution is used; otherwise the rational approximation
    is used. MTsat is computed from the first MT volume, analytically
    unless `opt.rational` is set.

    References
    ----------
    ..[1] Tabelow et al., "hMRI - A toolbox for quantitative MRI in
          neuroscience and clinical research", NeuroImage (2019).
          https://www.sciencedirect.com/science/article/pii/S1053811919300291
    ..[2] Helms et al., "Quantitative FLASH MRI at 3T using a rational
          approximation of the Ernst equation", MRM (2008).
          https://onlinelibrary.wiley.com/doi/full/10.1002/mrm.21732
    ..[3] Helms et al., "High-resolution maps of magnetization transfer
          with inherent correction for RF inhomogeneity and T1 relaxation
          obtained from 3D FLASH MRI", MRM (2008).
          https://onlinelibrary.wiley.com/doi/full/10.1002/mrm.21732

    """
    opt = VFAOptions().update(opt, **kwopt)
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)
    # units in which a field map is expressed in percent
    pct_units = ('%', 'pct', 'p.u.')

    data = core.py.make_list(data)
    if len(data) < 2:
        raise ValueError('Expected at least two input images')
    transmit = core.py.make_list(transmit or [])
    receive = core.py.make_list(receive or [])

    # --- Copy instances to avoid modifying the inputs ---
    data = [x.copy() for x in data]
    transmit = [x.copy() for x in transmit]
    receive = [x.copy() for x in receive]

    # --- check TEs ---
    if len({contrast.te for contrast in data}) > 1:
        raise ValueError('Echo times not consistent across contrasts')

    # --- register ---
    if opt.preproc.register:
        print('Register volumes')
        data_reg = [(contrast.fdata(rand=True, cache=False, **backend),
                     contrast.affine) for contrast in data]
        data_reg += [(fmap.magnitude.fdata(rand=True, cache=False, **backend),
                      fmap.magnitude.affine) for fmap in transmit]
        data_reg += [(fmap.magnitude.fdata(rand=True, cache=False, **backend),
                      fmap.magnitude.affine) for fmap in receive]
        dats, affines, _ = affine_align(data_reg, device=device, fwhm=3)

        if opt.verbose > 1 and plt:
            plt.figure()
            for i in range(len(dats)):
                plt.subplot(1, len(dats), i+1)
                plt.imshow(dats[i, :, dats.shape[2]//2, :].cpu())
                plt.axis('off')
            plt.suptitle('Registered magnitude images')
            plt.show()
        del dats

        # compose each volume's orientation with its registration matrix
        for vol, aff in zip(data + transmit + receive, affines):
            aff, vol.affine = core.utils.to_max_device(aff, vol.affine)
            vol.affine = torch.matmul(aff.inverse(), vol.affine)

    # --- repeat fields if not enough ---
    transmit = core.py.make_list(transmit or [None], len(data))
    receive = core.py.make_list(receive or [None], len(data))

    # --- compute recon space ---
    affines = [contrast.affine for contrast in data]
    shapes = [contrast.volume.shape for contrast in data]
    if opt.recon.affine is None:
        opt.recon.affine = opt.recon.space
    if opt.recon.fov is None:
        opt.recon.fov = opt.recon.space
    # an int selects one of the input contrasts; anything else is used as-is
    if isinstance(opt.recon.affine, int):
        mean_affine = affines[opt.recon.affine]
    else:
        mean_affine = torch.as_tensor(opt.recon.affine)
    if isinstance(opt.recon.fov, int):
        mean_shape = shapes[opt.recon.fov]
    else:
        mean_shape = tuple(opt.recon.fov)

    # --- helpers (closures over the recon space, backend and fields) ---
    def describe(struct):
        """Print the FA / TR / TE of a volume."""
        print(f'  - '
              f'FA = {struct.fa:2.0f} deg  /  '
              f'TR = {struct.tr*1e3:4.1f} ms /  '
              f'TE = {struct.te*1e3:4.1f} ms')

    def print_field_stats(label, field, unit):
        """Print min / max / mean of the positive voxels of a field map."""
        positive = field[field > 0]
        print(f'    with {label} map ('
              f'min= {positive.min():.2f}, '
              f'max = {positive.max():.2f}, '
              f'mean = {positive.mean():.2f} {unit})')

    def correct_receive(vol, idx):
        """Divide `vol` (in place) by the B1- map of contrast `idx`, if any."""
        if not receive[idx]:
            return vol
        b1m = load_and_pull(receive[idx], mean_affine, mean_shape, **backend)
        unit = receive[idx].unit
        print_field_stats('B1-', b1m, unit)
        vol /= b1m
        if unit in pct_units:
            vol *= 100
        return vol

    def correct_transmit(flip, idx):
        """Scale flip angle `flip` (rad) by the B1+ map of `idx`, if any."""
        if not transmit[idx]:
            return flip
        b1p = load_and_pull(transmit[idx], mean_affine, mean_shape, **backend)
        unit = transmit[idx].unit
        print_field_stats('B1+', b1p, unit)
        flip = b1p * flip
        if unit in pct_units:
            flip /= 100
        return flip

    # --- compute PD/R1 ---
    pdt1 = [(idx, contrast) for idx, contrast in enumerate(data)
            if not contrast.mt]
    if len(pdt1) > 2:
        warnings.warn('More than two volumes could be used to compute PD+R1')
    pdt1 = pdt1[:2]
    if len(pdt1) < 2:
        raise ValueError('Not enough volumes to compute PD+R1')
    (pdw_idx, pdw_struct), (t1w_idx, t1w_struct) = pdt1

    if t1w_struct.te != pdw_struct.te:
        warnings.warn('Echo times not consistent across volumes')

    # the analytical solution below only uses one TR: it requires both
    # volumes to share the same TR
    rational = opt.rational or t1w_struct.tr != pdw_struct.tr
    method = 'rational' if rational else 'analytical'
    print(f'Computing PD and R1 ({method}) from volumes:')

    describe(pdw_struct)
    pdw = load_and_pull(pdw_struct, mean_affine, mean_shape, **backend)
    pdw_tr = pdw_struct.tr
    pdw = correct_receive(pdw, pdw_idx)
    pdw_fa = correct_transmit(pdw_struct.fa / 180. * core.constants.pi,
                              pdw_idx)
    pdw_fa = torch.as_tensor(pdw_fa, **backend)

    describe(t1w_struct)
    t1w = load_and_pull(t1w_struct, mean_affine, mean_shape, **backend)
    t1w_tr = t1w_struct.tr
    t1w = correct_receive(t1w, t1w_idx)
    # the T1w flip angle is corrected with the T1w B1+ map (not the PDw one)
    t1w_fa = correct_transmit(t1w_struct.fa / 180. * core.constants.pi,
                              t1w_idx)
    t1w_fa = torch.as_tensor(t1w_fa, **backend)

    if rational:
        # we must use rational approximations
        r1 = 0.5 * (t1w * (t1w_fa / t1w_tr) - pdw * (pdw_fa / pdw_tr))
        r1 /= ((pdw / pdw_fa) - (t1w / t1w_fa))

        pd = (t1w * pdw) * (t1w_tr * (pdw_fa / t1w_fa) - pdw_tr * (t1w_fa / pdw_fa))
        pd /= (pdw * (pdw_tr * pdw_fa) - t1w * (t1w_tr * t1w_fa))
        del t1w_fa, pdw_fa, t1w, pdw
    else:
        # there is an analytical solution
        cosfa_t1w = t1w_fa.cos()
        sinfa_t1w = t1w_fa.sin_()
        del t1w_fa
        cosfa_pdw = pdw_fa.cos()
        sinfa_pdw = pdw_fa.sin_()
        del pdw_fa
        e1 = sinfa_pdw * t1w - sinfa_t1w * pdw
        e1 /= sinfa_pdw * cosfa_t1w * t1w - sinfa_t1w * cosfa_pdw * pdw

        # average the PD estimates obtained from each volume
        pd = t1w / sinfa_t1w * (1 - cosfa_t1w * e1) / (1 - e1)
        pd += pdw / sinfa_pdw * (1 - cosfa_pdw * e1) / (1 - e1)
        pd /= 2.

        r1 = e1.log_().neg_().div_(t1w_tr)
        del t1w, pdw, cosfa_pdw, cosfa_t1w

    r1[~torch.isfinite(r1)] = 0
    pd[~torch.isfinite(pd)] = 0

    # --- compute MTsat ---
    mtw_structs = [(idx, contrast) for idx, contrast in enumerate(data)
                   if contrast.mt]
    if not mtw_structs:
        return (ParameterMap(pd, affine=mean_affine, unit=None),
                ParameterMap(r1, affine=mean_affine, unit='1/s'))

    if len(mtw_structs) > 1:
        warnings.warn('More than one volume could be used to compute MTsat')
    mtw_idx, mtw_struct = mtw_structs[0]

    if mtw_struct.te != pdw_struct.te:
        warnings.warn('Echo times not consistent across volumes')

    method = 'rational' if opt.rational else 'analytical'
    print(f'Computing MTsat ({method}) from PD/R1 maps and volume:')
    describe(mtw_struct)

    mtw = load_and_pull(mtw_struct, mean_affine, mean_shape, **backend)
    mtw_tr = mtw_struct.tr
    mtw = correct_receive(mtw, mtw_idx)
    # the B1+ map is now loaded on `backend`, like every other map
    mtw_fa = correct_transmit(mtw_struct.fa / 180. * core.constants.pi,
                              mtw_idx)
    mtw_fa = torch.as_tensor(mtw_fa, **backend)

    if opt.rational:
        # we must use rational approximations
        mtsat = (mtw_fa * pd / mtw - 1) * r1 * mtw_tr - 0.5 * (mtw_fa ** 2)
        del mtw_fa, mtw
    else:
        # we have an analytical solution (work backward from PD/R1)
        cosfa_mtw = mtw_fa.cos()
        sinfa_mtw = mtw_fa.sin()
        del mtw_fa
        e1 = (-mtw_tr * r1).exp()
        mtsat = mtw / (cosfa_mtw * e1 * mtw + pd * sinfa_mtw * (1 - e1))
        del mtw, cosfa_mtw, sinfa_mtw
        mtsat = 1 - mtsat

    mtsat *= 100
    mtsat[~torch.isfinite(mtsat)] = 0

    return (ParameterMap(pd, affine=mean_affine, unit=None),
            ParameterMap(r1, affine=mean_affine, unit='1/s'),
            ParameterMap(mtsat, affine=mean_affine, unit='%'))
Esempio n. 8
0
def _prepare(data, dist, opt):
    """Normalise options, preprocess the data and report the fitting plan.

    Parameters
    ----------
    data : sequence
        Input contrasts; each item exposes ``te`` (echo times, printed in
        ms after scaling by 1e3) and, after preprocessing, ``affine`` and
        ``readout``.
    dist : object
        Distortion state, forwarded to ``preproc`` and returned as it
        comes back from it.
    opt : options or None
        User options, merged into a fresh ``ESTATICSOptions`` instance.

    Returns
    -------
    data, maps, dist
        Outputs of ``preproc``.
    opt : ESTATICSOptions
        Cleaned-up options. ``regularization.factor`` becomes the list
        ``[*lam_intercepts, lam_decay]``. ``distortion.factor`` becomes a
        dict with keys ``factor``, ``absolute``, ``membrane`` and
        ``bending``. The iteration counts may also be rescaled.
    rls : tensor or None
        IRLS weights initialised to one for (J)TV norms, else ``None``.
    """
    # A fresh options object is built so that fields can be overwritten /
    # simplified in place below.
    opt = ESTATICSOptions().update(opt).cleanup_()
    backend = dict(dtype=opt.backend.dtype, device=opt.backend.device)

    # --- be polite ----------------------------------------------------
    nb_data = len(data)
    if nb_data > 1:
        print(f'Fitting a (shared) exponential decay model with {nb_data} contrasts.')
    else:
        print('Fitting an exponential decay model.')
    print('Echo times:')
    for idx, contrast in enumerate(data):
        echoes = ', '.join(f'{te*1e3:.1f}' for te in contrast.te)
        print(f'    - contrast {idx:2d}: [{echoes}] ms')

    # --- estimate noise / register / initialize maps ------------------
    data, maps, dist = preproc(data, dist, opt)
    nb_contrasts = len(maps) - 1

    dopt = opt.distortion
    if dopt.enable:
        # The first matching group of layout letters wins.
        axis_labels = (('LR', 'left-right'),
                       ('IS', 'infra-supra'),
                       ('AP', 'antero-posterior'))
        print('Readout directions:')
        for idx, contrast in enumerate(data):
            layout_name = spatial.volume_layout_to_name(
                spatial.affine_to_layout(contrast.affine))
            letter = layout_name[contrast.readout]
            direction = next(
                (label for letters, label in axis_labels
                 if any(char in letter for char in letters)),
                'unknown')
            print(f'    - contrast {idx:2d}: {direction}')

    # --- prepare regularization factor --------------------------------
    # Parameter maps: the last factor is for the decay and the others are
    # broadcast to the number of intercepts.
    *lam, lam_decay = opt.regularization.factor
    lam = core.py.make_list(lam, nb_contrasts)
    lam.append(lam_decay)
    if not any(lam):
        # All factors are zero, so regularization is switched off.
        opt.regularization.norm = ''
    opt.regularization.factor = lam
    # Distortion fields: pack all penalty weights into a single dict.
    dopt.factor = dict(factor=dopt.factor,
                       absolute=dopt.absolute,
                       membrane=dopt.membrane,
                       bending=dopt.bending)

    # --- initialize weights (RLS) -------------------------------------
    norm = opt.regularization.norm
    reweighted = norm.endswith('tv')
    mean_shape = maps.decay.volume.shape
    rls = None
    if reweighted:
        # 'tv' uses one weight map per parameter map; 'jtv' shares a single
        # weight map across all of them.
        rls_shape = (len(maps), *mean_shape) if norm == 'tv' else mean_shape
        rls = ParameterMap(rls_shape, fill=1, **backend).volume

    if norm:
        print('Regularization:')
        print(f'    - type:           {norm.upper()}')
        intercepts_str = ', '.join(f'{val:.3g}' for val in lam[:-1])
        print(f'    - log intercepts: [{intercepts_str}]')
        print(f'    - decay:          {lam[-1]:.3g}')
    else:
        print('Without regularization')

    if dopt.enable:
        print('Distortion correction:')
        print(f'    - model:          {dopt.model.lower()}')
        scale = dopt.factor["factor"]
        print(f'    - absolute:       {dopt.absolute * scale}')
        print(f'    - membrane:       {dopt.membrane * scale}')
        print(f'    - bending:        {dopt.bending * scale}')
    else:
        print('Without distortion correction')

    # --- initialize nb of iterations ----------------------------------
    optim = opt.optim
    if not reweighted:
        # Without reweighting, spend the IRLS budget on more Gauss-Newton
        # updates instead.
        optim.max_iter_prm *= optim.max_iter_rls
        optim.max_iter_rls = 1
    print('Optimization:')
    print(f'    - Tolerance:        {optim.tolerance}')
    if reweighted:
        print(f'    - IRLS iterations:  {optim.max_iter_rls}')
    print(f'    - Param iterations: {optim.max_iter_prm}')
    if dopt.enable:
        print(f'    - Dist iterations:  {optim.max_iter_dist}')
    print('    - FMG cycles:       2')
    print(f'    - CG iterations:    {optim.max_iter_cg}'
          f' (tolerance: {optim.tolerance_cg})')
    if optim.nb_levels > 1:
        print(f'    - Levels:           {optim.nb_levels}')

    # ------------------------------------------------------------------
    #                     MAIN OPTIMIZATION LOOP (header)
    # ------------------------------------------------------------------
    if opt.verbose:
        columns = [f'{"rls":^3s} | {"gn":^3s} | {"step":^4s} | ',
                   f'{"fit":^12s} + {"reg":^12s} + {"rls":^12s} ']
        if dopt.enable:
            columns.append(f'+ {"dist":^12s} ')
        columns.append(f'= {"crit":^12s}')
        header = ''.join(columns)
        if optim.nb_levels > 1:
            header = f'{"lvl":3s} | ' + header
        print('\n' + header)
        print('-' * len(header))

    return data, maps, dist, opt, rls
Esempio n. 9
0
def nonlin(data, opt=None):
    """Fit the ESTATICS model to multi-echo Gradient-Echo data.

    Gauss-Newton updates are nested inside an iteratively reweighted
    least-squares (IRLS) loop. The IRLS loop runs only once, with more
    Gauss-Newton iterations, unless a 'tv' / 'jtv' norm is requested.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    opt : Options, optional
        Algorithm options.

    Returns
    -------
    intercepts : sequence[GradientEcho]
        Echo series extrapolated to TE=0
    decay : estatics.ParameterMap
        R2* decay map

    """

    opt = ESTATICSOptions().update(opt)
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    def _gain(prev, curr, ref):
        # Relative decrease of the (minimised) objective, normalised by the
        # total decrease since the worst value `ref`. The denominator is
        # zero whenever `prev == ref` (always on the first iteration). In
        # that case `curr <= ref`, so the gain is reported as infinite and
        # the loop keeps going, instead of raising ZeroDivisionError or
        # producing nan.
        denom = ref - prev
        if denom == 0:
            return float('inf')
        return (prev - curr) / denom

    # --- be polite ---
    print(f'Fitting a (multi) exponential decay model with {len(data)} '
          f'contrasts. Echo times:')
    for i, contrast in enumerate(data):
        print(f'    - contrast {i:2d}: [' +
              ', '.join([f'{te*1e3:.1f}' for te in contrast.te]) + '] ms')

    # --- estimate noise / register / initialize maps ---
    data, maps = preproc(data, opt)
    vx = spatial.voxel_size(maps.affine)

    # --- prepare regularization factor ---
    # -> lam = [*lam_intercepts, lam_decay]
    lam = opt.regularization.factor
    lam = core.utils.make_list(lam)
    if len(lam) > 1:
        *lam, lam_decay = lam
    else:
        lam_decay = lam[0]
    lam = core.utils.make_list(lam, len(maps) - 1)
    lam.append(lam_decay)

    # --- initialize weights (RLS) ---
    if (not opt.regularization.norm
            or opt.regularization.norm.lower() == 'none'
            or all(lam1 == 0 for lam1 in lam)):
        opt.regularization.norm = ''
    opt.regularization.norm = opt.regularization.norm.lower()
    mean_shape = maps.decay.volume.shape
    rls = None
    sumrls = 0
    if opt.regularization.norm in ('tv', 'jtv'):
        rls_shape = mean_shape
        if opt.regularization.norm == 'tv':
            # one weight map per parameter map
            rls_shape = (len(maps), ) + rls_shape
        rls = ParameterMap(rls_shape, fill=1, **backend).volume
        sumrls = 0.5 * rls.sum(dtype=torch.double)

    if opt.regularization.norm:
        print(f'With {opt.regularization.norm.upper()} regularization:')
        print('    - log intercepts: [' +
              ', '.join([f'{i:.3g}' for i in lam[:-1]]) + ']')
        print(f'    - decay:          {lam[-1]:.3g}')
    else:
        print('Without regularization:')

    # --- compute derivatives ---
    grad = torch.empty((len(data) + 1, ) + mean_shape, **backend)
    hess = torch.empty((len(data) * 2 + 1, ) + mean_shape, **backend)

    if opt.regularization.norm not in ('tv', 'jtv'):
        # no reweighting -> do more gauss-newton updates instead
        opt.optim.max_iter_gn *= opt.optim.max_iter_rls
        opt.optim.max_iter_rls = 1

    if opt.verbose:
        print('{:^3s} | {:^3s} | {:^12s} + {:^12s} + {:^12s} = {:^12s}'.format(
            'rls', 'gn', 'fit', 'reg', 'rls', 'crit'))

    ll_rls = []
    ll_max = core.constants.ninf

    for n_iter_rls in range(opt.optim.max_iter_rls):

        multi_rls = rls if opt.regularization.norm == 'tv' \
                    else [rls] * len(maps)

        # --- Gauss Newton loop ---
        ll_gn = []
        for n_iter_gn in range(opt.optim.max_iter_gn):
            decay = maps.decay
            crit = 0
            grad.zero_()
            hess.zero_()

            # --- loop over contrasts ---
            for i, (contrast,
                    intercept) in enumerate(zip(data, maps.intercepts)):
                # compute gradient
                crit1, g1, h1 = _nonlin_gradient(contrast, intercept, decay,
                                                 opt)

                # increment: index -1 (the decay) is shared by all contrasts;
                # 2*i and 2*i+1 are specific to contrast i.
                gind = [i, -1]
                grad[gind, ...] += g1
                hind = [2 * i, -1, 2 * i + 1]
                hess[hind, ...] += h1
                crit += crit1

            # --- regularization ---
            reg = 0.
            if opt.regularization.norm:
                for i, (prm, weight, lam1) in enumerate(zip(maps, multi_rls,
                                                            lam)):
                    if not lam1:
                        continue
                    reg1, g1 = _nonlin_reg(prm.volume, vx, weight, lam1)
                    reg += reg1
                    grad[i] += g1

            # --- gauss-newton ---
            if not hess.isfinite().all():
                print('WARNING: NaNs in hess (??)')
            hess = hessian_loaddiag(hess, 1e-6, 1e-8)
            if opt.regularization.norm:
                deltas = _nonlin_solve(hess, grad, multi_rls, lam, vx, opt)
            else:
                deltas = hessian_solve(hess, grad)
            if not deltas.isfinite().all():
                print('WARNING: NaNs in delta (non stable Hessian)')

            for prm, delta in zip(maps, deltas):
                prm.volume -= delta
                if prm.min is not None or prm.max is not None:
                    prm.volume.clamp_(prm.min, prm.max)

            # --- Compute gain ---
            ll = crit + reg + sumrls
            ll_max = max(ll_max, ll)
            ll_prev = ll_gn[-1] if ll_gn else ll_max
            gain = _gain(ll_prev, ll, ll_max)
            ll_gn.append(ll)
            if opt.verbose:
                print(
                    '{:3d} | {:3d} | {:12.6g} + {:12.6g} + {:12.6g} = {:12.6g} | gain = {:7.2g}'
                    .format(n_iter_rls, n_iter_gn, crit, reg, sumrls,
                            ll, gain))
            if gain < opt.optim.tolerance_gn:
                break

        # --- Update RLS weights ---
        if opt.regularization.norm in ('tv', 'jtv'):
            rls = _nonlin_rls(maps, lam, opt.regularization.norm)
            sumrls = 0.5 * rls.sum(dtype=torch.double)
            eps = core.constants.eps(rls.dtype)
            rls = rls.clamp_min_(eps).reciprocal_()

            # --- Compute gain ---
            # (we are late by one full RLS iteration when computing the
            #  gain but we save some computations)
            ll = ll_gn[-1]
            ll_prev = ll_rls[-1][-1] if ll_rls else ll_max
            ll_rls.append(ll_gn)
            gain = _gain(ll_prev, ll, ll_max)
            if gain < opt.optim.tolerance_rls:
                print(f'Converged ({gain:7.2g})')
                break

    # --- Prepare output ---
    return postproc(maps, data)
Esempio n. 10
0
 def decay(self):
     """Return the last entry of the stacked volume as a parameter map.

     Returns ``None`` when no volume is attached. Otherwise returns the
     last slice along the first axis of ``self.volume``, wrapped in a
     ``ParameterMap`` that carries this object's affine.
     """
     volume = self.volume
     if volume is not None:
         return ParameterMap(volume[-1], affine=self.affine)
     return None