Пример #1
0
 def __init__(self, *args, **kwargs):
     """Initialize the container and expose its maps as attributes.

     All arguments are forwarded to the parent constructor. The container
     must then hold 3 or 4 maps; they are exposed, in order, as `pd`, `r1`,
     `r2s` and (only when a fourth map exists) `mt`, each wrapped in a
     `ParameterMap` that shares `self.affine`.

     Raises
     ------
     ValueError
         If the container does not hold exactly 3 or 4 maps.
     """
     # NOTE(review): presumably backing storage for `volume`/`affine`;
     # set before the parent constructor in case it relies on them.
     self._greeq_volume = None
     self._greeq_affine = None
     super().__init__(*args, **kwargs)
     nb_maps = len(self)
     if nb_maps not in (3, 4):
         raise ValueError(f'Expected 3 or 4 maps, got {nb_maps}')
     self.pd = ParameterMap(self.volume[0], affine=self.affine)
     self.r1 = ParameterMap(self.volume[1], affine=self.affine)
     self.r2s = ParameterMap(self.volume[2], affine=self.affine)
     if nb_maps == 4:
         self.mt = ParameterMap(self.volume[3], affine=self.affine)
Пример #2
0
    def fit(self, data):
        """Fit the (multi-)exponential decay model to `data`.

        Runs, in order: `preproc` (noise estimation, registration and map
        initialization), the setup of the regularization factors and of the
        IRLS weights, the nested optimization (`self.loop`) and `postproc`.
        `self.opt` is modified in place (regularization disabled when all
        factors are zero; GN iterations scaled up when no reweighting).

        Parameters
        ----------
        data : sequence of contrasts
            Each contrast exposes its echo times through `.te`.

        Returns
        -------
        out
            Output of `postproc(self.maps, self.data)`; when
            `opt.distortion.enable` is set, a tuple of its items followed by
            the distortion fields `self.dist`.
        """
        opt = self.opt

        # report what is about to be fitted
        if opt.verbose > 0:
            print(f'Fitting a (multi) exponential decay model with '
                  f'{len(data)} contrasts. Echo times:')
            for index, contrast in enumerate(data):
                echo_times = ', '.join(f'{te * 1e3:.1f}' for te in contrast.te)
                print(f'    - contrast {index:2d}: [{echo_times}] ms')

        # noise estimation, registration and map initialization
        self.data, self.maps, self.dist = preproc(data, opt)
        self.affine0 = self.maps.affine
        self.shape0 = self.maps.decay.shape

        # regularization factors, ordered as [*intercepts, decay]
        *lam_intercepts, lam_decay = opt.regularization.factor
        factors = core.py.ensure_list(lam_intercepts, self.nb_contrasts)
        factors.append(lam_decay)
        if not any(factors):
            opt.regularization.norm = ''
        self.lam = factors

        # IRLS weights only exist for (joint) total variation
        self.rls = None
        if opt.regularization.norm.endswith('tv'):
            weight_shape = self.shape
            if opt.regularization.norm == 'tv':
                # TV: one weight map per parameter map
                weight_shape = (len(self.maps), *weight_shape)
            self.rls = ParameterMap(weight_shape, fill=1, **self.backend).volume
        else:
            # no reweighting: spend the RLS budget on Gauss-Newton updates
            opt.optim.max_iter_gn *= opt.optim.max_iter_rls
            opt.optim.max_iter_rls = 1

        self.print_opt()
        self.print_header()

        self.loop()

        result = postproc(self.maps, self.data)
        if opt.distortion.enable:
            result = (*result, self.dist)
        return result
Пример #3
0
 def intercepts(self):
     """Return one `ParameterMap` per stored map except the last one.

     The maps share `self.affine`. An empty list is returned when no volume
     is attached. (NOTE(review): the excluded last map is presumably the
     decay map.)
     """
     if self.volume is None:
         return []
     maps = []
     for index in range(len(self) - 1):
         maps.append(ParameterMap(self.volume[index], affine=self.affine))
     return maps
Пример #4
0
def preproc(data, opt):
    """Estimate noise variance + register + compute recon space + init maps

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        One multi-echo series per contrast. Modified in place: each echo
        receives `mean` and `sd` attributes, each contrast a `noise`
        attribute (noise variance) and, if registration is enabled, an
        updated `affine`.
    opt : Options, optional
        If None, default `ESTATICSOptions()` are used. `opt.recon.affine`
        and `opt.recon.fov` are filled in place from `opt.recon.space`
        when they are None. All progress messages respect `opt.verbose`.

    Returns
    -------
    data : sequence[GradientEchoMulti]
        The input series (same object).
    maps : ESTATICSParameterMaps
        Initial log-intercept maps (one per contrast) and decay map,
        defined in the reconstruction space.

    """

    if opt is None:
        opt = ESTATICSOptions()

    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    # --- estimate hyper parameters ---
    logmeans = []
    te = []
    for c, contrast in enumerate(data):
        means = []
        variances = []
        for e, echo in enumerate(contrast):
            if opt.verbose:
                print(f'Estimate noise: contrast {c+1:d} - echo {e+1:2d}', end='\r')
            dat = echo.fdata(**backend, rand=True, cache=False)
            sd0, sd1, mu0, mu1 = estimate_noise(dat)
            echo.mean = mu1.item()
            echo.sd = sd0.item()
            means.append(mu1)
            variances.append(sd0.square())
        means = torch.stack(means)
        variances = torch.stack(variances)
        # contrast noise variance: per-echo variances weighted by echo means
        var = (means * variances).sum() / means.sum()
        contrast.noise = var.item()

        te.append(contrast.te)
        logmeans.append(means.log())
    if opt.verbose:
        print('')

    # --- initial minifit ---
    if opt.verbose:
        print('Compute initial parameters')
    inter, decay = _loglin_minifit(logmeans, te)
    if opt.verbose:
        print('    - log intercepts: [' + ', '.join([f'{i:.1f}' for i in inter.tolist()]) + ']')
        print(f'    - decay:          {decay.tolist():.3g}')

    # --- initial align ---
    if opt.preproc.register and len(data) > 1:
        if opt.verbose:
            print('Register volumes')
        # the first echo of each contrast drives the registration
        data_reg = [(contrast.echo(0).fdata(rand=True, cache=False, **backend),
                     contrast.affine) for contrast in data]
        dats, affines, _ = affine_align(data_reg, device=device)
        if opt.verbose > 1 and plt:
            plt.figure()
            for i in range(len(dats)):
                plt.subplot(1, len(dats), i+1)
                plt.imshow(dats[i, :, dats.shape[2]//2, :].cpu())
            plt.show()
        for contrast, aff in zip(data, affines):
            aff, contrast.affine = core.utils.to_max_device(aff, contrast.affine)
            contrast.affine = torch.matmul(aff.inverse(), contrast.affine)

    # --- compute recon space ---
    # an integer `affine` / `fov` selects the space of that contrast
    affines = [contrast.affine for contrast in data]
    shapes = [dat.volume.shape[1:] for dat in data]
    if opt.recon.affine is None:
        opt.recon.affine = opt.recon.space
    if opt.recon.fov is None:
        opt.recon.fov = opt.recon.space
    if isinstance(opt.recon.affine, int):
        mean_affine = affines[opt.recon.affine]
    else:
        mean_affine = torch.as_tensor(opt.recon.affine)
    if isinstance(opt.recon.fov, int):
        mean_shape = shapes[opt.recon.fov]
    else:
        mean_shape = tuple(opt.recon.fov)

    # --- allocate maps ---
    maps = ESTATICSParameterMaps()
    maps.intercepts = [ParameterMap(mean_shape, fill=inter[c], affine=mean_affine, **backend)
                       for c in range(len(data))]
    maps.decay = ParameterMap(mean_shape, fill=decay, affine=mean_affine, min=0, **backend)
    maps.affine = mean_affine

    return data, maps
Пример #5
0
# Units under which a map/field is stored in percent (divided by 100 on load)
_GRE_PERCENT_UNITS = ('%', 'pct', 'p.u', 'p.u.')


def _gre_load_field(field, aff1, shape1, backend):
    """Load a field map and bring it into the simulation space.

    Parameters
    ----------
    field : PrecomputedFieldMap
        Field to load (its `unit` and `affine` are read from the map).
    aff1 : (4, 4) tensor or None
        Orientation matrix of the simulation space. If None, the field is
        returned in its own space (no resampling).
    shape1 : sequence[int]
        Shape of the simulation space.
    backend : dict
        `dtype` / `device` used to load the data.

    Returns
    -------
    dat : tensor
        Field values, divided by 100 if stored in percent.
    """
    unit = field.unit
    faff = field.affine
    dat = field.fdata(**backend)
    if unit in _GRE_PERCENT_UNITS:
        dat = dat / 100.
    if aff1 is not None:
        mat = core.linalg.lmdiv(faff.to(**backend), aff1)
        grid = smart_grid(mat, shape1, dat.shape)
        dat = smart_pull(dat[None], grid)[0]
    return dat


def gre(pd,
        r1,
        r2s=None,
        mt=None,
        transmit=None,
        receive=None,
        gfactor=None,
        te=0,
        tr=25e-3,
        fa=20,
        mtpulse=False,
        sigma=None,
        noise='rician',
        affine=None,
        shape=None,
        device=None):
    """Simulate data generated by a Gradient-Echo (FLASH) sequence.

    Tissue parameters
    -----------------
    pd : ParameterMap or tensor_like
        Proton density
    r1 : ParameterMap or tensor_like
        Longitudinal relaxation rate, in 1/sec
    r2s : ParameterMap, optional
        Transverse relaxation rate, in 1/sec. Mandatory if any `te > 0`.
    mt : ParameterMap, optional
        MTsat. Mandatory if any `mtpulse == True`.

    Fields
    ------
    transmit : (N-sequence of) PrecomputedFieldMap or tensor_like, optional
        Transmit B1 field
    receive : (N-sequence of) PrecomputedFieldMap or tensor_like, optional
        Receive B1 field
    gfactor : (N-sequence of) PrecomputedFieldMap or tensor_like, optional
        G-factor map.
        If provided and `sigma` is not `None`, the g-factor map is used
        to sample non-stationary noise.

    Sequence parameters
    -------------------
    te : ((N-sequence of) M-sequence of) float, default=0
        Echo time, in sec
    tr : (N-sequence of) float, default=25e-3
        Repetition time, in sec
    fa : (N-sequence of) float, default=20
        Flip angle, in deg
    mtpulse : (N-sequence of) bool, default=False
        Presence of an off-resonance pulse

    Noise
    -----
    sigma : (N-sequence of) float, optional
        Standard-deviation of the sampled noise (no sampling if `None`)
    noise : {'rician', 'gaussian'}, default='rician'
        Noise distribution. 'rician' returns the magnitude of the signal
        corrupted by Gaussian noise on both its real and imaginary parts.

    Space
    -----
    affine : ([N], 4, 4) tensor, optional
        Orientation matrix of the simulation space
    shape : (N-sequence of) sequence[int], default=pd.shape
        Shape of the simulation space

    Returns
    -------
    sim : (N-sequence of) GradientEchoMulti
        Simulated series of multi-echo GRE images

    """

    # 1) Find out the number of contrasts requested
    te = make_list(te)
    if any(isinstance(t, (list, tuple)) for t in te):
        te = [make_list(t) for t in te]
    else:
        te = [te]
    tr = make_list(tr)
    fa = make_list(fa)
    mtpulse = [bool(p) for p in make_list(mtpulse)]
    sigma = make_list(sigma)
    # explicit `is None` tests: `x or []` raises on multi-element tensors
    transmit = [] if transmit is None else make_list(transmit)
    receive = [] if receive is None else make_list(receive)
    gfactor = [] if gfactor is None else make_list(gfactor)
    if shape is None:
        # one "default shape" entry (resolved to `pd.shape` below)
        shape = [None]
    else:
        shape = make_list(shape)
        if any(s is None or isinstance(s, (list, tuple)) for s in shape):
            shape = [None if s is None else make_list(s) for s in shape]
        else:
            shape = [shape]
    if torch.is_tensor(affine):
        affine = [affine] if affine.dim() == 2 else affine.unbind(0)
    else:
        affine = make_list(affine)
    nb_contrasts = max(len(te), len(tr), len(fa), len(mtpulse), len(sigma),
                       len(transmit), len(receive), len(gfactor), len(shape),
                       len(affine))

    # 2) Pad all lists up to `nb_contrasts`
    te = make_list(te, nb_contrasts)
    tr = make_list(tr, nb_contrasts)
    fa = make_list(fa, nb_contrasts)
    mtpulse = make_list(mtpulse, nb_contrasts)
    sigma = make_list(sigma, nb_contrasts)
    transmit = make_list(transmit or [None], nb_contrasts)
    receive = make_list(receive or [None], nb_contrasts)
    gfactor = make_list(gfactor or [None], nb_contrasts)
    shape = make_list(shape, nb_contrasts)
    affine = make_list(affine, nb_contrasts)

    # 3) ensure parameters are `ParameterMap`s
    has_r2s = r2s is not None
    has_mt = mt is not None
    if not isinstance(pd, ParameterMap):
        pd = ParameterMap(pd)
    if not isinstance(r1, ParameterMap):
        r1 = ParameterMap(r1)
    if has_r2s and not isinstance(r2s, ParameterMap):
        r2s = ParameterMap(r2s)
    if has_mt and not isinstance(mt, ParameterMap):
        mt = ParameterMap(mt)
        mt.unit = '%'

    # 4) ensure all fields are `PrecomputedFieldMap`s
    for fields in (transmit, receive, gfactor):
        for n, field in enumerate(fields):
            if field is not None and not isinstance(field, PrecomputedFieldMap):
                fields[n] = PrecomputedFieldMap(field)

    # 5) choose backend
    all_var = [te, tr, fa, mtpulse, sigma, affine]
    for field in (*transmit, *receive, *gfactor):
        if field is not None and torch.is_tensor(field.volume):
            all_var.append(field.volume)
    for prm_map in (pd, r1, r2s, mt):
        if prm_map is not None and torch.is_tensor(prm_map.volume):
            all_var.append(prm_map.volume)
    backend = core.utils.max_backend(*all_var)
    if device:
        backend['device'] = device

    # 6) prepare parameter maps
    prm = stack_maps(pd.fdata(**backend), r1.fdata(**backend),
                     r2s.fdata(**backend) if has_r2s else None,
                     mt.fdata(**backend) if has_mt else None)
    fpd, fr1, fr2s, fmt = unstack_maps(prm, has_r2s, has_mt)
    if has_mt and mt.unit in _GRE_PERCENT_UNITS:
        fmt /= 100.
    logprm = None
    if any(aff is not None for aff in affine):
        # Maps are resampled in log space (logit for MTsat, which lies in
        # (0, 1)); the logit must be log(p) - log(1 - p): a sum would be
        # symmetric in p <-> 1 - p and thus not invertible.
        # NOTE(review): `exp_maps` is expected to apply the inverse
        # transforms (exp / sigmoid) -- confirm.
        logprm = stack_maps(
            safelog(fpd), safelog(fr1),
            safelog(fr2s) if has_r2s else None,
            safelog(fmt) - safelog(1 - fmt) if has_mt else None)

    # 7) generate signal for each contrast
    contrasts = []
    for n in range(nb_contrasts):

        shape1 = pd.shape if shape[n] is None else shape[n]
        aff1 = affine[n]
        if aff1 is not None:
            aff1 = aff1.to(**backend)
        te1 = torch.as_tensor(te[n], **backend)
        tr1 = torch.as_tensor(tr[n], **backend)
        fa1 = torch.as_tensor(fa[n], **backend) / 180. * core.constants.pi
        sigma1 = torch.as_tensor(sigma[n], **backend) if sigma[n] else None
        mtpulse1 = mtpulse[n]

        # parameter maps in the simulation space
        if aff1 is not None:
            mat = core.linalg.lmdiv(pd.affine.to(**backend), aff1)
            grid = smart_grid(mat, shape1, pd.shape, force=True)
            prm1 = smart_pull(logprm, grid)
            inplace = grid is not None
            f1pd, f1r1, f1r2s, f1mt = unstack_maps(prm1, has_r2s, has_mt)
            f1pd, f1r1, f1r2s, f1mt = exp_maps(f1pd,
                                               f1r1,
                                               f1r2s,
                                               f1mt,
                                               inplace=inplace)
        else:
            # clone so that we can work in-place later
            f1pd = fpd.clone()
            f1r1 = fr1.clone()
            f1r2s = fr2s.clone() if fr2s is not None else None
            f1mt = fmt.clone() if fmt is not None else None

        # the transmit field scales the nominal flip angle
        if transmit[n] is not None:
            fa1 = fa1 * _gre_load_field(transmit[n], aff1, shape1, backend)

        # the receive field scales the proton density
        if receive[n] is not None:
            f1pd = f1pd * _gre_load_field(receive[n], aff1, shape1, backend)

        # generate signal:
        #   PD * sin(fa) * (1 - E1) / (1 - cos(fa) * E1),  E1 = exp(-TR * R1)
        # with extra (1 - MT) factors when an MT pulse is played
        flash = f1pd
        flash *= fa1.sin()
        cosfa = fa1.cos_()
        e1 = f1r1
        e1 *= -tr1
        e1 = e1.exp_()
        flash *= (1 - e1)
        if mtpulse1:
            if not has_mt:
                raise ValueError('Cannot simulate an MT pulse: '
                                 'an MT must be provided.')
            omt = f1mt.neg_()
            omt += 1
            flash *= omt
            flash /= (1 - cosfa * omt * e1)
            del omt
        else:
            flash /= (1 - cosfa * e1)
        del e1, cosfa, fa1, f1r1, f1mt

        # multiply with r2* decay (adds one leading dimension for echoes);
        # `te1` itself is left 1-D so that it can be stored as metadata
        if (te1 > 0).any():
            if not has_r2s:
                raise ValueError('Cannot simulate an R2* decay: '
                                 'an R2* must be provided.')
            te1_b = te1.reshape([-1] + [1] * f1r2s.dim())
            flash = flash * (-te1_b * f1r2s).exp_()
            del te1_b
        del f1r2s

        # sample noise
        if sigma1:
            if gfactor[n] is not None:
                # non-stationary noise
                sigma1 = sigma1 * _gre_load_field(gfactor[n], aff1, shape1,
                                                  backend)
            if noise == 'rician':
                # magnitude of a complex signal with Gaussian noise on both
                # channels: sqrt((S + n_re)^2 + n_im^2)
                sample = torch.randn((2, *flash.shape), **backend)
                sample *= sigma1
                flash = flash.add_(sample[0]).square_()
                flash = flash.add_(sample[1].square_()).sqrt_()
            else:
                sample = torch.randn(flash.shape, **backend)
                sample *= sigma1
                flash += sample
            del sample, sigma1

        flash = GradientEchoMulti(flash,
                                  affine=aff1,
                                  tr=tr1.item(),
                                  fa=torch.as_tensor(fa[n]).item(),
                                  te=te1.tolist(),
                                  mt=torch.as_tensor(mtpulse1).item())
        contrasts.append(flash)

    return contrasts[0] if len(contrasts) == 1 else contrasts
Пример #6
0
class _ESTATICS_nonlin:

    IMAGE_BOUND = 'dft'
    DIST_BOUND = 'dct2'

    def __init__(self, opt):
        """Store cleaned-up options and reset the fitting state.

        Parameters
        ----------
        opt
            Options merged into a default `ESTATICSOptions()` (via `update`)
            and then cleaned up (`cleanup_`).
        """
        self.opt = ESTATICSOptions().update(opt).cleanup_()
        # fitting state, populated by `fit`
        self.data = self.maps = self.dist = None
        self.rls = self.lam = None
        self.shape0 = self.affine0 = None
        self.lam_scale = 1
        self.last_step = ''

    @property
    def numel(self):
        """Total number of elements summed over all contrasts in `self.data`."""
        total = 0
        for contrast in self.data:
            total += core.py.prod(contrast.shape)
        return total

    @property
    def shape(self):
        return self.maps.decay.shape

    @property
    def affine(self):
        return self.maps.affine

    @property
    def voxel_size(self):
        """Voxel size derived from `self.affine` via `spatial.voxel_size`."""
        orientation = self.affine
        return spatial.voxel_size(orientation)

    @property
    def nb_contrasts(self):
        return len(self.maps) - 1

    @property
    def lam_dist(self):
        return dict(factor=self.opt.distortion.factor,
                    absolute=self.opt.distortion.absolute,
                    membrane=self.opt.distortion.membrane,
                    bending=self.opt.distortion.bending)

    @property
    def backend(self):
        return dict(dtype=self.opt.backend.dtype,
                    device=self.opt.backend.device)

    def iter_rls(self):
        if self.rls is None:
            for _ in range(self.nb_contrasts + 1):
                yield None
        elif self.rls.dim() == 3:
            for _ in range(self.nb_contrasts + 1):
                yield self.rls
        else:
            assert self.rls.dim() == 4
            for rls1 in self.rls:
                yield rls1

    def fit(self, data):
        """Fit the (multi-)exponential decay model to `data`.

        Runs, in order: `preproc` (noise estimation, registration and map
        initialization), the setup of the regularization factors and of the
        IRLS weights, the nested optimization (`self.loop`) and `postproc`.
        `self.opt` is modified in place (regularization disabled when all
        factors are zero; GN iterations scaled up when no reweighting).

        Parameters
        ----------
        data : sequence of contrasts
            Each contrast exposes its echo times through `.te`.

        Returns
        -------
        out
            Output of `postproc(self.maps, self.data)`; when
            `opt.distortion.enable` is set, a tuple of its items followed by
            the distortion fields `self.dist`.
        """
        opt = self.opt

        # report what is about to be fitted
        if opt.verbose > 0:
            print(f'Fitting a (multi) exponential decay model with '
                  f'{len(data)} contrasts. Echo times:')
            for index, contrast in enumerate(data):
                echo_times = ', '.join(f'{te * 1e3:.1f}' for te in contrast.te)
                print(f'    - contrast {index:2d}: [{echo_times}] ms')

        # noise estimation, registration and map initialization
        self.data, self.maps, self.dist = preproc(data, opt)
        self.affine0 = self.maps.affine
        self.shape0 = self.maps.decay.shape

        # regularization factors, ordered as [*intercepts, decay]
        *lam_intercepts, lam_decay = opt.regularization.factor
        factors = core.py.ensure_list(lam_intercepts, self.nb_contrasts)
        factors.append(lam_decay)
        if not any(factors):
            opt.regularization.norm = ''
        self.lam = factors

        # IRLS weights only exist for (joint) total variation
        self.rls = None
        if opt.regularization.norm.endswith('tv'):
            weight_shape = self.shape
            if opt.regularization.norm == 'tv':
                # TV: one weight map per parameter map
                weight_shape = (len(self.maps), *weight_shape)
            self.rls = ParameterMap(weight_shape, fill=1, **self.backend).volume
        else:
            # no reweighting: spend the RLS budget on Gauss-Newton updates
            opt.optim.max_iter_gn *= opt.optim.max_iter_rls
            opt.optim.max_iter_rls = 1

        self.print_opt()
        self.print_header()

        self.loop()

        result = postproc(self.maps, self.data)
        if opt.distortion.enable:
            result = (*result, self.dist)
        return result

    def print_opt(self):
        if self.opt.verbose <= 1:
            return
        if self.opt.regularization.norm:
            print('Regularization:')
            print(
                f'    - type:           {self.opt.regularization.norm.upper()}'
            )
            print(f'    - log intercepts: [' +
                  ', '.join([f'{i:.3g}' for i in self.lam[:-1]]) + ']')
            print(f'    - decay:          {self.lam[-1]:.3g}')
        else:
            print('Without regularization')

        if self.opt.distortion.enable:
            print('Distortion correction:')
            print(f'    - model:          {self.opt.distortion.model.lower()}')
            print(
                f'    - absolute:       {self.opt.distortion.absolute * self.opt.distortion.factor}'
            )
            print(
                f'    - membrane:       {self.opt.distortion.membrane * self.opt.distortion.factor}'
            )
            print(
                f'    - bending:        {self.opt.distortion.bending * self.opt.distortion.factor}'
            )
            print(
                f'    - te_scaling:     {self.opt.distortion.te_scaling or "no"}'
            )

        else:
            print('Without distortion correction')

        print('Optimization:')
        if self.opt.regularization.norm.endswith('tv'):
            print(f'    - IRLS iterations: {self.opt.optim.max_iter_rls}'
                  f' (tolerance: {self.opt.optim.tolerance_rls})')
        print(f'    - GN iterations:   {self.opt.optim.max_iter_gn}'
              f' (tolerance: {self.opt.optim.tolerance_gn})')
        print(f'    - FMG cycles:      2')
        print(f'    - CG iterations:   {self.opt.optim.max_iter_cg}'
              f' (tolerance: {self.opt.optim.tolerance_cg})')

    def loop(self):
        """Nested optimization loops"""

        rls = self.rls.reciprocal().sum() if self.rls is not None else 0
        self.nll = dict(obs=float('inf'),
                        obs_prev=float('inf'),
                        reg=0,
                        reg_prev=0,
                        vreg=0,
                        vreg_prev=0,
                        rls=rls,
                        rls_prev=rls,
                        all=[])
        nll = float('inf')

        nll_scl = self.numel * len(self.data)

        # --- Multi-Resolution loop ------------------------------------
        for self.level in range(self.opt.optim.nb_levels, 0, -1):
            if self.opt.optim.nb_levels > 1:
                self.resize()

            # --- RLS loop ---------------------------------------------
            #   > max_iter_rls == 1 if regularization is not (J)TV
            for self.n_iter_rls in range(1, self.opt.optim.max_iter_rls + 1):

                # --- Gauss-Newton (prm + dist) ------------------------
                for self.n_iter_gn in range(1, self.opt.optim.max_iter_gn + 1):
                    nll0_gn = nll
                    self.n_iter_dist = 0
                    self.n_iter_prm = 0

                    # --- Gauss-Newton (prm) ---------------------------
                    max_iter_prm = self.opt.optim.max_iter_prm
                    if self.n_iter_gn == 1 and self.n_iter_rls == 1:
                        max_iter_prm = max_iter_prm * 2
                    for self.n_iter_prm in range(1, max_iter_prm + 1):
                        nll0_prm = nll
                        nll = self.update_prm()
                        if (self.n_iter_gn > 1 and (nll0_prm - nll) <
                                self.opt.optim.tolerance * nll_scl):
                            break

                        # ----------------------------------------------
                        # this is where we should check for RLS (== global) gain
                        if (self.n_iter_prm == 0 and self.n_iter_gn == 0
                                and self.n_iter_rls > 1 and (nll0_rls - nll) <
                                self.opt.optim.tolerance * nll_scl * 2):
                            return
                        # ----------------------------------------------

                    # --- Gauss-Newton (dist) --------------------------
                    for self.n_iter_dist in range(
                            1, self.opt.optim.max_iter_dist + 1):
                        nll0_dist = nll
                        nll = self.update_dist()
                        if (self.n_iter_dist > 1 and (nll0_dist - nll) <
                                self.opt.optim.tolerance * nll_scl):
                            break

                    if (self.n_iter_gn > 1 and
                        (nll0_gn - nll) < self.opt.optim.tolerance * nll_scl):
                        break

                nll0_rls = nll
                nll = self.update_rls()

    # ------------------------------------------------------------------
    #                       UPDATE PARAMETERS MAPS
    # ------------------------------------------------------------------

    def momentum_prm(self, dat):
        """Regularization momentum of the parameter maps `dat`.

        Applies `spatial.regulariser` (3-D) with the current IRLS weights
        `self.rls`, the factors in `self.lam_prm` and the maps' voxel size.
        """
        weights = self.rls
        lam = self.lam_prm
        vx = self.voxel_size
        return spatial.regulariser(dat, weights=weights, dim=3, **lam,
                                   voxel_size=vx)

    def solve_prm(self, hess, grad):
        """Solve the Newton step `H \\ g` for the parameter maps.

        The Hessian is NaN-checked and diagonally loaded before the
        (regularized, IRLS-weighted) system is solved with
        `spatial.solve_field_fmg`; the resulting step is NaN-checked too.
        """
        hess = hessian_loaddiag_(check_nans_(hess, warn='hessian'), 1e-6, 1e-8)
        step = spatial.solve_field_fmg(hess, grad, self.rls, **self.lam_prm,
                                       voxel_size=self.voxel_size)
        return check_nans_(step, warn='delta')

    def update_prm(self):
        """Update parameter maps (log-intercept and decay)

        Accumulates, over contrasts, the data-term gradient/Hessian with
        respect to [intercept_i, decay], adds the regularization of the
        parameter maps, and applies one Gauss-Newton step (no line search).
        Maps with bounds (`min`/`max`) are clamped afterwards.

        Returns
        -------
        nll
            Value returned by `self.print_nll()` after updating
            `self.nll['obs']` and `self.nll['reg']`.
        """

        nmaps = len(self.data)
        grad = torch.zeros((nmaps + 1, *self.shape), **self.backend)
        # packed Hessian: nmaps + 1 diagonal entries followed by nmaps
        # entries (presumably the intercept_i / decay cross terms)
        hess = torch.zeros((nmaps * 2 + 1, *self.shape), **self.backend)

        # --- loop over contrasts --------------------------------------
        iterator = zip(self.data, self.maps.intercepts, self.dist)
        nll = 0
        for i, (contrast, intercept, distortion) in enumerate(iterator):
            # compute gradient
            nll1, g1, h1 = self.derivatives_prm(contrast, distortion,
                                                intercept, self.maps.decay)
            # increment: gradient entries [intercept_i, decay]
            grad[[i, -1]] += g1
            hess[[i, len(grad) - 1, len(grad) + i]] += h1
            nll += nll1
            del g1, h1

        # --- regularization ------------------------------------------
        # the prior acts on the parameter maps (same layout as `grad`),
        # not on the observed data
        reg = 0
        if self.opt.regularization.norm:
            prm = self.maps.volume
            g1 = self.momentum_prm(prm)
            reg = 0.5 * dot(g1, prm)
            grad += g1
            del g1

        # --- gauss-newton ---------------------------------------------
        # Computing the GN step involves solving H\g
        deltas = self.solve_prm(hess, grad)

        # No need for a line search (hopefully)
        for prm_map, delta in zip(self.maps, deltas):
            prm_map.volume -= delta
            if prm_map.min is not None or prm_map.max is not None:
                prm_map.volume.clamp_(prm_map.min, prm_map.max)

        # --- track general improvement --------------------------------
        self.nll['obs_prev'] = self.nll['obs']
        self.nll['obs'] = nll
        self.nll['reg_prev'] = self.nll['reg']
        self.nll['reg'] = reg
        nll = self.print_nll()
        self.last_step = 'prm'
        return nll

    # ------------------------------------------------------------------
    #                       UPDATE DISTORTION MAPS
    # ------------------------------------------------------------------

    def _lam_dist_scaled(self, vx, readout):
        """Distortion regularization options with the factor scaled by the
        squared voxel size along the readout dimension."""
        lam = dict(self.lam_dist)
        lam['factor'] = lam['factor'] * (vx[readout]**2)
        return lam

    def momentum_dist(self, dat, vx, readout):
        """Momentum of the distortion maps

        Uses the readout-scaled regularization (`factor * vx[readout]**2`),
        the same one as `solve_dist`, so that the Newton step and the
        line-search penalty in `update_dist` agree.
        """
        lam = self._lam_dist_scaled(vx, readout)
        return spatial.regulariser(dat[None],
                                   **lam,
                                   dim=3,
                                   bound=self.DIST_BOUND,
                                   voxel_size=vx)[0]

    def solve_dist(self, hess, grad, vx, readout):
        """Solve Newton step

        The Hessian is NaN-checked and diagonally loaded; the system is
        regularized exactly like `momentum_dist` (readout-scaled factor).
        """
        hess = check_nans_(hess, warn='hessian')
        hess = hessian_loaddiag_(hess, 1e-6, 1e-8)
        lam = self._lam_dist_scaled(vx, readout)
        delta = spatial.solve_field_fmg(hess,
                                        grad,
                                        **lam,
                                        dim=3,
                                        bound=self.DIST_BOUND,
                                        voxel_size=vx)
        delta = check_nans_(delta, warn='delta')
        return delta

    def update_dist(self):
        """Update distortions

        For each contrast, take one Gauss-Newton step on its distortion
        field (`distortion.volume`, modified in place), followed by a
        backtracking line search: the step is halved up to 12 times and the
        first step that does not increase (data term + regularization) is
        accepted. If none is accepted, the field is restored.

        Returns
        -------
        nll
            Value returned by `self.print_nll()` after updating
            `self.nll['obs']` and `self.nll['vreg']`.
        """

        nll = 0
        reg = 0
        # --- loop over contrasts --------------------------------------
        iterator = zip(self.data, self.maps.intercepts, self.dist)
        for i, (contrast, intercept, distortion) in enumerate(iterator):

            # bind this contrast's voxel size and readout dimension
            momentum = lambda dat: self.momentum_dist(
                dat, distortion.voxel_size, contrast.readout)
            solve = lambda h, g: self.solve_dist(h, g, distortion.voxel_size,
                                                 contrast.readout)
            # updated in place below (assumes `.volume` returns the stored
            # tensor, not a copy -- TODO confirm)
            vol = distortion.volume

            # --- likelihood -------------------------------------------
            nll1, g, h = self.derivatives_dist(contrast, distortion, intercept,
                                               self.maps.decay)

            # --- regularization ---------------------------------------
            g1 = momentum(vol)
            reg1 = 0.5 * dot(g1, vol)
            g += g1
            del g1

            # --- gauss-newton -----------------------------------------
            delta = solve(h, g)
            del g, h

            # --- line search ------------------------------------------
            # `vol` is moved incrementally so that after each trial it
            # equals vol0 - armijo * delta. Assuming `momentum` is a
            # symmetric linear operator L, the change of 0.5 * v'Lv is
            # 0.5 * a * (a * d'Ld - 2 * d'Lv0), computed from `dd` / `dv`.
            armijo, armijo_prev = 1, 0
            dd = momentum(delta)
            dv = dot(dd, vol)
            dd = dot(dd, delta)
            success = False
            for n_ls in range(12):
                vol.sub_(delta, alpha=(armijo - armijo_prev))
                armijo_prev = armijo
                delta_reg1 = 0.5 * armijo * (armijo * dd - 2 * dv)
                new_nll1 = self.derivatives_dist(contrast,
                                                 distortion,
                                                 intercept,
                                                 self.maps.decay,
                                                 do_grad=False)
                # accept any non-increase of (data term + regularization)
                if new_nll1 + delta_reg1 <= nll1:
                    success = True
                    break
                armijo /= 2
            if not success:
                # undo the last trial step: back to the original field
                vol.add_(delta, alpha=armijo_prev)
                new_nll1 = nll1
                delta_reg1 = 0
            nll += new_nll1
            reg += reg1 + delta_reg1

            del delta

        # --- track general improvement --------------------------------
        self.nll['obs_prev'] = self.nll['obs']
        self.nll['obs'] = nll
        self.nll['vreg_prev'] = self.nll['vreg']
        self.nll['vreg'] = reg
        nll = self.print_nll()
        self.last_step = 'dist'
        return nll

    # ------------------------------------------------------------------
    #                        UPDATE WEIGHT MAP
    # ------------------------------------------------------------------

    def update_rls(self):
        """Refresh the reweighting map of (joint) total-variation regularization.

        Only acts when ``self.opt.regularization.norm`` is ``'tv'`` or
        ``'jtv'``. In that case:
        - the weights are recomputed from the current maps by the
          module-level ``update_rls`` function and stored in ``self.rls``;
        - half of the returned sum becomes the ``'rls'`` objective term (the
          previous value is kept in ``'rls_prev'``);
        - ``self.last_step`` is set to ``'rls'``.

        Returns
        -------
        nll
            Total objective ``obs + reg + vreg + rls`` read from ``self.nll``.
        """
        norm = self.opt.regularization.norm
        if norm in ('tv', 'jtv'):
            # Inside a method this bare name resolves to the module-level
            # function, not to this method.
            rls, sumrls = update_rls(self.maps, self.lam, norm)
            # Keep the new weights: `self.rls` is the map allocated in `fit`
            # and resampled in `resize`, so it must reflect the latest
            # reweighting.
            self.rls = rls
            self.nll['rls_prev'] = self.nll['rls']
            self.nll['rls'] = 0.5 * sumrls
            self.last_step = 'rls'
        return sum(self.nll[k] for k in ('obs', 'reg', 'vreg', 'rls'))

    def print_header(self):
        """Print the column titles of the iteration log.

        Nothing is printed when ``self.opt.verbose <= 0``. Iteration-counter
        columns (``rls``, ``gn``, ``sub``) only appear when the matching
        maximum number of iterations exceeds one. The ``reg``, ``rls`` and
        ``dist`` objective columns only appear when regularization, a TV-like
        norm, or distortion correction is enabled. A dashed rule of the same
        width is printed under the titles.
        """
        opt = self.opt
        if opt.verbose <= 0:
            return
        optim = opt.optim
        norm = opt.regularization.norm

        columns = []
        if optim.max_iter_rls > 1:
            columns.append(f'{"rls":^3s} | ')
        if optim.max_iter_gn > 1:
            columns.append(f'{"gn":^3s} | ')
        if optim.max_iter_prm > 1 or optim.max_iter_dist > 1:
            columns.append(f'{"sub":^3s} | ')
        columns.append(f'{"step":^4s} | ')
        columns.append(f'{"fit":^12s} ')
        if norm:
            columns.append(f'+ {"reg":^12s} ')
        if norm.endswith('tv'):
            columns.append(f'+ {"rls":^12s} ')
        if opt.distortion.enable:
            columns.append(f'+ {"dist":^12s} ')
        columns.append(f'= {"crit":^12s}')

        header = ''.join(columns)
        print('\n' + header)
        print('-' * len(header))

    def print_nll(self):
        """Record the total objective and, if verbose, print one log row.

        The objective is the sum of four terms stored in ``self.nll``:
        - ``obs``: data term;
        - ``reg``: parameter regularization;
        - ``vreg``: distortion regularization;
        - ``rls``: reweighting term.

        The total is appended to ``self.nll['all']``.

        When ``self.opt.verbose > 0``, a row is printed whose columns follow
        the same conditions as ``print_header``. The row ends with an
        evolution marker that compares the terms affected by
        ``self.last_step`` with their ``*_prev`` values:
        - ``'<='`` if they did not increase;
        - ``'>'`` if they did;
        - ``''`` for an unknown step.

        ``show_maps`` is then called.

        Returns
        -------
        nll
            Total objective ``obs + reg + vreg + rls``.
        """
        nll = sum(self.nll[k] for k in ['obs', 'reg', 'vreg', 'rls'])
        self.nll['all'].append(nll)
        if self.opt.verbose <= 0:
            return nll
        obs = self.nll['obs']
        reg = self.nll['reg']
        vreg = self.nll['vreg']
        rls = self.nll['rls']
        obs0 = self.nll['obs_prev']
        reg0 = self.nll['reg_prev']
        vreg0 = self.nll['vreg_prev']
        rls0 = self.nll['rls_prev']

        # Only the terms that the last step can modify are compared.
        if self.last_step == 'rls':
            improved = reg + rls <= reg0 + rls0
        elif self.last_step == 'prm':
            improved = obs + reg <= obs0 + reg0
        elif self.last_step == 'dist':
            improved = obs + vreg <= obs0 + vreg0
        else:
            improved = None
        if improved is None:
            evol = ''
        else:
            evol = '<=' if improved else '>'

        optim = self.opt.optim
        norm = self.opt.regularization.norm
        columns = []
        if optim.max_iter_rls > 1:
            columns.append(f'{self.n_iter_rls:3d} | ')
        if optim.max_iter_gn > 1:
            columns.append(f'{self.n_iter_gn:3d} | ')
        if optim.max_iter_prm > 1 or optim.max_iter_dist > 1:
            if self.last_step == 'rls':
                # RLS updates have no sub-iteration counter
                columns.append(f'{"-"*3} | ')
            else:
                columns.append(f'{self.n_iter_prm+self.n_iter_dist:3d} | ')
        columns.append(f'{self.last_step:4s} | ')
        columns.append(f'{obs:12.6g} ')
        if norm:
            columns.append(f'+ {reg:12.6g} ')
        if norm.endswith('tv'):
            columns.append(f'+ {rls:12.6g} ')
        if self.opt.distortion.enable:
            columns.append(f'+ {vreg:12.6g} ')
        columns.append(f'= {nll:12.6g} | ')
        columns.append(f'{evol}')
        print(''.join(columns))
        self.show_maps()
        return nll

    def resize(self):
        """Resample all maps (and RLS weights) to the current pyramid level.

        The working grid is derived from the original reconstruction grid
        (``self.affine0``, ``self.shape0``) by ``spatial.affine_resize`` with
        the factor ``1 / 2**(self.level - 1)``. Every parameter map, the
        maps container affine and, if allocated, the RLS weight map are
        resampled to it.

        Side effects
        ------------
        - ``self.lam_scale``: ratio of the new voxel volume to the original
          one.
        - ``self.nll['rls']``: recomputed from the resampled weights.
        """
        affine, shape = spatial.affine_resize(self.affine0, self.shape0,
                                              1 / (2**(self.level - 1)))

        # ratio of voxel volumes (new grid / original grid)
        scl0 = spatial.voxel_size(self.affine0).prod()
        scl = spatial.voxel_size(affine).prod() / scl0
        self.lam_scale = scl

        for prm_map in self.maps:
            prm_map.volume = spatial.resize(prm_map.volume[None, None, ...],
                                            shape=shape)[0, 0]
            prm_map.affine = affine
        self.maps.affine = affine
        if self.rls is not None:
            if self.rls.dim() == len(shape):
                # One weight map shared by all parameters: add batch and
                # channel dimensions.
                self.rls = spatial.resize(self.rls[None, None],
                                          shape=shape)[0, 0]
            else:
                # One weight map per parameter: the leading dimension
                # already acts as the channel dimension.
                self.rls = spatial.resize(self.rls[None], shape=shape)[0]
            # NOTE(review): `update_rls` stores `0.5 * sumrls` in this entry,
            # whereas here no 0.5 factor is applied — confirm the intended
            # scaling.
            self.nll['rls'] = self.rls.reciprocal().sum(dtype=torch.double)

    def show_maps(self):
        """Plot the current maps, distortion fields and objective history.

        Does nothing unless ``self.opt.plot`` is truthy. The figure layout
        is:
        - first row: one slice (middle index of the last spatial axis) of
          every parameter map. Intercepts are shown exponentiated and
          titled ``'TE=0'``; the last map is titled ``'R2*'``.
        - second row: only if at least one distortion field exists. It
          shows one slice of each field: the readout component when the
          contrast has a readout, the vector magnitude otherwise.
        - last row: the objective history ``self.nll['all']``.
        """
        if not self.opt.plot:
            return
        import matplotlib.pyplot as plt

        has_dist = any(d is not None for d in self.dist)
        nb_rows = 2 + has_dist
        nb_cols = max(len(self.maps), len(self.dist))
        last_map = len(self.maps) - 1

        for index, prm in enumerate(self.maps):
            plt.subplot(nb_rows, nb_cols, index + 1)
            is_intercept = index < last_map
            slab = prm.volume[:, :, prm.shape[-1] // 2]
            if is_intercept:
                slab = slab.exp()
            plt.imshow(slab.cpu())
            plt.axis('off')
            plt.title('TE=0' if is_intercept else 'R2*')
            plt.colorbar()

        if has_dist:
            for index, (contrast, field) in enumerate(zip(self.data, self.dist)):
                if field is None:
                    continue
                readout = contrast.readout
                field_vol = field.volume
                plt.subplot(nb_rows, nb_cols, nb_cols + index + 1)
                slab = field_vol[:, :, field.shape[-2] // 2]
                if readout is None:
                    slab = slab.square().sum(-1).sqrt()
                else:
                    slab = slab[..., readout]
                plt.imshow(slab.cpu())
                plt.axis('off')
                plt.colorbar()

        plt.subplot(nb_rows, 1, nb_rows)
        plt.plot(self.nll['all'])
        plt.show()
Пример #7
0
def preproc(data, transmit=None, receive=None, opt=None):
    """Estimate noise variance + register + compute recon space + init maps

    Steps, in order:
    1. For every echo of every contrast, run ``estimate_noise`` and store
       ``echo.mean`` / ``echo.sd``. The contrast noise variance
       ``contrast.noise`` is the average of the per-echo variances,
       weighted by the per-echo means.
    2. Fit initial PD, R1, R2* (and MT, if any contrast is MT-weighted)
       values from the log of the per-echo means, using
       ``_loglin_minifit`` and ``_rational_minifit``.
    3. If ``opt.preproc.register`` is set and there is more than one
       contrast, co-register the first echo of each contrast and the
       magnitude images of the field maps with ``affine_align``, then
       update their affines.
    4. Determine the reconstruction space (affine and shape) from
       ``opt.recon``.
    5. Allocate parameter maps in that space, filled with the initial
       estimates.
    6. Repeat the field-map lists so that there is one entry per contrast.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
    transmit : sequence[PrecomputedFieldMap], optional
    receive : sequence[PrecomputedFieldMap], optional
    opt : Options, optional
        Merged into ``GREEQOptions()``. ``opt.recon.affine`` /
        ``opt.recon.fov`` default to ``opt.recon.space``. An int selects
        the contrast whose affine / shape is used; anything else is taken
        as an explicit affine / shape.

    Returns
    -------
    data : sequence[GradientEchoMulti]
        The input contrasts. ``noise``, the echo ``mean``/``sd`` and (if
        registered) ``affine`` attributes are set in place.
    transmit : list[PrecomputedFieldMap or None]
        One entry per contrast (``None`` when no transmit map was given).
    receive : list[PrecomputedFieldMap or None]
        One entry per contrast (``None`` when no receive map was given).
    maps : GREEQParameterMaps
        Initial maps. PD, R1 and R2* are stored as logarithms; MT (if
        present) is stored as a logit.

    """

    opt = GREEQOptions().update(opt)
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    # --- estimate hyper parameters ---
    logmeans = []
    te = []
    tr = []
    fa = []
    mt = []  # per-contrast MT flags; rebound below to the initial MT value
    for c, contrast in enumerate(data):
        means = []
        # NOTE(review): `vars` shadows the builtin of the same name.
        vars = []
        for e, echo in enumerate(contrast):
            if opt.verbose:
                print(f'Estimate noise: contrast {c+1:d} - echo {e+1:2d}',
                      end='\r')
            dat = echo.fdata(**backend, rand=True, cache=False)
            # Only sd0 and mu1 are used. They are presumably the noise
            # standard deviation and the signal mean — confirm in
            # estimate_noise.
            sd0, sd1, mu0, mu1 = estimate_noise(dat)
            echo.mean = mu1.item()
            echo.sd = sd0.item()
            means.append(mu1)
            vars.append(sd0.square())
        means = torch.stack(means)
        vars = torch.stack(vars)
        # noise variance of the contrast: mean-weighted average over echoes
        var = (means * vars).sum() / means.sum()
        contrast.noise = var.item()

        te.append(contrast.te)
        tr.append(contrast.tr)
        fa.append(contrast.fa / 180 * core.constants.pi)  # degrees -> radians
        mt.append(contrast.mt)
        logmeans.append(means.log())
    if opt.verbose:
        print('')

    print('Estimating maps from volumes:')
    for i in range(len(data)):
        print(f'    - Contrast {i:d}: ', end='')
        print(f'FA = {fa[i]*180/core.constants.pi:2.0f} deg  / ', end='')
        print(f'TR = {tr[i]*1e3:4.1f} ms / ', end='')
        print('TE = [' + ', '.join([f'{t*1e3:.1f}' for t in te[i]]) + '] ms',
              end='')
        if mt[i]:
            print(f' / MT = True', end='')
        print()

    # --- initial minifit ---
    print('Compute initial parameters')
    inter, r2s = _loglin_minifit(logmeans, te)
    pd, r1, mt = _rational_minifit(inter, tr, fa, mt)
    print(f'    - PD:    {pd.tolist():9.3g} a.u.')
    print(f'    - R1:    {r1.tolist():9.3g} 1/s')
    print(f'    - R2*:   {r2s.tolist():9.3g} 1/s')
    # The maps are optimized in log space.
    pd = pd.log()
    r1 = r1.log()
    r2s = r2s.log()
    if mt is not None:
        print(f'    - MT:    {100*mt.tolist():9.3g} %')
        mt = mt.log() - (1 - mt).log()  # logit: MT is a fraction in (0, 1)

    # --- initial align ---
    transmit = core.utils.make_list(transmit or [])
    receive = core.utils.make_list(receive or [])
    if opt.preproc.register and len(data) > 1:
        print('Register volumes')
        # first echo of each contrast + magnitude of each field map
        data_reg = [(contrast.echo(0).fdata(rand=True, cache=False,
                                            **backend), contrast.affine)
                    for contrast in data]
        data_reg += [(map.magnitude.fdata(rand=True, cache=False,
                                          **backend), map.magnitude.affine)
                     for map in transmit]
        data_reg += [(map.magnitude.fdata(rand=True, cache=False,
                                          **backend), map.magnitude.affine)
                     for map in receive]
        dats, affines, _ = affine_align(data_reg, device=device)
        # NOTE(review): `plt` is not imported in this function; it must be
        # a module-level name (possibly None when matplotlib is missing).
        if opt.verbose > 1 and plt:
            plt.figure()
            for i in range(len(dats)):
                plt.subplot(1, len(dats), i + 1)
                plt.imshow(dats[i, :, dats.shape[2] // 2, :].cpu())
                plt.axis('off')
            plt.show()
        # Compose the estimated alignment with each volume's affine.
        for contrast, aff in zip(data + transmit + receive, affines):
            aff, contrast.affine = core.utils.to_max_device(
                aff, contrast.affine)
            contrast.affine = torch.matmul(aff.inverse(), contrast.affine)

    # --- compute recon space ---
    affines = [contrast.affine for contrast in data]
    # first dimension dropped — presumably the echo dimension
    shapes = [dat.volume.shape[1:] for dat in data]
    if opt.recon.affine is None:
        opt.recon.affine = opt.recon.space
    if opt.recon.fov is None:
        opt.recon.fov = opt.recon.space
    if isinstance(opt.recon.affine, int):
        mean_affine = affines[opt.recon.affine]
    else:
        mean_affine = torch.as_tensor(opt.recon.affine)
    if isinstance(opt.recon.fov, int):
        mean_shape = shapes[opt.recon.fov]
    else:
        mean_shape = tuple(opt.recon.fov)

    # --- allocate maps ---
    maps = GREEQParameterMaps()
    maps.pd = ParameterMap(mean_shape, fill=pd, affine=mean_affine, **backend)
    maps.r1 = ParameterMap(mean_shape, fill=r1, affine=mean_affine, **backend)
    maps.r2s = ParameterMap(mean_shape,
                            fill=r2s,
                            affine=mean_affine,
                            **backend)
    if mt is not None:
        maps.mt = ParameterMap(mean_shape,
                               fill=mt,
                               affine=mean_affine,
                               **backend)
    maps.affine = mean_affine

    # --- repeat fields if not enough ---
    if transmit:
        transmit = core.py.make_list(transmit, len(data))
    else:
        transmit = [None] * len(data)
    if receive:
        receive = core.py.make_list(receive, len(data))
    else:
        receive = [None] * len(data)

    return data, transmit, receive, maps
Пример #8
0
def vfa(data, transmit=None, receive=None, opt=None, **kwopt):
    """Compute PD, R1 (and MTsat) from two GRE at two flip angles using
    the Ernst equation (analytical inversion or rational approximation).

    Parameters
    ----------
    data : sequence[GradientEcho]
        Volumes with different contrasts (flip angle or MT pulse) but with
        the same echo time. Note that they do not need to be real echoes;
        they often are images extrapolated to TE = 0.
    transmit : sequence[PrecomputedFieldMap], optional
        Map(s) of the transmit field (b1+). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
    receive : sequence[PrecomputedFieldMap], optional
        Map(s) of the receive field (b1-). If a single map is provided,
        it is used to correct all contrasts. If multiple maps are
        provided, there should be one for each contrast.
        If no receive map is provided, the output `pd` map will have
        a remaining b1- bias field.
    opt : VFAOptions
        Algorithm options.
        {'preproc': {'register':      True},     # Co-register contrasts
         'backend': {'dtype':  torch.float32,    # Data type
                     'device': 'cpu'},           # Device
         'verbose': 1,                           # Verbosity: 1=print, 2=plot
         'rational': False}                      # Force rational approximation
        ``opt.recon.affine`` / ``opt.recon.fov`` (default:
        ``opt.recon.space``) select the reconstruction space. An int is the
        index of the contrast whose affine / shape is used; anything else
        is an explicit affine / shape.
    **kwopt
        Additional options, forwarded to ``VFAOptions().update``.

    Returns
    -------
    pd : ParameterMap
        Proton density (potentially, with R2* bias)
    r1 : ParameterMap
        Longitudinal relaxation rate
    mt : ParameterMap, optional
        Magnetization transfer saturation (in %). Only returned when one
        of the inputs is MT-weighted.

    Raises
    ------
    ValueError
        If fewer than two images are provided, if echo times differ across
        contrasts, or if fewer than two non-MT volumes are available.

    Notes
    -----
    PD and R1 are computed from the first two non-MT volumes (a warning
    is issued if more are available). The analytical inversion is used
    unless ``opt.rational`` is set or the two volumes have different TRs;
    otherwise the rational approximation [2] is used. MTsat is computed
    from the first MT-weighted volume, if any. It uses the rational
    approximation [3] when ``opt.rational`` is set and an analytical
    inversion otherwise. Receive fields divide the signal and transmit
    fields scale the nominal flip angle. Fields whose unit is a
    percentage are rescaled by 100.

    References
    ----------
    ..[1] Tabelow et al., "hMRI - A toolbox for quantitative MRI in
          neuroscience and clinical research", NeuroImage (2019).
          https://www.sciencedirect.com/science/article/pii/S1053811919300291
    ..[2] Helms et al., "Quantitative FLASH MRI at 3T using a rational
          approximation of the Ernst equation", MRM (2008).
          https://onlinelibrary.wiley.com/doi/full/10.1002/mrm.21542
    ..[3] Helms et al., "High-resolution maps of magnetization transfer
          with inherent correction for RF inhomogeneity and T1 relaxation
          obtained from 3D FLASH MRI", MRM (2008).
          https://onlinelibrary.wiley.com/doi/full/10.1002/mrm.21732

    """
    opt = VFAOptions().update(opt, **kwopt)
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    data = core.py.make_list(data)
    if len(data) < 2:
        raise ValueError('Expected at least two input images')
    transmit = core.py.make_list(transmit or [])
    receive = core.py.make_list(receive or [])

    # --- Copy instances to avoid modifying the inputs ---
    data = [x.copy() for x in data]
    transmit = [x.copy() for x in transmit]
    receive = [x.copy() for x in receive]

    # --- check TEs ---
    if len(set([contrast.te for contrast in data])) > 1:
        raise ValueError('Echo times not consistent across contrasts')

    # --- register ---
    if opt.preproc.register:
        print('Register volumes')
        data_reg = [(contrast.fdata(rand=True, cache=False, **backend),
                     contrast.affine) for contrast in data]
        data_reg += [(fmap.magnitude.fdata(rand=True, cache=False, **backend),
                      fmap.magnitude.affine) for fmap in transmit]
        data_reg += [(fmap.magnitude.fdata(rand=True, cache=False, **backend),
                      fmap.magnitude.affine) for fmap in receive]
        dats, affines, _ = affine_align(data_reg, device=device, fwhm=3)

        # `plt` is expected to be a module-level name (falsy if unavailable)
        if opt.verbose > 1 and plt:
            plt.figure()
            for i in range(len(dats)):
                plt.subplot(1, len(dats), i+1)
                plt.imshow(dats[i, :, dats.shape[2]//2, :].cpu())
                plt.axis('off')
            plt.suptitle('Registered magnitude images')
            plt.show()
        del dats

        # compose the estimated alignment with each volume's affine
        for vol, aff in zip(data + transmit + receive, affines):
            aff, vol.affine = core.utils.to_max_device(aff, vol.affine)
            vol.affine = torch.matmul(aff.inverse(), vol.affine)

    # --- repeat fields if not enough ---
    transmit = core.py.make_list(transmit or [None], len(data))
    receive = core.py.make_list(receive or [None], len(data))

    # --- compute recon space ---
    affines = [contrast.affine for contrast in data]
    shapes = [dat.volume.shape for dat in data]
    if opt.recon.affine is None:
        opt.recon.affine = opt.recon.space
    if opt.recon.fov is None:
        opt.recon.fov = opt.recon.space
    if isinstance(opt.recon.affine, int):
        mean_affine = affines[opt.recon.affine]
    else:
        mean_affine = torch.as_tensor(opt.recon.affine)
    if isinstance(opt.recon.fov, int):
        mean_shape = shapes[opt.recon.fov]
    else:
        mean_shape = tuple(opt.recon.fov)

    # --- compute PD/R1 ---
    pdt1 = [(idx, contrast) for idx, contrast in enumerate(data)
            if not contrast.mt]
    if len(pdt1) > 2:
        warnings.warn('More than two volumes could be used to compute PD+R1')
    pdt1 = pdt1[:2]
    if len(pdt1) < 2:
        raise ValueError('Not enough volumes to compute PD+R1')
    (pdw_idx, pdw_struct), (t1w_idx, t1w_struct) = pdt1

    if t1w_struct.te != pdw_struct.te:
        warnings.warn('Echo times not consistent across volumes')

    # The analytical inversion below assumes both volumes share one TR.
    rational = opt.rational or t1w_struct.tr != pdw_struct.tr
    method = 'rational' if rational else 'analytical'
    print(f'Computing PD and R1 ({method}) from volumes:')

    # Each contrast is corrected with its *own* B1- / B1+ maps.
    pdw, pdw_fa = _vfa_load_contrast(pdw_struct, receive[pdw_idx],
                                     transmit[pdw_idx], mean_affine,
                                     mean_shape, **backend)
    pdw_tr = pdw_struct.tr
    t1w, t1w_fa = _vfa_load_contrast(t1w_struct, receive[t1w_idx],
                                     transmit[t1w_idx], mean_affine,
                                     mean_shape, **backend)
    t1w_tr = t1w_struct.tr

    if rational:
        # we must use rational approximations
        r1 = 0.5 * (t1w * (t1w_fa / t1w_tr) - pdw * (pdw_fa / pdw_tr))
        r1 /= ((pdw / pdw_fa) - (t1w / t1w_fa))

        pd = (t1w * pdw) * (t1w_tr * (pdw_fa / t1w_fa) - pdw_tr * (t1w_fa / pdw_fa))
        pd /= (pdw * (pdw_tr * pdw_fa) - t1w * (t1w_tr * t1w_fa))
        del t1w_fa, pdw_fa, t1w, pdw
    else:
        # there is an analytical solution
        cosfa_t1w = t1w_fa.cos()
        sinfa_t1w = t1w_fa.sin_()
        del t1w_fa
        cosfa_pdw = pdw_fa.cos()
        sinfa_pdw = pdw_fa.sin_()
        del pdw_fa
        e1 = sinfa_pdw * t1w - sinfa_t1w * pdw
        e1 /= sinfa_pdw * cosfa_t1w * t1w - sinfa_t1w * cosfa_pdw * pdw

        pd = t1w / sinfa_t1w * (1 - cosfa_t1w * e1) / (1 - e1)
        pd += pdw / sinfa_pdw * (1 - cosfa_pdw * e1) / (1 - e1)
        pd /= 2.

        r1 = e1.log_().neg_().div_(t1w_tr)
        del t1w, pdw, cosfa_pdw, cosfa_t1w

    r1[~torch.isfinite(r1)] = 0
    pd[~torch.isfinite(pd)] = 0

    # --- compute MTsat ---
    mtw_candidates = [(idx, contrast) for idx, contrast in enumerate(data)
                      if contrast.mt]
    if len(mtw_candidates) == 0:
        return (ParameterMap(pd, affine=mean_affine, unit=None),
                ParameterMap(r1, affine=mean_affine, unit='1/s'))

    if len(mtw_candidates) > 1:
        warnings.warn('More than one volume could be used to compute MTsat')
    mtw_idx, mtw_struct = mtw_candidates[0]

    if mtw_struct.te != pdw_struct.te:
        warnings.warn('Echo times not consistent across volumes')

    method = 'rational' if opt.rational else 'analytical'
    print(f'Computing MTsat ({method}) from PD/R1 maps and volume:')
    mtw, mtw_fa = _vfa_load_contrast(mtw_struct, receive[mtw_idx],
                                     transmit[mtw_idx], mean_affine,
                                     mean_shape, **backend)
    mtw_tr = mtw_struct.tr

    if opt.rational:
        # we must use rational approximations
        mtsat = (mtw_fa * pd / mtw - 1) * r1 * mtw_tr - 0.5 * (mtw_fa ** 2)
        del mtw_fa, mtw
    else:
        # we have an analytical solution (work backward from PD/R1)
        cosfa_mtw = mtw_fa.cos()
        sinfa_mtw = mtw_fa.sin()
        del mtw_fa
        e1 = (-mtw_tr * r1).exp()
        mtsat = mtw / (cosfa_mtw * e1 * mtw + pd * sinfa_mtw * (1 - e1))
        del mtw, cosfa_mtw, sinfa_mtw
        mtsat = 1 - mtsat

    mtsat *= 100
    mtsat[~torch.isfinite(mtsat)] = 0

    return (ParameterMap(pd, affine=mean_affine, unit=None),
            ParameterMap(r1, affine=mean_affine, unit='1/s'),
            ParameterMap(mtsat, affine=mean_affine, unit='%'))


def _vfa_load_b1(fmap, sign, affine, shape, **backend):
    """Load a B1 field map in recon space and print summary statistics.

    Parameters
    ----------
    fmap : PrecomputedFieldMap
        Field map to resample onto (affine, shape).
    sign : {'+', '-'}
        Used only in the printed label (``B1+`` or ``B1-``).
    affine, shape
        Reconstruction space.
    **backend
        ``dtype`` / ``device`` forwarded to ``load_and_pull``.

    Returns
    -------
    b1 : tensor
        Resampled field.
    percent : bool
        True if the field unit is a percentage (values must be divided by
        100 to obtain a ratio).
    """
    b1 = load_and_pull(fmap, affine, shape, **backend)
    unit = fmap.unit
    positive = b1[b1 > 0]  # statistics are computed over positive voxels
    print(f'    with B1{sign} map ('
          f'min= {positive.min():.2f}, '
          f'max = {positive.max():.2f}, '
          f'mean = {positive.mean():.2f} {unit})')
    return b1, unit in ('%', 'pct', 'p.u.')


def _vfa_load_contrast(struct, receive, transmit, affine, shape, **backend):
    """Load one contrast in recon space together with its flip angle.

    Prints the acquisition parameters of the contrast, then:
    - resamples the image onto (affine, shape);
    - divides it by the receive (B1-) field, if any;
    - multiplies the nominal flip angle (converted to radians) by the
      transmit (B1+) field, if any.

    Returns
    -------
    dat : tensor
        Resampled (and B1- corrected) signal.
    fa : tensor
        Flip angle in radians: a scalar, or a voxel-wise map when a
        transmit field is used.
    """
    print(f'  - '
          f'FA = {struct.fa:2.0f} deg  /  '
          f'TR = {struct.tr*1e3:4.1f} ms /  '
          f'TE = {struct.te*1e3:4.1f} ms')
    dat = load_and_pull(struct, affine, shape, **backend)
    fa = struct.fa / 180. * core.constants.pi
    if receive:
        b1m, percent = _vfa_load_b1(receive, '-', affine, shape, **backend)
        dat /= b1m
        if percent:
            dat *= 100
        del b1m
    if transmit:
        b1p, percent = _vfa_load_b1(transmit, '+', affine, shape, **backend)
        fa = b1p * fa
        if percent:
            fa /= 100
        del b1p
    fa = torch.as_tensor(fa, **backend)
    return dat, fa
Пример #9
0
def _prepare(data, dist, opt):
    """Prepare data, initial maps, options and RLS weights for ESTATICS.

    The steps are:
    - print a summary of the problem;
    - run ``preproc`` (noise estimation, registration, map initialization);
    - expand the regularization factors;
    - optionally allocate the RLS weight map;
    - adjust the iteration counts;
    - print the iteration-log header (if verbose).

    Parameters
    ----------
    data : sequence of contrasts
        Each contrast exposes ``te`` (and, with distortion correction,
        ``affine`` and ``readout``). TODO confirm the exact type.
    dist : forwarded to ``preproc`` — TODO confirm the expected type.
    opt : options, merged into ``ESTATICSOptions()``.

    Returns
    -------
    data, maps, dist
        As returned by ``preproc``.
    opt
        Options after in-place modifications:
        - ``regularization.factor`` is expanded to one value per map;
        - ``regularization.norm`` is emptied if all factors are zero;
        - ``distortion.factor`` is replaced by a dict of distortion
          penalties;
        - parameter iterations absorb RLS iterations when there is no
          reweighting.
    rls : tensor or None
        Initial RLS weights (filled with ones) when the norm ends with
        ``'tv'``. The shape is ``(len(maps), *spatial)`` for ``'tv'`` and the
        spatial shape otherwise. ``None`` when there is no reweighting.
    """

    # --- options ------------------------------------------------------
    # we deepcopy all options so that we can overwrite/simplify them in place
    # NOTE(review): the copy is expected to happen in
    # ESTATICSOptions.update/cleanup_ — not visible here, confirm.
    opt = ESTATICSOptions().update(opt).cleanup_()
    backend = dict(dtype=opt.backend.dtype, device=opt.backend.device)

    # --- be polite ----------------------------------------------------
    if len(data) > 1:
        pstr = f'Fitting a (shared) exponential decay model with {len(data)} contrasts.'
    else:
        pstr = f'Fitting an exponential decay model.'
    print(pstr)
    print('Echo times:')
    for i, contrast in enumerate(data):
        print(f'    - contrast {i:2d}: [' +
              ', '.join([f'{te*1e3:.1f}' for te in contrast.te]) + '] ms')

    # --- estimate noise / register / initialize maps ------------------
    data, maps, dist = preproc(data, dist, opt)
    nb_contrasts = len(maps) - 1  # all maps but the last (decay) are intercepts

    if opt.distortion.enable:
        print('Readout directions:')
        for i, contrast in enumerate(data):
            # Translate the readout axis index into an anatomical name
            # using the voxel layout of the contrast's affine.
            layout = spatial.affine_to_layout(contrast.affine)
            layout = spatial.volume_layout_to_name(layout)
            readout = layout[contrast.readout]
            readout = ('left-right' if 'L' in readout or 'R' in readout else
                       'infra-supra' if 'I' in readout or 'S' in readout else
                       'antero-posterior'
                       if 'A' in readout or 'P' in readout else 'unknown')
            print(f'    - contrast {i:2d}: {readout}')

    # --- prepare regularization factor --------------------------------
    # 1. Parameter maps regularization
    #   -> we want lam = [*lam_intercepts, lam_decay]
    *lam, lam_decay = opt.regularization.factor
    lam = core.py.make_list(lam, nb_contrasts)
    lam.append(lam_decay)
    if not any(lam):
        # all factors are zero: disable regularization altogether
        opt.regularization.norm = ''
    opt.regularization.factor = lam
    # 2. Distortion fields regularization
    lam_dist = dict(factor=opt.distortion.factor,
                    absolute=opt.distortion.absolute,
                    membrane=opt.distortion.membrane,
                    bending=opt.distortion.bending)
    # From here on, `opt.distortion.factor` is a dict (see prints below).
    opt.distortion.factor = lam_dist

    # --- initialize weights (RLS) -------------------------------------
    mean_shape = maps.decay.volume.shape
    rls = None
    if opt.regularization.norm.endswith('tv'):
        rls_shape = mean_shape
        if opt.regularization.norm == 'tv':
            # one weight map per parameter map
            rls_shape = (len(maps), *rls_shape)
        rls = ParameterMap(rls_shape, fill=1, **backend).volume

    if opt.regularization.norm:
        print('Regularization:')
        print(f'    - type:           {opt.regularization.norm.upper()}')
        print(f'    - log intercepts: [' +
              ', '.join([f'{i:.3g}' for i in lam[:-1]]) + ']')
        print(f'    - decay:          {lam[-1]:.3g}')
    else:
        print('Without regularization')

    if opt.distortion.enable:
        print('Distortion correction:')
        print(f'    - model:          {opt.distortion.model.lower()}')
        print(
            f'    - absolute:       {opt.distortion.absolute * opt.distortion.factor["factor"]}'
        )
        print(
            f'    - membrane:       {opt.distortion.membrane * opt.distortion.factor["factor"]}'
        )
        print(
            f'    - bending:        {opt.distortion.bending * opt.distortion.factor["factor"]}'
        )

    else:
        print('Without distortion correction')

    # --- initialize nb of iterations ----------------------------------
    if not opt.regularization.norm.endswith('tv'):
        # no reweighting -> do more gauss-newton updates instead
        opt.optim.max_iter_prm *= opt.optim.max_iter_rls
        opt.optim.max_iter_rls = 1
    print('Optimization:')
    print(f'    - Tolerance:        {opt.optim.tolerance}')
    if opt.regularization.norm.endswith('tv'):
        print(f'    - IRLS iterations:  {opt.optim.max_iter_rls}')
    print(f'    - Param iterations: {opt.optim.max_iter_prm}')
    if opt.distortion.enable:
        print(f'    - Dist iterations:  {opt.optim.max_iter_dist}')
    print(f'    - FMG cycles:       2')
    print(f'    - CG iterations:    {opt.optim.max_iter_cg}'
          f' (tolerance: {opt.optim.tolerance_cg})')
    if opt.optim.nb_levels > 1:
        print(f'    - Levels:           {opt.optim.nb_levels}')

    # ------------------------------------------------------------------
    #                     MAIN OPTIMIZATION LOOP
    # ------------------------------------------------------------------

    # Only the log header is printed here; the loop itself is run by the
    # caller.
    if opt.verbose:
        pstr = f'{"rls":^3s} | {"gn":^3s} | {"step":^4s} | '
        pstr += f'{"fit":^12s} + {"reg":^12s} + {"rls":^12s} '
        if opt.distortion.enable:
            pstr += f'+ {"dist":^12s} '
        pstr += f'= {"crit":^12s}'
        if opt.optim.nb_levels > 1:
            pstr = f'{"lvl":3s} | ' + pstr
        print('\n' + pstr)
        print('-' * len(pstr))

    return data, maps, dist, opt, rls
Пример #10
0
def nonlin(data, opt=None):
    """Fit the ESTATICS model to multi-echo Gradient-Echo data.

    The fit runs an outer iteratively-reweighted least-squares (RLS) loop
    around an inner Gauss-Newton loop that updates every parameter map
    (log intercepts and decay). Reweighting only happens for the 'tv' and
    'jtv' norms; otherwise the outer loop runs once and the Gauss-Newton
    budget is multiplied by the RLS budget instead.

    Parameters
    ----------
    data : sequence[GradientEchoMulti]
        Observed GRE data.
    opt : Options, optional
        Algorithm options. They are merged into ``ESTATICSOptions()`` and
        some fields of the merged object (``regularization.norm``,
        ``optim.max_iter_gn``, ``optim.max_iter_rls``) are overwritten
        in place below.

    Returns
    -------
    intercepts : sequence[GradientEcho]
        Echo series extrapolated to TE=0
    decay : estatics.ParameterMap
        R2* decay map

    """

    opt = ESTATICSOptions().update(opt)
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    # --- be polite ---
    print(
        f'Fitting a (multi) exponential decay model with {len(data)} contrasts. Echo times:'
    )
    for i, contrast in enumerate(data):
        print(f'    - contrast {i:2d}: [' +
              ', '.join([f'{te*1e3:.1f}' for te in contrast.te]) + '] ms')

    # --- estimate noise / register / initialize maps ---
    data, maps = preproc(data, opt)
    vx = spatial.voxel_size(maps.affine)

    # --- prepare regularization factor ---
    # -> we want lam = [*lam_intercepts, lam_decay]. With several factors
    #    the last one is the decay factor; with a single factor it is used
    #    for the decay and (via make_list) for the intercepts as well.
    lam = opt.regularization.factor
    lam = core.utils.make_list(lam)
    if len(lam) > 1:
        *lam, lam_decay = lam
    else:
        lam_decay = lam[0]
    lam = core.utils.make_list(lam, len(maps) - 1)
    lam.append(lam_decay)

    # --- initialize weights (RLS) ---
    # An empty norm string means "no regularization" from here on.
    if (not opt.regularization.norm
            or opt.regularization.norm.lower() == 'none'
            or all(lam1 == 0 for lam1 in lam)):
        opt.regularization.norm = ''
    opt.regularization.norm = opt.regularization.norm.lower()
    mean_shape = maps.decay.volume.shape
    rls = None
    sumrls = 0
    if opt.regularization.norm in ('tv', 'jtv'):
        # 'tv' keeps one weight channel per map, 'jtv' a single shared one
        rls_shape = mean_shape
        if opt.regularization.norm == 'tv':
            rls_shape = (len(maps), ) + rls_shape
        rls = ParameterMap(rls_shape, fill=1, **backend).volume
        sumrls = 0.5 * rls.sum(dtype=torch.double)

    if opt.regularization.norm:
        print(f'With {opt.regularization.norm.upper()} regularization:')
        print('    - log intercepts: [' +
              ', '.join([f'{i:.3g}' for i in lam[:-1]]) + ']')
        print(f'    - decay:          {lam[-1]:.3g}')
    else:
        print('Without regularization:')

    # --- allocate derivative buffers ---
    # grad: one channel per contrast intercept + a last channel for decay.
    # hess: contrast i accumulates into channels 2*i, 2*i+1 and the shared
    #       last channel (the layout hessian_loaddiag / hessian_solve use).
    grad = torch.empty((len(data) + 1, ) + mean_shape, **backend)
    hess = torch.empty((len(data) * 2 + 1, ) + mean_shape, **backend)

    if opt.regularization.norm not in ('tv', 'jtv'):
        # no reweighting -> do more gauss-newton updates instead
        opt.optim.max_iter_gn *= opt.optim.max_iter_rls
        opt.optim.max_iter_rls = 1

    if opt.verbose:
        print('{:^3s} | {:^3s} | {:^12s} + {:^12s} + {:^12s} = {:^12s}'.format(
            'rls', 'gn', 'fit', 'reg', 'rls', 'crit'))

    def relative_gain(ll_prev, ll, ll_max):
        # Improvement of the last step relative to the total improvement
        # seen so far. The denominator is zero until the objective has
        # moved away from its running maximum (first iterations); report
        # an infinite gain then so the tolerance test does not stop the
        # loop and a plain-float objective cannot raise ZeroDivisionError.
        denom = ll_max - ll_prev
        if denom == 0:
            return float('inf')
        return (ll_prev - ll) / denom

    ll_rls = []
    ll_max = core.constants.ninf

    for n_iter_rls in range(opt.optim.max_iter_rls):

        # One weight map per parameter map: 'tv' already has one channel
        # per map, otherwise the same (possibly None) weights are shared.
        multi_rls = rls if opt.regularization.norm == 'tv' \
                    else [rls] * len(maps)

        # --- Gauss Newton loop ---
        ll_gn = []
        for n_iter_gn in range(opt.optim.max_iter_gn):
            decay = maps.decay
            crit = 0
            grad.zero_()
            hess.zero_()

            # --- loop over contrasts ---
            for i, (contrast,
                    intercept) in enumerate(zip(data, maps.intercepts)):
                # compute gradient
                crit1, g1, h1 = _nonlin_gradient(contrast, intercept, decay,
                                                 opt)

                # increment (own intercept channel + shared decay channel)
                gind = [i, -1]
                grad[gind, ...] += g1
                hind = [2 * i, -1, 2 * i + 1]
                hess[hind, ...] += h1
                crit += crit1

            # --- regularization ---
            reg = 0.
            if opt.regularization.norm:
                for i, (prm, weight, lam1) in enumerate(zip(maps, multi_rls,
                                                            lam)):
                    if not lam1:
                        continue
                    reg1, g1 = _nonlin_reg(prm.volume, vx, weight, lam1)
                    reg += reg1
                    grad[i] += g1

            # --- gauss-newton ---
            if not hess.isfinite().all():
                print('WARNING: NaNs in hess (??)')
            # diagonal loading for stability, needed by both solvers
            hess = hessian_loaddiag(hess, 1e-6, 1e-8)
            if opt.regularization.norm:
                deltas = _nonlin_solve(hess, grad, multi_rls, lam, vx, opt)
            else:
                deltas = hessian_solve(hess, grad)
            if not deltas.isfinite().all():
                print('WARNING: NaNs in delta (non stable Hessian)')

            for prm, delta in zip(maps, deltas):
                prm.volume -= delta
                if prm.min is not None or prm.max is not None:
                    prm.volume.clamp_(prm.min, prm.max)

            # --- Compute gain ---
            ll = crit + reg + sumrls
            ll_max = max(ll_max, ll)
            ll_prev = ll_gn[-1] if ll_gn else ll_max
            gain = relative_gain(ll_prev, ll, ll_max)
            ll_gn.append(ll)
            if opt.verbose:
                print(
                    '{:3d} | {:3d} | {:12.6g} + {:12.6g} + {:12.6g} = {:12.6g} | gain = {:7.2g}'
                    .format(n_iter_rls, n_iter_gn, crit, reg, sumrls,
                            ll, gain))
            if gain < opt.optim.tolerance_gn:
                break

        # --- Update RLS weights ---
        if opt.regularization.norm in ('tv', 'jtv'):
            rls = _nonlin_rls(maps, lam, opt.regularization.norm)
            sumrls = 0.5 * rls.sum(dtype=torch.double)
            # weights used by the next iteration are the clamped reciprocals
            eps = core.constants.eps(rls.dtype)
            rls = rls.clamp_min_(eps).reciprocal_()

            # --- Compute gain ---
            # (we are late by one full RLS iteration when computing the
            #  gain but we save some computations)
            ll = ll_gn[-1]
            ll_prev = ll_rls[-1][-1] if ll_rls else ll_max
            ll_rls.append(ll_gn)
            gain = relative_gain(ll_prev, ll, ll_max)
            if gain < opt.optim.tolerance_rls:
                print(f'Converged ({gain:7.2g})')
                break

    # --- Prepare output ---
    return postproc(maps, data)
Пример #11
0
 def decay(self):
     """Return the decay map held in the last channel of ``self.volume``.

     The map is wrapped in a ``ParameterMap`` that shares ``self.affine``.
     ``None`` is returned when no volume is available yet.
     """
     if self.volume is not None:
         return ParameterMap(self.volume[-1], affine=self.affine)
     return None