Beispiel #1
0
def _nonlin_rls(maps, lam=1., norm='jtv'):
    """Update the (L1) weights.

    Parameters
    ----------
    map : (P, *shape) ParameterMaps
        Parameter map
    lam : float or (P,) sequence[float], default=1
        Regularisation factor
    norm : {'tv', 'jtv'}, default='jtv'

    Returns
    -------
    rls : ([P], *shape) tensor
        Weights from the reweighted least squares scheme
    """

    if norm not in ('tv', 'jtv', '__internal__'):
        return None

    if isinstance(maps, ParameterMap):
        # single map
        # this should only be an internal call
        # -> we return the squared gradient map
        assert norm == '__internal__'
        vx = spatial.voxel_size(maps.affine)
        grad_fwd = spatial.diff(maps.fdata(),
                                dim=[0, 1, 2],
                                voxel_size=vx,
                                side='f')
        grad_bwd = spatial.diff(maps.fdata(),
                                dim=[0, 1, 2],
                                voxel_size=vx,
                                side='b')

        grad = grad_fwd.square_().sum(-1)
        grad += grad_bwd.square_().sum(-1)
        grad *= lam / 2.  # average across sides (2)
        return grad

    # multiple maps

    if norm == 'tv':
        rls = []
        for map, l in zip(maps, lam):
            rls1 = _nonlin_rls(map, l, '__internal__')
            rls1 = rls1.sqrt_()
            rls.append(rls1)
        return torch.stack(rls, dim=0)
    else:
        assert norm == 'jtv'
        rls = 0
        for map, l in zip(maps, lam):
            rls += _nonlin_rls(map, l, '__internal__')
        rls = rls.sqrt_()

    return rls
Beispiel #2
0
def reg(tensor, vx=1., rls=None, lam=1., do_grad=True):
    """Compute the gradient of the regularisation term.

    The regularisation term has the form:
    `0.5 * lam * sum(w[i] * (g+[i]**2 + g-[i]**2) / 2)`
    where `i` indexes a voxel, `lam` is the regularisation factor,
    `w[i]` is the RLS weight, `g+` and `g-` are the forward and
    backward spatial gradients of the parameter map.

    Parameters
    ----------
    tensor : (K, *shape) tensor
        Parameter map
    vx : float or sequence[float], default=1
        Voxel size
    rls : (K|1, *shape) tensor, optional
        Weights from the reweighted least squares scheme
    lam : float or sequence[float], default=1
        Regularisation factor
    do_grad : bool, default=True
        Return both the criterion and gradient

    Returns
    -------
    reg : () tensor[double]
        Regularisation term
    grad : (K, *shape) tensor
        Gradient with respect to the parameter map

    """
    nb_prm = tensor.shape[0]
    backend = dict(dtype=tensor.dtype, device=tensor.device)
    vx = core.utils.make_vector(vx, 3, **backend)
    lam = core.utils.make_vector(lam, nb_prm, **backend)

    grad_fwd = spatial.diff(tensor, dim=[1, 2, 3], voxel_size=vx, side='f')
    grad_bwd = spatial.diff(tensor, dim=[1, 2, 3], voxel_size=vx, side='b')
    if rls is not None:
        grad_fwd *= rls[..., None]
        grad_bwd *= rls[..., None]
    grad_fwd = spatial.div(grad_fwd, dim=[1, 2, 3], voxel_size=vx, side='f')
    grad_bwd = spatial.div(grad_bwd, dim=[1, 2, 3], voxel_size=vx, side='b')

    grad = grad_fwd
    grad += grad_bwd
    grad *= lam.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) / 2.
    # ^ average across side

    if do_grad:
        reg = (tensor * grad).sum(dtype=torch.double)
        return 0.5 * reg, grad
    else:
        grad *= tensor
        return 0.5 * grad.sum(dtype=torch.double)
Beispiel #3
0
def reg1(tensor, vx=1., rls=None, lam=1., do_grad=True):
    """Compute the gradient of the regularisation term.

    The regularisation term has the form:
    `0.5 * lam * sum(w[i] * (g+[i]**2 + g-[i]**2) / 2)`
    where `i` indexes a voxel, `lam` is the regularisation factor,
    `w[i]` is the RLS weight, `g+` and `g-` are the forward and
    backward spatial gradients of the parameter map.

    Parameters
    ----------
    tensor : (*shape) tensor
        Parameter map
    vx : float or sequence[float], default=1
        Voxel size
    rls : (*shape) tensor, optional
        Weights from the reweighted least squares scheme
    lam : float, default=1
        Regularisation factor
    do_grad : bool, default=True
        Return both the criterion and gradient

    Returns
    -------
    reg : () tensor[double]
        Regularisation term
    grad : (*shape) tensor
        Gradient with respect to the parameter map

    """

    grad_fwd = spatial.diff(tensor, dim=[0, 1, 2], voxel_size=vx, side='f')
    grad_bwd = spatial.diff(tensor, dim=[0, 1, 2], voxel_size=vx, side='b')
    if rls is not None:
        grad_fwd *= rls[..., None]
        grad_bwd *= rls[..., None]
    grad_fwd = spatial.div(grad_fwd, dim=[0, 1, 2], voxel_size=vx, side='f')
    grad_bwd = spatial.div(grad_bwd, dim=[0, 1, 2], voxel_size=vx, side='b')

    grad = grad_fwd
    grad += grad_bwd
    grad *= lam / 2.  # average across directions (3) and side (2)

    if do_grad:
        reg = (tensor * grad).sum(dtype=torch.double)
        return 0.5 * reg, grad
    else:
        grad *= tensor
        return 0.5 * grad.sum(dtype=torch.double)
Beispiel #4
0
def _loss_ssqd_jtv(dat_x,
                   dat_y,
                   tau,
                   lam,
                   voxel_size=1,
                   side='f',
                   bound='dct2'):
    """Computes an image denoising loss function, where:
    * fidelity term: sum-of-squared differences (SSQD)
    * regularisation term: joint total variation (JTV)
    * hyper-parameters: tau, lambda

    Parameters
    ----------
    dat_x : (dmx, dmy, dmz, nchannels) tensor
        Input image
    dat_y : (dmx, dmy, dmz, nchannels) tensor
        Reconstruction image
    tau : (nchannels) tensor
        Channel-specific noise precisions
    lam : (nchannels) tensor
        Channel-specific regularisation values
    voxel_size : float or sequence[float], default=1
        Unit size used in the denominator of the gradient.
    side : {'c', 'f', 'b'}, default='f'
        * 'c': central finite differences
        * 'f': forward finite differences
        * 'b': backward finite differences
    bound : {'dct2', 'dct1', 'dst2', 'dst1', 'dft', 'repeat', 'zero'}, default='dct2'
        Boundary condition.

    Returns
    ----------
    nll_yx : tensor
        Loss function value (negative log-posterior)

    """
    # compute negative log-likelihood (SSQD fidelity term)
    nll_xy = 0.5 * torch.sum(tau * torch.sum(
        (dat_x - dat_y)**2, dim=(0, 1, 2)))
    # compute gradients of reconstruction, shape=(dmx, dmy, dmz, nchannels, dmgr)
    nll_y = diff(dat_y,
                 order=1,
                 dim=(0, 1, 2),
                 voxel_size=voxel_size,
                 side=side,
                 bound=bound)
    # modulate channels with regularisation
    nll_y = lam[None, None, None, :, None] * nll_y
    # compute negative log-prior (JTV regularisation term)
    nll_y = torch.sum(
        nll_y**2 + eps(),
        dim=-1)  # to gradient magnitudes (sum over gradient directions)
    nll_y = torch.sum(nll_y, dim=-1)  # sum over reconstruction channels
    nll_y = torch.sqrt(nll_y)
    nll_y = torch.sum(nll_y)  # sum over voxels
    # compute negative log-posterior (loss function)
    nll_yx = nll_xy + nll_y

    return nll_yx
Beispiel #5
0
    def grad(self):
        """Compute the image gradients in each voxel.
        Almost equivalent to `self.pull_grad(identity)`.

        Returns
        -------
        grad : ([C], *spatial, dim) tensor

        """
        return spatial.diff(self.dat, dim=list(range(-self.dim, 0)),
                            bound=self.bound)
Beispiel #6
0
def derivatives_distortion(contrast,
                           distortion,
                           intercept,
                           decay,
                           opt,
                           do_grad=True):
    """Compute the gradient and Hessian of the distortion field.

    Parameters
    ----------
    contrast : (nb_echo, *obs_shape) GradientEchoMulti
        A single echo series (with the same weighting)
    distortion : ParameterizedDeformation
        A model of distortions caused by B0 inhomogeneities.
    intercept : (*recon_shape) ParameterMap
        Log-intercept of the contrast
    decay : (*recon_shape) ParameterMap
        Exponential decay
    opt : Options

    Returns
    -------
    crit : () tensor
        Log-likelihood
    grad : (*shape, 3) tensor
    hess : (*shape, 6) tensor

    """

    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    obs_shape = contrast.volume.shape[1:]
    recon_shape = intercept.volume.shape
    aff = core.linalg.lmdiv(intercept.affine, contrast.affine)
    aff = aff.to(**backend)
    lam = 1 / contrast.noise
    df = contrast.dof
    chi = opt.likelihood[0].lower() == 'c'

    # pull parameter maps to observed space
    grid = smart_grid(aff, obs_shape, recon_shape)
    inter = smart_pull(intercept.fdata(**backend), grid)
    slope = smart_pull(decay.fdata(**backend), grid)
    readout = contrast.readout
    if opt.distortion.te_scaling != 'pre':
        grid_up, grid_down = distortion.exp2(
            add_identity=not opt.distortion.te_scaling)
    else:
        grid_up = grid_down = None

    crit = 0
    grad = torch.zeros(obs_shape + (3, ), **backend) if do_grad else None
    hess = torch.zeros(obs_shape + (6, ), **backend) if do_grad else None

    te0 = 0
    for e, echo in enumerate(contrast):

        te = echo.te
        te0 = te0 or te
        blip = echo.blip or (2 * (e % 2) - 1)
        grid_blip = grid_up if blip > 0 else grid_down
        vscl = te / te0
        if opt.distortion.te_scaling == 'pre':
            vexp = distortion.iexp if blip < 0 else distortion.exp
            grid_blip = vexp(add_identity=True, alpha=vscl)
        elif opt.distortion.te_scaling:
            grid_blip = spatial.add_identity_grid_(vscl * grid_blip)

        # compute residuals
        dat = echo.fdata(**backend, rand=True, cache=False)  # observed
        fit = recon_fit(inter, slope, te)  # fitted
        if do_grad and isinstance(distortion, DenseDeformation):
            # D(fit) o phi
            gfit = smart_grad(fit, grid_blip, bound='dft', extrapolate=True)
        fit = smart_pull(fit, grid_blip, bound='dft', extrapolate=True)
        msk = get_mask_missing(dat, fit)  # mask of missing values
        if do_grad and isinstance(distortion, SVFDeformation):
            # D(fit o phi)
            gfit = spatial.diff(fit, bound='dft', dim=[-3, -2, -1])
            gfit.masked_fill_(msk.unsqueeze(-1), 0)
        dat.masked_fill_(msk, 0)
        fit.masked_fill_(msk, 0)
        msk = msk.bitwise_not_()

        if chi:
            crit1, res = nll_chi(dat, fit, msk, lam, df)
        else:
            crit1, res = nll_gauss(dat, fit, msk, lam)
        del dat, fit, msk
        crit += crit1

        if do_grad:
            g1 = res.unsqueeze(-1).mul(gfit)
            h1 = torch.zeros_like(hess)
            if readout is None:
                h1[..., :3] = gfit.square()
                h1[..., 3] = gfit[..., 0] * gfit[..., 1]
                h1[..., 4] = gfit[..., 0] * gfit[..., 2]
                h1[..., 5] = gfit[..., 1] * gfit[..., 2]
            else:
                h1[..., readout] = gfit[..., readout].square()

            # propagate backward
            if isinstance(distortion, SVFDeformation):
                vel = distortion.volume
                if opt.distortion.te_scaling == 'pre':
                    vel = ((-vscl) * vel) if blip < 0 else (vscl * vel)
                elif blip < 0:
                    vel = -vel
                g1, h1 = spatial.exp_backward(vel,
                                              g1,
                                              h1,
                                              steps=distortion.steps)

            alpha_g = alpha_h = lam
            alpha_g = alpha_g * blip
            if opt.distortion.te_scaling == 'pre':
                alpha_g = alpha_g * vscl
                alpha_h = alpha_h * (vscl * vscl)
            grad.add_(g1, alpha=alpha_g)
            hess.add_(h1, alpha=alpha_h)

    if not do_grad:
        return crit

    if readout is None:
        mask_nan_(grad)
        mask_nan_(hess[:-3], 1e-3)  # diagonal
        mask_nan_(hess[-3:])  # off-diagonal
    else:
        grad = grad[..., readout]
        hess = hess[..., readout]
        mask_nan_(grad)
        mask_nan_(hess)

    return crit, grad, hess
Beispiel #7
0
def update_rls(maps, lam=1., norm='jtv'):
    """Update the (L1) weights.

    Parameters
    ----------
    map : (P, *shape) ParameterMaps
        Parameter map
    lam : float or (P,) sequence[float], default=1
        Regularisation factor
    norm : {'tv', 'jtv'}, default='jtv'

    Returns
    -------
    rls : ([P], *shape) tensor
        (Inverted) Weights from the reweighted least squares scheme
    sumrls : () tensor
        Sum of the (non-inverted) weights
    """
    # vx = spatial.voxel_size(maps.affine)
    # return spatial.membrane_weights(maps.volume, dim=3, factor=lam,
    #                                 joint=(norm == 'jtv'), voxel_size=vx,
    #                                 return_sum=True)

    # ----------------------------------------------------------------
    # This is the old version of the code, before it got refactored
    # and generalized in the `spatial` module.
    # ----------------------------------------------------------------

    if norm not in ('tv', 'jtv', '__internal__'):
        return None

    if isinstance(maps, ParameterMap):
        # single map
        # this should only be an internal call
        # -> we return the squared gradient map
        assert norm == '__internal__'
        vx = spatial.voxel_size(maps.affine)
        grad_fwd = spatial.diff(maps.volume,
                                dim=[0, 1, 2],
                                voxel_size=vx,
                                side='f')
        grad_bwd = spatial.diff(maps.volume,
                                dim=[0, 1, 2],
                                voxel_size=vx,
                                side='b')

        grad = grad_fwd.square_().sum(-1)
        grad += grad_bwd.square_().sum(-1)
        grad *= lam / 2  # average across side (2)
        return grad

    # multiple maps

    if norm == 'tv':
        rls = []
        for map, l in zip(maps, lam):
            rls1 = update_rls(map, l, '__internal__')
            rls1 = rls1.sqrt_()
            rls.append(rls1)
    else:
        assert norm == 'jtv'
        rls = 0
        for map, l in zip(maps, lam):
            rls += update_rls(map, l, '__internal__')
        rls = rls.sqrt_()

    sumrls = rls.sum(dtype=torch.double)
    eps = core.constants.eps(rls.dtype)
    rls = rls.clamp_min_(eps).reciprocal_()

    return rls, sumrls
Beispiel #8
0
def test_adjoint_3d(order, bound, side):
    u = torch.randn([64, 64, 64, 3], dtype=torch.double)
    v = torch.randn([64, 64, 64], dtype=torch.double)
    Lv = diff(v, dim=[0, 1, 2], side=side, order=order, bound=bound)
    Ku = div(u, dim=[0, 1, 2], side=side, order=order, bound=bound)
    assert torch.allclose((Lv*u).sum(), (Ku*v).sum())
Beispiel #9
0
def estimate_fwhm(dat, vx=None, verbose=0, mn=-inf, mx=inf):
    """Estimates full width at half maximum (FWHM) and noise standard
    deviation (sd) of a 2D or 3D image.

    It is assumed that the image has been generated as:
        dat = Ky + n,
    where K is Gaussian smoothing with some FWHM and n is
    additive Gaussian noise. FWHM and n are estimated.

    Parameters
    ----------
    dat : str or (*spatial) tensor
        Image data or path to nifti file
    vx : [sequence of] float, default=1
        Voxel size
    verbose : {0, 1, 2}, default=0
        Verbosity level:
            * 0: No verbosity
            * 1: Print FWHM and sd to screen
            * 2: 1 + show mask
    mn : float, optional
        Exclude values below
    mx : float, optional
        Exclude values above

    Returns
    -------
    fwhm : (dim,) tensor
        Estimated FWHM
    sd : scalar tensor
        Estimated noise standard deviation.

    References
    ----------
    ..[1] "Linked independent component analysis for multimodal data fusion."
          Appendix A
          Groves AR, Beckmann CF, Smith SM, Woolrich MW.
          Neuroimage. 2011 Feb 1;54(3):2198-217.

    """
    if isinstance(dat, str):
        dat = io.map(dat)
    if isinstance(dat, io.MappedArray):
        if vx is None:
            vx = get_voxel_size(dat.affine)
        dat = dat.fdata(rand=True, missing=0)
    dat = torch.as_tensor(dat)

    dim = dat.dim()
    if vx is None:
        vx = 1
    vx = utils.make_vector(vx, dim)
    backend = utils.backend(dat)
    # Make mask
    msk = (dat > mn).bitwise_and_(dat <= mx)
    dat = dat.masked_fill(~msk, 0)
    # TODO: we should erode the mask so that only voxels whose neighbours
    #       are in the mask are considered when computing gradients.
    if verbose >= 2:
        show_slices(msk)
    # Compute image gradient
    g = diff(dat, dim=range(dim), side='central', voxel_size=vx,
             bound='dft').abs_()
    slicer = (slice(1, -1), ) * dim
    g = g[(*slicer, None)]
    g[msk[slicer], :] = 0
    g = g.reshape([-1, dim]).sum(0, dtype=torch.double)
    # Make dat have zero mean
    dat = dat[slicer]
    dat = dat[msk[slicer]]
    x0 = dat - dat.mean()
    # Compute FWHM
    fwhm = pymath.sqrt(4 * pymath.log(2)) * x0.abs().sum(dtype=torch.double)
    fwhm = fwhm / g
    if verbose >= 1:
        print(f'FWHM={fwhm.tolist()}')
    # Compute noise standard deviation
    sx = smooth('gauss', fwhm[0], x=0, **backend)[0][0, 0, 0]
    sy = smooth('gauss', fwhm[1], x=0, **backend)[0][0, 0, 0]
    sz = 1.0
    if dim == 3:
        sz = smooth('gauss', fwhm[2], x=0, **backend)[0][0, 0, 0]
    sc = (sx * sy * sz) / dim
    sc.clamp_min_(1)
    sd = torch.sqrt(x0.square().sum(dtype=torch.double) / (x0.numel() * sc))
    if verbose >= 1:
        print(f'sd={sd.tolist()}')
    return fwhm, sd
Beispiel #10
0
    def __call__(self,
                 vel,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False):
        # This loop performs the forward pass, and computes
        # derivatives along the way.

        dim = vel.shape[-1]
        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        in_line_search = not grad and not hess
        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # forward
        if self.kernel is None:
            self.kernel = spatial.greens(vel.shape[-dim - 1:-1], **self.prm,
                                         **utils.backend(vel))
        grid = spatial.shoot(vel, self.kernel, steps=self.steps, **self.prm)
        warped = spatial.grid_pull(self.moving,
                                   grid,
                                   bound='dct2',
                                   extrapolate=True)

        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        vel,
                        cat=iscat,
                        dim=dim)

        # gradient/Hessian of the log-likelihood in observed space
        if not grad and not hess:
            llx = self.loss.loss(warped, self.fixed)
        elif not hess:
            llx, grad = self.loss.loss_grad(warped, self.fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            llx, grad, hess = self.loss.loss_grad_hess(warped, self.fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped

        # compose with spatial gradients
        if grad is not False or hess is not False:
            if self.mugrad is None:
                self.mugrad = spatial.diff(self.moving,
                                           dim=list(range(-dim, 0)),
                                           bound='dct2')
            if grad is not False:
                grad = grad.neg_()  # "final inverse" to "initial"
                grad = spatial.grid_push(grad, grid)
                grad = jg(self.mugrad, grad)
            if hess is not False:
                hess = spatial.grid_push(hess, grid)
                hess = jhj(self.mugrad, hess)

        # add regularization term
        vgrad = spatial.regulariser_grid(vel, **self.prm, kernel=True)
        llv = 0.5 * (vel * vgrad).sum()
        if grad is not False:
            grad += vgrad
        del vgrad

        # print objective
        llx = llx.item()
        llv = llv.item()
        ll = llx + llv
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g}',
                    end='\r')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\r')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
Beispiel #11
0
 def compute_grad(dat):
     med = dat.reshape([dat.shape[0], -1]).median(dim=-1).values
     med = utils.unsqueeze(med, -1, 3)
     dat /= 0.5*med
     dat = spatial.diff(dat, dim=[1, 2, 3]).square().sum(-1)
     return dat