예제 #1
0
def test_tile_iterator(guarantee, n_dims):
    """Round-trip random arrays through ``tile_iterator`` and check the copy.

    For ten random configurations (fixed seed), an ``n_dims``-dimensional
    array is copied tile by tile into a second array. The test passes when
    the number of yielded tiles equals the value reported by
    ``total_n_tiles`` and the copy matches the source array.

    ``guarantee`` is forwarded unchanged to both ``total_n_tiles`` and
    ``tile_iterator``.
    """
    rng = np.random.RandomState(42)
    for _ in range(10):
        raw_shape = rng.randint(low=10, high=300, size=n_dims)
        n_blocks = list(rng.randint(low=1, high=10, size=n_dims))
        block_size = [
            extent // blocks for blocks, extent in zip(n_blocks, raw_shape)
        ]
        # Shrink every axis to a whole multiple of its block size.
        shape = [
            size * (extent // size)
            for size, extent in zip(block_size, raw_shape)
        ]
        n_block_overlap = [
            rng.randint(low=0, high=blocks + 1) for blocks in n_blocks
        ]
        n_tiles = [
            rng.randint(low=1, high=blocks + 1) for blocks in n_blocks
        ]

        x = rng.uniform(size=shape)
        y = np.empty_like(x)

        expected_count = total_n_tiles(
            x, n_tiles, block_size, n_block_overlap, guarantee=guarantee)

        # Copy each tile's source region into its destination region of y
        # while counting how many tiles were produced.
        seen = 0
        tiles = tile_iterator(x, n_tiles, block_size, n_block_overlap,
                              guarantee)
        for tile, src, dst in tiles:
            y[dst] = tile[src]
            seen += 1

        assert seen == expected_count
        assert np.allclose(x, y)
예제 #2
0
파일: base.py 프로젝트: gatoniel/stardist
    def predict(self,
                img,
                axes=None,
                normalizer=None,
                n_tiles=None,
                show_tile_progress=True,
                **predict_kwargs):
        """Predict.

        Parameters
        ----------
        img : :class:`numpy.ndarray`
            Input image
        axes : str or None
            Axes of the input ``img``.
            ``None`` denotes that axes of img are the same as denoted in the config.
        normalizer : :class:`csbdeep.data.Normalizer` or None
            (Optional) normalization of input image before prediction.
            Note that the default (``None``) assumes ``img`` to be already normalized.
        n_tiles : iterable or None
            Out of memory (OOM) errors can occur if the input image is too large.
            To avoid this problem, the input image is broken up into (overlapping) tiles
            that are processed independently and re-assembled.
            This parameter denotes a tuple of the number of tiles for every image axis (see ``axes``).
            ``None`` denotes that no tiling should be used.
        show_tile_progress: bool
            Whether to show progress during tiled prediction.
        predict_kwargs: dict
            Keyword arguments for ``predict`` function of Keras model.

        Returns
        -------
        (:class:`numpy.ndarray`,:class:`numpy.ndarray`)
            Returns the tuple (`prob`, `dist`) of per-pixel object probabilities and star-convex polygon/polyhedra distances.
            The channel axis is dropped from `prob` and moved to the last position in `dist`.

        Raises
        ------
        ValueError
            If ``n_tiles`` is not an iterable with one entry per dimension of ``img``,
            if any of its entries is not an integer value >= 1,
            if it requests more than one tile along an axis that is not eligible for tiling,
            or if the number of channels of ``img`` or the length of the configured grid
            does not match the network configuration.

        """
        if n_tiles is None:
            n_tiles = [1] * img.ndim
        try:
            n_tiles = tuple(n_tiles)
            # a length mismatch is funneled through the same handler as a
            # non-iterable n_tiles so both produce the same error message
            img.ndim == len(n_tiles) or _raise(TypeError())
        except TypeError:
            raise ValueError("n_tiles must be an iterable of length %d" %
                             img.ndim)
        all(np.isscalar(t) and 1 <= t and int(t) == t
            for t in n_tiles) or _raise(
                ValueError(
                    "all values of n_tiles must be integer values >= 1"))
        n_tiles = tuple(map(int, n_tiles))

        axes = self._normalize_axes(img, axes)
        axes_net = self.config.axes

        _permute_axes = self._make_permute_axes(axes, axes_net)
        x = _permute_axes(img)  # x has axes_net semantics

        channel = axes_dict(axes_net)['C']
        self.config.n_channel_in == x.shape[channel] or _raise(ValueError())
        axes_net_div_by = self._axes_div_by(axes_net)

        # one grid entry per non-channel axis of the network
        grid = tuple(self.config.grid)
        len(grid) == len(axes_net) - 1 or _raise(ValueError())
        grid_dict = dict(zip(axes_net.replace('C', ''), grid))

        normalizer = self._check_normalizer_resizer(normalizer, None)[0]
        resizer = StarDistPadAndCropResizer(grid=grid_dict)

        x = normalizer.before(x, axes_net)
        x = resizer.before(x, axes_net, axes_net_div_by)

        def predict_direct(tile):
            # The Keras model is called with two inputs: the tile itself and
            # an uninitialized single-channel array of the same spatial shape
            # (presumably a placeholder that does not affect the outputs --
            # TODO confirm against the model definition).
            sh = list(tile.shape)
            sh[channel] = 1
            dummy = np.empty(sh, np.float32)
            prob, dist = self.keras_model.predict(
                [tile[np.newaxis], dummy[np.newaxis]], **predict_kwargs)
            # strip the batch axis that was added for the model call
            return prob[0], dist[0]

        if np.prod(n_tiles) > 1:
            tiling_axes = axes_net.replace('C', '')  # axes eligible for tiling
            x_tiling_axis = tuple(
                axes_dict(axes_net)[a]
                for a in tiling_axes)  # numerical axis ids for x
            axes_net_tile_overlaps = self._axes_tile_overlap(axes_net)
            # hack: permute tiling axis in the same way as img -> x was permuted
            # (builtin ``bool`` is used because the ``np.bool`` alias was
            # removed in NumPy 1.24; the resulting dtype is identical)
            n_tiles = _permute_axes(np.empty(n_tiles, bool)).shape
            (all(n_tiles[i] == 1
                 for i in range(x.ndim) if i not in x_tiling_axis)
             or _raise(
                 ValueError("entry of n_tiles > 1 only allowed for axes '%s'" %
                            tiling_axes)))

            # output arrays are smaller than x by the grid factor along each
            # axis that has a grid entry; only the channel count differs
            # between prob (1) and dist (n_rays)
            sh = [s // grid_dict.get(a, 1) for a, s in zip(axes_net, x.shape)]
            sh[channel] = 1
            prob = np.empty(sh, np.float32)
            sh[channel] = self.config.n_rays
            dist = np.empty(sh, np.float32)

            # express the per-axis tile overlap in whole blocks (rounded up)
            n_block_overlaps = [
                int(np.ceil(overlap / blocksize)) for overlap, blocksize in
                zip(axes_net_tile_overlaps, axes_net_div_by)
            ]

            for tile, s_src, s_dst in tqdm(tile_iterator(
                    x,
                    n_tiles,
                    block_sizes=axes_net_div_by,
                    n_block_overlaps=n_block_overlaps),
                                           disable=(not show_tile_progress),
                                           total=np.prod(n_tiles)):
                prob_tile, dist_tile = predict_direct(tile)
                # account for grid
                s_src = [
                    slice(s.start // grid_dict.get(a, 1),
                          s.stop // grid_dict.get(a, 1))
                    for s, a in zip(s_src, axes_net)
                ]
                s_dst = [
                    slice(s.start // grid_dict.get(a, 1),
                          s.stop // grid_dict.get(a, 1))
                    for s, a in zip(s_dst, axes_net)
                ]
                # prob and dist have different channel dimensionality than image x
                s_src[channel] = slice(None)
                s_dst[channel] = slice(None)
                s_src, s_dst = tuple(s_src), tuple(s_dst)
                prob[s_dst] = prob_tile[s_src]
                dist[s_dst] = dist_tile[s_src]

        else:
            prob, dist = predict_direct(x)

        prob = resizer.after(prob, axes_net)
        dist = resizer.after(dist, axes_net)
        dist = np.maximum(
            1e-3, dist
        )  # avoid small/negative dist values to prevent problems with Qhull

        # drop the single channel of prob; put the channel axis of dist last
        prob = np.take(prob, 0, axis=channel)
        dist = np.moveaxis(dist, channel, -1)

        return prob, dist