Example #1
 def _axes_div_by(self, query_axes):
     if self.config.backbone == "unet":
         query_axes = axes_check_and_normalize(query_axes)
         assert len(self.config.unet_pool) == len(self.config.grid)
         div_by = dict(
             zip(
                 self.config.axes.replace('C', ''),
                 tuple(p**self.config.unet_n_depth * g for p, g in zip(
                     self.config.unet_pool, self.config.grid))))
         return tuple(div_by.get(a, 1) for a in query_axes)
     elif self.config.backbone == "resnet":
         grid_dict = dict(
             zip(self.config.axes.replace('C', ''), self.config.grid))
         return tuple(grid_dict.get(a, 1) for a in query_axes)
     else:
         raise NotImplementedError()
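Note: every example on this page calls ``axes_check_and_normalize`` from
``csbdeep.utils``. A minimal usage sketch of its assumed behavior (uppercasing,
validation against the allowed axes 'STCZYX', optional length/disallowed checks):

from csbdeep.utils import axes_check_and_normalize

axes = axes_check_and_normalize('zyx', length=3)  # -> 'ZYX' (uppercased, length-checked)
# axes_check_and_normalize('SYXC', disallowed='S')  -> raises ValueError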
Example #2
 def _axes_change(value: str):
     if value != value.upper():
         with plugin.axes.changed.blocked():
             plugin.axes.value = value.upper()
     image = plugin.image.value
     axes = ""
     try:
         image is not None or _raise(ValueError("no image selected"))
         axes = axes_check_and_normalize(value,
                                         length=get_data(image).ndim,
                                         disallowed="S")
         update("image_axes", True, (axes, image, None))
     except ValueError as err:
         update("image_axes", False, (value, image, err))
     finally:
         widgets_inactive(plugin.timelapse_opts, active=("T" in axes))
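The ``cond or _raise(exc)`` pattern above is an expression-level assertion. A
minimal sketch of the helper (csbdeep defines one along these lines):

def _raise(e):
    raise e

image = object()
image is not None or _raise(ValueError("no image selected"))  # passes silently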
Example #3
 def after(self, x, axes):
     # axes can include 'C', which may not have been present in before()
     axes = axes_check_and_normalize(axes, x.ndim)
     assert all(s_pad == s * g
                for s, s_pad, g in zip(x.shape, (
                    self.padded_shape.get(a, _s)
                    for a, _s in zip(axes, x.shape)), (self.grid.get(a, 1)
                                                       for a in axes)))
     # print(self.padded_shape)
     # print(self.pad)
     # print(self.grid)
     crop = tuple(
         slice(0, -(math.floor(p[1] / g)) if p[1] >= g else None)
         for p, g in zip((self.pad.get(a, (0, 0))
                          for a in axes), (self.grid.get(a, 1)
                                           for a in axes)))
     # print(crop)
     return x[crop]
Example #4
    def before(self, x, axes, axes_div_by):
        assert all(a % g == 0 for g, a in zip((self.grid.get(a, 1)
                                               for a in axes), axes_div_by))
        axes = axes_check_and_normalize(axes, x.ndim)

        def _split(v):
            return 0, v  # only pad at the end

        self.pad = {
            a: _split((div_n - s % div_n) % div_n)
            for a, div_n, s in zip(axes, axes_div_by, x.shape)
        }
        x_pad = np.pad(x,
                       tuple(self.pad[a] for a in axes),
                       mode=self.mode,
                       **self.kwargs)
        self.padded_shape = dict(zip(axes, x_pad.shape))
        if 'C' in self.padded_shape: del self.padded_shape['C']
        return x_pad
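A minimal NumPy sketch of the pad-then-crop round trip implemented by before()
(Example #4) and after() (Example #3), assuming grid == 1 for all axes:

import numpy as np

x = np.zeros((5, 7))                      # axes 'YX'
axes_div_by = (4, 4)                      # e.g. obtained from _axes_div_by('YX')
pad = {a: (0, (d - s % d) % d)            # only pad at the end
       for a, d, s in zip('YX', axes_div_by, x.shape)}
x_pad = np.pad(x, tuple(pad[a] for a in 'YX'), mode='reflect')
assert x_pad.shape == (8, 8)              # now divisible by 4 along Y and X
crop = tuple(slice(0, -p[1] if p[1] > 0 else None) for p in (pad[a] for a in 'YX'))
assert x_pad[crop].shape == x.shape       # padding cropped off again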
Example #5
File: big.py Project: ylch/stardist
    def cover(shape, axes, block_size, min_overlap, context, grid=1):
        """Return grid-aligned n-dimensional blocks to cover region
        of the given shape with axes semantics.

        Parameters block_size, min_overlap, and context can be different per
        dimension/axis (if provided as list) or the same (if provided as
        scalar value).

        Also see `Block.cover`.

        """
        shape = tuple(shape)
        n = len(shape)
        axes = axes_check_and_normalize(axes, length=n)
        if np.isscalar(block_size):  block_size  = n*[block_size]
        if np.isscalar(min_overlap): min_overlap = n*[min_overlap]
        if np.isscalar(context):     context     = n*[context]
        if np.isscalar(grid):        grid        = n*[grid]
        assert n == len(block_size) == len(min_overlap) == len(context) == len(grid)

        # compute cover for each dimension
        cover_1d = [Block.cover(*args) for args in zip(shape, block_size, min_overlap, context, grid)]
        # return cover as Cartesian product of 1-dimensional blocks
        return tuple(BlockND(i,blocks,axes) for i,blocks in enumerate(product(*cover_1d)))
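A hedged usage sketch of the cover shown above (BlockND from stardist.big; the
exact block boundaries depend on Block.cover):

blocks = BlockND.cover(shape=(2048, 2048), axes='YX',
                       block_size=512, min_overlap=64, context=64, grid=1)
for block in blocks:
    pass  # each BlockND covers one grid-aligned n-dimensional region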
Example #6
    def __init__(self, X, **kwargs):
        """See class docstring."""

        assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' dimensions are supported."

        n_dim = len(X.shape) - 2
        n_channel_in = X.shape[-1]
        n_channel_out = n_channel_in
        mean = np.mean(X)
        std = np.std(X)

        if n_dim == 2:
            axes = 'SYXC'
        elif n_dim == 3:
            axes = 'SZYXC'

        # parse and check axes
        axes = axes_check_and_normalize(axes)
        ax = axes_dict(axes)
        ax = {a: (ax[a] is not None) for a in ax}

        (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
        not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

        axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
        axes = axes.replace('S','') # remove sample axis if it exists

        if backend_channels_last():
            if ax['C']:
                axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
            else:
                axes += 'C'
        else:
            if ax['C']:
                axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
            else:
                axes = 'C'+axes

        # normalization parameters
        self.mean                  = str(mean)
        self.std                   = str(std)
        # directly set by parameters
        self.n_dim                 = n_dim
        self.axes                  = axes
        self.n_channel_in          = int(n_channel_in)
        self.n_channel_out         = int(n_channel_out)

        # default config (can be overwritten by kwargs below)
        self.unet_residual         = False
        self.unet_n_depth          = 2
        self.unet_kern_size        = 5 if self.n_dim==2 else 3
        self.unet_n_first          = 32
        self.unet_last_activation  = 'linear'
        if backend_channels_last():
            self.unet_input_shape  = self.n_dim*(None,) + (self.n_channel_in,)
        else:
            self.unet_input_shape  = (self.n_channel_in,) + self.n_dim*(None,)

        self.train_loss            = 'mae'
        self.train_epochs          = 100
        self.train_steps_per_epoch = 400
        self.train_learning_rate   = 0.0004
        self.train_batch_size      = 16
        self.train_tensorboard     = True
        self.train_checkpoint      = 'weights_best.h5'
        self.train_reduce_lr       = {'factor': 0.5, 'patience': 10}
        self.batch_norm            = True
        self.n2v_perc_pix          = 1.5
        self.n2v_patch_shape       = (64, 64) if self.n_dim==2 else (64, 64, 64)
        self.n2v_manipulator       = 'uniform_withCP'
        self.n2v_neighborhood_radius = 5

        # disallow setting 'n_dim' manually
        try:
            del kwargs['n_dim']
            # warnings.warn("ignoring parameter 'n_dim'")
        except KeyError:
            pass

        for k in kwargs:
            setattr(self, k, kwargs[k])
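A hedged usage sketch, assuming this __init__ belongs to an N2V-style Config
class (here called Config) and X is training data with axes 'SYXC':

import numpy as np

X = np.random.rand(16, 64, 64, 1).astype(np.float32)
config = Config(X, unet_n_depth=3, train_epochs=50)  # kwargs override defaults
assert config.n_dim == 2 and config.unet_n_depth == 3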
Example #7
 def blocks_for_axes(self, axes=None):
     axes = self.axes if axes is None else axes_check_and_normalize(axes)
     return tuple(self.axis_to_block[a] for a in axes)
Example #8
 def __init__(self, id, blocks, axes):
     self.id = id
     self.blocks = tuple(blocks)
     self.axes = axes_check_and_normalize(axes, length=len(self.blocks))
     self.axis_to_block = dict(zip(self.axes, self.blocks))
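Tying Examples #7 and #8 together: blocks_for_axes returns the per-axis 1-D
blocks in whatever axis order is queried, via the axis_to_block mapping built in
__init__. A sketch (block_y and block_x stand for hypothetical 1-D Block objects):

b = BlockND(0, (block_y, block_x), 'yx')     # axes normalized to 'YX'
assert b.blocks_for_axes('XY') == (block_x, block_y)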
Example #9
    def predict_instances_big(self, img, axes, block_size, min_overlap, context=None,
                              labels_out=None, labels_out_dtype=np.int32, show_progress=True, **kwargs):
        """Predict instance segmentation from very large input images.

        Intended to be used when `predict_instances` cannot be used due to memory limitations.
        This function will break the input image into blocks and process them individually
        via `predict_instances` and assemble all the partial results. If used as intended, the result
        should be the same as if `predict_instances` was used directly on the whole image.

        **Important**: The crucial assumption is that all predicted object instances are smaller than
                       the provided `min_overlap`. Also, it must hold that: min_overlap + 2*context < block_size.

        Example
        -------
        >>> img.shape
        (20000, 20000)
        >>> labels, polys = model.predict_instances_big(img, axes='YX', block_size=4096,
                                                        min_overlap=128, context=128, n_tiles=(4,4))

        Parameters
        ----------
        img: :class:`numpy.ndarray` or similar
            Input image
        axes: str
            Axes of the input ``img`` (such as 'YX', 'ZYX', 'YXC', etc.)
        block_size: int or iterable of int
            Process input image in blocks of the provided shape.
            (If a scalar value is given, it is used for all spatial image dimensions.)
        min_overlap: int or iterable of int
            Amount of guaranteed overlap between blocks.
            (If a scalar value is given, it is used for all spatial image dimensions.)
        context: int or iterable of int, or None
            Amount of image context on all sides of a block, which is discarded.
            If None, uses an automatic estimate that should work in many cases.
            (If a scalar value is given, it is used for all spatial image dimensions.)
        labels_out: :class:`numpy.ndarray` or similar, or None, or False
            numpy array or similar (must be of correct shape) to which the label image is written.
            If None, will allocate a numpy array of the correct shape and data type ``labels_out_dtype``.
            If False, will not write the label image (useful if only the dictionary is needed).
        labels_out_dtype: str or dtype
            Data type of returned label image if ``labels_out=None`` (has no effect otherwise).
        show_progress: bool
            Show progress bar for block processing.
        kwargs: dict
            Keyword arguments for ``predict_instances``.

        Returns
        -------
        (:class:`numpy.ndarray` or False, dict)
            Returns the label image and a dictionary with the details (coordinates, etc.) of the polygons/polyhedra.

        """
        from ..big import _grid_divisible, BlockND, OBJECT_KEYS#, repaint_labels
        from ..matching import relabel_sequential

        n = img.ndim
        axes = axes_check_and_normalize(axes, length=n)
        grid = self._axes_div_by(axes)
        axes_out = self._axes_out.replace('C','')
        shape_dict = dict(zip(axes,img.shape))
        shape_out = tuple(shape_dict[a] for a in axes_out)

        if context is None:
            context = self._axes_tile_overlap(axes)

        if np.isscalar(block_size):  block_size  = n*[block_size]
        if np.isscalar(min_overlap): min_overlap = n*[min_overlap]
        if np.isscalar(context):     context     = n*[context]
        block_size, min_overlap, context = list(block_size), list(min_overlap), list(context)
        assert n == len(block_size) == len(min_overlap) == len(context)

        if 'C' in axes:
            # single block for channel axis
            i = axes_dict(axes)['C']
            # if (block_size[i], min_overlap[i], context[i]) != (None, None, None):
            #     print("Ignoring values of 'block_size', 'min_overlap', and 'context' for channel axis " +
            #           "(set to 'None' to avoid this warning).", file=sys.stderr, flush=True)
            block_size[i] = img.shape[i]
            min_overlap[i] = context[i] = 0

        block_size  = tuple(_grid_divisible(g, v, name='block_size',  verbose=False) for v,g,a in zip(block_size, grid,axes))
        min_overlap = tuple(_grid_divisible(g, v, name='min_overlap', verbose=False) for v,g,a in zip(min_overlap,grid,axes))
        context     = tuple(_grid_divisible(g, v, name='context',     verbose=False) for v,g,a in zip(context,    grid,axes))

        # print(f"input: shape {img.shape} with axes {axes}")
        print(f'effective: block_size={block_size}, min_overlap={min_overlap}, context={context}', flush=True)

        for a,c,o in zip(axes,context,self._axes_tile_overlap(axes)):
            if c < o:
                print(f"{a}: context of {c} is small, recommended to use at least {o}", flush=True)

        # create block cover
        blocks = BlockND.cover(img.shape, axes, block_size, min_overlap, context, grid)

        if np.isscalar(labels_out) and bool(labels_out) is False:
            labels_out = None
        else:
            if labels_out is None:
                labels_out = np.zeros(shape_out, dtype=labels_out_dtype)
            else:
                labels_out.shape == shape_out or _raise(ValueError(f"'labels_out' must have shape {shape_out} (axes {axes_out})."))

        polys_all = {}
        # problem_ids = []
        label_offset = 1

        kwargs_override = dict(axes=axes, overlap_label=None)
        if show_progress:
            kwargs_override['show_tile_progress'] = False # disable progress for predict_instances
        for k,v in kwargs_override.items():
            if k in kwargs: print(f"changing '{k}' from {kwargs[k]} to {v}", flush=True)
            kwargs[k] = v

        blocks = tqdm(blocks, disable=(not show_progress))
        # actual computation
        for block in blocks:
            labels, polys = self.predict_instances(block.read(img, axes=axes), **kwargs)
            labels = block.crop_context(labels, axes=axes_out)
            labels, polys = block.filter_objects(labels, polys, axes=axes_out)
            # TODO: relabel_sequential is not very memory-efficient (will allocate memory proportional to label_offset)
            labels = relabel_sequential(labels, label_offset)[0]
            # labels, fwd_map, _ = relabel_sequential(labels, label_offset)
            # if len(incomplete) > 0:
            #     problem_ids.extend([fwd_map[i] for i in incomplete])
            #     if show_progress:
            #         blocks.set_postfix_str(f"found {len(problem_ids)} problematic {'object' if len(problem_ids)==1 else 'objects'}")
            if labels_out is not None:
                block.write(labels_out, labels, axes=axes_out)
            for k,v in polys.items():
                polys_all.setdefault(k,[]).append(v)
            label_offset += len(polys['prob'])

        polys_all = {k: (np.concatenate(v) if k in OBJECT_KEYS else v[0]) for k,v in polys_all.items()}

        # if labels_out is not None and len(problem_ids) > 0:
        #     # if show_progress:
        #     #     blocks.write('')
        #     # print(f"Found {len(problem_ids)} objects that violate the 'min_overlap' assumption.", file=sys.stderr, flush=True)
        #     repaint_labels(labels_out, problem_ids, polys_all, show_progress=False)

        return labels_out, polys_all#, tuple(problem_ids)
Example #10
    def train(self, X, Y, validation_data, epochs=None, steps_per_epoch=None):
        """Train the neural network with the given data.
        Parameters
        ----------
        X : :class:`numpy.ndarray`
            Array of source images.
        Y : :class:`numpy.ndarray`
            Array of target images.
        validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`)
            Tuple of arrays for source and target validation images.
        epochs : int
            Optional argument to use instead of the value from ``config``.
        steps_per_epoch : int
            Optional argument to use instead of the value from ``config``.

        Returns
        -------
        ``History`` object
            See `Keras training history <https://keras.io/models/model/#fit>`_.
        """

        ((isinstance(validation_data,
                     (list, tuple)) and len(validation_data) == 2) or
         _raise(ValueError('validation_data must be a pair of numpy arrays')))

        n_train, n_val = len(X), len(validation_data[0])
        frac_val = (1.0 * n_val) / (n_train + n_val)
        frac_warn = 0.05
        if frac_val < frac_warn:
            warnings.warn(
                "small number of validation images (only %.1f%% of all images)"
                % (100 * frac_val))
        axes = axes_check_and_normalize('S' + self.config.axes, X.ndim)
        ax = axes_dict(axes)

        for a, div_by in zip(axes, self._axes_div_by(axes)):
            n = X.shape[ax[a]]
            if n % div_by != 0:
                raise ValueError(
                    "training images must be evenly divisible by %d along axis %s"
                    " (which has incompatible size %d)" % (div_by, a, n))

        if epochs is None:
            epochs = self.config.train_epochs
        if steps_per_epoch is None:
            steps_per_epoch = self.config.train_steps_per_epoch

        if not self._model_prepared:
            self.prepare_for_training()

        training_data = CryoDataWrapper(X, Y, self.config.train_batch_size)

        history = self.keras_model.fit_generator(
            generator=training_data,
            validation_data=validation_data,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            callbacks=self.callbacks,
            verbose=1)

        if self.basedir is not None:
            self.keras_model.save_weights(str(self.logdir / 'weights_last.h5'))

            if self.config.train_checkpoint is not None:
                print()
                self._find_and_load_weights(self.config.train_checkpoint)
                try:
                    # remove temporary weights
                    (self.logdir / 'weights_now.h5').unlink()
                except FileNotFoundError:
                    pass

        return history
Example #11
    def train(self, X, validation_X, epochs=None, steps_per_epoch=None):
        """Train the neural network with the given data.

        Parameters
        ----------
        X : :class:`numpy.ndarray`
            Array of source images.
        validation_X : :class:`numpy.ndarray`
            Array of validation images.
        epochs : int
            Optional argument to use instead of the value from ``config``.
        steps_per_epoch : int
            Optional argument to use instead of the value from ``config``.

        Returns
        -------
        ``History`` object
            See `Keras training history <https://keras.io/models/model/#fit>`_.

        """

        n_train, n_val = len(X), len(validation_X)
        frac_val = (1.0 * n_val) / (n_train + n_val)
        frac_warn = 0.05
        if frac_val < frac_warn:
            warnings.warn("small number of validation images (only %.1f%% of all images)" % (100*frac_val))
        axes = axes_check_and_normalize('S'+self.config.axes,X.ndim)
        ax = axes_dict(axes)
        div_by = 2**self.config.unet_n_depth
        axes_relevant = ''.join(a for a in 'XYZT' if a in axes)
        val_num_pix = 1
        train_num_pix = 1
        val_patch_shape = ()
        for a in axes_relevant:
            n = X.shape[ax[a]]
            val_num_pix *= validation_X.shape[ax[a]]
            train_num_pix *= X.shape[ax[a]]
            val_patch_shape += tuple([validation_X.shape[ax[a]]])
            if n % div_by != 0:
                raise ValueError(
                    "training images must be evenly divisible by %d along axes %s"
                    " (axis %s has incompatible size %d)" % (div_by,axes_relevant,a,n)
                )

        if epochs is None:
            epochs = self.config.train_epochs
        if steps_per_epoch is None:
            steps_per_epoch = self.config.train_steps_per_epoch

        if not self._model_prepared:
            self.prepare_for_training()

        manipulator = eval('pm_{0}({1})'.format(self.config.n2v_manipulator, str(self.config.n2v_neighborhood_radius)))

        mean, std = float(self.config.mean), float(self.config.std)

        X = self.__normalize__(X, mean, std)
        validation_X = self.__normalize__(validation_X, mean, std)

        # Here we prepare the Noise2Void data. Our input is the noisy data X and as target we take X concatenated with
        # a masking channel. The N2V_DataWrapper will take care of the pixel masking and manipulating.
        training_data = N2V_DataWrapper(X, np.concatenate((X, np.zeros(X.shape, dtype=X.dtype)), axis=axes.index('C')),
                                                    self.config.train_batch_size, int(train_num_pix/100 * self.config.n2v_perc_pix),
                                                    self.config.n2v_patch_shape, manipulator)

        # validation_Y is also validation_X plus a concatenated masking channel.
        # To speed things up, we precompute the masking of the validation data.
        validation_Y = np.concatenate((validation_X, np.zeros(validation_X.shape, dtype=validation_X.dtype)), axis=axes.index('C'))
        n2v_utils.manipulate_val_data(validation_X, validation_Y,
                                                        num_pix=int(val_num_pix/100 * self.config.n2v_perc_pix),
                                                        shape=val_patch_shape,
                                                        value_manipulation=manipulator)

        history = self.keras_model.fit_generator(generator=training_data, validation_data=(validation_X, validation_Y),
                                                 epochs=epochs, steps_per_epoch=steps_per_epoch,
                                                 callbacks=self.callbacks, verbose=1)

        if self.basedir is not None:
            self.keras_model.save_weights(str(self.logdir / 'weights_last.h5'))

            if self.config.train_checkpoint is not None:
                print()
                self._find_and_load_weights(self.config.train_checkpoint)
                try:
                    # remove temporary weights
                    (self.logdir / 'weights_now.h5').unlink()
                except FileNotFoundError:
                    pass

        return history
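A worked example of the masked-pixel counts computed in train() above, assuming
images of shape 128x128 along 'YX' and the default n2v_perc_pix = 1.5:

train_num_pix = 128 * 128                     # pixels per image along 'YX'
num_masked = int(train_num_pix / 100 * 1.5)   # ~1.5% of 16384 pixels
assert num_masked == 245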
Example #12
def load_training_data_direct(X, Y, validation_split=0, axes=None, n_images=None, verbose=False):
    """Load training data from file in ``.npz`` format.

        The data file is expected to have the keys:

        - ``X``    : Array of training input images.
        - ``Y``    : Array of corresponding target images.
        - ``axes`` : Axes of the training images.


        Parameters
        ----------
        file : str
            File name
        validation_split : float
            Fraction of images to use as validation set during training.
        axes: str, optional
            Must be provided in case the loaded data does not contain ``axes`` information.
        n_images : int, optional
            Can be used to limit the number of images loaded from data.
        verbose : bool, optional
            Can be used to display information about the loaded images.

        Returns
        -------
        tuple( tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`), tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`), str )
            Returns two tuples (`X_train`, `Y_train`), (`X_val`, `Y_val`) of training and validation sets
            and the axes of the input images.
            The tuple of validation data will be ``None`` if ``validation_split = 0``.

        """

    # f = np.load(file)
    # X, Y = f['X'], f['Y']
    # if axes is None:
    #    axes = f['axes']
    axes = axes_check_and_normalize(axes)

    assert X.shape == Y.shape
    assert len(axes) == X.ndim
    assert 'C' in axes
    if n_images is None:
        n_images = X.shape[0]
    assert X.shape[0] == Y.shape[0]
    assert 0 < n_images <= X.shape[0]
    assert 0 <= validation_split < 1

    X, Y = X[:n_images], Y[:n_images]
    channel = axes_dict(axes)['C']

    if validation_split > 0:
        n_val = int(round(n_images * validation_split))
        n_train = n_images - n_val
        assert 0 < n_val and 0 < n_train
        X_t, Y_t = X[-n_val:], Y[-n_val:]
        X, Y = X[:n_train], Y[:n_train]
        assert X.shape[0] == n_train and X_t.shape[0] == n_val
        X_t = move_channel_for_backend(X_t, channel=channel)
        Y_t = move_channel_for_backend(Y_t, channel=channel)

    X = move_channel_for_backend(X, channel=channel)
    Y = move_channel_for_backend(Y, channel=channel)

    axes = axes.replace('C', '')  # remove channel
    if backend_channels_last():
        axes = axes + 'C'
    else:
        axes = axes[:1] + 'C' + axes[1:]

    data_val = (X_t, Y_t) if validation_split > 0 else None

    if verbose:
        ax = axes_dict(axes)
        n_train, n_val = len(X), len(X_t) if validation_split > 0 else 0
        image_size = tuple(X.shape[ax[a]] for a in 'TZYX' if a in axes)
        n_dim = len(image_size)
        n_channel_in, n_channel_out = X.shape[ax['C']], Y.shape[ax['C']]

        print('number of training images:\t', n_train)
        print('number of validation images:\t', n_val)
        print('image size (%dD):\t\t' % n_dim, image_size)
        print('axes:\t\t\t\t', axes)
        print('channels in / out:\t\t', n_channel_in, '/', n_channel_out)

    return (X, Y), data_val, axes
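A hedged usage sketch; X and Y are arrays with a leading sample axis and a
channel axis (here 'SYXC'):

import numpy as np

X = np.random.rand(100, 64, 64, 1).astype(np.float32)
Y = np.random.rand(100, 64, 64, 1).astype(np.float32)
(X_trn, Y_trn), (X_val, Y_val), axes = load_training_data_direct(
    X, Y, validation_split=0.1, axes='SYXC', verbose=True)
assert len(X_trn) == 90 and len(X_val) == 10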
Example #13
    def __init__(self, X, **kwargs):

        if X.size != 0:

            assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' dimensions are supported."

            n_dim = len(X.shape) - 2

            if n_dim == 2:
                axes = 'SYXC'
            elif n_dim == 3:
                axes = 'SZYXC'

            # parse and check axes
            axes = axes_check_and_normalize(axes)
            ax = axes_dict(axes)
            ax = {a: (ax[a] is not None) for a in ax}

            (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
            not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

            axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
            axes = axes.replace('S', '')  # remove sample axis if it exists

            if backend_channels_last():
                if ax['C']:
                    axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
                else:
                    axes += 'C'
            else:
                if ax['C']:
                    axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
                else:
                    axes = 'C' + axes

            means, stds = [], []
            for i in range(X.shape[-1]):
                means.append(np.mean(X[..., i]))
                stds.append(np.std(X[..., i]))

            # normalization parameters
            self.means = [str(el) for el in means]
            self.stds = [str(el) for el in stds]
            # directly set by parameters
            self.n_dim = n_dim
            self.axes = axes
            # fixed parameters
            if 'C' in axes:
                self.n_channel_in = X.shape[-1]
            else:
                self.n_channel_in = 1
            self.train_loss = 'demix'

            # default config (can be overwritten by kwargs below)

            self.unet_n_depth = 2
            self.unet_kern_size = 3
            self.unet_n_first = 64
            self.unet_last_activation = 'linear'
            self.probabilistic = False
            self.unet_residual = False
            if backend_channels_last():
                self.unet_input_shape = self.n_dim * (None,) + (self.n_channel_in,)
            else:
                self.unet_input_shape = (self.n_channel_in,) + self.n_dim * (None,)

            # fixed parameters
            self.train_epochs = 200
            self.train_steps_per_epoch = 50
            self.train_learning_rate = 0.0004
            self.train_batch_size = 64
            self.train_tensorboard = False
            self.train_checkpoint = 'weights_best.h5'
            self.train_checkpoint_last = 'weights_last.h5'
            self.train_checkpoint_epoch = 'weights_now.h5'
            self.train_reduce_lr = {'monitor': 'val_loss', 'factor': 0.5, 'patience': 10}
            self.batch_norm = False
            self.n2v_perc_pix = 1.5
            self.n2v_patch_shape = (128, 128) if self.n_dim == 2 else (64, 64, 64)
            self.n2v_manipulator = 'uniform_withCP'
            self.n2v_neighborhood_radius = 5

            self.single_net_per_channel = False

            self.channel_denoised = False
            self.multi_objective = True
            self.normalizer = 'none'
            self.weights_objectives = [0.1, 0, 0.45, 0.45]
            self.distributions = 'gauss'
            self.n2v_leave_center = False
            self.scale_aug = False
            self.structN2Vmask = None

            self.n_back_modes = 2
            self.n_fore_modes = 2
            self.n_instance_seg = 0
            self.n_back_i_modes = 1
            self.n_fore_i_modes = 1

            self.fit_std = False
            self.fit_mean = True

            # self.n_channel_out = (self.n_back_modes + self.n_fore_modes) * self.n_channel_in * self.n_components +\
            #                      self.n_instance_seg * (self.n_back_i_modes+self.n_fore_i_modes)

        # disallow setting 'probabilistic' and 'unet_residual' manually
        kwargs['probabilistic'] = False
        kwargs['unet_residual'] = False

        for k in kwargs:
            setattr(self, k, kwargs[k])
        self.n_components = 3 if (self.fit_std & self.fit_mean) else 2
        self.n_channel_out = (self.n_back_modes + self.n_fore_modes) * self.n_channel_in * self.n_components + \
                             self.n_instance_seg * (self.n_back_i_modes + self.n_fore_i_modes)
Example #14
    def __init__(self, X, **kwargs):
        """See class docstring"""
        assert X.size != 0
        assert len(X.shape) == 4 or len(
            X.shape) == 5, "Only 'SZYXC' or 'SYXC' dimensions are supported."

        n_dim = len(X.shape) - 2

        if n_dim == 2:
            axes = 'SYXC'
        elif n_dim == 3:
            axes = 'SZYXC'

        # parse and check axes
        axes = axes_check_and_normalize(axes)
        ax = axes_dict(axes)
        ax = {a: (ax[a] is not None) for a in ax}

        (ax['X'] and ax['Y']) or _raise(
            ValueError('lateral axes X and Y must be present.'))
        not (ax['Z'] and ax['T']) or _raise(
            ValueError('using Z and T axes together not supported.'))

        axes.startswith('S') or (not ax['S']) or _raise(
            ValueError('sample axis S must be first.'))
        axes = axes.replace('S', '')  # remove sample axis if it exists

        if backend_channels_last():
            if ax['C']:
                axes[-1] == 'C' or _raise(
                    ValueError('channel axis must be last for backend (%s).' %
                               K.backend()))
            else:
                axes += 'C'
        else:
            if ax['C']:
                axes[0] == 'C' or _raise(
                    ValueError('channel axis must be first for backend (%s).' %
                               K.backend()))
            else:
                axes = 'C' + axes

        assert X.shape[-1] == 1
        n_channel_in = 1
        n_channel_out = 3

        # directly set by parameters
        self.n_dim = n_dim
        self.axes = axes
        # fixed parameters
        self.n_channel_in = n_channel_in
        self.n_channel_out = n_channel_out
        self.train_loss = 'seg'

        # default config (can be overwritten by kwargs below)

        self.unet_n_depth = 4
        self.relative_weights = [1.0, 1.0, 5.0]
        self.unet_kern_size = 3
        self.unet_n_first = 32
        self.unet_last_activation = 'linear'
        self.probabilistic = False
        self.unet_residual = False
        if backend_channels_last():
            self.unet_input_shape = self.n_dim * (None, ) + (
                self.n_channel_in, )
        else:
            self.unet_input_shape = (
                self.n_channel_in, ) + self.n_dim * (None, )

        self.train_epochs = 200
        self.train_steps_per_epoch = 400
        self.train_learning_rate = 0.0004
        self.train_batch_size = 128
        self.train_tensorboard = False
        self.train_checkpoint = 'weights_best.h5'
        self.train_checkpoint_last = 'weights_last.h5'
        self.train_checkpoint_epoch = 'weights_now.h5'
        self.train_reduce_lr = {'factor': 0.5, 'patience': 10}
        self.batch_norm = True

        # disallow setting these config values manually
        for k in ('n_dim', 'n_channel_in', 'n_channel_out', 'train_loss',
                  'probabilistic', 'unet_residual'):
            kwargs.pop(k, None)

        for k in kwargs:
            setattr(self, k, kwargs[k])
Example #15
    def plugin(
        viewer: napari.Viewer,
        label_head,
        image: napari.layers.Image,
        axes,
        label_nn,
        model_type,
        model2d,
        model3d,
        model_folder,
        model_axes,
        norm_image,
        perc_low,
        perc_high,
        input_scale,
        label_nms,
        prob_thresh,
        nms_thresh,
        output_type,
        label_adv,
        n_tiles,
        norm_axes,
        timelapse_opts,
        cnn_output,
        set_thresholds,
        defaults_button,
        progress_bar: mw.ProgressBar,
    ) -> List[napari.types.LayerDataTuple]:

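        # note: `get_model`, `model_selected`, and the enums/helpers used below
        # are assumed to be bound in the enclosing plugin scope (not shown here)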
        model = get_model(*model_selected)
        if model._is_multiclass():
            warn(
                "multi-class mode not supported yet, ignoring classification output"
            )

        lkwargs = {}
        x = get_data(image)
        axes = axes_check_and_normalize(axes, length=x.ndim)

        if not (input_scale is None
                or isinstance(input_scale, numbers.Number)):
            input_scale = tuple(s for a, s in zip(axes, input_scale)
                                if a not in ("T", ))
            # print(f'scaling by {input_scale}')

        if not axes.replace("T", "").startswith(
                model._axes_out.replace("C", "")):
            warn(
                f"output images have different axes ({model._axes_out.replace('C','')}) than input image ({axes})"
            )
            # TODO: adjust image.scale according to shuffled axes

        if norm_image:
            axes_norm = axes_check_and_normalize(norm_axes)
            axes_norm = "".join(set(axes_norm).intersection(
                set(axes)))  # relevant axes present in input image
            assert len(axes_norm) > 0
            # always jointly normalize channels for RGB images
            if ("C" in axes and image.rgb == True) and ("C" not in axes_norm):
                axes_norm = axes_norm + "C"
                warn("jointly normalizing channels of RGB input image")
            ax = axes_dict(axes)
            _axis = tuple(sorted(ax[a] for a in axes_norm))
            # # TODO: address joint vs. channel/time-separate normalization properly (let user choose)
            # #       also needs to be documented somewhere
            # if 'T' in axes:
            #     if 'C' not in axes or image.rgb == True:
            #          # normalize channels jointly, frames independently
            #          _axis = tuple(i for i in range(x.ndim) if i not in (ax['T'],))
            #     else:
            #         # normalize channels independently, frames independently
            #         _axis = tuple(i for i in range(x.ndim) if i not in (ax['T'],ax['C']))
            # else:
            #     if 'C' not in axes or image.rgb == True:
            #          # normalize channels jointly
            #         _axis = None
            #     else:
            #         # normalize channels independently
            #         _axis = tuple(i for i in range(x.ndim) if i not in (ax['C'],))
            x = normalize(x, perc_low, perc_high, axis=_axis)

        # TODO: progress bar (labels) often don't show up. events not processed?
        if "T" in axes:
            app = use_app()
            t = axes_dict(axes)["T"]
            n_frames = x.shape[t]
            if n_tiles is not None:
                # remove tiling value for time axis
                n_tiles = tuple(v for i, v in enumerate(n_tiles) if i != t)

            def progress(it, **kwargs):
                progress_bar.label = "StarDist Prediction (frames)"
                progress_bar.range = (0, n_frames)
                progress_bar.value = 0
                progress_bar.show()
                app.process_events()
                for item in it:
                    yield item
                    progress_bar.increment()
                    app.process_events()
                app.process_events()

        elif n_tiles is not None and np.prod(n_tiles) > 1:
            n_tiles = tuple(n_tiles)
            app = use_app()

            def progress(it, **kwargs):
                progress_bar.label = "CNN Prediction (tiles)"
                progress_bar.range = (0, kwargs.get("total", 0))
                progress_bar.value = 0
                progress_bar.show()
                app.process_events()
                for item in it:
                    yield item
                    progress_bar.increment()
                    app.process_events()
                #
                progress_bar.label = "NMS Postprocessing"
                progress_bar.range = (0, 0)
                app.process_events()

        else:
            progress = False
            progress_bar.label = "StarDist Prediction"
            progress_bar.range = (0, 0)
            progress_bar.show()
            use_app().process_events()

        # semantic output axes of predictions
        assert model._axes_out[-1] == "C"
        axes_out = list(model._axes_out[:-1])

        if "T" in axes:
            x_reorder = np.moveaxis(x, t, 0)
            axes_reorder = axes.replace("T", "")
            axes_out.insert(t, "T")
            res = tuple(
                zip(*tuple(
                    model.predict_instances(
                        _x,
                        axes=axes_reorder,
                        prob_thresh=prob_thresh,
                        nms_thresh=nms_thresh,
                        n_tiles=n_tiles,
                        scale=input_scale,
                        sparse=(not cnn_output),
                        return_predict=cnn_output,
                    ) for _x in progress(x_reorder))))

            if cnn_output:
                labels, polys = tuple(zip(*res[0]))
                cnn_output = tuple(np.stack(c, t) for c in tuple(zip(*res[1])))
            else:
                labels, polys = res

            labels = np.asarray(labels)

            if len(polys) > 1:
                if timelapse_opts == TimelapseLabels.Match.value:
                    # match labels in consecutive frames (-> simple IoU tracking)
                    labels = group_matching_labels(labels)
                elif timelapse_opts == TimelapseLabels.Unique.value:
                    # make label ids unique (shift by offset)
                    offsets = np.cumsum([len(p["points"]) for p in polys])
                    for y, off in zip(labels[1:], offsets):
                        y[y > 0] += off
                elif timelapse_opts == TimelapseLabels.Separate.value:
                    # each frame processed separately (nothing to do)
                    pass
                else:
                    raise NotImplementedError(
                        f"unknown option '{timelapse_opts}' for time-lapse labels"
                    )

            labels = np.moveaxis(labels, 0, t)

            if isinstance(model, StarDist3D):
                # TODO poly output support for 3D timelapse
                polys = None
            else:
                polys = dict(
                    coord=np.concatenate(
                        tuple(
                            np.insert(p["coord"], t, _t, axis=-2)
                            for _t, p in enumerate(polys)),
                        axis=0,
                    ),
                    points=np.concatenate(
                        tuple(
                            np.insert(p["points"], t, _t, axis=-1)
                            for _t, p in enumerate(polys)),
                        axis=0,
                    ),
                )

            if cnn_output:
                pred = (labels, polys), cnn_output
            else:
                pred = labels, polys

        else:
            # TODO: possible to run this in a way that it can be canceled?
            pred = model.predict_instances(
                x,
                axes=axes,
                prob_thresh=prob_thresh,
                nms_thresh=nms_thresh,
                n_tiles=n_tiles,
                show_tile_progress=progress,
                scale=input_scale,
                sparse=(not cnn_output),
                return_predict=cnn_output,
            )
        progress_bar.hide()

        # determine scale for output axes
        scale_in_dict = dict(zip(axes, image.scale))
        scale_out = [scale_in_dict.get(a, 1.0) for a in axes_out]

        layers = []
        if cnn_output:
            (labels, polys), cnn_out = pred
            prob, dist = cnn_out[:2]
            dist = np.moveaxis(dist, -1, 0)

            assert len(model.config.grid) == len(model.config.axes) - 1
            grid_dict = dict(
                zip(model.config.axes.replace("C", ""), model.config.grid))
            # scale output axes to match input axes
            _scale = [
                s * grid_dict.get(a, 1) for a, s in zip(axes_out, scale_out)
            ]
            # small translation correction if grid > 1 (since napari centers objects)
            _translate = [0.5 * (grid_dict.get(a, 1) - 1) for a in axes_out]

            layers.append((
                dist,
                dict(
                    name="StarDist distances",
                    scale=[1] + _scale,
                    translate=[0] + _translate,
                    **lkwargs,
                ),
                "image",
            ))
            layers.append((
                prob,
                dict(
                    name="StarDist probability",
                    scale=_scale,
                    translate=_translate,
                    **lkwargs,
                ),
                "image",
            ))
        else:
            labels, polys = pred

        if output_type in (Output.Labels.value, Output.Both.value):
            layers.append((
                labels,
                dict(name="StarDist labels",
                     scale=scale_out,
                     opacity=0.5,
                     **lkwargs),
                "labels",
            ))
        if output_type in (Output.Polys.value, Output.Both.value):
            n_objects = len(polys["points"])
            if isinstance(model, StarDist3D):
                surface = surface_from_polys(polys)
                layers.append((
                    surface,
                    dict(
                        name="StarDist polyhedra",
                        contrast_limits=(0, surface[-1].max()),
                        scale=scale_out,
                        colormap=label_colormap(n_objects),
                        **lkwargs,
                    ),
                    "surface",
                ))
            else:
                # TODO: sometimes hangs for long time (indefinitely?) when returning many polygons (?)
                #       seems to be a known issue: https://github.com/napari/napari/issues/2015
                # TODO: coordinates correct or need offset (0.5 or so)?
                shapes = np.moveaxis(polys["coord"], -1, -2)
                layers.append((
                    shapes,
                    dict(
                        name="StarDist polygons",
                        shape_type="polygon",
                        scale=scale_out,
                        edge_width=0.75,
                        edge_color="yellow",
                        face_color=[0, 0, 0, 0],
                        **lkwargs,
                    ),
                    "shapes",
                ))
        return layers
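A worked example of the grid scale/translate correction above: for grid 2 along
'YX' and unit pixel scale, the low-resolution 'prob'/'dist' outputs are scaled by
the grid and shifted by half of (grid - 1) so napari aligns them with the input:

grid_dict = {'Y': 2, 'X': 2}
scale_out = [1.0, 1.0]
_scale = [s * grid_dict.get(a, 1) for a, s in zip('YX', scale_out)]  # [2.0, 2.0]
_translate = [0.5 * (grid_dict.get(a, 1) - 1) for a in 'YX']         # [0.5, 0.5]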
Example #16
    def is_valid(self, return_invalid=False):
        """Check if configuration is valid.

        Returns
        -------
        bool
            Flag that indicates whether the current configuration values are valid.
            If ``return_invalid=True``, a tuple with the names of the invalid
            attributes is returned as well.
        """

        def _is_int(v, low=None, high=None):
            return (
                    isinstance(v, int) and
                    (True if low is None else low <= v) and
                    (True if high is None else v <= high)
            )

        ok = {}
        ok['n_dim'] = self.n_dim in (2, 3)
        try:
            axes_check_and_normalize(self.axes, self.n_dim + 1, disallowed='S')
            ok['axes'] = True
        except:
            ok['axes'] = False
        ok['n_channel_in'] = _is_int(self.n_channel_in, 1)
        ok['n_channel_out'] = _is_int(self.n_channel_out, 4)
        ok['train_loss'] = (
            (self.train_loss in ('seg', 'denoiseg'))
        )
        ok['unet_n_depth'] = _is_int(self.unet_n_depth, 1)
        ok['relative_weights'] = isinstance(self.relative_weights, list) and len(self.relative_weights) == 3 and all(
            x > 0 for x in self.relative_weights)
        ok['unet_kern_size'] = _is_int(self.unet_kern_size, 1)
        ok['unet_n_first'] = _is_int(self.unet_n_first, 1)
        ok['unet_last_activation'] = self.unet_last_activation in ('linear', 'relu')
        ok['probabilistic'] = isinstance(self.probabilistic, bool) and not self.probabilistic
        ok['unet_residual'] = isinstance(self.unet_residual, bool) and not self.unet_residual
        ok['unet_input_shape'] = (
                isinstance(self.unet_input_shape, (list, tuple)) and
                len(self.unet_input_shape) == self.n_dim + 1 and
                self.unet_input_shape[-1] == self.n_channel_in and
                all((d is None or (_is_int(d) and d % (2 ** self.unet_n_depth) == 0) for d in
                     self.unet_input_shape[:-1]))
        )
        ok['train_epochs'] = _is_int(self.train_epochs, 1)
        ok['train_steps_per_epoch'] = _is_int(self.train_steps_per_epoch, 1)
        ok['train_learning_rate'] = np.isscalar(self.train_learning_rate) and self.train_learning_rate > 0
        ok['train_batch_size'] = _is_int(self.train_batch_size, 1)
        ok['train_tensorboard'] = isinstance(self.train_tensorboard, bool)
        ok['train_checkpoint'] = self.train_checkpoint is None or isinstance(self.train_checkpoint, string_types)
        ok['train_reduce_lr'] = self.train_reduce_lr is None or isinstance(self.train_reduce_lr, dict) and self.train_reduce_lr['monitor'] in ['val_loss', 'val_seg_loss', 'val_denoise_loss']
        ok['batch_norm'] = isinstance(self.batch_norm, bool)
        ok['n2v_perc_pix'] = self.n2v_perc_pix > 0 and self.n2v_perc_pix <= 100
        ok['n2v_patch_shape'] = (
                isinstance(self.n2v_patch_shape, (list, tuple)) and
                len(self.n2v_patch_shape) == self.n_dim and
                all(d > 0 for d in self.n2v_patch_shape)
        )
        ok['n2v_manipulator'] = self.n2v_manipulator in ['normal_withoutCP', 'uniform_withCP', 'normal_additive',
                                                         'normal_fitted', 'identity']
        ok['n2v_neighborhood_radius'] = _is_int(self.n2v_neighborhood_radius, 0)
        ok['denoiseg_alpha'] = isinstance(self.denoiseg_alpha, float) and self.denoiseg_alpha >= 0.0 and self.denoiseg_alpha <= 1.0

        if return_invalid:
            return all(ok.values()), tuple(k for (k, v) in ok.items() if not v)
        else:
            return all(ok.values())
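A hedged usage sketch of the validity check above:

valid, invalid_attrs = config.is_valid(return_invalid=True)
if not valid:
    raise ValueError("invalid configuration attributes: %s" % str(invalid_attrs))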
Example #17
    def __init__(self, X, **kwargs):
        """See class docstring"""

        # X is empty if config is None
        if X.size != 0:
            assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' dimensions are supported."

            n_dim = len(X.shape) - 2

            if n_dim == 2:
                axes = 'SYXC'
            elif n_dim == 3:
                axes = 'SZYXC'

            # parse and check axes
            axes = axes_check_and_normalize(axes)
            ax = axes_dict(axes)
            ax = {a: (ax[a] is not None) for a in ax}

            (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
            not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

            axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
            axes = axes.replace('S', '')  # remove sample axis if it exists

            if backend_channels_last():
                if ax['C']:
                    axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
                else:
                    axes += 'C'
            else:
                if ax['C']:
                    axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
                else:
                    axes = 'C' + axes

            means, stds = [], []
            for i in range(X.shape[-1]):
                means.append(np.mean(X[..., i]))
                stds.append(np.std(X[..., i]))

            # normalization parameters
            self.means = [str(el) for el in means]
            self.stds = [str(el) for el in stds]
            # directly set by parameters
            self.n_dim = n_dim
            self.axes = axes
            # fixed parameters
            self.n_channel_in = 1
            self.n_channel_out = 4
            self.train_loss = 'denoiseg'

            # default config (can be overwritten by kwargs below)

            self.unet_n_depth = 4
            self.relative_weights = [1.0, 1.0, 5.0]
            self.unet_kern_size = 3
            self.unet_n_first = 32
            self.unet_last_activation = 'linear'
            self.probabilistic = False
            self.unet_residual = False
            if backend_channels_last():
                self.unet_input_shape = self.n_dim * (None,) + (self.n_channel_in,)
            else:
                self.unet_input_shape = (self.n_channel_in,) + self.n_dim * (None,)

            self.train_epochs = 200
            self.train_steps_per_epoch = 400
            self.train_learning_rate = 0.0004
            self.train_batch_size = 128
            self.train_tensorboard = False
            self.train_checkpoint = 'weights_best.h5'
            self.train_checkpoint_last = 'weights_last.h5'
            self.train_checkpoint_epoch = 'weights_now.h5'
            self.train_reduce_lr = {'monitor': 'val_loss', 'factor': 0.5, 'patience': 10}
            self.batch_norm = True
            self.n2v_perc_pix = 1.5
            self.n2v_patch_shape = (64, 64) if self.n_dim == 2 else (64, 64, 64)
            self.n2v_manipulator = 'uniform_withCP'
            self.n2v_neighborhood_radius = 5
            self.denoiseg_alpha = 0.5

        # disallow setting 'probabilistic' and 'unet_residual' manually
        kwargs['probabilistic'] = False
        kwargs['unet_residual'] = False

        for k in kwargs:
            setattr(self, k, kwargs[k])
Example #18
    def is_valid(self, return_invalid=False):
        """Check if configuration is valid.

        Returns
        -------
        bool
            Flag that indicates whether the current configuration values are valid.
            If ``return_invalid=True``, a tuple with the names of the invalid
            attributes is returned as well.
        """
        def _is_int(v, low=None, high=None):
            return (isinstance(v, int) and (True if low is None else low <= v)
                    and (True if high is None else v <= high))

        ok = {}
        ok['means'] = True
        for mean in self.means:
            ok['means'] &= np.isscalar(float(mean))
        ok['stds'] = True
        for std in self.stds:
            ok['stds'] &= np.isscalar(float(std)) and float(std) > 0.0
        ok['n_dim'] = self.n_dim in (2, 3)
        try:
            axes_check_and_normalize(self.axes, self.n_dim + 1, disallowed='S')
            ok['axes'] = True
        except:
            ok['axes'] = False
        ok['n_channel_in'] = _is_int(self.n_channel_in, 1)
        ok['n_channel_out'] = _is_int(self.n_channel_out, 1)

        ok['unet_residual'] = (isinstance(self.unet_residual, bool)
                               and (not self.unet_residual or
                                    (self.n_channel_in == self.n_channel_out)))
        ok['unet_n_depth'] = _is_int(self.unet_n_depth, 1)
        ok['unet_kern_size'] = _is_int(self.unet_kern_size, 1)
        ok['unet_n_first'] = _is_int(self.unet_n_first, 1)
        ok['unet_last_activation'] = self.unet_last_activation in ('linear',
                                                                   'relu')
        ok['unet_input_shape'] = (
            isinstance(self.unet_input_shape, (list, tuple))
            and len(self.unet_input_shape) == self.n_dim + 1
            and self.unet_input_shape[-1] == self.n_channel_in and all(
                (d is None or (_is_int(d) and d % (2**self.unet_n_depth) == 0)
                 for d in self.unet_input_shape[:-1])))
        ok['train_loss'] = ((self.train_loss in ('mse', 'mae')))
        ok['train_epochs'] = _is_int(self.train_epochs, 1)
        ok['train_steps_per_epoch'] = _is_int(self.train_steps_per_epoch, 1)
        ok['train_learning_rate'] = np.isscalar(
            self.train_learning_rate) and self.train_learning_rate > 0
        ok['train_batch_size'] = _is_int(self.train_batch_size, 1)
        ok['train_tensorboard'] = isinstance(self.train_tensorboard, bool)
        ok['train_checkpoint'] = self.train_checkpoint is None or isinstance(
            self.train_checkpoint, string_types)
        ok['train_reduce_lr'] = self.train_reduce_lr is None or isinstance(
            self.train_reduce_lr, dict)
        ok['batch_norm'] = isinstance(self.batch_norm, bool)
        ok['n2v_perc_pix'] = self.n2v_perc_pix > 0 and self.n2v_perc_pix <= 100
        ok['n2v_patch_shape'] = (isinstance(self.n2v_patch_shape,
                                            (list, tuple))
                                 and len(self.n2v_patch_shape) == self.n_dim
                                 and all(d > 0 for d in self.n2v_patch_shape))
        ok['n2v_manipulator'] = self.n2v_manipulator in [
            'normal_withoutCP', 'uniform_withCP', 'normal_additive',
            'normal_fitted', 'identity'
        ]
        ok['n2v_neighborhood_radius'] = _is_int(self.n2v_neighborhood_radius,
                                                0)
        ok['single_net_per_channel'] = isinstance(self.single_net_per_channel,
                                                  bool)

        if self.structN2Vmask is None:
            ok['structN2Vmask'] = True
        else:
            mask = np.array(self.structN2Vmask)
            t1 = mask.ndim == self.n_dim
            t2 = all(x % 2 == 1 for x in mask.shape)
            t3 = all([x in [0, 1] for x in mask.flat])
            ok['structN2Vmask'] = t1 and t2 and t3

        if return_invalid:
            return all(ok.values()), tuple(k for (k, v) in ok.items() if not v)
        else:
            return all(ok.values())
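
A hedged sketch of a `structN2Vmask` that passes the checks above: the mask
must have `ndim == n_dim`, odd extents along every axis, and only 0/1 entries.
The config usage in the last comment is an assumption for illustration only:

import numpy as np

mask = [[0, 1, 1, 1, 0]]  # 2D mask of shape (1, 5): odd extents, binary entries
arr = np.array(mask)
assert arr.ndim == 2                        # matches n_dim == 2
assert all(s % 2 == 1 for s in arr.shape)   # odd along every axis
assert all(x in (0, 1) for x in arr.flat)   # only 0/1 entries
# config = Config(..., structN2Vmask=mask)  # config class assumed, not shown here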
Example #19
0
    def train(self, X, Y, validation_data, epochs=None, steps_per_epoch=None):
        n_train, n_val = len(X), len(validation_data[0])
        frac_val = (1.0 * n_val) / (n_train + n_val)
        frac_warn = 0.05
        if frac_val < frac_warn:
            warnings.warn(
                "small number of validation images (only %.05f%% of all images)"
                % (100 * frac_val))
        axes = axes_check_and_normalize('S' + self.config.axes, X.ndim)
        ax = axes_dict(axes)
        div_by = 2**self.config.unet_n_depth
        axes_relevant = ''.join(a for a in 'XYZT' if a in axes)
        val_num_pix = 1
        train_num_pix = 1
        val_patch_shape = ()
        for a in axes_relevant:
            n = X.shape[ax[a]]
            val_num_pix *= validation_data[0].shape[ax[a]]
            train_num_pix *= X.shape[ax[a]]
            val_patch_shape += tuple([validation_data[0].shape[ax[a]]])
            if n % div_by != 0:
                raise ValueError(
                    "training images must be evenly divisible by %d along axes %s"
                    " (axis %s has incompatible size %d)" %
                    (div_by, axes_relevant, a, n))

        if epochs is None:
            epochs = self.config.train_epochs
        if steps_per_epoch is None:
            steps_per_epoch = self.config.train_steps_per_epoch

        if not self._model_prepared:
            self.prepare_for_training()

        # build the pixel manipulator, e.g. pm_uniform_withCP(5); the pm_*
        # functions must be importable in this module's namespace
        manipulator = eval('pm_{0}({1})'.format(
            self.config.n2v_manipulator,
            str(self.config.n2v_neighborhood_radius)))

        means = np.array([float(mean) for mean in self.config.means],
                         ndmin=len(X.shape),
                         dtype=np.float32)
        stds = np.array([float(std) for std in self.config.stds],
                        ndmin=len(X.shape),
                        dtype=np.float32)

        X = self.__normalize__(X, means, stds)
        validation_X = self.__normalize__(validation_data[0], means, stds)

        # Here we prepare the Noise2Void data. Our input is the noisy data X, and
        # as target we take X concatenated with a masking channel. The
        # N2V_DataWrapper takes care of the pixel masking and manipulation.
        training_data = DenoiSeg_DataWrapper(
            X=X,
            n2v_Y=np.concatenate((X, np.zeros(X.shape, dtype=X.dtype)),
                                 axis=axes.index('C')),
            seg_Y=Y,
            batch_size=self.config.train_batch_size,
            perc_pix=self.config.n2v_perc_pix,
            shape=self.config.n2v_patch_shape,
            value_manipulation=manipulator)

        # validation_Y is also validation_X plus a concatenated masking channel.
        # To speed things up, we precompute the masking of the validation data.
        validation_Y = np.concatenate(
            (validation_X,
             np.zeros(validation_X.shape, dtype=validation_X.dtype)),
            axis=axes.index('C'))
        n2v_utils.manipulate_val_data(validation_X,
                                      validation_Y,
                                      perc_pix=self.config.n2v_perc_pix,
                                      shape=val_patch_shape,
                                      value_manipulation=manipulator)

        validation_Y = np.concatenate((validation_Y, validation_data[1]),
                                      axis=-1)

        history = self.keras_model.fit(training_data,
                                       validation_data=(validation_X,
                                                        validation_Y),
                                       epochs=epochs,
                                       steps_per_epoch=steps_per_epoch,
                                       callbacks=self.callbacks,
                                       verbose=1)

        if self.basedir is not None:
            self.keras_model.save_weights(str(self.logdir / 'weights_last.h5'))

            if self.config.train_checkpoint is not None:
                print()
                self._find_and_load_weights(self.config.train_checkpoint)
                try:
                    # remove temporary weights
                    (self.logdir / 'weights_now.h5').unlink()
                except FileNotFoundError:
                    pass

        return history
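
A hedged usage sketch for the `train` method above, assuming the SYXC axes
convention and spatial sizes divisible by 2**unet_n_depth; the model
constructor call is illustrative only:

import numpy as np

# 64 training and 8 validation images of 128x128 pixels with one channel;
# 128 is divisible by 2**n for n <= 7, so the divisibility check passes
X = np.random.rand(64, 128, 128, 1).astype(np.float32)
Y = np.random.rand(64, 128, 128, 3).astype(np.float32)    # e.g. 3 target channels
X_val = np.random.rand(8, 128, 128, 1).astype(np.float32)
Y_val = np.random.rand(8, 128, 128, 3).astype(np.float32)

# model = DenoiSeg(config, 'my_model', basedir='models')  # constructor assumed
# history = model.train(X, Y, validation_data=(X_val, Y_val))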
Example #20
0
def main():
    if not ('__file__' in locals() or '__file__' in globals()):
        print('running interactively, exiting.')
        sys.exit(0)

    # parse arguments
    parser, args = parse_args()
    args_dict = vars(args)

    # exit and show help if no arguments provided at all
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)

    # check for required arguments manually (because of argparse issue)
    required = ('--input-dir', '--input-axes', '--norm-pmin', '--norm-pmax',
                '--model-basedir', '--model-name', '--output-dir')
    for r in required:
        dest = r[2:].replace('-', '_')
        if args_dict[dest] is None:
            parser.print_usage(file=sys.stderr)
            print("%s: error: the following arguments are required: %s" %
                  (parser.prog, r),
                  file=sys.stderr)
            sys.exit(1)

    # show effective arguments (including defaults)
    if not args.quiet:
        print('Arguments')
        print('---------')
        pprint(args_dict)
        print()
        sys.stdout.flush()

    # logging function
    log = (lambda *a, **k: None) if args.quiet else tqdm.write

    # get list of input files and exit if there are none
    file_list = list(Path(args.input_dir).glob(args.input_pattern))
    if len(file_list) == 0:
        log("No files to process in '%s' with pattern '%s'." %
            (args.input_dir, args.input_pattern))
        sys.exit(0)

    # delay imports until all required arguments have been checked
    from tifffile import imread, imsave
    from csbdeep.utils.tf import keras_import
    K = keras_import('backend')
    from csbdeep.models import CARE
    from csbdeep.data import PercentileNormalizer
    sys.stdout.flush()
    sys.stderr.flush()

    # limit gpu memory
    if args.gpu_memory_limit is not None:
        from csbdeep.utils.tf import limit_gpu_memory
        limit_gpu_memory(args.gpu_memory_limit)

    # create CARE model and load weights, create normalizer
    K.clear_session()
    model = CARE(config=None, name=args.model_name, basedir=args.model_basedir)
    if args.model_weights is not None:
        print("Loading network weights from '%s'." % args.model_weights)
        model.load_weights(args.model_weights)
    normalizer = PercentileNormalizer(pmin=args.norm_pmin,
                                      pmax=args.norm_pmax,
                                      do_after=args.norm_undo)

    n_tiles = args.n_tiles
    if n_tiles is not None and len(n_tiles) == 1:
        n_tiles = n_tiles[0]

    processed = []

    # process all files
    for file_in in tqdm(file_list,
                        disable=args.quiet
                        or (n_tiles is not None and np.prod(n_tiles) > 1)):
        # construct output file name
        file_out = Path(args.output_dir) / args.output_name.format(
            file_path=str(file_in.relative_to(args.input_dir).parent),
            file_name=file_in.stem,
            file_ext=file_in.suffix,
            model_name=args.model_name,
            model_weights=Path(args.model_weights).stem
            if args.model_weights is not None else None)

        # checks
        (file_in.suffix.lower() in ('.tif', '.tiff')
         and file_out.suffix.lower() in ('.tif', '.tiff')) or _raise(
             ValueError('only tiff files supported.'))

        # load and predict restored image
        img = imread(str(file_in))
        restored = model.predict(img,
                                 axes=args.input_axes,
                                 normalizer=normalizer,
                                 n_tiles=n_tiles)

        # restored image could be multi-channel even if input image is not
        axes_out = axes_check_and_normalize(args.input_axes)
        if restored.ndim > img.ndim:
            assert restored.ndim == img.ndim + 1
            assert 'C' not in axes_out
            axes_out += 'C'

        # convert data type (if necessary)
        restored = restored.astype(np.dtype(args.output_dtype), copy=False)

        # save to disk
        if not args.dry_run:
            file_out.parent.mkdir(parents=True, exist_ok=True)
            if args.imagej_tiff:
                save_tiff_imagej_compatible(str(file_out), restored, axes_out)
            else:
                imsave(str(file_out), restored)

        processed.append((file_in, file_out))

    # print summary of processed files
    if not args.quiet:
        sys.stdout.flush()
        sys.stderr.flush()
        n_processed = len(processed)
        len_processed = len(str(n_processed))
        log('Finished processing %d %s' %
            (n_processed, 'files' if n_processed > 1 else 'file'))
        log('-' * (26 + len_processed if n_processed > 1 else 26))
        for i, (file_in, file_out) in enumerate(processed):
            len_file = max(len(str(file_in)), len(str(file_out)))
            log(('{:>%d}. in : {:>%d}' % (len_processed, len_file)).format(
                1 + i, str(file_in)))
            log(('{:>%d}  out: {:>%d}' % (len_processed, len_file)).format(
                '', str(file_out)))
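
For reference, a small sketch of the `--output-name` templating used in `main`;
the keys mirror the `format` call above, and the template and values are made up:

from pathlib import Path

output_name = '{file_path}/{model_name}/{file_name}{file_ext}'
file_in = Path('input/stack01/img.tif')
file_out = Path('results') / output_name.format(
    file_path=str(file_in.relative_to('input').parent),  # 'stack01'
    file_name=file_in.stem,                              # 'img'
    file_ext=file_in.suffix,                             # '.tif'
    model_name='my_model',
    model_weights=None)
print(file_out)  # results/stack01/my_model/img.tif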
Example #21
0
    def predict(self,
                img,
                axes=None,
                normalizer=None,
                n_tiles=None,
                show_tile_progress=True,
                **predict_kwargs):
        """Predict.

        Parameters
        ----------
        img : :class:`numpy.ndarray`
            Input image
        axes : str or None
            Axes of the input ``img``.
            ``None`` denotes that axes of img are the same as denoted in the config.
        normalizer : :class:`csbdeep.data.Normalizer` or None
            (Optional) normalization of input image before prediction.
            Note that the default (``None``) assumes ``img`` to be already normalized.
        n_tiles : iterable or None
            Out of memory (OOM) errors can occur if the input image is too large.
            To avoid this problem, the input image is broken up into (overlapping) tiles
            that are processed independently and re-assembled.
            This parameter denotes a tuple of the number of tiles for every image axis (see ``axes``).
            ``None`` denotes that no tiling should be used.
        show_tile_progress: bool
            Whether to show progress during tiled prediction.
        predict_kwargs: dict
            Keyword arguments for ``predict`` function of Keras model.

        Returns
        -------
        (:class:`numpy.ndarray`,:class:`numpy.ndarray`)
            Returns the tuple (`prob`, `dist`) of per-pixel object probabilities and star-convex polygon/polyhedra distances.

        """
        if n_tiles is None:
            n_tiles = [1] * img.ndim
        try:
            n_tiles = tuple(n_tiles)
            img.ndim == len(n_tiles) or _raise(TypeError())
        except TypeError:
            raise ValueError("n_tiles must be an iterable of length %d" %
                             img.ndim)
        all(np.isscalar(t) and 1 <= t and int(t) == t
            for t in n_tiles) or _raise(
                ValueError(
                    "all values of n_tiles must be integer values >= 1"))
        n_tiles = tuple(map(int, n_tiles))

        if axes is None:
            axes = self.config.axes
            assert 'C' in axes
            if img.ndim == len(axes) - 1 and self.config.n_channel_in == 1:
                # img has no dedicated channel axis, but 'C' always part of config axes
                axes = axes.replace('C', '')

        axes = axes_check_and_normalize(axes, img.ndim)
        axes_net = self.config.axes

        _permute_axes = self._make_permute_axes(axes, axes_net)
        x = _permute_axes(img)  # x has axes_net semantics

        channel = axes_dict(axes_net)['C']
        self.config.n_channel_in == x.shape[channel] or _raise(ValueError())
        axes_net_div_by = self._axes_div_by(axes_net)

        grid = tuple(self.config.grid)
        len(grid) == len(axes_net) - 1 or _raise(ValueError())
        grid_dict = dict(zip(axes_net.replace('C', ''), grid))

        normalizer = self._check_normalizer_resizer(normalizer, None)[0]
        resizer = StarDistPadAndCropResizer(grid=grid_dict)

        x = normalizer.before(x, axes_net)
        x = resizer.before(x, axes_net, axes_net_div_by)

        def predict_direct(tile):
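            # the empty 'dummy' tensor stands in for the network's second input
            # (presumably only needed at training time for the loss); indexing
            # [0] below drops the batch axis again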
            sh = list(tile.shape)
            sh[channel] = 1
            dummy = np.empty(sh, np.float32)
            prob, dist = self.keras_model.predict(
                [tile[np.newaxis], dummy[np.newaxis]], **predict_kwargs)
            return prob[0], dist[0]

        if np.prod(n_tiles) > 1:
            tiling_axes = axes_net.replace('C', '')  # axes eligible for tiling
            x_tiling_axis = tuple(
                axes_dict(axes_net)[a]
                for a in tiling_axes)  # numerical axis ids for x
            axes_net_tile_overlaps = self._axes_tile_overlap(axes_net)
            # hack: permute tiling axis in the same way as img -> x was permuted
            n_tiles = _permute_axes(np.empty(n_tiles, bool)).shape
            (all(n_tiles[i] == 1
                 for i in range(x.ndim) if i not in x_tiling_axis)
             or _raise(
                 ValueError("entry of n_tiles > 1 only allowed for axes '%s'" %
                            tiling_axes)))

            sh = [s // grid_dict.get(a, 1) for a, s in zip(axes_net, x.shape)]
            sh[channel] = 1
            prob = np.empty(sh, np.float32)
            sh[channel] = self.config.n_rays
            dist = np.empty(sh, np.float32)

            n_block_overlaps = [
                int(np.ceil(overlap / blocksize)) for overlap, blocksize in
                zip(axes_net_tile_overlaps, axes_net_div_by)
            ]

            for tile, s_src, s_dst in tqdm(tile_iterator(
                    x,
                    n_tiles,
                    block_sizes=axes_net_div_by,
                    n_block_overlaps=n_block_overlaps),
                                           disable=(not show_tile_progress),
                                           total=np.prod(n_tiles)):
                prob_tile, dist_tile = predict_direct(tile)
                # account for grid
                s_src = [
                    slice(s.start // grid_dict.get(a, 1),
                          s.stop // grid_dict.get(a, 1))
                    for s, a in zip(s_src, axes_net)
                ]
                s_dst = [
                    slice(s.start // grid_dict.get(a, 1),
                          s.stop // grid_dict.get(a, 1))
                    for s, a in zip(s_dst, axes_net)
                ]
                # prob and dist have different channel dimensionality than image x
                s_src[channel] = slice(None)
                s_dst[channel] = slice(None)
                s_src, s_dst = tuple(s_src), tuple(s_dst)
                # print(s_src,s_dst)
                prob[s_dst] = prob_tile[s_src]
                dist[s_dst] = dist_tile[s_src]

        else:
            prob, dist = predict_direct(x)

        prob = resizer.after(prob, axes_net)
        dist = resizer.after(dist, axes_net)
        dist = np.maximum(
            1e-3, dist
        )  # avoid small/negative dist values to prevent problems with Qhull

        prob = np.take(prob, 0, axis=channel)
        dist = np.moveaxis(dist, channel, -1)

        return prob, dist
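
A hedged sketch of calling `predict` with tiling; the model construction and
axes/tile counts are illustrative, not taken from the source:

import numpy as np

img = np.random.rand(2048, 2048).astype(np.float32)  # YX image, no channel axis
# model = StarDist2D(None, name='my_model', basedir='models')  # assumed
# prob, dist = model.predict(img, axes='YX', n_tiles=(4, 4))
# prob: per-pixel object probabilities (channel axis removed, grid-downscaled)
# dist: star-convex distances with the n_rays channel moved to the last axis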
Example #22
0
    def train(self, X, Y, validation_data, epochs=None, steps_per_epoch=None):
        """Train the neural network with the given data.
        Parameters
        ----------
        X : :class:`numpy.ndarray`
            Array of source images.
        Y : :class:`numpy.ndarray`
            Array of target images.
        validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`)
            Tuple of arrays for source and target validation images.
        epochs : int
            Optional argument to use instead of the value from ``config``.
        steps_per_epoch : int
            Optional argument to use instead of the value from ``config``.
        Returns
        -------
        ``History`` object
            See `Keras training history <https://keras.io/models/model/#fit>`_.
        """
        leave_center = self.config.n2v_leave_center
        scale_augmentation = self.config.scale_aug

        ## Resize validation patches if their spatial shape differs from the
        ## configured n2v_patch_shape (compare shape without sample/channel axes)
        if (np.sum(
                np.abs(
                    np.array(validation_data[0].shape[1:-1]) -
                    np.array(self.config.n2v_patch_shape))) != 0):
            X_val = subpatch_2D(validation_data[0],
                                np.array(self.config.n2v_patch_shape))
            Y_val = subpatch_2D(validation_data[1],
                                np.array(self.config.n2v_patch_shape))
            validation_data = (X_val, Y_val)

        ((isinstance(validation_data,
                     (list, tuple)) and len(validation_data) == 2) or
         _raise(ValueError('validation_data must be a pair of numpy arrays')))

        n_train, n_val = len(X), len(validation_data[0])

        ## Warning about validation size
        frac_val = (1.0 * n_val) / (n_train + n_val)
        frac_warn = 0.05
        if frac_val < frac_warn:
            warnings.warn(
                "small number of validation images (only %.1f%% of all images)"
                % (100 * frac_val))

        # axes description
        axes = axes_check_and_normalize('S' + self.config.axes, X.ndim)
        ax = axes_dict(axes)

        # for a,div_by in zip(axes,self._axes_div_by(axes)):
        #     n = X.shape[ax[a]]
        #     if n % div_by != 0:
        #         raise ValueError(
        #             "training images must be evenly divisible by %d along axis %s"
        #             " (which has incompatible size %d)" % (div_by,a,n)
        #         )

        ## check divisibility by the U-Net downsampling factor and collect the
        ## validation patch shape / pixel counts
        div_by = 2**self.config.unet_n_depth
        axes_relevant = ''.join(a for a in 'XYZT' if a in axes)
        val_num_pix = 1
        train_num_pix = 1
        val_patch_shape = ()
        for a in axes_relevant:
            n = X.shape[ax[a]]
            val_num_pix *= validation_data[0].shape[ax[a]]
            train_num_pix *= X.shape[ax[a]]
            val_patch_shape += tuple([validation_data[0].shape[ax[a]]])
            if n % div_by != 0:
                raise ValueError(
                    "training images must be evenly divisible by %d along axes %s"
                    " (axis %s has incompatible size %d)" %
                    (div_by, axes_relevant, a, n))

        # epochs & steps per epochs
        if epochs is None:
            epochs = self.config.train_epochs
        if steps_per_epoch is None:
            steps_per_epoch = self.config.train_steps_per_epoch

        if not self._model_prepared:
            self.prepare_for_training()

        # if (self.config.train_tensorboard and self.basedir is not None and not any(isinstance(cb,CARETensorBoardImage) for cb in self.callbacks)):
        #     self.callbacks.append(CARETensorBoardImage(model=self.keras_model, data=validation_data,
        #                                                log_dir=str(self.logdir/'logs'/'images'),
        #                                                n_images=3, prob_out=self.config.probabilistic))
        #
        # training_data = DataWrapper(X, Y, self.config.train_batch_size,epochs*steps_per_epoch)

        manipulator = eval('pm_{0}({1})'.format(
            self.config.n2v_manipulator,
            str(self.config.n2v_neighborhood_radius)))

        if self.config.normalizer == 'std':
            means = np.array([float(mean) for mean in self.config.means],
                             ndmin=len(X.shape),
                             dtype=np.float32)
            stds = np.array([float(std) for std in self.config.stds],
                            ndmin=len(X.shape),
                            dtype=np.float32)

            X = self.__normalize__(X, means, stds)
            validation_X = self.__normalize__(validation_data[0], means, stds)
        else:
            validation_X = validation_data[0]
        # TODO: normalize validation data when statistics are available; make
        # the type of normalization a configurable option

        # structN2V mask (structured noise pattern to inpaint)
        _mask = np.array(
            self.config.structN2Vmask) if self.config.structN2Vmask else None
        # print(_mask,self.config.channel_denoised)
        training_data = BioSeg_DataWrapper(
            X,
            Y,
            self.config.train_batch_size,
            self.config.n2v_perc_pix,
            self.config.n2v_patch_shape,
            manipulator,
            structN2Vmask=_mask,
            chan_denoise=self.config.channel_denoised,
            multiple_objectives=self.config.multi_objective,
            leave_center=leave_center,
            scale_augmentation=scale_augmentation)

        # validation_Y is also validation_X plus a concatenated masking channel.
        # To speed things up, we precompute the masking of the validation data.

        if not self.config.channel_denoised:
            validation_Y = np.concatenate(
                (validation_X,
                 np.zeros(validation_X.shape, dtype=validation_X.dtype)),
                axis=axes.index('C'))
        else:
            val_aux = validation_data[1][..., 0:X.shape[-1]]
            # print(val_aux.shape)
            # if X.shape[-1] == 1:
            #     val_aux = val_aux[...,np.newaxis]
            validation_Y = np.concatenate(
                (val_aux, np.zeros(val_aux.shape, dtype=validation_X.dtype)),
                axis=axes.index('C'))

        # print(validation_Y.shape, validation_X.shape)

        manipulate_val_data(validation_X,
                            validation_Y,
                            perc_pix=self.config.n2v_perc_pix,
                            shape=val_patch_shape,
                            value_manipulation=manipulator,
                            chan_denoise=self.config.channel_denoised)

        # print(self.config)
        # print(self.config.multi_objective)

        if self.config.multi_objective:
            if (self.config.channel_denoised) and (
                    validation_data[1].shape[-1] >
                    X.shape[-1]):  #additional channels
                validation_Y = np.concatenate(
                    (validation_Y, validation_data[1][..., X.shape[-1]:]),
                    axis=-1)

            if not self.config.channel_denoised:
                validation_Y = np.concatenate(
                    (validation_Y, validation_data[1][..., :]), axis=-1)

        # print(validation_Y.shape, validation_X.shape)
        # fit_generator is deprecated in TensorFlow 2, where keras_model.fit
        # accepts generators/Sequences directly
        fit = self.keras_model.fit_generator

        # fit = self.keras_model.fit

        history = fit(training_data,
                      validation_data=(validation_X, validation_Y),
                      epochs=epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=self.callbacks,
                      verbose=1)

        ## save the final weights and restore the best checkpoint (if configured)
        if self.basedir is not None:
            self.keras_model.save_weights(str(self.logdir / 'weights_last.h5'))

            if self.config.train_checkpoint is not None:
                print()
                self._find_and_load_weights(self.config.train_checkpoint)
                try:
                    # remove temporary weights
                    (self.logdir / 'weights_now.h5').unlink()
                except FileNotFoundError:
                    pass

        #self._training_finished()

        return history
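
For reference, a minimal sketch of the validation-target layout assembled
above: the denoising target is the (normalized) validation input plus an
all-zero mask channel, with any extra objective channels concatenated last;
shapes are illustrative:

import numpy as np

validation_X = np.random.rand(8, 64, 64, 1).astype(np.float32)  # SYXC
mask = np.zeros_like(validation_X)
validation_Y = np.concatenate((validation_X, mask), axis=-1)    # (8, 64, 64, 2)
seg = np.random.rand(8, 64, 64, 3).astype(np.float32)           # extra objectives
validation_Y = np.concatenate((validation_Y, seg), axis=-1)     # (8, 64, 64, 5)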