def permute_axes(axes):
    """Transformation to permute images axes.

    Parameters
    ----------
    axes : str
        Target axes, to which the input images will be permuted.

    Returns
    -------
    Transform
        Returns a :class:`Transform` object whose `generator` will
        perform the axes permutation of `x`, `y`, and `mask`.

    Notes
    -----
    `x`, `y` and `mask` are all moved with the same call to
    ``move_image_axes(..., True)`` (the trailing ``True`` is presumably the
    singleton-adjustment flag -- confirm against ``move_image_axes``), so that
    a mask stays shape-compatible with `x` after the permutation.

    """
    axes = axes_check_and_normalize(axes)

    def _generator(inputs):
        for x, y, axes_in, mask in inputs:
            axes_in = axes_check_and_normalize(axes_in)
            if axes_in != axes:
                x = move_image_axes(x, axes_in, axes, True)
                y = move_image_axes(y, axes_in, axes, True)
                if mask is not None:
                    # Move the mask exactly like x and y; otherwise any singleton
                    # axes added/removed for x would not be applied to the mask.
                    mask = move_image_axes(mask, axes_in, axes, True)
            yield x, y, axes, mask

    return Transform('Permute axes to %s' % axes, _generator, 1)
 def _generator(inputs):
     """Yield (x, y, axes, mask) with all arrays permuted to the target ``axes``.

     ``axes`` (the normalized target axes string) comes from an enclosing scope
     that is not visible in this fragment.
     """
     for x, y, axes_in, mask in inputs:
         axes_in = axes_check_and_normalize(axes_in)
         if axes_in != axes:
             x = move_image_axes(x, axes_in, axes, True)
             y = move_image_axes(y, axes_in, axes, True)
             if mask is not None:
                 # Move the mask exactly like x and y (same trailing flag) so it
                 # stays shape-compatible with x.
                 mask = move_image_axes(mask, axes_in, axes, True)
         yield x, y, axes, mask
    def _make_normalize_data(axes_in):
        """Build a function that moves ``subsample_axis`` to the front of an image.

        The returned ``_normalize_data(data, undo=False)`` permutes ``data`` from
        ``axes_in`` to an axis order that starts with ``subsample_axis`` (all other
        axes keep their relative order); with ``undo`` truthy it applies the inverse
        permutation. ``subsample_axis`` comes from an enclosing scope that is not
        visible in this fragment.
        """
        src_axes = axes_check_and_normalize(axes_in)
        dst_axes = subsample_axis + ''.join(a for a in src_axes if a not in subsample_axis)

        def _normalize_data(data, undo=False):
            fr, to = (dst_axes, src_axes) if undo else (src_axes, dst_axes)
            return move_image_axes(data, fr, to)

        return _normalize_data
# Example #4 (listing header "Beispiel #4"; listing value: 0)
def create_patches_reduced_target(
        raw_data,
        patch_size,
        n_patches_per_image,
        reduction_axes,
        target_axes = None, # TODO: this should rather be part of RawData and also exposed to transforms
        **kwargs
    ):
    """Create normalized training data to be used for neural network training.

    In contrast to :func:`create_patches`, it is assumed that the target image has reduced
    dimensionality (i.e. size 1) along one or several axes (`reduction_axes`).

    Parameters
    ----------
    raw_data : :class:`RawData`
        See :func:`create_patches`.
    patch_size : tuple
        See :func:`create_patches`. Entries may be ``None`` only along reduction axes;
        they are then filled in with the full image size along that axis.
    n_patches_per_image : int
        See :func:`create_patches`.
    reduction_axes : str
        Axes where the target images have a reduced dimension (i.e. size 1) compared to the source images.
    target_axes : str
        Axes of the raw target images. If ``None``, will be assumed to be equal to that of the raw source images.
    kwargs : dict
        Additional parameters as in :func:`create_patches`.

    Returns
    -------
    tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`, str)
        See :func:`create_patches`. Note that the shape of the target data will be 1 along all reduction axes.

    Raises
    ------
    ValueError
        If a ``None`` entry of `patch_size` belongs to a non-reduction axis, if a
        reduction axis is missing from the extracted patches, or for any reason
        listed in :func:`create_patches`.

    """
    reduction_axes = axes_check_and_normalize(reduction_axes,disallowed='S')

    # The broadcast transform runs first (presumably broadcasting the reduced target
    # to the source's shape), so all user transforms see equally shaped x and y.
    transforms = kwargs.get('transforms')
    if transforms is None:
        transforms = []
    transforms = list(transforms)
    transforms.insert(0,broadcast_target(target_axes))
    kwargs['transforms'] = transforms

    # Saving happens here, after the target has been reduced below.
    save_file = kwargs.pop('save_file',None)

    if any(s is None for s in patch_size):
        # Resolve None entries of patch_size from the first transformed image pair.
        patch_axes = kwargs.get('patch_axes')
        if patch_axes is not None:
            _transforms = list(transforms)
            _transforms.append(permute_axes(patch_axes))
        else:
            _transforms = transforms
        tf = Transform(*zip(*_transforms))
        image_pairs = compose(*tf.generator)(raw_data.generator())
        x,y,axes,mask = next(image_pairs) # get the first entry from the generator
        patch_size = list(patch_size)
        for i,(a,s) in enumerate(zip(axes,patch_size)):
            if s is not None: continue
            a in reduction_axes or _raise(ValueError("entry of patch_size is None for non reduction axis %s." % a))
            patch_size[i] = x.shape[i]
        patch_size = tuple(patch_size)
        del x,y,axes,mask

    X,Y,axes = create_patches (
        raw_data            = raw_data,
        patch_size          = patch_size,
        n_patches_per_image = n_patches_per_image,
        **kwargs
    )

    ax = axes_dict(axes)
    for a in reduction_axes:
        # axis names are single-letter strings, so they must be formatted with %s
        a in axes or _raise(ValueError("reduction axis %s not present in extracted patches" % a))
        n_dims = Y.shape[ax[a]]
        if n_dims == 1:
            warnings.warn("extracted target patches already have dimensionality 1 along reduction axis %s." % a)
        else:
            t = np.take(Y,(1,),axis=ax[a])
            Y = np.take(Y,(0,),axis=ax[a])
            # Spot-check 100 random positions to detect targets that actually vary
            # along the reduction axis (np.random.choice fails on an empty population).
            if Y.size > 0:
                i = np.random.choice(Y.size,size=100)
                if not np.all(t.flat[i]==Y.flat[i]):
                    warnings.warn("extracted target patches vary along reduction axis %s." % a)

    if save_file is not None:
        print('Saving data to %s.' % str(Path(save_file)))
        save_training_data(save_file, X, Y, axes)

    return X,Y,axes
# Example #5 (listing header "Beispiel #5"; listing value: 0)
def create_patches(
        raw_data,
        patch_size,
        n_patches_per_image,
        patch_axes    = None,
        save_file     = None,
        transforms    = None,
        patch_filter  = no_background_patches(),
        normalization = norm_percentiles(),
        shuffle       = True,
        verbose       = True,
    ):
    """Create normalized training data to be used for neural network training.

    Parameters
    ----------
    raw_data : :class:`RawData`
        Object that yields matching pairs of raw images.
    patch_size : tuple
        Shape of the patches to be extracted from raw images.
        Must be compatible with the number of dimensions and axes of the raw images.
        As a general rule, use a power of two along all XYZT axes, or at least divisible by 8.
    n_patches_per_image : int
        Number of patches to be sampled/extracted from each raw image pair (after transformations, see below).
    patch_axes : str or None
        Axes of the extracted patches. If ``None``, will assume to be equal to that of transformed raw data.
    save_file : str or None
        File name to save training data to disk in ``.npz`` format (see :func:`csbdeep.io.save_training_data`).
        If ``None``, data will not be saved.
    transforms : list or tuple, optional
        List of :class:`Transform` objects that apply additional transformations to the raw images.
        This can be used to augment the set of raw images (e.g., by including rotations).
        Set to ``None`` to disable. Default: ``None``.
    patch_filter : function, optional
        Function to determine for each image pair which patches are eligible to be extracted
        (default: :func:`no_background_patches`). Set to ``None`` to disable.
    normalization : function, optional
        Function that takes arguments `(patches_x, patches_y, x, y, mask, channel)`, whose purpose is to
        normalize the patches (`patches_x`, `patches_y`) extracted from the associated raw images
        (`x`, `y`, with `mask`; see :class:`RawData`). Default: :func:`norm_percentiles`.
    shuffle : bool, optional
        Randomly shuffle all extracted patches.
    verbose : bool, optional
        Display overview of images, transforms, etc.

    Returns
    -------
    tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`, str)
        Returns a tuple (`X`, `Y`, `axes`) with the normalized extracted patches from all (transformed) raw images
        and their axes.
        `X` is the array of patches extracted from source images with `Y` being the array of corresponding target patches.
        The shape of `X` and `Y` is as follows: `(n_total_patches, n_channels, ...)`.
        For single-channel images, `n_channels` will be 1.
        If the (transformed) raw data yields fewer images than expected from ``raw_data.size`` and
        the number of transformations, a warning is issued and only the patches that were actually
        extracted are returned.

    Raises
    ------
    ValueError
        Various reasons, including when no (transformed) raw images are produced at all.

    Example
    -------
    >>> raw_data = RawData.from_folder(basepath='data', source_dirs=['source1','source2'], target_dir='GT', axes='ZYX')
    >>> X, Y, XY_axes = create_patches(raw_data, patch_size=(32,128,128), n_patches_per_image=16)

    Todo
    ----
    - Save created patches directly to disk using :class:`numpy.memmap` or similar?
      Would allow to work with large data that doesn't fit in memory.

    """
    ## images and transforms
    if transforms is None:
        transforms = []
    transforms = list(transforms)
    if patch_axes is not None:
        transforms.append(permute_axes(patch_axes))
    if len(transforms) == 0:
        transforms.append(Transform.identity())

    image_pairs, n_raw_images = raw_data.generator(), raw_data.size
    tf = Transform(*zip(*transforms)) # convert list of Transforms into Transform of lists
    image_pairs = compose(*tf.generator)(image_pairs) # combine all transformations with raw images as input
    n_transforms = np.prod(tf.size)
    n_images = n_raw_images * n_transforms
    n_patches = n_images * n_patches_per_image
    # X and Y are both float32 arrays (4 bytes per element)
    n_required_memory_bytes = 2 * n_patches*np.prod(patch_size) * 4

    ## memory check
    _memory_check(n_required_memory_bytes)

    ## summary
    if verbose:
        print('='*66)
        print('%5d raw images x %4d transformations   = %5d images' % (n_raw_images,n_transforms,n_images))
        print('%5d images     x %4d patches per image = %5d patches in total' % (n_images,n_patches_per_image,n_patches))
        print('='*66)
        print('Input data:')
        print(raw_data.description)
        print('='*66)
        print('Transformations:')
        for t in transforms:
            print('{t.size} x {t.name}'.format(t=t))
        print('='*66)
        print('Patch size:')
        print(" x ".join(str(p) for p in patch_size))
        print('=' * 66)

    sys.stdout.flush()

    ## sample patches from each pair of transformed raw images
    # np.empty leaves memory uninitialized: every slot must be filled, or trimmed below.
    X = np.empty((n_patches,)+tuple(patch_size),dtype=np.float32)
    Y = np.empty_like(X)

    n_processed = 0
    for i, (x,y,_axes,mask) in tqdm(enumerate(image_pairs),total=n_images,disable=(not verbose)):
        if i >= n_images:
            warnings.warn('more raw images (or transformations thereof) than expected, skipping excess images.')
            break
        if i==0:
            axes = axes_check_and_normalize(_axes,len(patch_size))
            channel = axes_dict(axes)['C']
        # checks
        axes == axes_check_and_normalize(_axes) or _raise(ValueError('not all images have the same axes.'))
        x.shape == y.shape or _raise(ValueError())
        mask is None or mask.shape == x.shape or _raise(ValueError())
        (channel is None or (isinstance(channel,int) and 0<=channel<x.ndim)) or _raise(ValueError())
        channel is None or patch_size[channel]==x.shape[channel] or _raise(ValueError('extracted patches must contain all channels.'))

        _Y,_X = sample_patches_from_multiple_stacks((y,x), patch_size, n_patches_per_image, mask, patch_filter)

        s = slice(i*n_patches_per_image,(i+1)*n_patches_per_image)
        X[s], Y[s] = normalization(_X,_Y, x,y,mask,channel)
        n_processed += 1

    # Without any image, `axes`/`channel` would be undefined below.
    n_processed > 0 or _raise(ValueError('no (transformed) raw images to extract patches from.'))
    if n_processed < n_images:
        # Drop the never-written (uninitialized) tail instead of returning garbage patches.
        warnings.warn(
            'fewer raw images (or transformations thereof) than expected (%d instead of %d), '
            'discarding unused patch slots.' % (n_processed, n_images)
        )
        n_valid = n_processed * n_patches_per_image
        X, Y = X[:n_valid], Y[:n_valid]

    if shuffle:
        shuffle_inplace(X,Y)

    # Output layout: sample axis first, channel axis second.
    axes = 'SC'+axes.replace('C','')
    if channel is None:
        X = np.expand_dims(X,1)
        Y = np.expand_dims(Y,1)
    else:
        X = np.moveaxis(X, 1+channel, 1)
        Y = np.moveaxis(Y, 1+channel, 1)

    if save_file is not None:
        print('Saving data to %s.' % str(Path(save_file)))
        save_training_data(save_file, X, Y, axes)

    return X,Y,axes
def anisotropic_distortions(
        subsample,
        psf,
        psf_axes       = None,
        poisson_noise  = False,
        gauss_sigma    = 0,
        subsample_axis = 'X',
        yield_target   = 'source',
        crop_threshold = 0.2,
    ):
    """Simulate anisotropic distortions.

    Modify the first image (obtained from input generator) along one axis to mimic the
    distortions that typically occur due to low resolution along the Z axis.
    Note that the modified image is finally upscaled to obtain the same resolution
    as the unmodified input image and is yielded as the 'source' image (see :class:`RawData`).
    The mask from the input generator is simply passed through.

    The following operations are applied to the image (in order):

    1. Convolution with PSF
    2. Poisson noise
    3. Gaussian noise
    4. Subsampling along ``subsample_axis``
    5. Upsampling along ``subsample_axis`` (to former size).


    Parameters
    ----------
    subsample : float
        Subsampling factor to mimic distortions along Z.
    psf : :class:`numpy.ndarray` or None
        Point spread function (PSF) that is supposed to mimic blurring
        of the microscope due to reduced axial resolution. Set to ``None`` to disable.
        The PSF is normalized (to sum 1) on a private float32 copy; the array passed in
        is never modified.
    psf_axes : str or None
        Axes of the PSF. If ``None``, psf axes are assumed to be the same as of the image
        that it is applied to.
    poisson_noise : bool
        Flag to indicate whether Poisson noise should be applied to the image.
    gauss_sigma : float
        Standard deviation of white Gaussian noise to be added to the image.
    subsample_axis : str
        Subsampling image axis (default X).
    yield_target : str
        Which image from the input generator should be yielded by the generator ('source' or 'target').
        If 'source', the unmodified input/source image (from which the distorted image is computed)
        is yielded as the target image. If 'target', the target image from the input generator is simply
        passed through.
    crop_threshold : float
        The subsample factor must evenly divide the image size along the subsampling axis to prevent
        potential image misalignment. If this is not the case the subsample factors are
        modified and the raw image may be cropped along the subsampling axis
        up to a fraction indicated by `crop_threshold`.

    Returns
    -------
    Transform
        Returns a :class:`Transform` object intended to be used with :func:`create_patches`.

    Raises
    ------
    ValueError
        Various reasons, e.g. invalid arguments, a PSF with negative values or a
        non-positive sum, or image/PSF dimensions that do not match.

    """
    zoom_order = 1

    (np.isscalar(subsample) and subsample >= 1) or _raise(ValueError('subsample must be >= 1'))
    _subsample = subsample

    subsample_axis = axes_check_and_normalize(subsample_axis)
    len(subsample_axis)==1 or _raise(ValueError())

    psf is None or isinstance(psf,np.ndarray) or _raise(ValueError())
    if psf_axes is not None:
        psf_axes = axes_check_and_normalize(psf_axes)

    0 < crop_threshold < 1 or _raise(ValueError())

    yield_target in ('source','target') or _raise(ValueError())

    if psf is None and yield_target == 'source':
        warnings.warn(
            "It is strongly recommended to use an appropriate PSF to "
            "mimic the optical effects of the microscope. "
            "We found that training with synthesized anisotropic images "
            "that were created without a PSF "
            "can sometimes lead to unwanted artifacts in the reconstructed images."
        )

    # Normalize the PSF once, on a copy: normalizing `psf.astype(float32, copy=False)`
    # in place would silently modify the caller's array when it already is float32.
    if psf is not None:
        _psf_normalized = np.array(psf, dtype=np.float32)  # always a copy
        np.min(_psf_normalized) >= 0 or _raise(ValueError('psf has negative values.'))
        _psf_sum = np.sum(_psf_normalized)
        _psf_sum > 0 or _raise(ValueError('psf must have a positive sum.'))
        _psf_normalized /= _psf_sum
    else:
        _psf_normalized = None


    def _make_normalize_data(axes_in):
        """Return a function moving ``subsample_axis`` to the front (``undo=True`` reverts it)."""
        axes_in  = axes_check_and_normalize(axes_in)
        axes_out = subsample_axis
        # append the remaining axes of axes_in in their original order
        axes_out += ''.join(a for a in axes_in if a not in axes_out)

        def _normalize_data(data,undo=False):
            if undo:
                return move_image_axes(data, axes_out, axes_in)
            else:
                return move_image_axes(data, axes_in, axes_out)
        return _normalize_data


    def _scale_down_up(data,subsample):
        """Nearest-neighbour downscale then interpolating upscale along axis 0 by `subsample`."""
        # import from scipy.ndimage directly; scipy.ndimage.interpolation is deprecated
        from scipy.ndimage import zoom
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            factor = np.ones(data.ndim)
            factor[0] = subsample
            return zoom(zoom(data, 1/factor, order=0),
                                     factor, order=zoom_order)


    def _adjust_subsample(d,s,c):
        """length d, subsample s, tolerated crop loss fraction c

        Returns a (possibly rounded) subsample factor and the cropped length along
        axis 0 such that length/factor is an integer, losing at most fraction c.
        """
        from fractions import Fraction

        def crop_size(n_digits,frac):
            _s = round(s,n_digits)
            _div = frac.denominator
            s_multiple_max = np.floor(d/_s)
            s_multiple = (s_multiple_max//_div)*_div
            size = s_multiple * _s
            assert np.allclose(size,round(size))
            return size

        def decimals(v,n_digits=None):
            if n_digits is not None:
                v = round(v,n_digits)
            s = str(v)
            assert '.' in s
            decimals = s[1+s.find('.'):]
            return int(decimals), len(decimals)

        s = float(s)
        dec, n_digits = decimals(s)
        frac = Fraction(dec,10**n_digits)
        # a multiple of s that is also an integer number must be
        # divisible by the denominator of the fraction that represents the decimal points

        # round off decimals points if needed
        while n_digits > 0 and (d-crop_size(n_digits,frac))/d > c:
            n_digits -= 1
            frac = Fraction(decimals(s,n_digits)[0], 10**n_digits)

        size = crop_size(n_digits,frac)
        if size == 0 or (d-size)/d > c:
            raise ValueError("subsample factor %g too large (crop_threshold=%g)" % (s,c))

        return round(s,n_digits), int(round(crop_size(n_digits,frac)))


    def _make_divisible_by_subsample(x,size):
        """Center-crop axis 0 of x to length `size`."""
        def _split_slice(v):
            return slice(None) if v==0 else slice(v//2,-(v-v//2))
        slices = [slice(None) for _ in x.shape]
        slices[0] = _split_slice(x.shape[0]-size)
        # NumPy requires a tuple (not a list) for multi-dimensional slicing
        return x[tuple(slices)]


    def _generator(inputs):
        for img,y,axes,mask in inputs:

            if yield_target == 'source':
                y is None or np.allclose(img,y) or warnings.warn("ignoring 'target' image from input generator")
                target = img
            else:
                target = y

            img.shape == target.shape or _raise(ValueError())

            axes = axes_check_and_normalize(axes)
            _normalize_data = _make_normalize_data(axes)

            x = img.astype(np.float32, copy=False)

            if _psf_normalized is not None:
                from scipy.signal import fftconvolve
                _psf = _psf_normalized
                if psf_axes is not None:
                    _psf = move_image_axes(_psf, psf_axes, axes, True)
                x.ndim == _psf.ndim or _raise(ValueError('image and psf must have the same number of dimensions.'))

                if 'C' in axes:
                    ch = axes_dict(axes)['C']
                    n_channels = x.shape[ch]
                    # convolve with psf separately for every channel
                    if _psf.shape[ch] == 1:
                        warnings.warn('applying same psf to every channel of the image.')
                    if _psf.shape[ch] in (1,n_channels):
                        x = np.stack([
                            fftconvolve(
                                np.take(x,   i,axis=ch),
                                np.take(_psf,i,axis=ch,mode='clip'),
                                mode='same'
                            )
                            for i in range(n_channels)
                        ],axis=ch)
                    else:
                        raise ValueError('number of psf channels (%d) incompatible with number of image channels (%d).' % (_psf.shape[ch],n_channels))
                else:
                    x = fftconvolve(x, _psf, mode='same')

            if bool(poisson_noise):
                # builtin int: the np.int alias was removed in NumPy 1.24
                x = np.random.poisson(np.maximum(0,x).astype(int)).astype(np.float32)

            if gauss_sigma > 0:
                noise = np.random.normal(0,gauss_sigma,size=x.shape).astype(np.float32)
                x = np.maximum(0,x+noise)

            if _subsample != 1:
                # work with the subsampling axis at position 0
                target = _normalize_data(target)
                x      = _normalize_data(x)

                factor, subsample_size = _adjust_subsample(x.shape[0],_subsample,crop_threshold)
                if _subsample != factor:
                    warnings.warn('changing subsample from %s to %s' % (str(_subsample),str(factor)))

                target = _make_divisible_by_subsample(target,subsample_size)
                x      = _make_divisible_by_subsample(x,     subsample_size)
                x      = _scale_down_up(x,factor)

                assert x.shape == target.shape, (x.shape, target.shape)

                target = _normalize_data(target,undo=True)
                x      = _normalize_data(x,     undo=True)

            yield x, target, axes, mask


    return Transform('Anisotropic distortion (along %s axis)' % subsample_axis, _generator, 1)
 def _generator(inputs):
     """Yield (x, y, axes, mask) with the target broadcast to the source's shape.

     If ``target_axes`` (from an enclosing scope not visible in this fragment) is
     set, the target is first moved from those axes to the source axes.
     """
     for src, tgt, src_axes, msk in inputs:
         if target_axes is not None:
             tgt = move_image_axes(
                 tgt,
                 axes_check_and_normalize(target_axes, length=tgt.ndim),
                 src_axes,
                 True,
             )
         yield src, np.broadcast_to(tgt, src.shape), src_axes, msk
 def _generator(inputs):
     """Yield (x, y, axes, mask) cropped with ``slices`` from the enclosing scope.

     ``slices`` is converted to a tuple so that a list of slice objects is used as a
     multi-dimensional index (NumPy no longer accepts a non-tuple sequence for that).
     """
     _slices = tuple(slices)
     for x, y, axes, mask in inputs:
         axes = axes_check_and_normalize(axes)
         len(axes) == len(_slices) or _raise(ValueError(
             'number of slices (%d) does not match number of image axes (%d).' % (len(_slices), len(axes))))
         yield x[_slices], y[_slices], axes, (mask[_slices] if mask is not None else None)
    def _generator(inputs):
        """Yield distorted (x, target, axes, mask) tuples for each input image.

        Relies on names from an enclosing scope not visible in this fragment
        (yield_target, psf, psf_axes, poisson_noise, gauss_sigma, _subsample,
        crop_threshold and the helper closures).
        """
        for img,y,axes,mask in inputs:

            if yield_target == 'source':
                y is None or np.allclose(img,y) or warnings.warn("ignoring 'target' image from input generator")
                target = img
            else:
                target = y

            img.shape == target.shape or _raise(ValueError())

            axes = axes_check_and_normalize(axes)
            _normalize_data = _make_normalize_data(axes)

            x = img.astype(np.float32, copy=False)

            if psf is not None:
                from scipy.signal import fftconvolve
                # Always copy before the in-place normalization below, so the
                # caller's psf array is never modified.
                _psf = psf.astype(np.float32)
                np.min(_psf) >= 0 or _raise(ValueError('psf has negative values.'))
                _psf /= np.sum(_psf)
                if psf_axes is not None:
                    _psf = move_image_axes(_psf, psf_axes, axes, True)
                x.ndim == _psf.ndim or _raise(ValueError('image and psf must have the same number of dimensions.'))

                if 'C' in axes:
                    ch = axes_dict(axes)['C']
                    n_channels = x.shape[ch]
                    # convolve with psf separately for every channel
                    if _psf.shape[ch] == 1:
                        warnings.warn('applying same psf to every channel of the image.')
                    if _psf.shape[ch] in (1,n_channels):
                        x = np.stack([
                            fftconvolve(
                                np.take(x,   i,axis=ch),
                                np.take(_psf,i,axis=ch,mode='clip'),
                                mode='same'
                            )
                            for i in range(n_channels)
                        ],axis=ch)
                    else:
                        raise ValueError('number of psf channels (%d) incompatible with number of image channels (%d).' % (_psf.shape[ch],n_channels))
                else:
                    x = fftconvolve(x, _psf, mode='same')

            if bool(poisson_noise):
                # builtin int: the np.int alias was removed in NumPy 1.24
                x = np.random.poisson(np.maximum(0,x).astype(int)).astype(np.float32)

            if gauss_sigma > 0:
                noise = np.random.normal(0,gauss_sigma,size=x.shape).astype(np.float32)
                x = np.maximum(0,x+noise)

            if _subsample != 1:
                # work with the subsampling axis at position 0
                target = _normalize_data(target)
                x      = _normalize_data(x)

                subsample, subsample_size = _adjust_subsample(x.shape[0],_subsample,crop_threshold)
                if _subsample != subsample:
                    warnings.warn('changing subsample from %s to %s' % (str(_subsample),str(subsample)))

                target = _make_divisible_by_subsample(target,subsample_size)
                x      = _make_divisible_by_subsample(x,     subsample_size)
                x      = _scale_down_up(x,subsample)

                assert x.shape == target.shape, (x.shape, target.shape)

                target = _normalize_data(target,undo=True)
                x      = _normalize_data(x,     undo=True)

            yield x, target, axes, mask