Exemple #1
0
 def _build_pyramid(self, dat, levels, method, dim, bound,
                    mask=None, preview=None):
     """Build a multi-resolution pyramid of `dat` (and optional mask/preview).

     Level 0 is the input itself; level k is obtained by applying k
     successive stride-2 downsampling steps to the input.

     Parameters
     ----------
     dat : tensor
         Input data. Its last `dim` dimensions are the spatial ones.
     levels : sequence of int
         Requested level for each output slot. Duplicates are allowed;
         an empty sequence yields empty outputs.
     method : str
         Downsampling method, selected by its first letter:
         'g' (gaussian smoothing + stride 2), 'a' (2-voxel average window),
         'm' (2-voxel median window) or 's' (plain stride 2).
     dim : int
         Number of spatial dimensions.
     bound : str
         Boundary condition, only used by the gaussian method.
     mask : tensor, optional
         Moved to `dat.device`. Downsampled levels are cast to `dat.dtype`;
         the level-0 entry keeps its original dtype.
     preview : tensor, optional
         Downsampled levels are cast to `dat.dtype`; the level-0 entry is
         returned as given.

     Returns
     -------
     dats, masks, previews : list
         One entry per element of `levels`, in the same order. Entries of
         `masks`/`previews` are None when the corresponding input is None.

     Raises
     ------
     ValueError
         If a level > 0 is requested and `method` is not recognised.
     """
     levels = list(levels)
     if not levels:
         # `max` would raise on an empty sequence; nothing to build.
         return [], [], []
     nb_levels = max(levels)

     if mask is not None:
         mask = mask.to(dat.device)
     # Level 0 is the input as given (mask only moved to the right device).
     nb_full = levels.count(0)
     dats = [dat] * nb_full
     masks = [mask] * nb_full
     previews = [preview] * nb_full
     # The downsampling operators are applied on `dat.dtype` data.
     if mask is not None:
         mask = mask.to(dat.dtype)
     if preview is not None:
         preview = preview.to(dat.dtype)

     strided = (Ellipsis, *([slice(None, None, 2)] * dim))

     def downsample_once(x, kernel_size):
         # One stride-2 step. `kernel_size` is computed from the current
         # `dat` shape so that dat/mask/preview all use the same window.
         kind = method[0]
         if kind == 'g':  # gaussian pyramid
             # We assume the original data has a PSF of 1 input voxel.
             # We smooth by an additional 1-vx FWHM so that the data has a
             # PSF of 2 input voxels == 1 output voxel, then subsample.
             return spatial.smooth(x, fwhm=1, stride=2, dim=dim, bound=bound)
         if kind == 'a':  # average window
             return spatial.pool(dim, x, kernel_size=kernel_size,
                                 stride=2, reduction='mean')
         if kind == 'm':  # median window
             return spatial.pool(dim, x, kernel_size=kernel_size,
                                 stride=2, reduction='median')
         if kind == 's':  # strides
             return x[strided]
         raise ValueError(method)

     for level in range(1, nb_levels + 1):
         # Windows are clipped so they never exceed the spatial size.
         kernel_size = [min(2, s) for s in dat.shape[-dim:]]
         dat = downsample_once(dat, kernel_size)
         if mask is not None:
             mask = downsample_once(mask, kernel_size)
         if preview is not None:
             preview = downsample_once(preview, kernel_size)
         count = levels.count(level)
         dats += [dat] * count
         masks += [mask] * count
         previews += [preview] * count

     # The lists above are ordered by increasing level; scatter them back
     # to the order in which levels were requested (`sorted` is stable,
     # which matches the order in which equal levels were appended).
     order = sorted(range(len(levels)), key=levels.__getitem__)
     out_dats = [None] * len(levels)
     out_masks = [None] * len(levels)
     out_previews = [None] * len(levels)
     for i, d, m, p in zip(order, dats, masks, previews):
         out_dats[i] = d
         out_masks[i] = m
         out_previews[i] = p
     return out_dats, out_masks, out_previews
Exemple #2
0
def downsample(x, aff_in, vx_out):
    """Pool `x` by integer factors to approach a target voxel size.

    Parameters
    ----------
    x : tensor
        Image to downsample (forwarded to `spatial.pool`).
    aff_in : tensor
        Input orientation matrix, from which the input voxel size is read.
    vx_out : float or sequence of float
        Target voxel size (expanded to one value per spatial dimension).

    Returns
    -------
    x : tensor
        Downsampled image, or the input itself if no pooling is needed.
    aff_out : tensor
        Orientation matrix matching the returned image.
    """
    in_voxel = spatial.voxel_size(aff_in)
    ndim = len(in_voxel)
    target = utils.make_vector(vx_out, ndim)

    # Per-dimension integer factor: target/input ratio rounded down and
    # never below 1, so this function only ever downsamples.
    ratio = target / in_voxel
    factor = ratio.clamp_min(1).floor().long()

    nothing_to_do = (factor == 1).all()
    if nothing_to_do:
        return x, aff_in

    pooled, aff_out = spatial.pool(ndim, x, factor.tolist(), affine=aff_in)
    return pooled, aff_out
Exemple #3
0
    def forward(self, x, **overload):
        """Pool a tensor, optionally followed by an activation.

        Parameters
        ----------
        x : (batch, channel, *spatial) tensor
            Tensor to pool
        overload : dict
            Most parameters defined at build time can be overridden at
            call time: `kernel_size`, `stride`, `padding`, `dilation`,
            `reduction`, `return_indices` and `activation`.

        Returns
        -------
        x : (batch, channel, *spatial_out) tensor
            Pooled tensor (with the activation applied, if any)
        indices : (batch, channel, *spatial_out, dim) tensor, if `return_indices`
            Indices of input elements.

        """

        dim = self.dim
        kernel_size = make_list(overload.get('kernel_size', self.kernel_size),
                                dim)
        stride = make_list(overload.get('stride', self.stride), dim)
        padding = make_list(overload.get('padding', self.padding), dim)
        dilation = make_list(overload.get('dilation', self.dilation), dim)
        reduction = overload.get('reduction', self.reduction)
        return_indices = overload.get('return_indices', self.return_indices)

        # Activation: a name, a class (instantiated), or a callable.
        # Anything else (including unknown names) disables it.
        activation = overload.get('activation', self.activation)
        if isinstance(activation, str):
            activation = _map_activations.get(activation.lower(), None)
        activation = (activation() if inspect.isclass(activation) else
                      activation if callable(activation) else None)

        out = pool(dim,
                   x,
                   kernel_size=kernel_size,
                   stride=stride,
                   dilation=dilation,
                   padding=padding,
                   reduction=reduction,
                   return_indices=return_indices)

        # With `return_indices`, `pool` returns (pooled, indices): the
        # activation must only be applied to the pooled values.
        if return_indices:
            x, indices = out
        else:
            x = out

        if activation:
            x = activation(x)
        return (x, indices) if return_indices else x
Exemple #4
0
    def forward(self, x, return_indices=None):
        """Pool a tensor, optionally followed by the module's activation.

        Parameters
        ----------
        x : (batch, channel, *spatial) tensor
            Tensor to pool
        return_indices : bool, default=self.return_indices
            Whether to also return the indices of the pooled elements.

        Returns
        -------
        x : (batch, channel, *spatial_out) tensor
            Pooled tensor
        indices : (batch, channel, *spatial_out, dim) tensor, if `return_indices`
            Indices of input elements.

        """
        # Only fall back on the build-time value when the caller did not
        # specify one.
        if return_indices is None:
            return_indices = self.return_indices

        x = pool(self.dim,
                 x,
                 kernel_size=self.kernel_size,
                 stride=self.stride,
                 dilation=self.dilation,
                 padding=self.padding,
                 reduction=self.reduction,
                 return_indices=return_indices,
                 ceil=self.ceil)
        if return_indices:
            # `pool` returns (pooled, indices) in this case.
            x, ind = x

        if self.activation:
            x = self.activation(x)
        return (x, ind) if return_indices else x
Exemple #5
0
def pool(inp,
         window=3,
         stride=None,
         method='mean',
         dim=3,
         output=None,
         device=None):
    """Pool a ND volume, while preserving the orientation matrices.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix. The data is expected to have its `dim`
        spatial dimensions first; any trailing dimensions are pooled
        independently.
    window : [sequence of] int, default=3
        Window size
    stride : [sequence of] int, optional
        Stride between output elements.
        By default, it is the same as `window`.
    method : {'mean', 'sum', 'min', 'max', 'median'}, default='mean'
        Pooling function.
    dim : int, default=3
        Number of spatial dimensions. If None (or 0), it is inferred from
        the size of the orientation matrix.
    output : str, optional
        Output filename.
        If the input is not a path, the pooled data is not written
        on disk by default, and the only template field available in
        `output` is `{sep}`.
        If the input is a path, the default output filename is
        '{dir}{sep}{base}.pool{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file,
        and `sep` is the OS path separator.
    device : torch.device, optional
        Device on which the data is loaded and pooled.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the pooled data and orientation matrix are returned.

    """
    dirname = ''
    base = ''
    ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.fdata(device=device), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.pool{ext}'
        dirname, base, ext = py.fileparts(fname)

    dat, aff0 = inp
    dat = dat.to(device)
    dim = dim or aff0.shape[-1] - 1

    # `pool` needs the spatial dimensions at the end: flatten all
    # non-spatial (trailing) dimensions into one leading batch dimension.
    spatial_in = dat.shape[:dim]
    batch = dat.shape[dim:]
    dat = dat.reshape([*spatial_in, -1])
    dat = utils.movedim(dat, -1, 0)
    dat, aff = spatial.pool(dim,
                            dat,
                            kernel_size=window,
                            stride=stride,
                            reduction=method,
                            affine=aff0)
    # Restore the original layout: spatial first, then batch dimensions.
    dat = utils.movedim(dat, 0, -1)
    dat = dat.reshape([*dat.shape[:dim], *batch])

    if output:
        if is_file:
            output = output.format(dir=dirname or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff)

    if is_file:
        return output
    else:
        return dat, aff
Exemple #6
0
def correct_smooth(x,
                   sigma=None,
                   lam=10,
                   gamma=10,
                   downsample=None,
                   max_iter=16,
                   max_rls=8,
                   tol=1e-6,
                   verbose=False,
                   device=None):
    """Correct the intensity non-uniformity in a SPIM image.

    The signal is modelled as: f = exp(s + b) + eps, with a penalty on
    the (squared) gradients of s and on the (squared) curvature of b.

    Parameters
    ----------
    x : tensor
        SPIM image with the z dimension last and the z=0 plane first
    sigma : float, optional
        Noise standard deviation. Default: educated guess (median of the
        central third of the image, along every dimension, divided by 5).
    lam : float, default=10
        Regularisation on the signal.
    gamma : float, default=10
        Regularisation on the bias field.
    downsample : [sequence of] int, optional
        If set, the model is fitted on the image pooled by this factor,
        and the fitted signal and bias are resized back to the input shape.
    max_iter : int, default=16
        Maximum number of Newton iterations (per reweighting iteration).
    max_rls : int, default=8
        Maximum number of reweighting iterations.
        If 1, this is effectively an l2 regularisation.
    tol : float, default=1e-6
        Tolerance for early stopping.
    verbose : int or bool, default=False
        Verbosity level
    device : torch.device, default=x.device
        Use this device during fitting.

    Returns
    -------
    y : tensor
        Fitted (bias-free) signal, exp(s)
    bias : tensor
        Fitted multiplicative bias field, exp(b)
    x : tensor
        Corrected image, x / bias

    """

    x = torch.as_tensor(x)
    if not x.dtype.is_floating_point:
        x = x.to(dtype=torch.get_default_dtype())
    dim = x.dim()

    # downsampling: fit on a pooled image, resize the results at the end
    if downsample:
        x0 = x
        downsample = py.make_list(downsample, dim)
        x = spatial.pool(dim, x, downsample)
    shape = x.shape
    x = x.to(device)

    # noise educated guess: assume SNR=5 in the central third of the volume
    center = tuple(slice(s // 3, 2 * s // 3) for s in shape)
    sigma = sigma or x[center].median() / 5
    lam = lam**2 * sigma**2
    gamma = gamma**2 * sigma**2
    regy = lambda y, w: spatial.regulariser(
        y[None], membrane=lam, dim=dim, weights=w)[0]
    regb = lambda b: spatial.regulariser(b[None], bending=gamma, dim=dim)[0]
    solvey = lambda h, g, w: spatial.solve_field_sym(
        h[None], g[None], membrane=lam, dim=dim, weights=w)[0]
    solveb = lambda h, g: spatial.solve_field_sym(
        h[None], g[None], bending=gamma, dim=dim)[0]

    # init
    l1 = max_rls > 1
    if l1:
        # reweighted least squares: start from uniform weights and honour
        # the requested number of reweighting iterations
        w = torch.ones_like(x)[None]
        llw = w.sum()
    else:
        w = None
        llw = 0
        max_rls = 1
    logb = torch.zeros_like(x)
    logy = x.clamp_min(1e-3).log_()
    y = logy.exp()
    b = logb.exp()
    fit = y * b
    res = fit - x
    llx = res.square().sum()
    lly = (regy(logy, w).mul_(logy)).sum()
    llb = (regb(logb).mul_(logb)).sum()
    ll0 = llx + lly + llb + llw
    ll1 = ll0
    ll = ll0  # defined even if max_iter == 0

    for it_ls in range(max_rls):
        for it in range(max_iter):

            # update bias
            g = h = fit
            h = (h * res).abs_()
            h.addcmul_(g, g)
            g *= res
            g += regb(logb)
            logb -= solveb(h, g)
            # keep the bias zero-mean in log space: move its mean into y
            logb0 = logb.mean()
            logb -= logb0
            logy += logb0

            # update fit / ll
            llb = (regb(logb).mul_(logb)).sum()
            b = torch.exp(logb, out=b)
            y = torch.exp(logy, out=y)
            fit = y * b
            res = fit - x

            # update y
            g = h = fit
            h = (h * res).abs_()
            h.addcmul_(g, g)
            g *= res
            g += regy(logy, w)
            logy -= solvey(h, g, w)

            # update fit / ll
            y = torch.exp(logy, out=y)
            fit = y * b
            res = fit - x
            lly = (regy(logy, w).mul_(logy)).sum()

            # compute objective
            llx = res.square().sum()
            ll = llx + lly + llb + llw
            gain = (ll1 - ll) / ll0
            ll1 = ll
            if verbose:
                end = '\n' if verbose > 1 else '\r'
                pre = f'{it_ls:3d} | ' if l1 else ''
                print(pre + f'{it:3d} | {ll:12.6g} | gain = {gain:12.6g}',
                      end=end)
            if it > 0 and abs(gain) < tol:
                break

        if l1:
            w, llw = spatial.membrane_weights(logy[None],
                                              lam,
                                              dim=dim,
                                              return_sum=True)
            ll0 = ll
    if verbose:
        print('')

    if downsample:
        b = spatial.resize(logb.to(x0.device)[None, None],
                           downsample,
                           shape=x0.shape,
                           anchor='f')[0, 0].exp_()
        y = spatial.resize(logy.to(x0.device)[None, None],
                           downsample,
                           shape=x0.shape,
                           anchor='f')[0, 0].exp_()
        x = x0
    else:
        y = torch.exp(logy, out=y)
    x = x / b
    return y, b, x