Exemplo n.º 1
0
    def load_imgs(self, files, dims='YX'):
        """
        Helper to read a list of files. The images are not required to have same size,
        but have to be of same dimensionality.

        Parameters
        ----------
        files  : list(String)
                 List of paths to image files. Supported extensions (matched
                 case-insensitively) are '.tif', '.tiff' and '.png'.
        dims   : String, optional(default='YX')
                 Dimensions of the images to read. Known dimensions are: 'TZYXC'

        Returns
        -------
        images : list(array(float))
                 A list of the read images as float32 arrays. Each image has
                 dimensionality 'SZYXC' (if 'Z' in dims) or 'SYXC'. The leading
                 axis is the 'T' axis if present, otherwise a singleton axis.
                 A singleton channel axis is appended if 'C' is not in dims.

        Raises
        ------
        AssertionError
            If 'dims' is malformed or an image's rank does not match 'dims'.
        Exception
            If a JPEG file is given (lossy compression is not supported).
        ValueError
            If a file has any other unsupported extension.
        """
        assert 'Y' in dims and 'X' in dims, "'dims' has to contain 'X' and 'Y'."

        tmp_dims = dims
        for b in ['X', 'Y', 'Z', 'T', 'C']:
            assert tmp_dims.count(b) <= 1, "'dims' has to contain {} at most once.".format(b)
            tmp_dims = tmp_dims.replace(b, '')

        assert len(tmp_dims) == 0, "Unknown dimensions in 'dims'."

        net_axes = 'ZYXC' if 'Z' in dims else 'YXC'
        has_t = 'T' in dims
        has_c = 'C' in dims

        # Source/destination positions for np.moveaxis: T goes to the front,
        # C to the back, and each spatial axis to its index in net_axes
        # (shifted by one when a leading T axis is present).
        move_axis_from = tuple(range(len(dims)))
        move_axis_to = []
        for b in dims:
            if b == 'T':
                move_axis_to.append(0)
            elif b == 'C':
                move_axis_to.append(-1)
            else:  # 'X', 'Y' or 'Z' -- dims was validated above
                move_axis_to.append(net_axes.index(b) + int(has_t))
        move_axis_to = tuple(move_axis_to)

        imgs = []
        for f in files:
            # Compare extensions case-insensitively; str() also accepts path-like objects.
            name = str(f).lower()
            if name.endswith(('.tif', '.tiff')):
                imread = tifffile.imread
            elif name.endswith('.png'):
                imread = image.imread
            elif name.endswith(('.jpg', '.jpeg')):
                raise Exception("JPEG is not supported, because it is not loss-less and breaks the pixel-wise independence assumption.")
            else:
                # Raise a real exception object: raising a bare string is itself
                # a TypeError in Python 3 and would hide this message.
                raise ValueError("Filetype '{}' is not supported.".format(f))

            img = imread(f).astype(np.float32)
            assert len(img.shape) == len(dims), "Number of image dimensions doesn't match 'dims'."

            img = np.moveaxis(img, move_axis_from, move_axis_to)

            if not has_t:
                img = img[np.newaxis]

            if not has_c:
                img = img[..., np.newaxis]

            imgs.append(img)

        return imgs
Exemplo n.º 2
0
    def predict(self,
                img,
                axes=None,
                normalizer=None,
                n_tiles=None,
                show_tile_progress=True,
                **predict_kwargs):
        """Predict.

        Parameters
        ----------
        img : :class:`numpy.ndarray`
            Input image
        axes : str or None
            Axes of the input ``img``.
            ``None`` denotes that axes of img are the same as denoted in the config.
        normalizer : :class:`csbdeep.data.Normalizer` or None
            (Optional) normalization of input image before prediction.
            Note that the default (``None``) assumes ``img`` to be already normalized.
        n_tiles : iterable or None
            Out of memory (OOM) errors can occur if the input image is too large.
            To avoid this problem, the input image is broken up into (overlapping) tiles
            that are processed independently and re-assembled.
            This parameter denotes a tuple of the number of tiles for every image axis (see ``axes``).
            ``None`` denotes that no tiling should be used.
        show_tile_progress: bool
            Whether to show progress during tiled prediction.
        predict_kwargs: dict
            Keyword arguments for ``predict`` function of Keras model.

        Returns
        -------
        (:class:`numpy.ndarray`,:class:`numpy.ndarray`)
            Returns the tuple (`prob`, `dist`) of per-pixel object probabilities and star-convex polygon/polyhedra distances.

        Raises
        ------
        ValueError
            If ``n_tiles`` is malformed or requests tiling along a non-tileable
            axis, or if the image's channel count or the config's grid do not
            match the network configuration.
        """
        # Validate n_tiles: one integer >= 1 per axis of img (default: no tiling).
        if n_tiles is None:
            n_tiles = [1] * img.ndim
        try:
            n_tiles = tuple(n_tiles)
            img.ndim == len(n_tiles) or _raise(TypeError())
        except TypeError:
            raise ValueError("n_tiles must be an iterable of length %d" %
                             img.ndim)
        all(np.isscalar(t) and 1 <= t and int(t) == t
            for t in n_tiles) or _raise(
                ValueError(
                    "all values of n_tiles must be integer values >= 1"))
        n_tiles = tuple(map(int, n_tiles))

        axes = self._normalize_axes(img, axes)
        axes_net = self.config.axes

        _permute_axes = self._make_permute_axes(axes, axes_net)
        x = _permute_axes(img)  # x has axes_net semantics

        channel = axes_dict(axes_net)['C']
        self.config.n_channel_in == x.shape[channel] or _raise(
            ValueError("image has %d channel(s), but config.n_channel_in is %d" %
                       (x.shape[channel], self.config.n_channel_in)))
        axes_net_div_by = self._axes_div_by(axes_net)

        grid = tuple(self.config.grid)
        len(grid) == len(axes_net) - 1 or _raise(
            ValueError("config.grid must have one entry per non-channel axis of '%s'" %
                       axes_net))
        grid_dict = dict(zip(axes_net.replace('C', ''), grid))

        normalizer = self._check_normalizer_resizer(normalizer, None)[0]
        resizer = StarDistPadAndCropResizer(grid=grid_dict)

        x = normalizer.before(x, axes_net)
        x = resizer.before(x, axes_net, axes_net_div_by)

        def predict_direct(tile):
            # The Keras model is called with a second input of the same shape
            # as the tile but a single channel; its contents are uninitialized
            # (np.empty) -- presumably it is ignored at inference time.
            sh = list(tile.shape)
            sh[channel] = 1
            dummy = np.empty(sh, np.float32)
            prob, dist = self.keras_model.predict(
                [tile[np.newaxis], dummy[np.newaxis]], **predict_kwargs)
            return prob[0], dist[0]

        if np.prod(n_tiles) > 1:
            tiling_axes = axes_net.replace('C', '')  # axes eligible for tiling
            x_tiling_axis = tuple(
                axes_dict(axes_net)[a]
                for a in tiling_axes)  # numerical axis ids for x
            axes_net_tile_overlaps = self._axes_tile_overlap(axes_net)
            # hack: permute tiling axis in the same way as img -> x was permuted
            # (builtin bool: the np.bool alias was removed in NumPy 1.24)
            n_tiles = _permute_axes(np.empty(n_tiles, bool)).shape
            (all(n_tiles[i] == 1
                 for i in range(x.ndim) if i not in x_tiling_axis)
             or _raise(
                 ValueError("entry of n_tiles > 1 only allowed for axes '%s'" %
                            tiling_axes)))

            # Outputs are subsampled by the grid along each spatial axis.
            sh = [s // grid_dict.get(a, 1) for a, s in zip(axes_net, x.shape)]
            sh[channel] = 1
            prob = np.empty(sh, np.float32)
            sh[channel] = self.config.n_rays
            dist = np.empty(sh, np.float32)

            n_block_overlaps = [
                int(np.ceil(overlap / blocksize)) for overlap, blocksize in
                zip(axes_net_tile_overlaps, axes_net_div_by)
            ]

            for tile, s_src, s_dst in tqdm(tile_iterator(
                    x,
                    n_tiles,
                    block_sizes=axes_net_div_by,
                    n_block_overlaps=n_block_overlaps),
                                           disable=(not show_tile_progress),
                                           total=np.prod(n_tiles)):
                prob_tile, dist_tile = predict_direct(tile)
                # account for grid
                s_src = [
                    slice(s.start // grid_dict.get(a, 1),
                          s.stop // grid_dict.get(a, 1))
                    for s, a in zip(s_src, axes_net)
                ]
                s_dst = [
                    slice(s.start // grid_dict.get(a, 1),
                          s.stop // grid_dict.get(a, 1))
                    for s, a in zip(s_dst, axes_net)
                ]
                # prob and dist have different channel dimensionality than image x
                s_src[channel] = slice(None)
                s_dst[channel] = slice(None)
                s_src, s_dst = tuple(s_src), tuple(s_dst)
                prob[s_dst] = prob_tile[s_src]
                dist[s_dst] = dist_tile[s_src]

        else:
            prob, dist = predict_direct(x)

        prob = resizer.after(prob, axes_net)
        dist = resizer.after(dist, axes_net)
        # avoid small/negative dist values to prevent problems with Qhull
        dist = np.maximum(1e-3, dist)

        prob = np.take(prob, 0, axis=channel)
        dist = np.moveaxis(dist, channel, -1)

        return prob, dist
Exemplo n.º 3
0
    def __init__(self,
                 X,
                 Y,
                 n_rays,
                 grid,
                 batch_size,
                 patch_size,
                 use_gpu=False,
                 maxfilter_cache=True,
                 maxfilter_patch_size=None,
                 augmenter=None):
        """Store training images/labels and set up the max filter used for patch sampling.

        Parameters
        ----------
        X : list of numpy.ndarray
            Input images. Each has ``len(patch_size)`` leading axes matching the
            shape of its label image, plus an optional trailing channel axis whose
            size must be the same for all images. Converted to float32 (no copy
            when already float32).
        Y : list of numpy.ndarray
            Label images; same length as ``X``, each with exactly ``len(patch_size)`` axes.
        n_rays : int
            Stored as ``self.n_rays``.
        grid : iterable of int
            One step per entry; used to build ``self.ss_grid`` (a leading full
            slice followed by a ``slice(0, None, g)`` per grid entry).
        batch_size : int
            Stored as ``self.batch_size``.
        patch_size : tuple of int
            Its length (2 or 3) fixes the spatial dimensionality.
        use_gpu : bool
            If True, use ``gputools.max_filter`` (on a float32 copy) instead of
            SciPy's ``maximum_filter`` with ``mode='constant'``.
        maxfilter_cache : bool
            If True, precompute ``self.R`` with ``self.no_background_patches``
            for every (y, x) pair; otherwise ``self.R`` is None.
        maxfilter_patch_size : tuple of int or None
            Filter size for the max filter; defaults to ``patch_size``.
        augmenter : callable or None
            Augmentation function; ``None`` means a function that returns its
            positional arguments unchanged (as a tuple).

        Raises
        ------
        AssertionError
            If the shapes of ``X``/``Y`` are inconsistent as described above.
        ValueError
            If ``augmenter`` is neither None nor callable.
        """
        X = [x.astype(np.float32, copy=False) for x in X]

        # sanity checks
        assert len(X) == len(Y) and len(X) > 0
        nD = len(patch_size)
        assert nD in (2, 3)
        x_ndim = X[0].ndim
        assert x_ndim in (nD, nD + 1)
        assert all(
            y.ndim == nD and x.ndim == x_ndim and x.shape[:nD] == y.shape
            for x, y in zip(X, Y))
        if x_ndim == nD:
            self.n_channel = None
        else:
            self.n_channel = X[0].shape[-1]
            assert all(x.shape[-1] == self.n_channel for x in X)

        self.X, self.Y = X, Y
        self.batch_size = batch_size
        self.n_rays = n_rays
        self.patch_size = patch_size
        self.ss_grid = (slice(None), ) + tuple(slice(0, None, g) for g in grid)
        self.perm = np.random.permutation(len(self.X))
        self.use_gpu = bool(use_gpu)
        if augmenter is None:
            augmenter = lambda *args: args
        callable(augmenter) or _raise(
            ValueError("augmenter must be None or callable"))
        self.augmenter = augmenter

        if self.use_gpu:
            from gputools import max_filter
            self.max_filter = lambda y, patch_size: max_filter(
                y.astype(np.float32), patch_size)
        else:
            # Public location; the scipy.ndimage.filters namespace is deprecated.
            from scipy.ndimage import maximum_filter
            self.max_filter = lambda y, patch_size: maximum_filter(
                y, patch_size, mode='constant')

        self.maxfilter_patch_size = maxfilter_patch_size if maxfilter_patch_size is not None else self.patch_size

        if maxfilter_cache:
            self.R = [
                self.no_background_patches((y, x))
                for x, y in zip(self.X, self.Y)
            ]
        else:
            self.R = None
Exemplo n.º 4
0
    def __init__(self, X, **kwargs):
        """Build a configuration from training data ``X`` and keyword overrides.

        If ``X`` is non-empty, dimensionality, channel counts, axes and
        per-channel mean/std (stored as strings) are derived from it, and
        default network/training parameters are set. A user-supplied ``n_dim``
        is discarded in that case. Finally ``probabilistic`` is set to False,
        and every keyword argument is set as an attribute, overriding defaults.

        Parameters
        ----------
        X : numpy.ndarray
            Training data with axes 'SYXC' (4D) or 'SZYXC' (5D), or an empty
            array to skip all data-derived settings.
        **kwargs
            Attribute values to set on the config.

        Raises
        ------
        AssertionError
            If a non-empty ``X`` is not 4D or 5D.
        ValueError
            If the derived axes are inconsistent with the backend's channel order.
        """
        # X is empty if config is None
        if X.size != 0:

            assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

            n_dim = len(X.shape) - 2
            n_channel_in = X.shape[-1]
            n_channel_out = n_channel_in

            # Per-channel statistics over all samples and spatial positions.
            means = [np.mean(X[..., i]) for i in range(n_channel_in)]
            stds = [np.std(X[..., i]) for i in range(n_channel_in)]

            # The assertion above guarantees n_dim is 2 or 3.
            axes = 'SYXC' if n_dim == 2 else 'SZYXC'

            # parse and check axes
            axes = axes_check_and_normalize(axes)
            ax = axes_dict(axes)
            ax = {a: (ax[a] is not None) for a in ax}

            (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
            not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

            axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
            axes = axes.replace('S','') # remove sample axis if it exists

            if backend_channels_last():
                if ax['C']:
                    axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
                else:
                    axes += 'C'
            else:
                if ax['C']:
                    axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
                else:
                    axes = 'C'+axes

            # normalization parameters
            self.means                 = [str(el) for el in means]
            self.stds                  = [str(el) for el in stds]
            # directly set by parameters
            self.n_dim                 = n_dim
            self.axes                  = axes
            self.n_channel_in          = int(n_channel_in)
            self.n_channel_out         = int(n_channel_out)

            # default config (can be overwritten by kwargs below)
            self.unet_residual         = False
            self.unet_n_depth          = 2
            self.unet_kern_size        = 5 if self.n_dim==2 else 3
            self.unet_n_first          = 32
            self.unet_last_activation  = 'linear'
            if backend_channels_last():
                self.unet_input_shape  = self.n_dim*(None,) + (self.n_channel_in,)
            else:
                self.unet_input_shape  = (self.n_channel_in,) + self.n_dim*(None,)

            self.train_loss            = 'mae'
            self.train_epochs          = 100
            self.train_steps_per_epoch = 400
            self.train_learning_rate   = 0.0004
            self.train_batch_size      = 16
            self.train_tensorboard     = True
            self.train_checkpoint      = 'weights_best.h5'
            self.train_reduce_lr       = {'factor': 0.5, 'patience': 10}
            self.batch_norm            = True
            self.n2v_perc_pix          = 1.5
            self.n2v_patch_shape       = (64, 64) if self.n_dim==2 else (64, 64, 64)
            self.n2v_manipulator       = 'uniform_withCP'
            self.n2v_neighborhood_radius = 5

            # disallow setting 'n_dim' manually: it is derived from X above.
            # pop() with a default replaces a bare try/except around `del`.
            kwargs.pop('n_dim', None)

        self.probabilistic         = False

        for key, value in kwargs.items():
            setattr(self, key, value)
Exemplo n.º 5
0
def _cpp_star_dist(a, n_rays=32):
    """Compute star distances for ``a`` with the compiled ``c_star_dist`` helper.

    Parameters
    ----------
    a : numpy.ndarray
        Input array, presumably a label image; cast to uint16 (without copying
        if it already is uint16) before being passed to ``c_star_dist``.
    n_rays : int
        Number of rays; must be a scalar with ``int(n_rays) > 0``.
        Passed on as ``int(n_rays)``.

    Raises
    ------
    ValueError
        If ``n_rays`` is not a scalar or ``int(n_rays)`` is not positive.
    """
    if not (np.isscalar(n_rays) and 0 < int(n_rays)):
        raise ValueError("n_rays must be a positive scalar, got %r" % (n_rays,))
    return c_star_dist(a.astype(np.uint16, copy=False), int(n_rays))
Exemplo n.º 6
0
    def train(self,
              X,
              Y,
              validation_data,
              epochs=None,
              steps_per_epoch=None,
              numGPU=1):
        """Train the neural network with the given data.

        Parameters
        ----------
        X : :class:`numpy.ndarray`
            Array of source images.
        Y : :class:`numpy.ndarray`
            Array of target images.
        validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`)
            Tuple of arrays for source and target validation images.
        epochs : int
            Optional argument to use instead of the value from ``config``.
        steps_per_epoch : int
            Optional argument to use instead of the value from ``config``.
        numGPU : int
            Number of GPUs. If greater than 1, ``self.keras_model`` is replaced
            by a ``multi_gpu_model`` wrapper before training.

        Returns
        -------
        ``History`` object
            See `Keras training history <https://keras.io/models/model/#fit>`_.

        Raises
        ------
        ValueError
            If ``validation_data`` is not a pair, or if ``X`` is not evenly
            divisible along some axis as required by the network.
        """
        # Validate all inputs before modifying the model, so that a bad call
        # does not leave self.keras_model wrapped for multi-GPU training.
        if not (isinstance(validation_data, (list, tuple)) and len(validation_data) == 2):
            raise ValueError('validation_data must be a pair of numpy arrays')

        n_train, n_val = len(X), len(validation_data[0])
        frac_val = (1.0 * n_val) / (n_train + n_val)
        frac_warn = 0.05
        if frac_val < frac_warn:
            warnings.warn(
                "small number of validation images (only %.1f%% of all images)"
                % (100 * frac_val))

        axes = axes_check_and_normalize('S' + self.config.axes, X.ndim)
        ax = axes_dict(axes)

        for a, div_by in zip(axes, self._axes_div_by(axes)):
            n = X.shape[ax[a]]
            if n % div_by != 0:
                raise ValueError(
                    "training images must be evenly divisible by %d along axis %s"
                    " (which has incompatible size %d)" % (div_by, a, n))

        # The wrap must happen before prepare_for_training() below.
        if numGPU > 1:  # if more than 1 gpu is requested, use multimodel
            print('Using multiple GPUs for training')
            self.keras_model = multi_gpu_model(self.keras_model,
                                               gpus=numGPU,
                                               cpu_merge=True,
                                               cpu_relocation=False)

        if epochs is None:
            epochs = self.config.train_epochs
        if steps_per_epoch is None:
            steps_per_epoch = self.config.train_steps_per_epoch

        if not self._model_prepared:
            self.prepare_for_training()

        training_data = CryoDataWrapper(X, Y, self.config.train_batch_size,
                                        self.config.n_dim)

        history = self.keras_model.fit_generator(
            generator=training_data,
            validation_data=validation_data,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            callbacks=self.callbacks,
            verbose=1)

        if self.basedir is not None:
            self.keras_model.save_weights(str(self.logdir / 'weights_last.h5'))

            if self.config.train_checkpoint is not None:
                print()
                self._find_and_load_weights(self.config.train_checkpoint)
                try:
                    # remove temporary weights
                    (self.logdir / 'weights_now.h5').unlink()
                except FileNotFoundError:
                    pass

        return history
Exemplo n.º 7
0
def matching_dataset_lazy(y_gen,
                          thresh=0.5,
                          criterion='iou',
                          by_image=False,
                          show_progress=True,
                          parallel=False):
    """Accumulate :func:`matching` statistics over a stream of label-image pairs.

    Parameters
    ----------
    y_gen : iterable of (y_true, y_pred)
        Pairs of label images, each passed to :func:`matching`.
    thresh : float or sequence of float
        Matching threshold(s).
    criterion : str
        Matching criterion, passed to :func:`matching`.
    by_image : bool
        If True, 'precision', 'recall', 'accuracy', 'f1' and 'mean_true_score'
        are averaged over images. Otherwise they are recomputed from the summed
        tp/fp/fn counts, and 'mean_true_score' is weighted by 'n_true'.
    show_progress : bool or int
        Show a tqdm progress bar; an int > 1 is passed as the expected total.
    parallel : bool
        If True, evaluate the pairs with a ``ThreadPoolExecutor``.

    Returns
    -------
    A 'DatasetMatching' namedtuple for a scalar ``thresh``, otherwise a tuple
    with one such namedtuple per threshold.

    Raises
    ------
    ValueError
        If ``y_gen`` yields no pairs, or the per-image statistics have
        unexpected keys.
    """
    expected_keys = {
        'fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1',
        'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score',
    }

    single_thresh = False
    if np.isscalar(thresh):
        single_thresh = True
        thresh = (thresh, )

    tqdm_kwargs = {'disable': not bool(show_progress)}
    if int(show_progress) > 1:
        tqdm_kwargs['total'] = int(show_progress)

    # compute matching stats for every pair of label images
    if parallel:
        from concurrent.futures import ThreadPoolExecutor
        fn = lambda pair: matching(
            *pair, thresh=thresh, criterion=criterion, report_matches=False)
        with ThreadPoolExecutor() as pool:
            stats_all = tuple(pool.map(fn, tqdm(y_gen, **tqdm_kwargs)))
    else:
        stats_all = tuple(
            matching(y_t,
                     y_p,
                     thresh=thresh,
                     criterion=criterion,
                     report_matches=False)
            for y_t, y_p in tqdm(y_gen, **tqdm_kwargs))

    n_images, n_threshs = len(stats_all), len(thresh)
    # Without any pair the accumulators stay empty; fail with a clear message
    # instead of a misleading "unexpected keys" error (or division by zero).
    if n_images == 0:
        raise ValueError("y_gen did not yield any (y_true, y_pred) pairs")

    # accumulate results over all images for each threshold separately
    accumulate = [{} for _ in range(n_threshs)]
    for stats in stats_all:
        for i, s in enumerate(stats):
            acc = accumulate[i]
            for k, v in s._asdict().items():
                if k == 'mean_true_score' and not bool(by_image):
                    # convert mean_true_score to "sum_true_score"
                    acc[k] = acc.setdefault(k, 0) + v * s.n_true
                else:
                    try:
                        acc[k] = acc.setdefault(k, 0) + v
                    except TypeError:
                        # Non-numeric fields (e.g. 'criterion') cannot be summed;
                        # the key is still registered and overwritten below.
                        pass

    # normalize/compute 'precision', 'recall', 'accuracy', 'f1'
    for thr, acc in zip(thresh, accumulate):
        set(acc.keys()) == expected_keys or _raise(
            ValueError("unexpected keys"))
        acc['criterion'] = criterion
        acc['thresh'] = thr
        acc['by_image'] = bool(by_image)
        if bool(by_image):
            for k in ('precision', 'recall', 'accuracy', 'f1',
                      'mean_true_score'):
                acc[k] /= n_images
        else:
            tp, fp, fn = acc['tp'], acc['fp'], acc['fn']
            acc.update(
                precision=precision(tp, fp, fn),
                recall=recall(tp, fp, fn),
                accuracy=accuracy(tp, fp, fn),
                f1=f1(tp, fp, fn),
                mean_true_score=(acc['mean_true_score'] / acc['n_true']
                                 if acc['n_true'] > 0 else 0.0),
            )

    accumulate = tuple(
        namedtuple('DatasetMatching', acc.keys())(*acc.values())
        for acc in accumulate)
    return accumulate[0] if single_thresh else accumulate
Exemplo n.º 8
0
def matching(y_true,
             y_pred,
             thresh=0.5,
             criterion='iou',
             report_matches=False):
    """Match predicted to ground-truth objects and compute detection statistics.

    Both label images are relabeled sequentially, pairwise scores are computed
    with ``matching_criteria[criterion]`` (background excluded), and an optimal
    one-to-one assignment is found with ``linear_sum_assignment``. The assignment
    prefers pairs with score >= thr and uses the scores as tie-breaker. Assigned
    pairs with score >= thr count as true positives.

    Parameters
    ----------
    y_true, y_pred : numpy.ndarray
        Label images of identical shape.
    thresh : float, sequence of float, or None
        Score threshold(s); ``None`` is treated as 0.
    criterion : str
        Key into ``matching_criteria``.
    report_matches : bool
        If True, each result also contains ``matched_pairs`` (original label ids
        of every assigned pair, including pairs below the threshold),
        ``matched_scores`` and ``matched_tps`` (indices of the assigned pairs
        that are true positives).

    Returns
    -------
    A 'Matching' namedtuple for a scalar ``thresh``, otherwise a tuple with one
    such namedtuple per threshold.

    Raises
    ------
    ValueError
        If the shapes differ or ``criterion`` is not supported.
    """
    _check_label_array(y_true, 'y_true')
    _check_label_array(y_pred, 'y_pred')
    y_true.shape == y_pred.shape or _raise(
        ValueError(
            "y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes"
            .format(y_true=y_true, y_pred=y_pred)))
    criterion in matching_criteria or _raise(
        ValueError("Matching criterion '%s' not supported." % criterion))
    if thresh is None:
        thresh = 0
    # Materialize multiple thresholds eagerly (rather than as a one-shot lazy
    # map): invalid values fail here, before the overlap computation, and the
    # sequence can safely be iterated more than once.
    thresh = float(thresh) if np.isscalar(thresh) else tuple(map(float, thresh))

    y_true, _, map_rev_true = relabel_sequential(y_true)
    y_pred, _, map_rev_pred = relabel_sequential(y_pred)

    overlap = label_overlap(y_true, y_pred, check=False)
    scores = matching_criteria[criterion](overlap)
    assert 0 <= np.min(scores) <= np.max(scores) <= 1

    # ignoring background
    scores = scores[1:, 1:]
    n_true, n_pred = scores.shape
    n_matched = min(n_true, n_pred)

    def _single(thr):
        not_trivial = n_matched > 0 and np.any(scores >= thr)
        if not_trivial:
            # compute optimal matching with scores as tie-breaker: the scaled
            # score term is < 1 in total, so it never outweighs one extra
            # above-threshold match
            costs = -(scores >= thr).astype(float) - scores / (2 * n_matched)
            true_ind, pred_ind = linear_sum_assignment(costs)
            assert n_matched == len(true_ind) == len(pred_ind)
            match_ok = scores[true_ind, pred_ind] >= thr
            tp = np.count_nonzero(match_ok)
        else:
            tp = 0
        fp = n_pred - tp
        fn = n_true - tp
        stats_dict = dict(
            criterion=criterion,
            thresh=thr,
            fp=fp,
            tp=tp,
            fn=fn,
            precision=precision(tp, fp, fn),
            recall=recall(tp, fp, fn),
            accuracy=accuracy(tp, fp, fn),
            f1=f1(tp, fp, fn),
            n_true=n_true,
            n_pred=n_pred,
            mean_true_score=(np.sum(scores[true_ind, pred_ind][match_ok]) / n_true
                             if not_trivial else 0.0),
        )
        if bool(report_matches):
            if not_trivial:
                stats_dict.update(
                    # int() to be json serializable
                    matched_pairs=tuple(
                        (int(map_rev_true[i]), int(map_rev_pred[j]))
                        for i, j in zip(1 + true_ind, 1 + pred_ind)),
                    matched_scores=tuple(scores[true_ind, pred_ind]),
                    matched_tps=tuple(map(int, np.flatnonzero(match_ok))),
                )
            else:
                stats_dict.update(
                    matched_pairs=(),
                    matched_scores=(),
                    matched_tps=(),
                )
        return namedtuple('Matching', stats_dict.keys())(*stats_dict.values())

    return _single(thresh) if np.isscalar(thresh) else tuple(
        map(_single, thresh))