Example 1
 def _create(red_none, img_size, img_axes, patch_size, patch_axes):
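     # n_images, n_patches_per_image, rng, and get_data presumably come from the enclosing test scope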
     raw_data, red_axes, keepdims = get_data(n_images, img_axes, img_size)
     # change patch_size to (img_size or None) for red_axes
     patch_size = list(patch_size)
     for a in red_axes:
         patch_size[axes_dict(
             img_axes if patch_axes is None else patch_axes)[a]] = (
                 None if red_none else img_size[axes_dict(img_axes)[a]])
     X, Y, XYaxes = create_patches_reduced_target(
         raw_data=raw_data,
         patch_size=patch_size,
         patch_axes=patch_axes,
         n_patches_per_image=n_patches_per_image,
         reduction_axes=red_axes,
         target_axes=rng.choice(
             (None,
              img_axes)) if keepdims else ''.join(a for a in img_axes
                                                  if a not in red_axes),
         # identity normalization: return patches unchanged
         normalization=lambda patches_x, patches_y, *args:
         (patches_x, patches_y),
         verbose=False,
     )
     assert len(X) == n_images * n_patches_per_image
     _X = np.mean(X,
                  axis=tuple(axes_dict(XYaxes)[a] for a in red_axes),
                  keepdims=True)
     err = np.max(np.abs(_X - Y))
     assert err < 1e-5
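
All of these snippets pivot on csbdeep's axes_dict helper. A minimal sketch of its behavior, as I understand csbdeep.utils.axes_dict (each axis letter maps to its index in the axes string, and absent axes map to None):

from csbdeep.utils import axes_dict

ax = axes_dict('SYXC')
assert ax['S'] == 0 and ax['Y'] == 1 and ax['X'] == 2 and ax['C'] == 3
assert ax['Z'] is None  # absent axes are still keys, with value None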
Example 2
 def _guess_n_tiles(self, img):
     axes = self._normalize_axes(img, axes=None)
     shape = list(img.shape)
     if 'C' in axes:
         del shape[axes_dict(axes)['C']]
     b = self.config.train_batch_size ** (1.0 / self.config.n_dim)
     n_tiles = [int(np.ceil(s / (p * b)))
                for s, p in zip(shape, self.config.train_patch_size)]
     if 'C' in axes:
         n_tiles.insert(axes_dict(axes)['C'],1)
     return tuple(n_tiles)
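
For intuition about the tiling heuristic above, a quick worked check with assumed values (train_batch_size=4, n_dim=2, a 1024x1024 image, 128x128 patches):

import numpy as np

b = 4 ** (1.0 / 2)                  # = 2.0
n = int(np.ceil(1024 / (128 * b)))  # ceil(1024 / 256) = 4 tiles per axis
assert n == 4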
Example 3
    def _predict(imdims, axes):
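        # config, model, rng, and proj_axis presumably come from the enclosing test scope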
        img = rng.uniform(size=imdims)
        n_tiles = [1]*len(axes)
        ax = axes_dict(axes)

        if config.probabilistic:
            prob = model.predict_probabilistic(img, axes, None, None)
            mean, scale = prob.mean(), prob.scale()
            assert mean.shape == scale.shape
        else:
            mean = model.predict(img, axes, None, None)

            n_tiles[ax['X']] = 3
            n_tiles[ax['Y']] = 2
            mean_tiled = model.predict(img, axes, None, None, n_tiles=n_tiles)
            error_max = np.max(np.abs(mean-mean_tiled))
            # print(n_tiles, error_max)
            assert error_max < 1e-3

            with pytest.raises(ValueError):
                n_tiles[ax[proj_axis]] = 2
                model.predict(img, axes, None, None, n_tiles=n_tiles)

        shape_out = list(imdims)
        if 'C' in axes:
            shape_out[ax['C']] = config.n_channel_out
        elif config.n_channel_out > 1:
            shape_out.append(config.n_channel_out)

        del shape_out[ax[proj_axis]]
        assert tuple(shape_out) == mean.shape
Example 4
    def prepare_input(self, img, axes='YX', normalizer=None):
        """
        experimental function to extract normalized/resized input for prediction on hidden layers ##
        basically a copy of the original predict function
        currently only assume 2D n_tiles = (1,1)
        """

        # n_tiles is vestigial here (copied from predict) and unused below
        n_tiles = [1] * img.ndim
        n_tiles = tuple(map(int, n_tiles))

        axes = self._normalize_axes(img, axes)
        axes_net = self.config.axes
        _permute_axes = self._make_permute_axes(axes, axes_net)
        x = _permute_axes(img)  # x has axes_net semantics

        channel = axes_dict(axes_net)['C']
        self.config.n_channel_in == x.shape[channel] or _raise(ValueError())
        axes_net_div_by = self._axes_div_by(axes_net)

        grid = tuple(self.config.grid)
        len(grid) == len(axes_net) - 1 or _raise(ValueError())
        grid_dict = dict(zip(axes_net.replace('C', ''), grid))

        normalizer = self._check_normalizer_resizer(normalizer, None)[0]
        resizer = StarDistPadAndCropResizer(grid=grid_dict)

        x = normalizer.before(x, axes_net)
        x = resizer.before(x, axes_net, axes_net_div_by)

        sh = list(x.shape)
        sh[channel] = 1
        dummy = np.empty(sh, np.float32)

        return x[np.newaxis], dummy[np.newaxis]
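
The recurring expr or _raise(ValueError(...)) guards seen above rely on csbdeep's tiny _raise helper, which simply raises its argument so that raising can happen inside an expression:

def _raise(e):
    raise e

# usage: behaves like an inline assertion
x = 5
x > 0 or _raise(ValueError('x must be positive'))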
Example 5
 def _gen():
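     # n_images, shape, axes, red_axes, keepdims, and rng presumably come from the enclosing scope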
     for i in range(n_images):
         x = rng.uniform(size=shape)
         y = np.mean(x,
                     axis=tuple(axes_dict(axes)[a] for a in red_axes),
                     keepdims=keepdims)
         yield x, y, axes, None
Example 6
def original():
    mypath = Path('isonet_psf_1')
    mypath.mkdir(exist_ok=True)

    # sys.stdout = open(mypath / 'train_stdout.txt', 'w')
    # sys.stderr = open(mypath / 'train_stderr.txt', 'w')

    (X, Y), (X_val, Y_val), data_axes = load_training_data(
        mypath / 'my_training_data.npz', validation_split=0.1)
    ax = axes_dict(data_axes)

    n_train, n_val = len(X), len(X_val)
    image_size = tuple(
        X.shape[i]
        for i in ((ax['Z'], ax['Y'],
                   ax['X']) if (ax['Z'] is not None) else (ax['Y'], ax['X'])))
    n_dim = len(image_size)
    n_channel_in, n_channel_out = X.shape[ax['C']], Y.shape[ax['C']]

    print('number of training images:\t', n_train)
    print('number of validation images:\t', n_val)
    print('image size (%dD):\t\t' % n_dim, image_size)
    print('Channels in / out:\t\t', n_channel_in, '/', n_channel_out)

    plt.figure(figsize=(10, 4))
    plot_some(X_val[:5], Y_val[:5])
    plt.suptitle(
        '5 example validation patches (top row: source, bottom row: target)')
    plt.savefig(mypath / 'train_1.png')

    config = Config(data_axes, n_channel_in, n_channel_out, train_epochs=200)
    print(config)
    vars(config)  # no-op in a script; displays the config fields in a notebook

    model = IsotropicCARE(config, str(mypath / 'my_model'))

    history = model.train(X, Y, validation_data=(X_val, Y_val))

    print(sorted(list(history.history.keys())))
    plt.figure(figsize=(16, 5))
    plot_history(history, ['loss', 'val_loss'],
                 ['mse', 'val_mse', 'mae', 'val_mae'])
    plt.savefig(mypath / 'train_history.png')

    model.load_weights()  # load best weights according to validation loss

    plt.figure(figsize=(12, 7))
    _P = model.keras_model.predict(X_val[:5])
    if config.probabilistic:
        _P = _P[..., :(_P.shape[-1] // 2)]
    plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5)
    plt.suptitle('5 example validation patches\n' +
                 'top row: input (source),  ' +
                 'middle row: target (ground truth),  ' +
                 'bottom row: predicted from source')
    plt.tight_layout()
    plt.savefig(mypath / 'train_2.png')

    model.export_TF()
Example 7
    def train(self, channels=None, **config_args):
        #limit_gpu_memory(fraction=1)
        if channels is None:
            channels = self.train_channels

        for ch in channels:
            print("-- Training channel {}...".format(ch))
            (X, Y), (X_val, Y_val), axes = load_training_data(
                self.get_training_patch_path() /
                'CH_{}_training_patches.npz'.format(ch),
                validation_split=0.1,
                verbose=False)

            c = axes_dict(axes)['C']
            n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            train_epochs=self.train_epochs,
                            train_steps_per_epoch=self.train_steps_per_epoch,
                            train_batch_size=self.train_batch_size,
                            **config_args)
            # Training
            model = CARE(config,
                         'CH_{}_model'.format(ch),
                         basedir=pathlib.Path(self.out_dir) / 'models')

            # Show learning curve and example validation results
            try:
                history = model.train(X, Y, validation_data=(X_val, Y_val))
            except tf.errors.ResourceExhaustedError:
                print(
                    "ResourceExhaustedError: Aborting...\n Training data too big for GPU. "
                    "Are other GPU jobs running? Perhaps reduce batch size or patch size."
                )
                return

            #print(sorted(list(history.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(history, ['loss', 'val_loss'],
                         ['mse', 'val_mse', 'mae', 'val_mae'])

            plt.figure(figsize=(12, 7))
            _P = model.keras_model.predict(X_val[:5])

            plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5, cmap="gray")
            plt.suptitle('5 example validation patches\n'
                         'top row: input (source),  '
                         'middle row: target (ground truth),  '
                         'bottom row: predicted from source')

            plt.show()

            print("-- Export model for use in Fiji...")
            model.export_TF()
            print("Done")
Example 8
 def _predict(imdims, axes):
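     # config, model, factor, and rng presumably come from the enclosing test scope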
     img = rng.uniform(size=imdims)
     if config.probabilistic:
         prob = model.predict_probabilistic(img, axes, factor, None, None)
         mean, scale = prob.mean(), prob.scale()
         assert mean.shape == scale.shape
     else:
         mean = model.predict(img, axes, factor, None, None)
     a = axes_dict(axes)['Z']
     assert imdims[a]*factor == mean.shape[a]
Example 9
def dev(args):
    import json

    # Load and parse training data
    (X, Y), (X_val, Y_val), axes = load_training_data(
        args.train_data,
        validation_split=args.valid_split,
        axes=args.axes,
        verbose=True)

    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    # Model config
    print('args.resume: ', args.resume)
    if args.resume:
        # If resuming, config=None will reload the saved config
        config = None
        print('Attempting to resume')
    elif args.config:
        print('loading config from args')
        config_args = json.load(open(args.config))
        config = Config(**config_args)
    else:
        config = Config(axes,
                        n_channel_in,
                        n_channel_out,
                        probabilistic=args.prob,
                        train_steps_per_epoch=args.steps,
                        train_epochs=args.epochs)
        print(vars(config))

    # Load or init model
    model = CARE(config, args.model_name, basedir='models')

    # Training, tensorboard available
    history = model.train(X, Y, validation_data=(X_val, Y_val))

    # Plot training results
    print(sorted(list(history.history.keys())))
    plt.figure(figsize=(16, 5))
    plot_history(history, ['loss', 'val_loss'],
                 ['mse', 'val_mse', 'mae', 'val_mae'])
    plt.savefig(args.model_name + '_training.png')

    # Export model to be used w/ csbdeep fiji plugins and KNIME flows
    model.export_TF()
Example 10
    def _predict(imdims, axes):
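        # config, model, rng, normalizer, and resizer presumably come from the enclosing test scope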
        img = rng.uniform(size=imdims)
        # print(img.shape, axes, config.n_channel_out)
        mean, scale = model._predict_mean_and_scale(img, axes, normalizer,
                                                    resizer)
        if config.probabilistic:
            assert mean.shape == scale.shape
        else:
            assert scale is None

        if 'C' not in axes:
            if config.n_channel_out == 1:
                assert mean.shape == img.shape
            else:
                assert mean.shape == img.shape + (config.n_channel_out, )
        else:
            channel = axes_dict(axes)['C']
            imdims[channel] = config.n_channel_out
            assert mean.shape == tuple(imdims)
Example 11
    def __init__(self, X, **kwargs):
        """See class docstring"""

        # X is empty if config is None
        if X.size != 0:
            assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

            n_dim = len(X.shape) - 2

            if n_dim == 2:
                axes = 'SYXC'
            elif n_dim == 3:
                axes = 'SZYXC'

            # parse and check axes
            axes = axes_check_and_normalize(axes)
            ax = axes_dict(axes)
            ax = {a: (ax[a] is not None) for a in ax}

            (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
            not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

            axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
            axes = axes.replace('S', '')  # remove sample axis if it exists

            if backend_channels_last():
                if ax['C']:
                    axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
                else:
                    axes += 'C'
            else:
                if ax['C']:
                    axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
                else:
                    axes = 'C' + axes

            means, stds = [], []
            for i in range(X.shape[-1]):
                means.append(np.mean(X[..., i]))
                stds.append(np.std(X[..., i]))

            # normalization parameters
            self.means = [str(el) for el in means]
            self.stds = [str(el) for el in stds]
            # directly set by parameters
            self.n_dim = n_dim
            self.axes = axes
            # fixed parameters
            self.n_channel_in = 1
            self.n_channel_out = 4
            self.train_loss = 'denoiseg'

            # default config (can be overwritten by kwargs below)

            self.unet_n_depth = 4
            self.relative_weights = [1.0, 1.0, 5.0]
            self.unet_kern_size = 3
            self.unet_n_first = 32
            self.unet_last_activation = 'linear'
            self.probabilistic = False
            self.unet_residual = False
            if backend_channels_last():
                self.unet_input_shape = self.n_dim * (None,) + (self.n_channel_in,)
            else:
                self.unet_input_shape = (self.n_channel_in,) + self.n_dim * (None,)

            self.train_epochs = 200
            self.train_steps_per_epoch = 400
            self.train_learning_rate = 0.0004
            self.train_batch_size = 128
            self.train_tensorboard = False
            self.train_checkpoint = 'weights_best.h5'
            self.train_checkpoint_last  = 'weights_last.h5'
            self.train_checkpoint_epoch = 'weights_now.h5'
            self.train_reduce_lr = {'monitor': 'val_loss', 'factor': 0.5, 'patience': 10}
            self.batch_norm = True
            self.n2v_perc_pix = 1.5
            self.n2v_patch_shape = (64, 64) if self.n_dim == 2 else (64, 64, 64)
            self.n2v_manipulator = 'uniform_withCP'
            self.n2v_neighborhood_radius = 5
            self.denoiseg_alpha = 0.5

        # force 'probabilistic' and 'unet_residual' to False (not user-configurable)
        kwargs['probabilistic'] = False
        kwargs['unet_residual'] = False

        for k in kwargs:
            setattr(self, k, kwargs[k])
Example 12
    def __init__(self, X, **kwargs):
        """See class docstring."""

        assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

        n_dim = len(X.shape) - 2
        n_channel_in = X.shape[-1]
        n_channel_out = n_channel_in
        mean = np.mean(X)
        std = np.std(X)

        if n_dim == 2:
            axes = 'SYXC'
        elif n_dim == 3:
            axes = 'SZYXC'

        # parse and check axes
        axes = axes_check_and_normalize(axes)
        ax = axes_dict(axes)
        ax = {a: (ax[a] is not None) for a in ax}

        (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
        not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

        axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
        axes = axes.replace('S','') # remove sample axis if it exists

        if backend_channels_last():
            if ax['C']:
                axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
            else:
                axes += 'C'
        else:
            if ax['C']:
                axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
            else:
                axes = 'C'+axes

        # normalization parameters
        self.mean                  = str(mean)
        self.std                   = str(std)
        # directly set by parameters
        self.n_dim                 = n_dim
        self.axes                  = axes
        self.n_channel_in          = int(n_channel_in)
        self.n_channel_out         = int(n_channel_out)

        # default config (can be overwritten by kwargs below)
        self.unet_residual         = False
        self.unet_n_depth          = 2
        self.unet_kern_size        = 5 if self.n_dim==2 else 3
        self.unet_n_first          = 32
        self.unet_last_activation  = 'linear'
        if backend_channels_last():
            self.unet_input_shape  = self.n_dim*(None,) + (self.n_channel_in,)
        else:
            self.unet_input_shape  = (self.n_channel_in,) + self.n_dim*(None,)

        self.train_loss            = 'mae'
        self.train_epochs          = 100
        self.train_steps_per_epoch = 400
        self.train_learning_rate   = 0.0004
        self.train_batch_size      = 16
        self.train_tensorboard     = True
        self.train_checkpoint      = 'weights_best.h5'
        self.train_reduce_lr       = {'factor': 0.5, 'patience': 10}
        self.batch_norm            = True
        self.n2v_perc_pix           = 1.5
        self.n2v_patch_shape       = (64, 64) if self.n_dim==2 else (64, 64, 64)
        self.n2v_manipulator       = 'uniform_withCP'
        self.n2v_neighborhood_radius = 5

        # disallow setting 'n_dim' manually
        kwargs.pop('n_dim', None)

        for k in kwargs:
            setattr(self, k, kwargs[k])
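
A quick illustration of the kwargs handling above (hypothetical usage; the class name Config and the array X_train are stand-ins): user keywords override the defaults, while 'n_dim' is silently dropped because it is derived from X:

import numpy as np

X_train = np.zeros((10, 64, 64, 1), dtype=np.float32)  # 'SYXC' data
config = Config(X_train, train_epochs=50, n_dim=7)     # hypothetical call
assert config.train_epochs == 50  # kwarg applied
assert config.n_dim == 2          # 'n_dim' ignored, derived from X_train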
Example 13
    def train(self, X, Y, validation_data, augmenter=None, seed=None, epochs=None, steps_per_epoch=None):
        """Train the neural network with the given data.

        Parameters
        ----------
        X : :class:`numpy.ndarray`
            Array of input images.
        Y : :class:`numpy.ndarray`
            Array of label masks.
        validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`)
            Tuple of X,Y validation arrays.
        augmenter : None or callable
            Function with expected signature ``Xbt, Ybt = augmenter(Xb, Yb)``
            that takes in batch input/label images (Xb,Yb) and returns
            transformed images (Xbt, Ybt) for the purpose of data augmentation
            during training. Not applied to validation images.
        seed : int
            Convenience to set ``np.random.seed(seed)``. (To obtain reproducible validation patches, etc.)
        epochs : int
            Optional argument to use instead of the value from ``config``.
        steps_per_epoch : int
            Optional argument to use instead of the value from ``config``.

        Returns
        -------
        ``History`` object
            See `Keras training history <https://keras.io/models/model/#fit>`_.

        """
        if seed is not None:
            # https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development
            np.random.seed(seed)
        if epochs is None:
            epochs = self.config.train_epochs
        if steps_per_epoch is None:
            steps_per_epoch = self.config.train_steps_per_epoch

        validation_data is not None or _raise(ValueError())
        ((isinstance(validation_data,(list,tuple)) and len(validation_data)==2)
            or _raise(ValueError('validation_data must be a pair of numpy arrays')))

        patch_size = self.config.train_patch_size
        axes = self.config.axes.replace('C','')
        div_by = self._axes_div_by(axes)
        [p % d == 0 or _raise(ValueError(
            "'train_patch_size' must be divisible by {d} along axis '{a}'".format(a=a,d=d)
         )) for p,d,a in zip(patch_size,div_by,axes)]

        if not self._model_prepared:
            self.prepare_for_training()

        data_kwargs = dict(
            rays                  = rays_from_json(self.config.rays_json),
            grid                  = self.config.grid,
            patch_size            = self.config.train_patch_size,
            anisotropy            = self.config.anisotropy,
            use_gpu               = self.config.use_gpu,
        )

        # generate validation data and store in numpy arrays
        # data_val = StarDistData3D(*validation_data, batch_size=1, augment=False, **data_kwargs)
        _data_val = StarDistData3D(*validation_data, batch_size=1, **data_kwargs)
        n_data_val = len(_data_val)
        n_take = self.config.train_n_val_patches if self.config.train_n_val_patches is not None else n_data_val
        ids = tuple(np.random.choice(n_data_val, size=n_take, replace=(n_take > n_data_val)))
        Xv, Mv, Pv, Dv = [None]*n_take, [None]*n_take, [None]*n_take, [None]*n_take
        for i,k in enumerate(ids):
            (Xv[i],Mv[i]),(Pv[i],Dv[i]) = _data_val[k]
        Xv, Mv, Pv, Dv = np.concatenate(Xv,axis=0), np.concatenate(Mv,axis=0), np.concatenate(Pv,axis=0), np.concatenate(Dv,axis=0)
        data_val = [[Xv,Mv],[Pv,Dv]]

        data_train = StarDistData3D(X, Y, batch_size=self.config.train_batch_size, augmenter=augmenter, **data_kwargs)

        for cb in self.callbacks:
            if isinstance(cb,CARETensorBoard):
                # only show middle slice of 3D inputs/outputs
                cb.input_slices, cb.output_slices = [[slice(None)]*5,[slice(None)]*5], [[slice(None)]*5,[slice(None)]*5]
                i = axes_dict(self.config.axes)['Z']
                _n_in  = _data_val.patch_size[i] // 2
                _n_out = _data_val.patch_size[i] // (2 * (self.config.grid[i] if self.config.grid is not None else 1))
                cb.input_slices[0][1+i] = _n_in
                cb.input_slices[1][1+i] = _n_out
                cb.output_slices[0][1+i] = _n_out
                cb.output_slices[1][1+i] = _n_out
                # show dist for three rays
                _n = min(3, self.config.n_rays)
                cb.output_slices[1][1+axes_dict(self.config.axes)['C']] = slice(0,(self.config.n_rays//_n)*_n,self.config.n_rays//_n)

        history = self.keras_model.fit_generator(generator=data_train, validation_data=data_val,
                                                 epochs=epochs, steps_per_epoch=steps_per_epoch,
                                                 callbacks=self.callbacks, verbose=1)
        self._training_finished()

        return history
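
The augmenter signature documented above suggests a simple sketch (hypothetical; assumes Xb and Yb share a leading sample axis and matching spatial layout):

import numpy as np

def flip_augmenter(Xb, Yb):
    # randomly flip inputs and label masks along the first spatial axis,
    # applied identically to both so they stay registered
    if np.random.rand() < 0.5:
        Xb, Yb = np.flip(Xb, axis=1), np.flip(Yb, axis=1)
    return Xb, Yb

# history = model.train(X, Y, validation_data=(X_val, Y_val), augmenter=flip_augmenter)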
Example 14
    def plugin(
        viewer: napari.Viewer,
        label_head,
        image: napari.layers.Image,
        axes,
        label_nn,
        model_type,
        model2d,
        model3d,
        model_folder,
        model_axes,
        norm_image,
        perc_low,
        perc_high,
        input_scale,
        label_nms,
        prob_thresh,
        nms_thresh,
        output_type,
        label_adv,
        n_tiles,
        norm_axes,
        timelapse_opts,
        cnn_output,
        set_thresholds,
        defaults_button,
        progress_bar: mw.ProgressBar,
    ) -> List[napari.types.LayerDataTuple]:

        model = get_model(*model_selected)
        if model._is_multiclass():
            warn(
                "multi-class mode not supported yet, ignoring classification output"
            )

        lkwargs = {}
        x = get_data(image)
        axes = axes_check_and_normalize(axes, length=x.ndim)

        if not (input_scale is None
                or isinstance(input_scale, numbers.Number)):
            input_scale = tuple(s for a, s in zip(axes, input_scale)
                                if a not in ("T", ))
            # print(f'scaling by {input_scale}')

        if not axes.replace("T", "").startswith(
                model._axes_out.replace("C", "")):
            warn(
                f"output images have different axes ({model._axes_out.replace('C','')}) than input image ({axes})"
            )
            # TODO: adjust image.scale according to shuffled axes

        if norm_image:
            axes_norm = axes_check_and_normalize(norm_axes)
            axes_norm = "".join(set(axes_norm).intersection(
                set(axes)))  # relevant axes present in input image
            assert len(axes_norm) > 0
            # always jointly normalize channels for RGB images
            if ("C" in axes and image.rgb == True) and ("C" not in axes_norm):
                axes_norm = axes_norm + "C"
                warn("jointly normalizing channels of RGB input image")
            ax = axes_dict(axes)
            _axis = tuple(sorted(ax[a] for a in axes_norm))
            # # TODO: address joint vs. channel/time-separate normalization properly (let user choose)
            # #       also needs to be documented somewhere
            # if 'T' in axes:
            #     if 'C' not in axes or image.rgb == True:
            #          # normalize channels jointly, frames independently
            #          _axis = tuple(i for i in range(x.ndim) if i not in (ax['T'],))
            #     else:
            #         # normalize channels independently, frames independently
            #         _axis = tuple(i for i in range(x.ndim) if i not in (ax['T'],ax['C']))
            # else:
            #     if 'C' not in axes or image.rgb == True:
            #          # normalize channels jointly
            #         _axis = None
            #     else:
            #         # normalize channels independently
            #         _axis = tuple(i for i in range(x.ndim) if i not in (ax['C'],))
            x = normalize(x, perc_low, perc_high, axis=_axis)

        # TODO: progress bar (labels) often don't show up. events not processed?
        if "T" in axes:
            app = use_app()
            t = axes_dict(axes)["T"]
            n_frames = x.shape[t]
            if n_tiles is not None:
                # remove tiling value for time axis
                n_tiles = tuple(v for i, v in enumerate(n_tiles) if i != t)

            def progress(it, **kwargs):
                progress_bar.label = "StarDist Prediction (frames)"
                progress_bar.range = (0, n_frames)
                progress_bar.value = 0
                progress_bar.show()
                app.process_events()
                for item in it:
                    yield item
                    progress_bar.increment()
                    app.process_events()
                app.process_events()

        elif n_tiles is not None and np.prod(n_tiles) > 1:
            n_tiles = tuple(n_tiles)
            app = use_app()

            def progress(it, **kwargs):
                progress_bar.label = "CNN Prediction (tiles)"
                progress_bar.range = (0, kwargs.get("total", 0))
                progress_bar.value = 0
                progress_bar.show()
                app.process_events()
                for item in it:
                    yield item
                    progress_bar.increment()
                    app.process_events()
                #
                progress_bar.label = "NMS Postprocessing"
                progress_bar.range = (0, 0)
                app.process_events()

        else:
            progress = False
            progress_bar.label = "StarDist Prediction"
            progress_bar.range = (0, 0)
            progress_bar.show()
            use_app().process_events()

        # semantic output axes of predictions
        assert model._axes_out[-1] == "C"
        axes_out = list(model._axes_out[:-1])

        if "T" in axes:
            x_reorder = np.moveaxis(x, t, 0)
            axes_reorder = axes.replace("T", "")
            axes_out.insert(t, "T")
            res = tuple(
                zip(*tuple(
                    model.predict_instances(
                        _x,
                        axes=axes_reorder,
                        prob_thresh=prob_thresh,
                        nms_thresh=nms_thresh,
                        n_tiles=n_tiles,
                        scale=input_scale,
                        sparse=(not cnn_output),
                        return_predict=cnn_output,
                    ) for _x in progress(x_reorder))))

            if cnn_output:
                labels, polys = tuple(zip(*res[0]))
                cnn_output = tuple(np.stack(c, t) for c in tuple(zip(*res[1])))
            else:
                labels, polys = res

            labels = np.asarray(labels)

            if len(polys) > 1:
                if timelapse_opts == TimelapseLabels.Match.value:
                    # match labels in consecutive frames (-> simple IoU tracking)
                    labels = group_matching_labels(labels)
                elif timelapse_opts == TimelapseLabels.Unique.value:
                    # make label ids unique (shift by offset)
                    offsets = np.cumsum([len(p["points"]) for p in polys])
                    for y, off in zip(labels[1:], offsets):
                        y[y > 0] += off
                elif timelapse_opts == TimelapseLabels.Separate.value:
                    # each frame processed separately (nothing to do)
                    pass
                else:
                    raise NotImplementedError(
                        f"unknown option '{timelapse_opts}' for time-lapse labels"
                    )

            labels = np.moveaxis(labels, 0, t)

            if isinstance(model, StarDist3D):
                # TODO poly output support for 3D timelapse
                polys = None
            else:
                polys = dict(
                    coord=np.concatenate(
                        tuple(
                            np.insert(p["coord"], t, _t, axis=-2)
                            for _t, p in enumerate(polys)),
                        axis=0,
                    ),
                    points=np.concatenate(
                        tuple(
                            np.insert(p["points"], t, _t, axis=-1)
                            for _t, p in enumerate(polys)),
                        axis=0,
                    ),
                )

            if cnn_output:
                pred = (labels, polys), cnn_output
            else:
                pred = labels, polys

        else:
            # TODO: possible to run this in a way that it can be canceled?
            pred = model.predict_instances(
                x,
                axes=axes,
                prob_thresh=prob_thresh,
                nms_thresh=nms_thresh,
                n_tiles=n_tiles,
                show_tile_progress=progress,
                scale=input_scale,
                sparse=(not cnn_output),
                return_predict=cnn_output,
            )
        progress_bar.hide()

        # determine scale for output axes
        scale_in_dict = dict(zip(axes, image.scale))
        scale_out = [scale_in_dict.get(a, 1.0) for a in axes_out]

        layers = []
        if cnn_output:
            (labels, polys), cnn_out = pred
            prob, dist = cnn_out[:2]
            dist = np.moveaxis(dist, -1, 0)

            assert len(model.config.grid) == len(model.config.axes) - 1
            grid_dict = dict(
                zip(model.config.axes.replace("C", ""), model.config.grid))
            # scale output axes to match input axes
            _scale = [
                s * grid_dict.get(a, 1) for a, s in zip(axes_out, scale_out)
            ]
            # small translation correction if grid > 1 (since napari centers objects)
            _translate = [0.5 * (grid_dict.get(a, 1) - 1) for a in axes_out]

            layers.append((
                dist,
                dict(
                    name="StarDist distances",
                    scale=[1] + _scale,
                    translate=[0] + _translate,
                    **lkwargs,
                ),
                "image",
            ))
            layers.append((
                prob,
                dict(
                    name="StarDist probability",
                    scale=_scale,
                    translate=_translate,
                    **lkwargs,
                ),
                "image",
            ))
        else:
            labels, polys = pred

        if output_type in (Output.Labels.value, Output.Both.value):
            layers.append((
                labels,
                dict(name="StarDist labels",
                     scale=scale_out,
                     opacity=0.5,
                     **lkwargs),
                "labels",
            ))
        if output_type in (Output.Polys.value, Output.Both.value):
            n_objects = len(polys["points"])
            if isinstance(model, StarDist3D):
                surface = surface_from_polys(polys)
                layers.append((
                    surface,
                    dict(
                        name="StarDist polyhedra",
                        contrast_limits=(0, surface[-1].max()),
                        scale=scale_out,
                        colormap=label_colormap(n_objects),
                        **lkwargs,
                    ),
                    "surface",
                ))
            else:
                # TODO: sometimes hangs for long time (indefinitely?) when returning many polygons (?)
                #       seems to be a known issue: https://github.com/napari/napari/issues/2015
                # TODO: coordinates correct or need offset (0.5 or so)?
                shapes = np.moveaxis(polys["coord"], -1, -2)
                layers.append((
                    shapes,
                    dict(
                        name="StarDist polygons",
                        shape_type="polygon",
                        scale=scale_out,
                        edge_width=0.75,
                        edge_color="yellow",
                        face_color=[0, 0, 0, 0],
                        **lkwargs,
                    ),
                    "shapes",
                ))
        return layers
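
For the scale/translate correction above: with an assumed grid factor of 2 along an axis, the CNN outputs are at half resolution, so they get scaled by 2 and shifted by half an input pixel to stay centered on the input:

grid = 2
scale = 1.0 * grid            # one output pixel spans 'grid' input pixels
translate = 0.5 * (grid - 1)  # = 0.5 input pixels
assert (scale, translate) == (2.0, 0.5)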
Example 15
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
from tifffile import imread
from csbdeep.utils import axes_dict, plot_some, plot_history
from csbdeep.utils.tf import limit_gpu_memory
from csbdeep.io import load_training_data
from csbdeep.models import Config, ProjectionCARE

(X, Y), (X_val, Y_val), axes = load_training_data(
    '/local/u934/private/v_kapoor/CurieTrainingDatasets/Drosophilla/DenoisingProjection.npz',
    validation_split=0.1,
    verbose=True)

c = axes_dict(axes)['C']
n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

config = Config(axes,
                n_channel_in,
                n_channel_out,
                unet_n_depth=4,
                train_epochs=50,
                train_steps_per_epoch=400,
                train_batch_size=16,
                train_reduce_lr={
                    'patience': 5,
Example 16
def train(Training_source=".",
          Training_target=".",
          model_name="No_name",
          model_path=".",
          Visual_validation_after_training=True,
          number_of_epochs=100,
          patch_size=64,
          number_of_patches=10,
          Use_Default_Advanced_Parameters=True,
          number_of_steps=300,
          batch_size=32,
          percentage_validation=15):
    '''
    Main function of the script. Trains the model and saves it in model_path.

    Parameters
    ----------
    Training_source : (str) Path to the noisy images
    Training_target : (str) Path to the GT images
    model_name : (str) Name of the model
    model_path : (str) Path where the model is saved
    Visual_validation_after_training : (bool) Predict a random image after training
    number_of_epochs : (int) Number of epochs
    patch_size : (int) Patch size
    number_of_patches : (int) Number of patches per image
    Use_Default_Advanced_Parameters : (bool) Use default parameters for the training
    number_of_steps : (int) Number of steps
    batch_size : (int) Batch size
    percentage_validation : (int) Percentage of data used for validation

    Returns
    -------
    None
    '''
    OutputFile = Training_target + "/*.tif"
    InputFile = Training_source + "/*.tif"
    base = "/content/"
    training_data = base + "/my_training_data.npz"
    if (Use_Default_Advanced_Parameters):
        print("Default advanced parameters enabled")
        batch_size = 64
        percentage_validation = 10

    percentage = percentage_validation / 100
    #here we check that no model with the same name already exist, if so delete
    if os.path.exists(model_path + '/' + model_name):
        shutil.rmtree(model_path + '/' + model_name)

    # The shape of the images.
    x = imread(InputFile)
    y = imread(OutputFile)

    print('Loaded Input images (number, width, length) =', x.shape)
    print('Loaded Output images (number, width, length) =', y.shape)
    print("Parameters initiated.")

    # RawData Object

    # This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.
    # This file is saved in .npz format and later loaded as the training data.

    raw_data = data.RawData.from_folder(basepath=base,
                                        source_dirs=[Training_source],
                                        target_dir=Training_target,
                                        axes='CYX',
                                        pattern='*.tif*')

    X, Y, XY_axes = data.create_patches(raw_data,
                                        patch_filter=None,
                                        patch_size=(patch_size, patch_size),
                                        n_patches_per_image=number_of_patches)

    print('Creating 2D training dataset')
    training_path = model_path + "/rawdata"
    rawdata1 = training_path + ".npz"
    np.savez(training_path, X=X, Y=Y, axes=XY_axes)

    # Load Training Data
    (X, Y), (X_val,
             Y_val), axes = load_training_data(rawdata1,
                                               validation_split=percentage,
                                               verbose=True)
    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]
    #Show_patches(X,Y)

    #Here we automatically define number_of_step in function of training data and batch size
    if (Use_Default_Advanced_Parameters):
        number_of_steps = int(X.shape[0] / batch_size) + 1

    print(number_of_steps)

    #Here we create the configuration file

    config = Config(axes,
                    n_channel_in,
                    n_channel_out,
                    probabilistic=False,
                    train_steps_per_epoch=number_of_steps,
                    train_epochs=number_of_epochs,
                    unet_kern_size=5,
                    unet_n_depth=3,
                    train_batch_size=batch_size,
                    train_learning_rate=0.0004)

    print(config)
    vars(config)  # no-op in a script; displays the config fields in a notebook

    # Compile the CARE model for network training
    model_training = CARE(config, model_name, basedir=model_path)

    if (Visual_validation_after_training):
        Cell_executed = 1  # notebook artifact; value is unused here

    import time
    start = time.time()

    #@markdown ##Start Training

    # Start Training
    history = model_training.train(X, Y, validation_data=(X_val, Y_val))

    print("Training, done.")
    if (Visual_validation_after_training):
        Predict_a_image(Training_source, Training_target, model_path,
                        model_training)

    # Displaying the time elapsed for training
    dt = time.time() - start
    mins, sec = divmod(dt, 60)  # avoid shadowing the builtin min()
    hour, mins = divmod(mins, 60)
    print("Time elapsed:", hour, "hour(s)", mins, "min(s)", round(sec),
          "sec(s)")

    Show_loss_function(history, model_path)
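
The automatic number_of_steps above rounds up so that roughly every patch is seen once per epoch; with assumed values of 950 patches and the default batch size of 64:

number_of_steps = int(950 / 64) + 1  # 14 full batches + 1 -> 15 steps
assert number_of_steps == 15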
Example 17
    def __init__(self, X, **kwargs):

        if X.size != 0:

            assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

            n_dim = len(X.shape) - 2

            if n_dim == 2:
                axes = 'SYXC'
            elif n_dim == 3:
                axes = 'SZYXC'

            # parse and check axes
            axes = axes_check_and_normalize(axes)
            ax = axes_dict(axes)
            ax = {a: (ax[a] is not None) for a in ax}

            (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
            not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

            axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
            axes = axes.replace('S', '')  # remove sample axis if it exists

            if backend_channels_last():
                if ax['C']:
                    axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
                else:
                    axes += 'C'
            else:
                if ax['C']:
                    axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
                else:
                    axes = 'C' + axes

            means, stds = [], []
            for i in range(X.shape[-1]):
                means.append(np.mean(X[..., i]))
                stds.append(np.std(X[..., i]))

            # normalization parameters
            self.means = [str(el) for el in means]
            self.stds = [str(el) for el in stds]
            # directly set by parameters
            self.n_dim = n_dim
            self.axes = axes
            # fixed parameters
            if 'C' in axes:
                self.n_channel_in = X.shape[-1]
            else:
                self.n_channel_in = 1
            self.train_loss = 'demix'

            # default config (can be overwritten by kwargs below)

            self.unet_n_depth = 2
            self.unet_kern_size = 3
            self.unet_n_first = 64
            self.unet_last_activation = 'linear'
            self.probabilistic = False
            self.unet_residual = False
            if backend_channels_last():
                self.unet_input_shape = self.n_dim * (None,) + (self.n_channel_in,)
            else:
                self.unet_input_shape = (self.n_channel_in,) + self.n_dim * (None,)

            # fixed parameters
            self.train_epochs = 200
            self.train_steps_per_epoch = 50
            self.train_learning_rate = 0.0004
            self.train_batch_size = 64
            self.train_tensorboard = False
            self.train_checkpoint = 'weights_best.h5'
            self.train_checkpoint_last  = 'weights_last.h5'
            self.train_checkpoint_epoch = 'weights_now.h5'
            self.train_reduce_lr = {'monitor': 'val_loss', 'factor': 0.5, 'patience': 10}
            self.batch_norm = False
            self.n2v_perc_pix = 1.5
            self.n2v_patch_shape = (128, 128) if self.n_dim == 2 else (64, 64, 64)
            self.n2v_manipulator = 'uniform_withCP'
            self.n2v_neighborhood_radius = 5

            self.single_net_per_channel = False

            self.channel_denoised = False
            self.multi_objective = True
            self.normalizer = 'none'
            self.weights_objectives = [0.1, 0, 0.45, 0.45]
            self.distributions = 'gauss'
            self.n2v_leave_center = False
            self.scale_aug = False
            self.structN2Vmask = None

            self.n_back_modes = 2
            self.n_fore_modes = 2
            self.n_instance_seg = 0
            self.n_back_i_modes = 1
            self.n_fore_i_modes = 1

            self.fit_std = False
            self.fit_mean = True

        # force 'probabilistic' and 'unet_residual' to False (not user-configurable)
        kwargs['probabilistic'] = False
        kwargs['unet_residual'] = False

        for k in kwargs:
            setattr(self, k, kwargs[k])
        self.n_components = 3 if (self.fit_std & self.fit_mean) else 2
        self.n_channel_out = (self.n_back_modes + self.n_fore_modes) * self.n_channel_in * self.n_components + \
                             self.n_instance_seg * (self.n_back_i_modes + self.n_fore_i_modes)
Example 18
    def __init__(self, X, **kwargs):
        """See class docstring"""
        assert X.size != 0
        assert len(X.shape) == 4 or len(
            X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

        n_dim = len(X.shape) - 2

        if n_dim == 2:
            axes = 'SYXC'
        elif n_dim == 3:
            axes = 'SZYXC'

        # parse and check axes
        axes = axes_check_and_normalize(axes)
        ax = axes_dict(axes)
        ax = {a: (ax[a] is not None) for a in ax}

        (ax['X'] and ax['Y']) or _raise(
            ValueError('lateral axes X and Y must be present.'))
        not (ax['Z'] and ax['T']) or _raise(
            ValueError('using Z and T axes together not supported.'))

        axes.startswith('S') or (not ax['S']) or _raise(
            ValueError('sample axis S must be first.'))
        axes = axes.replace('S', '')  # remove sample axis if it exists

        if backend_channels_last():
            if ax['C']:
                axes[-1] == 'C' or _raise(
                    ValueError('channel axis must be last for backend (%s).' %
                               K.backend()))
            else:
                axes += 'C'
        else:
            if ax['C']:
                axes[0] == 'C' or _raise(
                    ValueError('channel axis must be first for backend (%s).' %
                               K.backend()))
            else:
                axes = 'C' + axes

        assert X.shape[-1] == 1
        n_channel_in = 1
        n_channel_out = 3

        # directly set by parameters
        self.n_dim = n_dim
        self.axes = axes
        # fixed parameters
        self.n_channel_in = n_channel_in
        self.n_channel_out = n_channel_out
        self.train_loss = 'seg'

        # default config (can be overwritten by kwargs below)

        self.unet_n_depth = 4
        self.relative_weights = [1.0, 1.0, 5.0]
        self.unet_kern_size = 3
        self.unet_n_first = 32
        self.unet_last_activation = 'linear'
        self.probabilistic = False
        self.unet_residual = False
        if backend_channels_last():
            self.unet_input_shape = self.n_dim * (None, ) + (
                self.n_channel_in, )
        else:
            self.unet_input_shape = (
                self.n_channel_in, ) + self.n_dim * (None, )

        self.train_epochs = 200
        self.train_steps_per_epoch = 400
        self.train_learning_rate = 0.0004
        self.train_batch_size = 128
        self.train_tensorboard = False
        self.train_checkpoint = 'weights_best.h5'
        self.train_checkpoint_last = 'weights_last.h5'
        self.train_checkpoint_epoch = 'weights_now.h5'
        self.train_reduce_lr = {'factor': 0.5, 'patience': 10}
        self.batch_norm = True

        # disallow setting these manually (they are fixed by this config)
        for k in ('n_dim', 'n_channel_in', 'n_channel_out',
                  'train_loss', 'probabilistic', 'unet_residual'):
            kwargs.pop(k, None)

        for k in kwargs:
            setattr(self, k, kwargs[k])
Example 19
        def _update(self):

            # try to get a hold of the viewer (can be None when plugin starts)
            if self.viewer is None:
                # TODO: when is this not safe to do and will hang forever?
                # while plugin.viewer.value is None:
                #     time.sleep(0.01)
                if plugin.viewer.value is not None:
                    self.viewer = plugin.viewer.value
                    if DEBUG:
                        print("GOT viewer")

                    @self.viewer.layers.events.removed.connect
                    def _layer_removed(event):
                        layers_remaining = event.source
                        if len(layers_remaining) == 0:
                            plugin.image.tooltip = ""
                            plugin.axes.value = ""
                            plugin.n_tiles.value = "None"
                            plugin.input_scale.value = "None"

            def _model(valid):
                widgets_valid(
                    plugin.model2d,
                    plugin.model3d,
                    plugin.model_folder.line_edit,
                    valid=valid,
                )
                if valid:
                    config = self.args.model
                    axes = config.get("axes",
                                      "ZYXC"[-len(config["net_input_shape"]):])
                    if "T" in axes:
                        raise RuntimeError("model with axis 'T' not supported")
                    plugin.model_axes.value = axes.replace(
                        "C", f"C[{config['n_channel_in']}]")
                    plugin.model_folder.line_edit.tooltip = ""
                    return axes, config
                else:
                    plugin.model_axes.value = ""
                    plugin.model_folder.line_edit.tooltip = "Invalid model directory"

            def _image_axes(valid):
                axes, image, err = getattr(self.args, "image_axes",
                                           (None, None, None))
                widgets_valid(
                    plugin.axes,
                    valid=(valid or (image is None and
                                     (axes is None or len(axes) == 0))),
                )
                if (valid and "T" in axes and plugin.output_type.value
                        in (Output.Polys.value, Output.Both.value)):
                    plugin.output_type.native.setStyleSheet(
                        "background-color: orange")
                    plugin.output_type.tooltip = (
                        "Displaying many polygons/polyhedra can be very slow.")
                else:
                    plugin.output_type.native.setStyleSheet("")
                    plugin.output_type.tooltip = ""
                if valid:
                    plugin.axes.tooltip = "\n".join([
                        f"{a} = {s}" for a, s in zip(axes,
                                                     get_data(image).shape)
                    ])
                    return axes, image
                else:
                    if err is not None:
                        err = str(err)
                        err = err[:-1] if err.endswith(".") else err
                        plugin.axes.tooltip = err
                        # warn(err) # alternative to tooltip (gui doesn't show up in ipython)
                    else:
                        plugin.axes.tooltip = ""

            def _norm_axes(valid):
                norm_axes, err = getattr(self.args, "norm_axes", (None, None))
                widgets_valid(plugin.norm_axes, valid=valid)
                if valid:
                    plugin.norm_axes.tooltip = (
                        "Axes to jointly normalize (if present in selected input image). "
                        "Note: channels of RGB images are always normalized together.")
                    return norm_axes
                else:
                    if err is not None:
                        err = str(err)
                        err = err[:-1] if err.endswith(".") else err
                        plugin.norm_axes.tooltip = err
                        # warn(err) # alternative to tooltip (gui doesn't show up in ipython)
                    else:
                        plugin.norm_axes.tooltip = ""

            def _n_tiles(valid):
                n_tiles, image, err = getattr(self.args, "n_tiles",
                                              (None, None, None))
                widgets_valid(plugin.n_tiles, valid=(valid or image is None))
                if valid:
                    plugin.n_tiles.tooltip = (
                        "no tiling" if n_tiles is None else "\n".join([
                            f"{t}: {s}"
                            for t, s in zip(n_tiles,
                                            get_data(image).shape)
                        ]))
                    return n_tiles
                else:
                    msg = str(err) if err is not None else ""
                    plugin.n_tiles.tooltip = msg

            def _no_tiling_for_axis(axes_image, n_tiles, axis):
                if n_tiles is not None and axis in axes_image:
                    return n_tiles[axes_dict(axes_image)[axis]] == 1
                return True

            def _input_scale(valid):
                input_scale, image, err = getattr(self.args, "input_scale",
                                                  (None, None, None))
                widgets_valid(plugin.input_scale,
                              valid=(valid or image is None))
                if valid:
                    if input_scale is None:
                        plugin.input_scale.tooltip = "no scaling"
                    elif isinstance(input_scale, numbers.Number):
                        plugin.input_scale.tooltip = f"{input_scale} for all spatial axes"
                    else:
                        assert len(input_scale) == len(get_data(image).shape)
                        plugin.input_scale.tooltip = "\n".join(
                            [f"{s}" for s in input_scale])
                    return input_scale
                else:
                    msg = str(err) if err is not None else ""
                    plugin.input_scale.tooltip = msg

            def _input_scale_check(axes_image, input_scale):
                if input_scale is not None and not isinstance(
                        input_scale, numbers.Number):
                    assert len(input_scale) == len(axes_image)
                    # s != 1 only allowed for spatial axes XYZ
                    return all(s == 1 or a in "XYZ"
                               for a, s in zip(axes_image, input_scale))
                return True

            def _restore():
                widgets_valid(plugin.image,
                              valid=plugin.image.value is not None)

            all_valid = False
            help_msg = ""

            if (self.valid.image_axes and self.valid.n_tiles
                    and self.valid.model and self.valid.norm_axes
                    and self.valid.input_scale):
                axes_image, image = _image_axes(True)
                axes_model, config = _model(True)
                axes_norm = _norm_axes(True)
                n_tiles = _n_tiles(True)
                input_scale = _input_scale(True)

                if not _no_tiling_for_axis(axes_image, n_tiles, "C"):
                    # check if image axes and n_tiles are compatible
                    widgets_valid(plugin.n_tiles, valid=False)
                    err = "number of tiles must be 1 for C axis"
                    plugin.n_tiles.tooltip = err
                    _restore()
                elif not _no_tiling_for_axis(axes_image, n_tiles, "T"):
                    # check if image axes and n_tiles are compatible
                    widgets_valid(plugin.n_tiles, valid=False)
                    err = "number of tiles must be 1 for T axis"
                    plugin.n_tiles.tooltip = err
                    _restore()
                elif not _input_scale_check(axes_image, input_scale):
                    # check if image axes and input_scale are compatible
                    widgets_valid(plugin.input_scale, valid=False)
                    _violations = ", ".join(
                        a for a, s in zip(axes_image, input_scale)
                        if not (s == 1 or a in "XYZ"))
                    err = f"values for non-spatial axes ({_violations}) must be 1"
                    plugin.input_scale.tooltip = err
                    _restore()
                elif set(axes_norm).isdisjoint(set(axes_image)):
                    # check if image axes and normalization axes are compatible
                    widgets_valid(plugin.norm_axes, valid=False)
                    err = f"Image axes ({axes_image}) must contain at least one of the normalization axes ({', '.join(axes_norm)})"
                    plugin.norm_axes.tooltip = err
                    _restore()
                elif ("T" in axes_image and config.get("n_dim") == 3
                      and plugin.output_type.value
                      in (Output.Polys.value, Output.Both.value)):
                    # not supported
                    widgets_valid(plugin.output_type, valid=False)
                    plugin.output_type.tooltip = (
                        "Polyhedra output currently not supported for 3D timelapse data"
                    )
                    _restore()
                else:
                    # tooltip for input_scale
                    if isinstance(input_scale, numbers.Number):
                        plugin.input_scale.tooltip = "\n".join([
                            f'{a} = {input_scale if a in "XYZ" else 1}'
                            for a in axes_image
                        ])
                    elif input_scale is not None:
                        plugin.input_scale.tooltip = "\n".join([
                            f"{a} = {s}"
                            for a, s in zip(axes_image, input_scale)
                        ])

                    # check if image and model are compatible
                    ch_model = config["n_channel_in"]
                    ch_image = (
                        get_data(image).shape[axes_dict(axes_image)["C"]]
                        if "C" in axes_image else 1)
                    all_valid = (set(axes_model.replace("C", "")) == set(
                        axes_image.replace("C", "").replace("T", ""))
                                 and ch_model == ch_image)

                    widgets_valid(
                        plugin.image,
                        plugin.model2d,
                        plugin.model3d,
                        plugin.model_folder.line_edit,
                        valid=all_valid,
                    )
                    if all_valid:
                        help_msg = ""
                    else:
                        help_msg = f'Model with axes {axes_model.replace("C", f"C[{ch_model}]")} and image with axes {axes_image.replace("C", f"C[{ch_image}]")} not compatible'
            else:
                _image_axes(self.valid.image_axes)
                _norm_axes(self.valid.norm_axes)
                _n_tiles(self.valid.n_tiles)
                _input_scale(self.valid.input_scale)
                _model(self.valid.model)
                _restore()

            self.help(help_msg)
            plugin.call_button.enabled = all_valid
            # widgets_valid(plugin.call_button, valid=all_valid)
            if self.debug:
                print(
                    f"valid ({all_valid}):",
                    ", ".join(
                        [f"{k}={v}" for k, v in vars(self.valid).items()]),
                )
Example n. 20
    def train(self, X, Y, validation_data, epochs=None, steps_per_epoch=None):
        """Train the neural network with the given data.
        Parameters
        ----------
        X : :class:`numpy.ndarray`
            Array of source images.
        Y : :class:`numpy.ndarray`
            Array of target images.
        validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`)
            Tuple of arrays for source and target validation images.
        epochs : int
            Optional argument to use instead of the value from ``config``.
        steps_per_epoch : int
            Optional argument to use instead of the value from ``config``.
        Returns
        -------
        ``History`` object
            See `Keras training history <https://keras.io/models/model/#fit>`_.
        """

        ((isinstance(validation_data,
                     (list, tuple)) and len(validation_data) == 2) or
         _raise(ValueError('validation_data must be a pair of numpy arrays')))

        n_train, n_val = len(X), len(validation_data[0])
        frac_val = (1.0 * n_val) / (n_train + n_val)
        frac_warn = 0.05
        if frac_val < frac_warn:
            warnings.warn(
                "small number of validation images (only %.1f%% of all images)"
                % (100 * frac_val))
        axes = axes_check_and_normalize('S' + self.config.axes, X.ndim)
        ax = axes_dict(axes)

        for a, div_by in zip(axes, self._axes_div_by(axes)):
            n = X.shape[ax[a]]
            if n % div_by != 0:
                raise ValueError(
                    "training images must be evenly divisible by %d along axis %s"
                    " (which has incompatible size %d)" % (div_by, a, n))

        if epochs is None:
            epochs = self.config.train_epochs
        if steps_per_epoch is None:
            steps_per_epoch = self.config.train_steps_per_epoch

        if not self._model_prepared:
            self.prepare_for_training()

        training_data = CryoDataWrapper(X, Y, self.config.train_batch_size)

        history = self.keras_model.fit_generator(
            generator=training_data,
            validation_data=validation_data,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            callbacks=self.callbacks,
            verbose=1)

        if self.basedir is not None:
            self.keras_model.save_weights(str(self.logdir / 'weights_last.h5'))

            if self.config.train_checkpoint is not None:
                print()
                self._find_and_load_weights(self.config.train_checkpoint)
                try:
                    # remove temporary weights
                    (self.logdir / 'weights_now.h5').unlink()
                except FileNotFoundError:
                    pass

        return history
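A minimal usage sketch for the train method above. All names and shapes here are hypothetical: it assumes a configured model instance `model` of the class defining this method, with `config.axes == 'YXC'` and image sizes divisible by the network's downsampling factor.

import numpy as np

# 64 source/target training pairs plus 8 validation pairs, axes 'SYXC'
X = np.random.uniform(size=(64, 128, 128, 1)).astype(np.float32)
Y = np.random.uniform(size=(64, 128, 128, 1)).astype(np.float32)
X_val = np.random.uniform(size=(8, 128, 128, 1)).astype(np.float32)
Y_val = np.random.uniform(size=(8, 128, 128, 1)).astype(np.float32)

history = model.train(X, Y, validation_data=(X_val, Y_val), epochs=10)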
Example n. 21
    def train(self, X, Y, validation_data, epochs=None, steps_per_epoch=None):
        n_train, n_val = len(X), len(validation_data[0])
        frac_val = (1.0 * n_val) / (n_train + n_val)
        frac_warn = 0.05
        if frac_val < frac_warn:
            warnings.warn(
                "small number of validation images (only %.05f%% of all images)"
                % (100 * frac_val))
        axes = axes_check_and_normalize('S' + self.config.axes, X.ndim)
        ax = axes_dict(axes)
        div_by = 2**self.config.unet_n_depth
        axes_relevant = ''.join(a for a in 'XYZT' if a in axes)
        val_num_pix = 1
        train_num_pix = 1
        val_patch_shape = ()
        for a in axes_relevant:
            n = X.shape[ax[a]]
            val_num_pix *= validation_data[0].shape[ax[a]]
            train_num_pix *= X.shape[ax[a]]
            val_patch_shape += tuple([validation_data[0].shape[ax[a]]])
            if n % div_by != 0:
                raise ValueError(
                    "training images must be evenly divisible by %d along axes %s"
                    " (axis %s has incompatible size %d)" %
                    (div_by, axes_relevant, a, n))

        if epochs is None:
            epochs = self.config.train_epochs
        if steps_per_epoch is None:
            steps_per_epoch = self.config.train_steps_per_epoch

        if not self._model_prepared:
            self.prepare_for_training()

        manipulator = eval('pm_{0}({1})'.format(
            self.config.n2v_manipulator,
            str(self.config.n2v_neighborhood_radius)))

        means = np.array([float(mean) for mean in self.config.means],
                         ndmin=len(X.shape),
                         dtype=np.float32)
        stds = np.array([float(std) for std in self.config.stds],
                        ndmin=len(X.shape),
                        dtype=np.float32)

        X = self.__normalize__(X, means, stds)
        validation_X = self.__normalize__(validation_data[0], means, stds)

        # Here we prepare the Noise2Void data. Our input is the noisy data X and as target we take X concatenated with
        # a masking channel. The DenoiSeg_DataWrapper will take care of the pixel masking and manipulation.
        training_data = DenoiSeg_DataWrapper(
            X=X,
            n2v_Y=np.concatenate((X, np.zeros(X.shape, dtype=X.dtype)),
                                 axis=axes.index('C')),
            seg_Y=Y,
            batch_size=self.config.train_batch_size,
            perc_pix=self.config.n2v_perc_pix,
            shape=self.config.n2v_patch_shape,
            value_manipulation=manipulator)

        # validation_Y is also validation_X plus a concatenated masking channel.
        # To speed things up, we precompute the masking of the validation data.
        validation_Y = np.concatenate(
            (validation_X,
             np.zeros(validation_X.shape, dtype=validation_X.dtype)),
            axis=axes.index('C'))
        n2v_utils.manipulate_val_data(validation_X,
                                      validation_Y,
                                      perc_pix=self.config.n2v_perc_pix,
                                      shape=val_patch_shape,
                                      value_manipulation=manipulator)

        validation_Y = np.concatenate((validation_Y, validation_data[1]),
                                      axis=-1)

        history = self.keras_model.fit(training_data,
                                       validation_data=(validation_X,
                                                        validation_Y),
                                       epochs=epochs,
                                       steps_per_epoch=steps_per_epoch,
                                       callbacks=self.callbacks,
                                       verbose=1)

        if self.basedir is not None:
            self.keras_model.save_weights(str(self.logdir / 'weights_last.h5'))

            if self.config.train_checkpoint is not None:
                print()
                self._find_and_load_weights(self.config.train_checkpoint)
                try:
                    # remove temporary weights
                    (self.logdir / 'weights_now.h5').unlink()
                except FileNotFoundError:
                    pass

        return history
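The divisibility requirement enforced above can be met by padding inputs up to the next multiple of 2**unet_n_depth. A minimal stand-alone sketch under that assumption (hypothetical helper, not part of the original code):

import numpy as np

def pad_to_divisible(x, div_by, axes, relevant='XYZT'):
    # pad each relevant axis up to the next multiple of div_by (other axes untouched)
    pads = [(0, (-n) % div_by if a in relevant else 0)
            for a, n in zip(axes, x.shape)]
    return np.pad(x, pads, mode='reflect')

img = np.zeros((8, 99, 99, 1))                    # axes 'SYXC', unet_n_depth=2 -> div_by=4
print(pad_to_divisible(img, 4, 'SYXC').shape)     # -> (8, 100, 100, 1)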
Example n. 22
def load_training_data_direct(X, Y, validation_split=0, axes=None, n_images=None, verbose=False):
    """Load training data from file in ``.npz`` format.

        The data file is expected to have the keys:

        - ``X``    : Array of training input images.
        - ``Y``    : Array of corresponding target images.
        - ``axes`` : Axes of the training images.


        Parameters
        ----------
        file : str
            File name
        validation_split : float
            Fraction of images to use as validation set during training.
        axes: str, optional
            Must be provided in case the loaded data does not contain ``axes`` information.
        n_images : int, optional
            Can be used to limit the number of images loaded from data.
        verbose : bool, optional
            Can be used to display information about the loaded images.

        Returns
        -------
        tuple( tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`), tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`), str )
            Returns two tuples (`X_train`, `Y_train`), (`X_val`, `Y_val`) of training and validation sets
            and the axes of the input images.
            The tuple of validation data will be ``None`` if ``validation_split = 0``.

        """

    # f = np.load(file)
    # X, Y = f['X'], f['Y']
    # if axes is None:
    #    axes = f['axes']
    axes = axes_check_and_normalize(axes)

    assert X.shape == Y.shape
    assert len(axes) == X.ndim
    assert 'C' in axes
    if n_images is None:
        n_images = X.shape[0]
    assert X.shape[0] == Y.shape[0]
    assert 0 < n_images <= X.shape[0]
    assert 0 <= validation_split < 1

    X, Y = X[:n_images], Y[:n_images]
    channel = axes_dict(axes)['C']

    if validation_split > 0:
        n_val = int(round(n_images * validation_split))
        n_train = n_images - n_val
        assert 0 < n_val and 0 < n_train
        X_t, Y_t = X[-n_val:], Y[-n_val:]
        X, Y = X[:n_train], Y[:n_train]
        assert X.shape[0] == n_train and X_t.shape[0] == n_val
        X_t = move_channel_for_backend(X_t, channel=channel)
        Y_t = move_channel_for_backend(Y_t, channel=channel)

    X = move_channel_for_backend(X, channel=channel)
    Y = move_channel_for_backend(Y, channel=channel)

    axes = axes.replace('C', '')  # remove channel
    if backend_channels_last():
        axes = axes + 'C'
    else:
        axes = axes[:1] + 'C' + axes[1:]

    data_val = (X_t, Y_t) if validation_split > 0 else None

    if verbose:
        ax = axes_dict(axes)
        n_train, n_val = len(X), len(X_t) if validation_split > 0 else 0
        image_size = tuple(X.shape[ax[a]] for a in 'TZYX' if a in axes)
        n_dim = len(image_size)
        n_channel_in, n_channel_out = X.shape[ax['C']], Y.shape[ax['C']]

        print('number of training images:\t', n_train)
        print('number of validation images:\t', n_val)
        print('image size (%dD):\t\t' % n_dim, image_size)
        print('axes:\t\t\t\t', axes)
        print('channels in / out:\t\t', n_channel_in, '/', n_channel_out)

    return (X, Y), data_val, axes
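A quick usage sketch for the function above (hypothetical arrays; the asserts require matching X/Y shapes and a channel axis in ``axes``):

import numpy as np

X = np.random.uniform(size=(100, 1, 64, 64)).astype(np.float32)   # axes 'SCYX'
Y = np.random.uniform(size=(100, 1, 64, 64)).astype(np.float32)
(X_trn, Y_trn), data_val, axes = load_training_data_direct(
    X, Y, validation_split=0.1, axes='SCYX', verbose=True)
# data_val == (X_val, Y_val) holds the last 10 image pairs;
# the returned axes string has C moved to match the Keras backend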
Example n. 23
    def train(self,
              X,
              Y,
              validation_data,
              augmenter=None,
              seed=None,
              epochs=None,
              steps_per_epoch=None):
        """Train the neural network with the given data.

        Parameters
        ----------
        X : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
            Input images
        Y : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
            Label masks
        validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`)
            Tuple of X,Y validation arrays.
        augmenter : None or callable
            Function with expected signature ``xt, yt = augmenter(x, y)``
            that takes in a single pair of input/label image (x,y) and returns
            the transformed images (xt, yt) for the purpose of data augmentation
            during training. Not applied to validation images.
            Example:
            def simple_augmenter(x,y):
                x = x + 0.05*np.random.normal(0,1,x.shape)
                return x,y
        seed : int
            Convenience to set ``np.random.seed(seed)``. (To obtain reproducible validation patches, etc.)
        epochs : int
            Optional argument to use instead of the value from ``config``.
        steps_per_epoch : int
            Optional argument to use instead of the value from ``config``.

        Returns
        -------
        ``History`` object
            See `Keras training history <https://keras.io/models/model/#fit>`_.

        """
        if seed is not None:
            # https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development
            np.random.seed(seed)
        if epochs is None:
            epochs = self.config.train_epochs
        if steps_per_epoch is None:
            steps_per_epoch = self.config.train_steps_per_epoch

        validation_data is not None or _raise(ValueError())
        ((isinstance(validation_data,
                     (list, tuple)) and len(validation_data) == 2) or
         _raise(ValueError('validation_data must be a pair of numpy arrays')))

        patch_size = self.config.train_patch_size
        axes = self.config.axes.replace('C', '')
        b = self.config.train_completion_crop if self.config.train_shape_completion else 0
        div_by = self._axes_div_by(axes)
        [(p - 2 * b) % d == 0 or _raise(
            ValueError(
                "'train_patch_size' - 2*'train_completion_crop' must be divisible by {d} along axis '{a}'"
                .format(a=a, d=d) if self.config.train_shape_completion else
                "'train_patch_size' must be divisible by {d} along axis '{a}'".
                format(a=a, d=d)))
         for p, d, a in zip(patch_size, div_by, axes)]

        if not self._model_prepared:
            self.prepare_for_training()

        data_kwargs = dict(
            n_params=self.config.n_params,
            patch_size=self.config.train_patch_size,
            grid=self.config.grid,
            shape_completion=self.config.train_shape_completion,
            b=self.config.train_completion_crop,
            use_gpu=self.config.use_gpu,
            foreground_prob=self.config.train_foreground_only,
            contoursize_max=self.config.contoursize_max,
        )

        # generate validation data and store in numpy arrays
        n_data_val = len(validation_data[0])
        n_take = self.config.train_n_val_patches if self.config.train_n_val_patches is not None else n_data_val
        _data_val = SplineDistData2D(*validation_data,
                                     batch_size=n_take,
                                     length=1,
                                     **data_kwargs)
        data_val = _data_val[0]

        data_train = SplineDistData2D(X,
                                      Y,
                                      batch_size=self.config.train_batch_size,
                                      augmenter=augmenter,
                                      length=epochs * steps_per_epoch,
                                      **data_kwargs)

        if self.config.train_tensorboard:
            # show dist for three rays
            _n = min(3, self.config.n_params)
            channel = axes_dict(self.config.axes)['C']
            output_slices = [[slice(None)] * 4, [slice(None)] * 4]
            output_slices[1][1 + channel] = slice(
                0, (self.config.n_params // _n) * _n,
                self.config.n_params // _n)
            if IS_TF_1:
                for cb in self.callbacks:
                    if isinstance(cb, CARETensorBoard):
                        cb.output_slices = output_slices
                        # target image for dist includes dist_mask and thus has more channels than dist output
                        cb.output_target_shapes = [None, [None] * 4]
                        cb.output_target_shapes[1][
                            1 + channel] = data_val[1][1].shape[1 + channel]
            elif self.basedir is not None and not any(
                    isinstance(cb, CARETensorBoardImage)
                    for cb in self.callbacks):
                self.callbacks.append(
                    CARETensorBoardImage(model=self.keras_model,
                                         data=data_val,
                                         log_dir=str(self.logdir / 'logs' /
                                                     'images'),
                                         n_images=3,
                                         prob_out=False,
                                         output_slices=output_slices))

        fit = self.keras_model.fit_generator if IS_TF_1 else self.keras_model.fit
        history = fit(iter(data_train),
                      validation_data=data_val,
                      epochs=epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=self.callbacks,
                      verbose=1)
        self._training_finished()

        return history
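A slightly richer augmenter than the docstring example, usable as the ``augmenter`` argument above. This is a sketch assuming 2D patches with leading YX axes; geometric transforms must be applied identically to image and labels:

import numpy as np

def flip_rot_augmenter(x, y):
    # random 90-degree rotation plus an optional vertical flip
    k = np.random.randint(4)
    x, y = np.rot90(x, k, axes=(0, 1)), np.rot90(y, k, axes=(0, 1))
    if np.random.rand() < 0.5:
        x, y = np.flip(x, axis=0), np.flip(y, axis=0)
    return x, y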
Example n. 24
    def train(self, X, validation_X, epochs=None, steps_per_epoch=None):
        """Train the neural network with the given data.

        Parameters
        ----------
        X : :class:`numpy.ndarray`
            Array of source images.
        validation_X : :class:`numpy.ndarray`
            Array of validation images.
        epochs : int
            Optional argument to use instead of the value from ``config``.
        steps_per_epoch : int
            Optional argument to use instead of the value from ``config``.

        Returns
        -------
        ``History`` object
            See `Keras training history <https://keras.io/models/model/#fit>`_.

        """

        n_train, n_val = len(X), len(validation_X)
        frac_val = (1.0 * n_val) / (n_train + n_val)
        frac_warn = 0.05
        if frac_val < frac_warn:
            warnings.warn("small number of validation images (only %.1f%% of all images)" % (100*frac_val))
        axes = axes_check_and_normalize('S'+self.config.axes,X.ndim)
        ax = axes_dict(axes)
        div_by = 2**self.config.unet_n_depth
        axes_relevant = ''.join(a for a in 'XYZT' if a in axes)
        val_num_pix = 1
        train_num_pix = 1
        val_patch_shape = ()
        for a in axes_relevant:
            n = X.shape[ax[a]]
            val_num_pix *= validation_X.shape[ax[a]]
            train_num_pix *= X.shape[ax[a]]
            val_patch_shape += tuple([validation_X.shape[ax[a]]])
            if n % div_by != 0:
                raise ValueError(
                    "training images must be evenly divisible by %d along axes %s"
                    " (axis %s has incompatible size %d)" % (div_by,axes_relevant,a,n)
                )

        if epochs is None:
            epochs = self.config.train_epochs
        if steps_per_epoch is None:
            steps_per_epoch = self.config.train_steps_per_epoch

        if not self._model_prepared:
            self.prepare_for_training()

        manipulator = eval('pm_{0}({1})'.format(self.config.n2v_manipulator, str(self.config.n2v_neighborhood_radius)))

        mean, std = float(self.config.mean), float(self.config.std)

        X = self.__normalize__(X, mean, std)
        validation_X = self.__normalize__(validation_X, mean, std)

        # Here we prepare the Noise2Void data. Our input is the noisy data X and as target we take X concatenated with
        # a masking channel. The N2V_DataWrapper will take care of the pixel masking and manipulating.
        training_data = N2V_DataWrapper(X, np.concatenate((X, np.zeros(X.shape, dtype=X.dtype)), axis=axes.index('C')),
                                                    self.config.train_batch_size, int(train_num_pix/100 * self.config.n2v_perc_pix),
                                                    self.config.n2v_patch_shape, manipulator)

        # validation_Y is also validation_X plus a concatenated masking channel.
        # To speed things up, we precompute the masking of the validation data.
        validation_Y = np.concatenate((validation_X, np.zeros(validation_X.shape, dtype=validation_X.dtype)), axis=axes.index('C'))
        n2v_utils.manipulate_val_data(validation_X, validation_Y,
                                                        num_pix=int(val_num_pix/100 * self.config.n2v_perc_pix),
                                                        shape=val_patch_shape,
                                                        value_manipulation=manipulator)

        history = self.keras_model.fit_generator(generator=training_data, validation_data=(validation_X, validation_Y),
                                                 epochs=epochs, steps_per_epoch=steps_per_epoch,
                                                 callbacks=self.callbacks, verbose=1)

        if self.basedir is not None:
            self.keras_model.save_weights(str(self.logdir / 'weights_last.h5'))

            if self.config.train_checkpoint is not None:
                print()
                self._find_and_load_weights(self.config.train_checkpoint)
                try:
                    # remove temporary weights
                    (self.logdir / 'weights_now.h5').unlink()
                except FileNotFoundError:
                    pass

        return history
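The masked-pixel budgets passed to the data wrappers above are derived from n2v_perc_pix, interpreted as a percentage of the pixels across the relevant spatial axes. A stand-alone illustration of that arithmetic with hypothetical numbers:

n2v_perc_pix = 1.5                        # percent of pixels to manipulate
train_num_pix = 128 * 128                 # product of the relevant 'YX' axis sizes
num_pix = int(train_num_pix / 100 * n2v_perc_pix)
print(num_pix)                            # -> 245 masked pixels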
Example n. 25
    def Train(self):

        BinaryName = 'BinaryMask/'
        RealName = 'RealMask/'
        Raw = sorted(glob.glob(self.BaseDir + '/Raw/' + '*.tif'))
        Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
        Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
        RealMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*.tif'))
        ValRaw = sorted(glob.glob(self.BaseDir + '/ValRaw/' + '*.tif'))
        ValRealMask = sorted(
            glob.glob(self.BaseDir + '/ValRealMask/' + '*.tif'))

        print('Instance segmentation masks:', len(RealMask))
        if len(RealMask) == 0:

            print('Making labels')
            Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))

            for fname in Mask:

                image = imread(fname)

                Name = os.path.basename(os.path.splitext(fname)[0])

                Labelimage = label(image)

                imwrite((self.BaseDir + '/' + RealName + Name + '.tif'),
                        Labelimage.astype('uint16'))

        Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
        print('Semantic segmentation masks:', len(Mask))
        if len(Mask) == 0:
            print('Generating Binary images')

            RealfilesMask = sorted(
                glob.glob(self.BaseDir + '/' + RealName + '*tif'))

            for fname in RealfilesMask:

                image = imread(fname)

                Name = os.path.basename(os.path.splitext(fname)[0])

                Binaryimage = image > 0

                imwrite((self.BaseDir + '/' + BinaryName + Name + '.tif'),
                        Binaryimage.astype('uint16'))

        if self.GenerateNPZ:

            raw_data = RawData.from_folder(
                basepath=self.BaseDir,
                source_dirs=['Raw/'],
                target_dir='BinaryMask/',
                axes='ZYX',
            )

            X, Y, XY_axes = create_patches(
                raw_data=raw_data,
                patch_size=(self.PatchZ, self.PatchY, self.PatchX),
                n_patches_per_image=self.n_patches_per_image,
                save_file=self.BaseDir + self.NPZfilename + '.npz',
            )

        # Training UNET model
        if self.TrainUNET:
            print('Training UNET model')
            load_path = self.BaseDir + self.NPZfilename + '.npz'

            (X, Y), (X_val,
                     Y_val), axes = load_training_data(load_path,
                                                       validation_split=0.1,
                                                       verbose=True)
            c = axes_dict(axes)['C']
            n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            unet_n_depth=self.depth,
                            train_epochs=self.epochs,
                            train_batch_size=self.batch_size,
                            unet_n_first=self.startfilter,
                            train_loss='mse',
                            unet_kern_size=self.kern_size,
                            train_learning_rate=self.learning_rate,
                            train_reduce_lr={
                                'patience': 5,
                                'factor': 0.5
                            })
            print(config)
            vars(config)

            model = CARE(config,
                         name='UNET' + self.model_name,
                         basedir=self.model_dir)

            if self.copy_model_dir is not None:
                if os.path.exists(self.copy_model_dir + 'UNET' +
                                  self.copy_model_name + '/' +
                                  'weights_now.h5') and not os.path.exists(
                                      self.model_dir + 'UNET' +
                                      self.model_name + '/' +
                                      'weights_now.h5'):
                    print('Loading copy model')
                    model.load_weights(self.copy_model_dir + 'UNET' +
                                       self.copy_model_name + '/' +
                                       'weights_now.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_now.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_now.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_last.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_last.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_best.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_best.h5')

            history = model.train(X, Y, validation_data=(X_val, Y_val))

            print(sorted(list(history.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(history, ['loss', 'val_loss'],
                         ['mse', 'val_mse', 'mae', 'val_mae'])

        if self.TrainSTAR:
            print('Training StarDistModel model with', self.backbone,
                  'backbone')
            self.axis_norm = (0, 1, 2)
            if not self.CroppedLoad:
                assert len(Raw) > 1, "not enough training data"
                print(len(Raw))
                rng = np.random.RandomState(42)
                ind = rng.permutation(len(Raw))

                X_train = list(map(ReadFloat, Raw))
                Y_train = list(map(ReadInt, RealMask))
                self.Y = [
                    label(DownsampleData(y, self.DownsampleFactor))
                    for y in tqdm(Y_train)
                ]
                self.X = [
                    normalize(DownsampleData(x, self.DownsampleFactor),
                              1,
                              99.8,
                              axis=self.axis_norm) for x in tqdm(X_train)
                ]
                n_val = max(1, int(round(0.15 * len(ind))))
                ind_train, ind_val = ind[:-n_val], ind[-n_val:]

                self.X_val, self.Y_val = [self.X[i] for i in ind_val
                                          ], [self.Y[i] for i in ind_val]
                self.X_trn, self.Y_trn = [self.X[i] for i in ind_train
                                          ], [self.Y[i] for i in ind_train]

                print('number of images: %3d' % len(self.X))
                print('- training:       %3d' % len(self.X_trn))
                print('- validation:     %3d' % len(self.X_val))

            if self.CroppedLoad:
                self.X_trn = self.DataSequencer(Raw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_trn = self.DataSequencer(RealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)

                self.X_val = self.DataSequencer(ValRaw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_val = self.DataSequencer(ValRealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)
                self.train_sample_cache = False

            print(Config3D.__doc__)

            anisotropy = (1, 1, 1)
            rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)

            if self.backbone == 'resnet':

                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    resnet_n_blocks=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    resnet_kernel_size=(self.kern_size, self.kern_size,
                                        self.kern_size),
                    train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                    train_batch_size=self.batch_size,
                    resnet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1)

            if self.backbone == 'unet':

                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    unet_n_depth=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    unet_kernel_size=(self.kern_size, self.kern_size,
                                      self.kern_size),
                    train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                    train_batch_size=self.batch_size,
                    unet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1,
                    train_sample_cache=False)

            print(conf)
            vars(conf)

            Starmodel = StarDist3D(conf,
                                   name=self.model_name,
                                   basedir=self.model_dir)
            print(
                Starmodel._axes_tile_overlap('ZYX'),
                os.path.exists(self.model_dir + self.model_name + '/' +
                               'weights_now.h5'))

            if self.copy_model_dir is not None:
                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_now.h5') and not os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_now.h5'):
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_now.h5')
                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_last.h5') and not os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_last.h5'):
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_last.h5')

                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_best.h5') and not os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_best.h5'):
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_best.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_now.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_now.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_last.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_last.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_best.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_best.h5')

            historyStar = Starmodel.train(self.X_trn,
                                          self.Y_trn,
                                          validation_data=(self.X_val,
                                                           self.Y_val),
                                          epochs=self.epochs)
            print(sorted(list(historyStar.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(historyStar, ['loss', 'val_loss'], [
                'dist_relevant_mae', 'val_dist_relevant_mae',
                'dist_relevant_mse', 'val_dist_relevant_mse'
            ])
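The repeated checkpoint-loading blocks in Train could be collapsed into one helper. A hedged refactoring sketch (hypothetical function, mirroring the inline 'now'/'last'/'best' sequence so that the last existing file wins):

import os

def load_existing_weights(model, model_dir, model_name):
    # mirrors the inline blocks above: each existing checkpoint is loaded in turn
    for suffix in ('weights_now.h5', 'weights_last.h5', 'weights_best.h5'):
        path = os.path.join(model_dir, model_name, suffix)
        if os.path.exists(path):
            print('Loading checkpoint model')
            model.load_weights(path)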
Example n. 26
    def train(self,
              X,
              Y,
              validation_data,
              classes='auto',
              augmenter=None,
              seed=None,
              epochs=None,
              steps_per_epoch=None,
              workers=1):
        """Train the neural network with the given data.

        Parameters
        ----------
        X : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
            Input images
        Y : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
            Label masks
        validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`) or triple (if multiclass)
            Tuple (triple if multiclass) of X,Y,[classes] validation data.
        classes : 'auto' or iterable of same length as X, optional
            Label id -> class id mapping for each label mask of Y if multiclass prediction is activated (n_classes > 0);
            given as a list of dicts with label id -> class id (1,...,n_classes).
            'auto' assigns all objects to the first non-background class,
            or the mapping is ignored if config.n_classes is None.
        augmenter : None or callable
            Function with expected signature ``xt, yt = augmenter(x, y)``
            that takes in a single pair of input/label image (x,y) and returns
            the transformed images (xt, yt) for the purpose of data augmentation
            during training. Not applied to validation images.
            Example:
            def simple_augmenter(x,y):
                x = x + 0.05*np.random.normal(0,1,x.shape)
                return x,y
        seed : int
            Convenience to set ``np.random.seed(seed)``. (To obtain reproducible validation patches, etc.)
        epochs : int
            Optional argument to use instead of the value from ``config``.
        steps_per_epoch : int
            Optional argument to use instead of the value from ``config``.

        Returns
        -------
        ``History`` object
            See `Keras training history <https://keras.io/models/model/#fit>`_.

        """
        if seed is not None:
            # https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development
            np.random.seed(seed)
        if epochs is None:
            epochs = self.config.train_epochs
        if steps_per_epoch is None:
            steps_per_epoch = self.config.train_steps_per_epoch

        classes = self._parse_classes_arg(classes, len(X))

        if not self._is_multiclass() and classes is not None:
            warnings.warn("Ignoring given classes as n_classes is set to None")

        isinstance(validation_data, (list, tuple)) or _raise(ValueError())
        if self._is_multiclass() and len(validation_data) == 2:
            validation_data = tuple(validation_data) + ('auto', )
        ((len(validation_data) == (3 if self._is_multiclass(
        ) else 2)) or _raise(
            ValueError(
                f'len(validation_data) = {len(validation_data)}, but should be {3 if self._is_multiclass() else 2}'
            )))

        patch_size = self.config.train_patch_size
        axes = self.config.axes.replace('C', '')
        b = self.config.train_completion_crop if self.config.train_shape_completion else 0
        div_by = self._axes_div_by(axes)
        [(p - 2 * b) % d == 0 or _raise(
            ValueError(
                "'train_patch_size' - 2*'train_completion_crop' must be divisible by {d} along axis '{a}'"
                .format(a=a, d=d) if self.config.train_shape_completion else
                "'train_patch_size' must be divisible by {d} along axis '{a}'".
                format(a=a, d=d)))
         for p, d, a in zip(patch_size, div_by, axes)]

        if not self._model_prepared:
            self.prepare_for_training()

        data_kwargs = dict(
            n_rays=self.config.n_rays,
            patch_size=self.config.train_patch_size,
            grid=self.config.grid,
            shape_completion=self.config.train_shape_completion,
            b=self.config.train_completion_crop,
            use_gpu=self.config.use_gpu,
            foreground_prob=self.config.train_foreground_only,
            n_classes=self.config.n_classes,
            sample_ind_cache=self.config.train_sample_cache,
        )

        # generate validation data and store in numpy arrays
        n_data_val = len(validation_data[0])
        classes_val = self._parse_classes_arg(
            validation_data[2], n_data_val) if self._is_multiclass() else None
        n_take = self.config.train_n_val_patches if self.config.train_n_val_patches is not None else n_data_val
        _data_val = StarDistData2D(validation_data[0],
                                   validation_data[1],
                                   classes=classes_val,
                                   batch_size=n_take,
                                   length=1,
                                   **data_kwargs)
        data_val = _data_val[0]

        # expose data generator as member for general diagnostics
        self.data_train = StarDistData2D(
            X,
            Y,
            classes=classes,
            batch_size=self.config.train_batch_size,
            augmenter=augmenter,
            length=epochs * steps_per_epoch,
            **data_kwargs)

        if self.config.train_tensorboard:
            # show dist for three rays
            _n = min(3, self.config.n_rays)
            channel = axes_dict(self.config.axes)['C']
            output_slices = [[slice(None)] * 4, [slice(None)] * 4]
            output_slices[1][1 + channel] = slice(
                0, (self.config.n_rays // _n) * _n, self.config.n_rays // _n)
            if self._is_multiclass():
                _n = min(3, self.config.n_classes)
                output_slices += [[slice(None)] * 4]
                output_slices[2][1 + channel] = slice(
                    1, 1 + ((self.config.n_classes + 1) // _n) * _n,
                    self.config.n_classes // _n)

            if IS_TF_1:
                for cb in self.callbacks:
                    if isinstance(cb, CARETensorBoard):
                        cb.output_slices = output_slices
                        # target image for dist includes dist_mask and thus has more channels than dist output
                        cb.output_target_shapes = [None, [None] * 4, None]
                        cb.output_target_shapes[1][
                            1 + channel] = data_val[1][1].shape[1 + channel]
            elif self.basedir is not None and not any(
                    isinstance(cb, CARETensorBoardImage)
                    for cb in self.callbacks):
                self.callbacks.append(
                    CARETensorBoardImage(model=self.keras_model,
                                         data=data_val,
                                         log_dir=str(self.logdir / 'logs' /
                                                     'images'),
                                         n_images=3,
                                         prob_out=False,
                                         output_slices=output_slices))

        fit = self.keras_model.fit_generator if IS_TF_1 else self.keras_model.fit
        history = fit(
            iter(self.data_train),
            validation_data=data_val,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            workers=workers,
            use_multiprocessing=workers > 1,
            callbacks=self.callbacks,
            verbose=1,
            # set validation batchsize to training batchsize (only works for tf >= 2.2)
            **(dict(validation_batch_size=self.config.train_batch_size)
               if _tf_version_at_least("2.2.0") else {}))
        self._training_finished()

        return history
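A sketch of the classes argument accepted above. All data names here (model, X, Y, X_val, Y_val, classes_val) are hypothetical; each dict maps label ids in a mask of Y to class ids 1..n_classes, as described in the docstring:

# first mask: objects 1 and 2 belong to classes 1 and 2;
# second mask: all three objects belong to class 1
classes = [{1: 1, 2: 2}, {1: 1, 2: 1, 3: 1}]
history = model.train(X, Y, classes=classes,
                      validation_data=(X_val, Y_val, classes_val))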
Example n. 27
    def train(self, X, Y, validation_data, augmenter=None, seed=None, epochs=None, steps_per_epoch=None, multi=False, ncpu=1):
        """Train the neural network with the given data.

        Parameters
        ----------
        X : :class:`numpy.ndarray`
            Array of input images.
        Y : :class:`numpy.ndarray`
            Array of label masks.
        validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`)
            Tuple of X,Y validation arrays.
        augmenter : None or callable
            Function with expected signature ``xt, yt = augmenter(x, y)``
            that takes in a single pair of input/label image (x,y) and returns
            the transformed images (xt, yt) for the purpose of data augmentation
            during training. Not applied to validation images.
            Example:
            def simple_augmenter(x,y):
                x = x + 0.05*np.random.normal(0,1,x.shape)
                return x,y
        seed : int
            Convenience to set ``np.random.seed(seed)``. (To obtain reproducible validation patches, etc.)
        epochs : int
            Optional argument to use instead of the value from ``config``.
        steps_per_epoch : int
            Optional argument to use instead of the value from ``config``.

        Returns
        -------
        ``History`` object
            See `Keras training history <https://keras.io/models/model/#fit>`_.

        """
        if seed is not None:
            # https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development
            np.random.seed(seed)
        if epochs is None:
            epochs = self.config.train_epochs
        if steps_per_epoch is None:
            steps_per_epoch = self.config.train_steps_per_epoch

        validation_data is not None or _raise(ValueError())
        ((isinstance(validation_data,(list,tuple)) and len(validation_data)==2)
            or _raise(ValueError('validation_data must be a pair of numpy arrays')))

        patch_size = self.config.train_patch_size
        axes = self.config.axes.replace('C','')
        b = self.config.train_completion_crop if self.config.train_shape_completion else 0
        div_by = self._axes_div_by(axes)
        [(p-2*b) % d == 0 or _raise(ValueError(
            "'train_patch_size' - 2*'train_completion_crop' must be divisible by {d} along axis '{a}'".format(a=a,d=d) if self.config.train_shape_completion else
            "'train_patch_size' must be divisible by {d} along axis '{a}'".format(a=a,d=d)
         )) for p,d,a in zip(patch_size,div_by,axes)]

        if not self._model_prepared:
            self.prepare_for_training()

        data_kwargs = dict (
            n_rays           = self.config.n_rays,
            patch_size       = self.config.train_patch_size,
            grid             = self.config.grid,
            shape_completion = self.config.train_shape_completion,
            b                = self.config.train_completion_crop,
            use_gpu          = self.config.use_gpu,
            foreground_prob  = self.config.train_foreground_only,
            prob_thr         = self.config.EDT_prob_threshold,
            border_R         = self.config.EDT_border_R,
            black_border     = self.config.EDT_black_border
            )
        # generate validation data and store in numpy arrays
        data_val = StarDistData2D(*validation_data, batch_size=1, **data_kwargs)
        n_data_val = len(data_val)
        n_take = self.config.train_n_val_patches if self.config.train_n_val_patches is not None else n_data_val
        ids = tuple(np.random.choice(n_data_val, size=n_take, replace=(n_take > n_data_val)))
        Xv, Mv, Pv, Dv = [None]*n_take, [None]*n_take, [None]*n_take, [None]*n_take
        for i,k in enumerate(ids):
            (Xv[i],Mv[i]),(Pv[i],Dv[i]) = data_val[k]
        Xv, Mv, Pv, Dv = np.concatenate(Xv,axis=0), np.concatenate(Mv,axis=0), np.concatenate(Pv,axis=0), np.concatenate(Dv,axis=0)
        data_val = [[Xv,Mv],[Pv,Dv]]

        data_train = StarDistData2D(X, Y, batch_size=self.config.train_batch_size, augmenter=augmenter, **data_kwargs)
                
        for cb in self.callbacks:
            if isinstance(cb,CARETensorBoard):
                # show dist for three rays
                _n = min(3, self.config.n_rays)
                cb.output_slices = [[slice(None)]*4,[slice(None)]*4]
                cb.output_slices[1][1+axes_dict(self.config.axes)['C']] = slice(0,(self.config.n_rays//_n)*_n,self.config.n_rays//_n)

        history = self.keras_model.fit_generator(generator=data_train, validation_data=data_val,
                                                 epochs=epochs, steps_per_epoch=steps_per_epoch,
                                                 callbacks=self.callbacks, verbose=1, use_multiprocessing=multi, workers=ncpu)
        self._training_finished()

        return history
Example n. 28
    def predict_instances_big(self, img, axes, block_size, min_overlap, context=None,
                              labels_out=None, labels_out_dtype=np.int32, show_progress=True, **kwargs):
        """Predict instance segmentation from very large input images.

        Intended to be used when `predict_instances` cannot be used due to memory limitations.
        This function will break the input image into blocks and process them individually
        via `predict_instances` and assemble all the partial results. If used as intended, the result
        should be the same as if `predict_instances` was used directly on the whole image.

        **Important**: The crucial assumption is that all predicted object instances are smaller than
                       the provided `min_overlap`. Also, it must hold that: min_overlap + 2*context < block_size.

        Example
        -------
        >>> img.shape
        (20000, 20000)
        >>> labels, polys = model.predict_instances_big(img, axes='YX', block_size=4096,
                                                        min_overlap=128, context=128, n_tiles=(4,4))

        Parameters
        ----------
        img: :class:`numpy.ndarray` or similar
            Input image
        axes: str
            Axes of the input ``img`` (such as 'YX', 'ZYX', 'YXC', etc.)
        block_size: int or iterable of int
            Process input image in blocks of the provided shape.
            (If a scalar value is given, it is used for all spatial image dimensions.)
        min_overlap: int or iterable of int
            Amount of guaranteed overlap between blocks.
            (If a scalar value is given, it is used for all spatial image dimensions.)
        context: int or iterable of int, or None
            Amount of image context on all sides of a block, which is discarded.
            If None, uses an automatic estimate that should work in many cases.
            (If a scalar value is given, it is used for all spatial image dimensions.)
        labels_out: :class:`numpy.ndarray` or similar, or None, or False
            numpy array or similar (must be of correct shape) to which the label image is written.
            If None, will allocate a numpy array of the correct shape and data type ``labels_out_dtype``.
            If False, will not write the label image (useful if only the dictionary is needed).
        labels_out_dtype: str or dtype
            Data type of returned label image if ``labels_out=None`` (has no effect otherwise).
        show_progress: bool
            Show progress bar for block processing.
        kwargs: dict
            Keyword arguments for ``predict_instances``.

        Returns
        -------
        (:class:`numpy.ndarray` or False, dict)
            Returns the label image and a dictionary with the details (coordinates, etc.) of the polygons/polyhedra.

        """
        from ..big import _grid_divisible, BlockND, OBJECT_KEYS#, repaint_labels
        from ..matching import relabel_sequential

        n = img.ndim
        axes = axes_check_and_normalize(axes, length=n)
        grid = self._axes_div_by(axes)
        axes_out = self._axes_out.replace('C','')
        shape_dict = dict(zip(axes,img.shape))
        shape_out = tuple(shape_dict[a] for a in axes_out)

        if context is None:
            context = self._axes_tile_overlap(axes)

        if np.isscalar(block_size):  block_size  = n*[block_size]
        if np.isscalar(min_overlap): min_overlap = n*[min_overlap]
        if np.isscalar(context):     context     = n*[context]
        block_size, min_overlap, context = list(block_size), list(min_overlap), list(context)
        assert n == len(block_size) == len(min_overlap) == len(context)

        if 'C' in axes:
            # single block for channel axis
            i = axes_dict(axes)['C']
            # if (block_size[i], min_overlap[i], context[i]) != (None, None, None):
            #     print("Ignoring values of 'block_size', 'min_overlap', and 'context' for channel axis " +
            #           "(set to 'None' to avoid this warning).", file=sys.stderr, flush=True)
            block_size[i] = img.shape[i]
            min_overlap[i] = context[i] = 0

        block_size  = tuple(_grid_divisible(g, v, name='block_size',  verbose=False) for v,g,a in zip(block_size, grid,axes))
        min_overlap = tuple(_grid_divisible(g, v, name='min_overlap', verbose=False) for v,g,a in zip(min_overlap,grid,axes))
        context     = tuple(_grid_divisible(g, v, name='context',     verbose=False) for v,g,a in zip(context,    grid,axes))

        # print(f"input: shape {img.shape} with axes {axes}")
        print(f'effective: block_size={block_size}, min_overlap={min_overlap}, context={context}', flush=True)

        for a,c,o in zip(axes,context,self._axes_tile_overlap(axes)):
            if c < o:
                print(f"{a}: context of {c} is small, recommended to use at least {o}", flush=True)

        # create block cover
        blocks = BlockND.cover(img.shape, axes, block_size, min_overlap, context, grid)

        if np.isscalar(labels_out) and bool(labels_out) is False:
            labels_out = None
        else:
            if labels_out is None:
                labels_out = np.zeros(shape_out, dtype=labels_out_dtype)
            else:
                labels_out.shape == shape_out or _raise(ValueError(f"'labels_out' must have shape {shape_out} (axes {axes_out})."))

        polys_all = {}
        # problem_ids = []
        label_offset = 1

        kwargs_override = dict(axes=axes, overlap_label=None)
        if show_progress:
            kwargs_override['show_tile_progress'] = False # disable progress for predict_instances
        for k,v in kwargs_override.items():
            if k in kwargs: print(f"changing '{k}' from {kwargs[k]} to {v}", flush=True)
            kwargs[k] = v

        blocks = tqdm(blocks, disable=(not show_progress))
        # actual computation
        for block in blocks:
            labels, polys = self.predict_instances(block.read(img, axes=axes), **kwargs)
            labels = block.crop_context(labels, axes=axes_out)
            labels, polys = block.filter_objects(labels, polys, axes=axes_out)
            # TODO: relabel_sequential is not very memory-efficient (will allocate memory proportional to label_offset)
            labels = relabel_sequential(labels, label_offset)[0]
            # labels, fwd_map, _ = relabel_sequential(labels, label_offset)
            # if len(incomplete) > 0:
            #     problem_ids.extend([fwd_map[i] for i in incomplete])
            #     if show_progress:
            #         blocks.set_postfix_str(f"found {len(problem_ids)} problematic {'object' if len(problem_ids)==1 else 'objects'}")
            if labels_out is not None:
                block.write(labels_out, labels, axes=axes_out)
            for k,v in polys.items():
                polys_all.setdefault(k,[]).append(v)
            label_offset += len(polys['prob'])

        polys_all = {k: (np.concatenate(v) if k in OBJECT_KEYS else v[0]) for k,v in polys_all.items()}

        # if labels_out is not None and len(problem_ids) > 0:
        #     # if show_progress:
        #     #     blocks.write('')
        #     # print(f"Found {len(problem_ids)} objects that violate the 'min_overlap' assumption.", file=sys.stderr, flush=True)
        #     repaint_labels(labels_out, problem_ids, polys_all, show_progress=False)

        return labels_out, polys_all#, tuple(problem_ids)
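
The method above matches StarDist's block-wise `predict_instances_big`. A minimal usage sketch, assuming the `stardist` and `csbdeep` packages are installed and a pretrained 2D model can be downloaded (image and block sizes are placeholders):

import numpy as np
from csbdeep.utils import normalize
from stardist.models import StarDist2D

# assumption: pretrained weights can be fetched for this registered model name
model = StarDist2D.from_pretrained('2D_versatile_fluo')

img = np.random.uniform(size=(4096, 4096)).astype(np.float32)  # placeholder image
img = normalize(img, 1, 99.8)  # percentile normalization expected by predict_instances

# each 1024x1024 block is predicted independently and re-assembled;
# min_overlap should exceed the diameter of the largest object for correct stitching
labels, polys = model.predict_instances_big(
    img, axes='YX',
    block_size=1024, min_overlap=64, context=64,
)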
Example n. 29
0
    def predict(self,
                img,
                axes=None,
                normalizer=None,
                n_tiles=None,
                show_tile_progress=True,
                **predict_kwargs):
        """Predict.

        Parameters
        ----------
        img : :class:`numpy.ndarray`
            Input image
        axes : str or None
            Axes of the input ``img``.
            ``None`` denotes that the axes of ``img`` are the same as specified in the config.
        normalizer : :class:`csbdeep.data.Normalizer` or None
            (Optional) normalization of input image before prediction.
            Note that the default (``None``) assumes ``img`` to be already normalized.
        n_tiles : iterable or None
            Out of memory (OOM) errors can occur if the input image is too large.
            To avoid this problem, the input image is broken up into (overlapping) tiles
            that are processed independently and re-assembled.
            This parameter denotes a tuple of the number of tiles for every image axis (see ``axes``).
            ``None`` denotes that no tiling should be used.
        show_tile_progress: bool
            Whether to show progress during tiled prediction.
        predict_kwargs: dict
            Keyword arguments for the ``predict`` function of the Keras model.

        Returns
        -------
        (:class:`numpy.ndarray`,:class:`numpy.ndarray`)
            Returns the tuple (`prob`, `dist`) of per-pixel object probabilities and star-convex polygon/polyhedra distances.

        """
        if n_tiles is None:
            n_tiles = [1] * img.ndim
        try:
            n_tiles = tuple(n_tiles)
            img.ndim == len(n_tiles) or _raise(TypeError())
        except TypeError:
            raise ValueError("n_tiles must be an iterable of length %d" %
                             img.ndim)
        all(np.isscalar(t) and 1 <= t and int(t) == t
            for t in n_tiles) or _raise(
                ValueError(
                    "all values of n_tiles must be integer values >= 1"))
        n_tiles = tuple(map(int, n_tiles))

        axes = self._normalize_axes(img, axes)
        axes_net = self.config.axes

        _permute_axes = self._make_permute_axes(axes, axes_net)
        x = _permute_axes(img)  # x has axes_net semantics

        channel = axes_dict(axes_net)['C']
        self.config.n_channel_in == x.shape[channel] or _raise(ValueError())
        axes_net_div_by = self._axes_div_by(axes_net)

        grid = tuple(self.config.grid)
        len(grid) == len(axes_net) - 1 or _raise(ValueError())
        grid_dict = dict(zip(axes_net.replace('C', ''), grid))

        normalizer = self._check_normalizer_resizer(normalizer, None)[0]
        resizer = StarDistPadAndCropResizer(grid=grid_dict)

        x = normalizer.before(x, axes_net)
        x = resizer.before(x, axes_net, axes_net_div_by)

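        # note: the underlying Keras graph expects a second input tensor; a dummy
        # array with a single channel is fed alongside the image at prediction time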
        def predict_direct(tile):
            sh = list(tile.shape)
            sh[channel] = 1
            dummy = np.empty(sh, np.float32)
            prob, dist = self.keras_model.predict(
                [tile[np.newaxis], dummy[np.newaxis]], **predict_kwargs)
            return prob[0], dist[0]

        if np.prod(n_tiles) > 1:
            tiling_axes = axes_net.replace('C', '')  # axes eligible for tiling
            x_tiling_axis = tuple(
                axes_dict(axes_net)[a]
                for a in tiling_axes)  # numerical axis ids for x
            axes_net_tile_overlaps = self._axes_tile_overlap(axes_net)
            # hack: permute tiling axis in the same way as img -> x was permuted
            n_tiles = _permute_axes(np.empty(n_tiles, bool)).shape  # 'np.bool' alias removed in NumPy >= 1.24
            (all(n_tiles[i] == 1
                 for i in range(x.ndim) if i not in x_tiling_axis)
             or _raise(
                 ValueError("entry of n_tiles > 1 only allowed for axes '%s'" %
                            tiling_axes)))

            sh = [s // grid_dict.get(a, 1) for a, s in zip(axes_net, x.shape)]
            sh[channel] = 1
            prob = np.empty(sh, np.float32)
            sh[channel] = self.config.n_rays
            dist = np.empty(sh, np.float32)

            # express the required tile overlap in units of the network's
            # divisibility blocks, rounding up so the context is never undersized
            n_block_overlaps = [
                int(np.ceil(overlap / blocksize)) for overlap, blocksize in
                zip(axes_net_tile_overlaps, axes_net_div_by)
            ]

            for tile, s_src, s_dst in tqdm(tile_iterator(
                    x,
                    n_tiles,
                    block_sizes=axes_net_div_by,
                    n_block_overlaps=n_block_overlaps),
                                           disable=(not show_tile_progress),
                                           total=np.prod(n_tiles)):
                prob_tile, dist_tile = predict_direct(tile)
                # account for grid
                s_src = [
                    slice(s.start // grid_dict.get(a, 1),
                          s.stop // grid_dict.get(a, 1))
                    for s, a in zip(s_src, axes_net)
                ]
                s_dst = [
                    slice(s.start // grid_dict.get(a, 1),
                          s.stop // grid_dict.get(a, 1))
                    for s, a in zip(s_dst, axes_net)
                ]
                # prob and dist have different channel dimensionality than image x
                s_src[channel] = slice(None)
                s_dst[channel] = slice(None)
                s_src, s_dst = tuple(s_src), tuple(s_dst)
                # print(s_src,s_dst)
                prob[s_dst] = prob_tile[s_src]
                dist[s_dst] = dist_tile[s_src]

        else:
            prob, dist = predict_direct(x)

        prob = resizer.after(prob, axes_net)
        dist = resizer.after(dist, axes_net)
        dist = np.maximum(
            1e-3, dist
        )  # avoid small/negative dist values to prevent problems with Qhull

        prob = np.take(prob, 0, axis=channel)
        dist = np.moveaxis(dist, channel, -1)

        return prob, dist
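
A short sketch of calling the low-level `predict` above with tiling, assuming a trained 2D StarDist-style model instance `model` and a 2D single-channel image `img` (both placeholders):

# prob/dist are raw network outputs on the (possibly grid-subsampled) image
prob, dist = model.predict(img, axes='YX', n_tiles=(2, 2))

# prob: per-pixel object probabilities, channel axis removed via np.take above
# dist: radial distances, channel moved last -> shape prob.shape + (n_rays,)
assert dist.shape[:-1] == prob.shape
assert (dist >= 1e-3).all()  # clipped above to avoid Qhull issues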
Example n. 30
0
 def _no_tiling_for_axis(axes_image, n_tiles, axis):
     if n_tiles is not None and axis in axes_image:
         return n_tiles[axes_dict(axes_image)[axis]] == 1
     return True
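
A few hedged examples of what the helper above returns (axis positions resolved with csbdeep's `axes_dict`):

assert _no_tiling_for_axis('YXC', None, 'C')           # no tiling requested -> True
assert _no_tiling_for_axis('YXC', (2, 2, 1), 'C')      # single tile along 'C' -> True
assert not _no_tiling_for_axis('YXC', (2, 2, 3), 'C')  # tiling over 'C' -> False
assert _no_tiling_for_axis('YX', (2, 2), 'C')          # axis absent -> trivially True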