Example 1
def correct_smooth(x,
                   sigma=None,
                   lam=10,
                   gamma=10,
                   downsample=None,
                   max_iter=16,
                   max_rls=8,
                   tol=1e-6,
                   verbose=False,
                   device=None):
    """Correct the intensity non-uniformity in a SPIM image.

    The signal is modelled as: f = exp(s + b) + eps, with a penalty on
    the (squared) gradients of s and on the (squared) curvature of b.

    Parameters
    ----------
    x : tensor
        SPIM image with the z dimension last and the z=0 plane first
    sigma : float, optional
        Noise standard deviation. Default: educated guess.
    lam : float, default=10
        Regularisation on the signal.
    gamma : float, default=10
        Regularisation on the bias field.
    downsample : int or sequence[int], optional
        Downsample the image by this factor during the fit; the fitted
        fields are resized back to the input lattice before returning.
    max_iter : int, default=16
        Maximum number of Newton iterations.
    max_rls : int, default=8
        Maximum number of reweighting iterations.
        If 1, this is effectively an l2 regularisation.
    tol : float, default=1e-6
        Tolerance for early stopping.
    verbose : int or bool, default=False
        Verbosity level.
    device : torch.device, default=x.device
        Use this device during fitting.

    Returns
    -------
    y : tensor
        Fitted image
    b : tensor
        Fitted bias field
    x : tensor
        Corrected image

    """

    x = torch.as_tensor(x)
    if not x.dtype.is_floating_point:
        x = x.to(dtype=torch.get_default_dtype())
    dim = x.dim()

    # downsampling
    if downsample:
        x0 = x
        downsample = py.make_list(downsample, dim)
        x = spatial.pool(dim, x, downsample)
    shape = x.shape
    x = x.to(device)

    # noise educated guess: assume SNR=5 in the centre of the image
    center = tuple(slice(s // 3, 2 * s // 3) for s in shape)
    sigma = sigma or x[center].median() / 5
    lam = lam**2 * sigma**2
    gamma = gamma**2 * sigma**2
    regy = lambda y, w: spatial.regulariser(
        y[None], membrane=lam, dim=dim, weights=w)[0]
    regb = lambda b: spatial.regulariser(b[None], bending=gamma, dim=dim)[0]
    solvey = lambda h, g, w: spatial.solve_field_sym(
        h[None], g[None], membrane=lam, dim=dim, weights=w)[0]
    solveb = lambda h, g: spatial.solve_field_sym(
        h[None], g[None], bending=gamma, dim=dim)[0]

    # init
    l1 = max_rls > 1
    if l1:
        w = torch.ones_like(x)[None]
        llw = w.sum()
    else:
        w = None
        llw = 0
        max_rls = 1
    logb = torch.zeros_like(x)
    logy = x.clamp_min(1e-3).log_()
    y = logy.exp()
    b = logb.exp()
    fit = y * b
    res = fit - x
    llx = res.square().sum()
    lly = (regy(logy, w).mul_(logy)).sum()
    llb = (regb(logb).mul_(logb)).sum()
    ll0 = llx + lly + llb + llw
    ll1 = ll0

    for it_ls in range(max_rls):
        for it in range(max_iter):

            # update bias
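            # `g` and `h` start as aliases of `fit`; after the in-place ops:
            #   h = |fit * res| + fit**2   (robust Gauss-Newton curvature)
            #   g = fit * res              (gradient of the data term w.r.t. logb)
            # `g *= res` deliberately overwrites `fit`, which is recomputed below.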
            g = h = fit
            h = (h * res).abs_()
            h.addcmul_(g, g)
            g *= res
            g += regb(logb)
            logb -= solveb(h, g)
            logb0 = logb.mean()
            logb -= logb0
            logy += logb0

            # update fit / ll
            llb = (regb(logb).mul_(logb)).sum()
            b = torch.exp(logb, out=b)
            y = torch.exp(logy, out=y)
            fit = y * b
            res = fit - x

            # update y (same in-place Gauss-Newton trick as for the bias)
            g = h = fit
            h = (h * res).abs_()
            h.addcmul_(g, g)
            g *= res
            g += regy(logy, w)
            logy -= solvey(h, g, w)

            # update fit / ll
            y = torch.exp(logy, out=y)
            fit = y * b
            res = fit - x
            lly = (regy(logy, w).mul_(logy)).sum()

            # compute objective
            llx = res.square().sum()
            ll = llx + lly + llb + llw
            gain = (ll1 - ll) / ll0
            ll1 = ll
            if verbose:
                end = '\n' if verbose > 1 else '\r'
                pre = f'{it_ls:3d} | ' if l1 else ''
                print(pre + f'{it:3d} | {ll:12.6g} | gain = {gain:12.6g}',
                      end=end)
            if it > 0 and abs(gain) < tol:
                break

        if l1:
            w, llw = spatial.membrane_weights(logy[None],
                                              lam,
                                              dim=dim,
                                              return_sum=True)
            ll0 = ll
    if verbose:
        print('')

    if downsample:
        b = spatial.resize(logb.to(x0.device)[None, None],
                           downsample,
                           shape=x0.shape,
                           anchor='f')[0, 0].exp_()
        y = spatial.resize(logy.to(x0.device)[None, None],
                           downsample,
                           shape=x0.shape,
                           anchor='f')[0, 0].exp_()
        x = x0
    else:
        y = torch.exp(logy, out=y)
    x = x / b
    return y, b, x
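
A minimal usage sketch (hypothetical, not part of the original example): the
random volume stands in for a real SPIM acquisition, and `correct_smooth` is
assumed to be importable alongside the `nitorch`-style helpers it relies on.

# Hypothetical usage -- synthetic data, assumed imports.
import torch

x = torch.rand(64, 64, 32) * 100      # stand-in SPIM volume, z dimension last
y, b, x_corr = correct_smooth(x, lam=10, gamma=10, downsample=2, verbose=True)
# y      : fitted image exp(s)
# b      : fitted bias field exp(b)
# x_corr : input divided by the fitted bias field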
Example 2
def zcorrect_square(x,
                    decay=None,
                    sigma=None,
                    lam=10,
                    max_iter=128,
                    tol=1e-6,
                    verbose=False):
    """Correct the z signal decay in a SPIM image.

    The signal is modelled as: f(z) = s(z) / (1 + b * z**2) + eps
    where z=0 is the top slice, s(z) is the theoretical signal in the
    absence of absorption, and b is the decay coefficient.

    Parameters
    ----------
    x : (..., nz) tensor
        SPIM image with the z dimension last and the z=0 plane first
    decay : float, optional
        Initial guess for decay parameter. Default: educated guess.
    sigma : float, optional
        Noise standard deviation. Default: educated guess.
    lam : float, default=10
        Regularisation.
    max_iter : int, default=128
        Maximum number of Newton iterations.
    tol : float, default=1e-6
        Tolerance for early stopping.
    verbose : int or bool, default=False
        Verbosity level.

    Returns
    -------
    y : tensor
        Fitted image
    b : tensor
        Fitted decay coefficient
    x : tensor
        Corrected image

    """

    x = torch.as_tensor(x)
    if not x.dtype.is_floating_point:
        x = x.to(dtype=torch.get_default_dtype())
    backend = utils.backend(x)
    shape = x.shape
    nz = shape[-1]
    x = x.reshape([-1, nz])
    b = decay

    # decay educated guess: closed form from two values at z=1/3 and z=2/3
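    # Solving x1 * (1 + b * z1**2) = x2 * (1 + b * z2**2) = y0 for (b, y0) gives
    #   b  = (x2 - x1) / (x1 * z1**2 - x2 * z2**2)
    #   y0 = x1 * (1 + b * z1**2)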
    z1 = nz // 3
    z2 = 2 * nz // 3
    x1 = x[:, z1].median()
    x2 = x[:, z2].median()
    z1 = float(z1)**2
    z2 = float(z2)**2
    b = b or (x2 - x1) / (x1 * z1 - x2 * z2)
    b = abs(b)
    y0 = x1 * (1 + b * z1)

    y0 = y0.item()
    b = b.item() if torch.is_tensor(b) else b

    # noise educated guess: assume SNR=5 at z=1/2
    sigma = sigma or (y0 / (1 + b * (nz / 2)**2)) / 5
    lam = lam**2 * sigma**2
    reg = lambda y: spatial.regulariser(y[:, None], membrane=lam, dim=1)[:, 0]
    solve = lambda h, g: spatial.solve_field_sym(
        h[:, None], g[:, None], membrane=lam, dim=1)[:, 0]

    if verbose:
        print(f'init: y0={y0:g}, b={b:g}, sigma={sigma:g}, lam={lam:g}')

    # init
    z2 = torch.arange(nz, **backend).square_()
    logy = torch.full_like(x, y0).log_()
    logb = torch.as_tensor(b, **backend)
    y = logy.exp()
    b = logb.exp()
    ll0 = (y / (1 + b * z2) - x).square_().sum() + (logy * reg(logy)).sum()
    ll1 = ll0
    for it in range(max_iter):

        # exponentiate
        y = torch.exp(logy, out=y)
        fit = y / (1 + b * z2)
        res = fit - x

        # compute objective
        ll = res.square().sum() + (logy * reg(logy)).sum()
        gain = (ll1 - ll) / ll0
        if verbose:
            end = '\n' if verbose > 1 else '\r'
            print(f'{it:3d} | {ll:12.6g} | gain = {gain:12.6g}', end=end)
        if it > 0 and gain < tol:
            break
        ll1 = ll

        # update decay
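        # Newton step on logb. With fit = y / (1 + b * z2) and b = exp(logb):
        #   d fit / d logb   = -y * z2 * b / (1 + b * z2)**2            -> g
        #   d2 fit / d logb2 = -y * z2 * b * (1 - b * z2) / (1 + b * z2)**3
        # h = |d2 fit| * |res| + g**2 is a positive surrogate for the Hessian.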
        g = -(z2 * b * y) / (z2 * b + 1).square()
        h = b * y - z2 * b.square() * y
        h *= z2 / (z2 * b + 1).pow(3)
        h = h.abs_() * res.abs()
        h += g.square()
        g *= res

        g = g.sum()
        h = h.sum()
        logb -= g / h

        # update fit
        b = torch.exp(logb, out=b)
        fit = y / (1 + b * z2)
        res = fit - x


        # update y
        g = h = y / (z2 * b + 1)
        h = h.abs() * res.abs()
        h += g.square()
        g *= res
        g += reg(logy)
        logy -= solve(h, g)

    y = torch.exp(logy, out=y)
    y = y.reshape(shape)
    x = x * (1 + b * z2)
    x = x.reshape(shape)
    return y, b, x
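
A minimal usage sketch (hypothetical, not part of the original example): a
constant signal is attenuated with a known quadratic decay, and the recovered
coefficient should approach the simulated one.

# Hypothetical usage -- synthetic data, assumed imports.
import torch

nz = 64
z = torch.arange(nz, dtype=torch.float)
signal = torch.full((32, 32, nz), 100.)
x = signal / (1 + 5e-4 * z**2)        # simulate quadratic decay along z
y, b, x_corr = zcorrect_square(x, verbose=True)
# b should be close to the simulated decay coefficient (5e-4)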
Example 3
def grid_inv(grid, type='grid', lam=0.1, bound='dft', extrapolate=True):
    """Invert a dense deformation (or displacement) grid
    
    Notes
    -----
    The deformation/displacement grid must be expressed in 
    voxels, and map from/to the same lattice.
    
    Let `f = id + d` be the transformation. The inverse is obtained as
    `id - (k * (f.T @ d)) / (k * (f.T @ 1))`, where `k` is a smoothing
    kernel, `f.T @ _` is the adjoint operation ("push") of `f @ _`
    ("pull"), and `1` is an image of ones.

    Parameters
    ----------
    grid : (..., *spatial, dim) tensor
        Transformation (or displacement) grid
    type : {'grid', 'disp'}, default='grid'
        Type of deformation.
    lam : float, default=0.1
        Regularisation
    bound : str, default='dft'
        Boundary conditions.
    extrapolate : bool, default=True
        Extrapolate data outside the field of view.
        
    Returns
    -------
    grid_inv : (..., *spatial, dim) tensor
        Inverse transformation (or displacement) grid
    
    """
    # get shape components
    dim = grid.shape[-1]
    shape = grid.shape[-(dim + 1):-1]
    batch = grid.shape[:-(dim + 1)]
    grid = grid.reshape([-1, *shape, dim])
    backend = dict(dtype=grid.dtype, device=grid.device)

    # get displacement
    identity = spatial.identity_grid(shape, **backend)
    if type == 'grid':
        disp = grid - identity
    else:
        disp = grid
        grid = disp + identity

    # push displacement
    push_opt = dict(bound=bound, extrapolate=extrapolate)
    disp = core.utils.movedim(disp, -1, 1)
    disp = spatial.grid_push(disp, grid, **push_opt)
    count = spatial.grid_count(grid, **push_opt)

    # Fill missing values using regularised least squares
    disp = spatial.solve_field_sym(count,
                                   disp,
                                   membrane=lam,
                                   bound=bound,
                                   dim=dim)
    disp = core.utils.movedim(disp, 1, -1)
    disp = disp.reshape([*batch, *shape, dim])

    if type == 'grid':
        return identity - disp
    else:
        return -disp
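
A minimal usage sketch (hypothetical, not part of the original example): a
small random displacement is inverted under both conventions; composing the
forward and inverse grids should land close to the identity.

# Hypothetical usage -- synthetic data, assumed imports.
import torch

shape = (32, 32)
identity = spatial.identity_grid(shape)
disp = 0.5 * torch.randn(*shape, 2)   # small random displacement field
grid = identity + disp                # dense transformation grid

igrid = grid_inv(grid, type='grid')   # inverse transformation grid
idisp = grid_inv(disp, type='disp')   # inverse displacement grid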