def test_tile_iterator(guarantee, n_dims):
    """Randomized round-trip check of ``tile_iterator`` against ``total_n_tiles``.

    Ten random configurations are drawn from a fixed-seed generator.  For each
    one, an array is cut into tiles by ``tile_iterator``.  For every yielded
    tile, the part selected by ``s_src`` is written back into an output array at
    ``s_dst``.

    The test asserts two things:
    - the reassembled array equals the input;
    - the number of yielded tiles equals ``total_n_tiles`` called with the same
      arguments.
    """
    rs = np.random.RandomState(42)
    for _trial in range(10):
        shape = rs.randint(low=10, high=300, size=n_dims)
        blocks_per_axis = list(rs.randint(low=1, high=10, size=n_dims))
        sizes = [length // nb for nb, length in zip(blocks_per_axis, shape)]
        # trim every axis so its length is an exact multiple of its block size
        shape = [bs * (length // bs) for bs, length in zip(sizes, shape)]
        # NOTE: the draws below must stay in this order (overlaps, then tiles,
        # then data) so the seeded generator produces the same cases
        overlaps = [rs.randint(low=0, high=nb + 1) for nb in blocks_per_axis]
        tiles_per_axis = [rs.randint(low=1, high=nb + 1) for nb in blocks_per_axis]

        src = rs.uniform(size=shape)
        dst = np.empty_like(src)

        expected = total_n_tiles(src, tiles_per_axis, sizes, overlaps, guarantee=guarantee)
        seen = 0
        for tile, s_src, s_dst in tile_iterator(src, tiles_per_axis, sizes, overlaps, guarantee):
            dst[s_dst] = tile[s_src]
            seen += 1

        assert seen == expected
        assert np.allclose(src, dst)
def predict(self, img, axes=None, normalizer=None, n_tiles=None, show_tile_progress=True, **predict_kwargs):
    """Predict per-pixel object probabilities and star-convex distances for an image.

    Parameters
    ----------
    img : :class:`numpy.ndarray`
        Input image
    axes : str or None
        Axes of the input ``img``.
        ``None`` denotes that axes of img are the same as denoted in the config.
    normalizer : :class:`csbdeep.data.Normalizer` or None
        (Optional) normalization of input image before prediction.
        Note that the default (``None``) assumes ``img`` to be already normalized.
    n_tiles : iterable or None
        Out of memory (OOM) errors can occur if the input image is too large.
        To avoid this problem, the input image is broken up into (overlapping) tiles
        that are processed independently and re-assembled.
        This parameter denotes a tuple of the number of tiles for every image axis (see ``axes``).
        ``None`` denotes that no tiling should be used.
    show_tile_progress: bool
        Whether to show progress during tiled prediction.
    predict_kwargs: dict
        Keyword arguments for ``predict`` function of Keras model.

    Returns
    -------
    (:class:`numpy.ndarray`,:class:`numpy.ndarray`)
        Returns the tuple (`prob`, `dist`) of per-pixel object probabilities and
        star-convex polygon/polyhedra distances.  ``prob`` has the channel axis
        removed; ``dist`` has its channel (ray) axis moved to the last position.
        Distances are clipped from below at ``1e-3``.

    Raises
    ------
    ValueError
        Raised in any of these cases:
        - ``n_tiles`` is not an iterable of length ``img.ndim``;
        - an entry of ``n_tiles`` is not an integer value >= 1;
        - an entry > 1 is given for an axis that is not eligible for tiling;
        - the number of input channels differs from ``self.config.n_channel_in``;
        - the length of ``self.config.grid`` does not match the network's
          non-channel axes.
    """
    if n_tiles is None:
        n_tiles = [1] * img.ndim
    try:
        n_tiles = tuple(n_tiles)
        img.ndim == len(n_tiles) or _raise(TypeError())
    except TypeError:
        # Report "not iterable" and "wrong length" as one ValueError.
        # The internal TypeError carries no extra information for the
        # caller, so it is hidden from the traceback.
        raise ValueError("n_tiles must be an iterable of length %d" % img.ndim) from None
    all(np.isscalar(t) and 1 <= t and int(t) == t for t in n_tiles) or _raise(
        ValueError("all values of n_tiles must be integer values >= 1"))
    n_tiles = tuple(map(int, n_tiles))

    axes = self._normalize_axes(img, axes)
    axes_net = self.config.axes

    _permute_axes = self._make_permute_axes(axes, axes_net)
    x = _permute_axes(img)  # x has axes_net semantics

    channel = axes_dict(axes_net)['C']
    self.config.n_channel_in == x.shape[channel] or _raise(ValueError())
    axes_net_div_by = self._axes_div_by(axes_net)

    grid = tuple(self.config.grid)
    len(grid) == len(axes_net) - 1 or _raise(ValueError())
    # grid factor per non-channel axis of the network
    grid_dict = dict(zip(axes_net.replace('C', ''), grid))

    normalizer = self._check_normalizer_resizer(normalizer, None)[0]
    resizer = StarDistPadAndCropResizer(grid=grid_dict)

    x = normalizer.before(x, axes_net)
    x = resizer.before(x, axes_net, axes_net_div_by)

    def predict_direct(tile):
        """Run the Keras model on a single (un-batched) tile; return (prob, dist)."""
        # The model is called with two inputs.  The second one is an
        # uninitialised single-channel array matching the tile's shape.
        # Presumably it is ignored at inference time — TODO confirm.
        sh = list(tile.shape)
        sh[channel] = 1
        dummy = np.empty(sh, np.float32)
        prob, dist = self.keras_model.predict([tile[np.newaxis], dummy[np.newaxis]], **predict_kwargs)
        return prob[0], dist[0]

    if np.prod(n_tiles) > 1:
        tiling_axes = axes_net.replace('C', '')  # axes eligible for tiling
        x_tiling_axis = tuple(axes_dict(axes_net)[a] for a in tiling_axes)  # numerical axis ids for x
        axes_net_tile_overlaps = self._axes_tile_overlap(axes_net)
        # hack: permute tiling axis in the same way as img -> x was permuted.
        # Uses builtin ``bool``; the ``np.bool`` alias was removed in NumPy 1.24.
        n_tiles = _permute_axes(np.empty(n_tiles, bool)).shape
        (all(n_tiles[i] == 1 for i in range(x.ndim) if i not in x_tiling_axis) or
            _raise(ValueError("entry of n_tiles > 1 only allowed for axes '%s'" % tiling_axes)))

        # outputs are downsampled by the grid factor along each spatial axis
        sh = [s // grid_dict.get(a, 1) for a, s in zip(axes_net, x.shape)]
        sh[channel] = 1
        prob = np.empty(sh, np.float32)
        sh[channel] = self.config.n_rays
        dist = np.empty(sh, np.float32)

        # overlap expressed in whole blocks (rounded up)
        n_block_overlaps = [
            int(np.ceil(overlap / blocksize))
            for overlap, blocksize in zip(axes_net_tile_overlaps, axes_net_div_by)
        ]

        tiles = tile_iterator(x, n_tiles, block_sizes=axes_net_div_by, n_block_overlaps=n_block_overlaps)
        for tile, s_src, s_dst in tqdm(tiles, disable=(not show_tile_progress), total=np.prod(n_tiles)):
            prob_tile, dist_tile = predict_direct(tile)
            # account for grid
            s_src = [slice(s.start // grid_dict.get(a, 1), s.stop // grid_dict.get(a, 1))
                     for s, a in zip(s_src, axes_net)]
            s_dst = [slice(s.start // grid_dict.get(a, 1), s.stop // grid_dict.get(a, 1))
                     for s, a in zip(s_dst, axes_net)]
            # prob and dist have different channel dimensionality than image x
            s_src[channel] = slice(None)
            s_dst[channel] = slice(None)
            s_src, s_dst = tuple(s_src), tuple(s_dst)
            prob[s_dst] = prob_tile[s_src]
            dist[s_dst] = dist_tile[s_src]
    else:
        prob, dist = predict_direct(x)

    # undo the padding added by resizer.before
    prob = resizer.after(prob, axes_net)
    dist = resizer.after(dist, axes_net)
    dist = np.maximum(1e-3, dist)  # avoid small/negative dist values to prevent problems with Qhull

    prob = np.take(prob, 0, axis=channel)
    dist = np.moveaxis(dist, channel, -1)

    return prob, dist