def load_imgs(self, files, dims='YX'):
    """
    Helper to read a list of files. The images are not required to have same size,
    but have to be of same dimensionality.

    Parameters
    ----------
    files : list(String)
            List of paths to tiff-files (``.tif``/``.tiff``) or ``.png`` files.
            The extension check is case-insensitive.
    dims : String, optional(default='YX')
           Dimensions of the images to read. Known dimensions are: 'TZYXC'

    Returns
    -------
    images : list(array(float))
             A list of the read files as float32 arrays. The images have
             dimensionality 'SZYXC' or 'SYXC' ('T' is mapped to the leading 'S'
             axis; a singleton S and/or C axis is added when missing from ``dims``).

    Raises
    ------
    AssertionError
        If ``dims`` is malformed, or an image's rank does not match ``len(dims)``.
    Exception
        If a JPEG file is given (lossy compression is not supported).
    ValueError
        If a file has an unsupported extension.
    """
    assert 'Y' in dims and 'X' in dims, "'dims' has to contain 'X' and 'Y'."

    # Every known axis may appear at most once; anything left over is unknown.
    tmp_dims = dims
    for b in ['X', 'Y', 'Z', 'T', 'C']:
        assert tmp_dims.count(b) <= 1, "'dims' has to contain {} at most once.".format(b)
        tmp_dims = tmp_dims.replace(b, '')

    assert len(tmp_dims) == 0, "Unknown dimensions in 'dims'."

    net_axes = 'ZYXC' if 'Z' in dims else 'YXC'
    # Spatial axes are shifted by one when a leading T axis is present.
    offset = 1 if 'T' in dims else 0

    # Target layout: T first, then (Z)YX in net order, C last.
    move_axis_from = ()
    move_axis_to = ()
    for d, b in enumerate(dims):
        move_axis_from += (d,)
        if b == 'T':
            move_axis_to += (0,)
        elif b == 'C':
            move_axis_to += (-1,)
        elif b in 'XYZ':
            move_axis_to += (net_axes.index(b) + offset,)

    imgs = []
    for f in files:
        name_lower = f.lower()
        if name_lower.endswith(('.tif', '.tiff')):
            imread = tifffile.imread
        elif name_lower.endswith('.png'):
            imread = image.imread
        elif name_lower.endswith(('.jpg', '.jpeg')):
            raise Exception("JPEG is not supported, because it is not loss-less and breaks the pixel-wise independence assumption.")
        else:
            # Raise a real exception instance (previously a bare string was passed on).
            raise ValueError("Filetype '{}' is not supported.".format(f))

        img = imread(f).astype(np.float32)
        assert len(img.shape) == len(dims), "Number of image dimensions doesn't match 'dims'."

        img = np.moveaxis(img, move_axis_from, move_axis_to)

        if 'T' not in dims:
            img = img[np.newaxis]

        if 'C' not in dims:
            img = img[..., np.newaxis]

        imgs.append(img)

    return imgs
def predict(self, img, axes=None, normalizer=None, n_tiles=None, show_tile_progress=True, **predict_kwargs):
    """Predict.

    Parameters
    ----------
    img : :class:`numpy.ndarray`
        Input image
    axes : str or None
        Axes of the input ``img``.
        ``None`` denotes that axes of img are the same as denoted in the config.
    normalizer : :class:`csbdeep.data.Normalizer` or None
        (Optional) normalization of input image before prediction.
        Note that the default (``None``) assumes ``img`` to be already normalized.
    n_tiles : iterable or None
        Out of memory (OOM) errors can occur if the input image is too large.
        To avoid this problem, the input image is broken up into (overlapping) tiles
        that are processed independently and re-assembled.
        This parameter denotes a tuple of the number of tiles for every image axis (see ``axes``).
        ``None`` denotes that no tiling should be used.
    show_tile_progress: bool
        Whether to show progress during tiled prediction.
    predict_kwargs: dict
        Keyword arguments for ``predict`` function of Keras model.

    Returns
    -------
    (:class:`numpy.ndarray`,:class:`numpy.ndarray`)
        Returns the tuple (`prob`, `dist`) of per-pixel object probabilities and star-convex polygon/polyhedra distances.
        ``prob`` has its channel axis removed; ``dist`` has the ray axis moved last.
        Distances are clipped from below at 1e-3.

    Raises
    ------
    ValueError
        If ``n_tiles`` is not an iterable of ``img.ndim`` integers >= 1. The
        channel/grid consistency checks also pass ``ValueError`` instances to ``_raise``.
    """
    if n_tiles is None:
        n_tiles = [1] * img.ndim
    try:
        n_tiles = tuple(n_tiles)
        img.ndim == len(n_tiles) or _raise(TypeError())
    except TypeError:
        raise ValueError("n_tiles must be an iterable of length %d" % img.ndim)
    all(np.isscalar(t) and 1 <= t and int(t) == t for t in n_tiles) or _raise(
        ValueError("all values of n_tiles must be integer values >= 1"))
    n_tiles = tuple(map(int, n_tiles))

    axes = self._normalize_axes(img, axes)
    axes_net = self.config.axes

    _permute_axes = self._make_permute_axes(axes, axes_net)
    x = _permute_axes(img)  # x has axes_net semantics

    channel = axes_dict(axes_net)['C']
    self.config.n_channel_in == x.shape[channel] or _raise(ValueError())
    axes_net_div_by = self._axes_div_by(axes_net)

    grid = tuple(self.config.grid)
    len(grid) == len(axes_net) - 1 or _raise(ValueError())
    grid_dict = dict(zip(axes_net.replace('C', ''), grid))

    normalizer = self._check_normalizer_resizer(normalizer, None)[0]
    resizer = StarDistPadAndCropResizer(grid=grid_dict)

    x = normalizer.before(x, axes_net)
    x = resizer.before(x, axes_net, axes_net_div_by)

    def predict_direct(tile):
        # The model takes a second input; a dummy single-channel array of the
        # tile's shape is passed for it.
        sh = list(tile.shape)
        sh[channel] = 1
        dummy = np.empty(sh, np.float32)
        prob, dist = self.keras_model.predict(
            [tile[np.newaxis], dummy[np.newaxis]], **predict_kwargs)
        return prob[0], dist[0]

    if np.prod(n_tiles) > 1:
        tiling_axes = axes_net.replace('C', '')  # axes eligible for tiling
        x_tiling_axis = tuple(axes_dict(axes_net)[a] for a in tiling_axes)  # numerical axis ids for x
        axes_net_tile_overlaps = self._axes_tile_overlap(axes_net)
        # hack: permute tiling axis in the same way as img -> x was permuted.
        # Builtin ``bool`` is used because the ``np.bool`` alias was removed in NumPy 1.24.
        n_tiles = _permute_axes(np.empty(n_tiles, bool)).shape
        (all(n_tiles[i] == 1 for i in range(x.ndim) if i not in x_tiling_axis) or
            _raise(ValueError("entry of n_tiles > 1 only allowed for axes '%s'" % tiling_axes)))

        # Outputs are subsampled by the grid along each spatial axis.
        sh = [s // grid_dict.get(a, 1) for a, s in zip(axes_net, x.shape)]
        sh[channel] = 1
        prob = np.empty(sh, np.float32)
        sh[channel] = self.config.n_rays
        dist = np.empty(sh, np.float32)

        n_block_overlaps = [
            int(np.ceil(overlap / blocksize))
            for overlap, blocksize in zip(axes_net_tile_overlaps, axes_net_div_by)
        ]

        for tile, s_src, s_dst in tqdm(
                tile_iterator(x, n_tiles, block_sizes=axes_net_div_by, n_block_overlaps=n_block_overlaps),
                disable=(not show_tile_progress), total=np.prod(n_tiles)):
            prob_tile, dist_tile = predict_direct(tile)
            # account for grid
            s_src = [slice(s.start // grid_dict.get(a, 1), s.stop // grid_dict.get(a, 1))
                     for s, a in zip(s_src, axes_net)]
            s_dst = [slice(s.start // grid_dict.get(a, 1), s.stop // grid_dict.get(a, 1))
                     for s, a in zip(s_dst, axes_net)]
            # prob and dist have different channel dimensionality than image x
            s_src[channel] = slice(None)
            s_dst[channel] = slice(None)
            s_src, s_dst = tuple(s_src), tuple(s_dst)
            prob[s_dst] = prob_tile[s_src]
            dist[s_dst] = dist_tile[s_src]
    else:
        prob, dist = predict_direct(x)

    prob = resizer.after(prob, axes_net)
    dist = resizer.after(dist, axes_net)
    dist = np.maximum(1e-3, dist)  # avoid small/negative dist values to prevent problems with Qhull

    prob = np.take(prob, 0, axis=channel)
    dist = np.moveaxis(dist, channel, -1)

    return prob, dist
def __init__(self, X, Y, n_rays, grid, batch_size, patch_size, use_gpu=False, maxfilter_cache=True, maxfilter_patch_size=None, augmenter=None):
    """Validate training data and set up patch-sampling state.

    Parameters
    ----------
    X : list of numpy.ndarray
        Input images, converted to float32 (without copying when already float32).
        Each has ``len(patch_size)`` spatial axes, optionally followed by one
        trailing channel axis whose size must match across all images.
    Y : list of numpy.ndarray
        Label images; ``Y[i].shape`` must equal the spatial shape of ``X[i]``.
    n_rays : int
        Stored as ``self.n_rays``.
    grid : iterable of int
        Per-spatial-axis step used to build ``self.ss_grid``.
    batch_size : int
        Stored as ``self.batch_size``.
    patch_size : tuple of int
        Patch size; its length (2 or 3) fixes the spatial dimensionality.
    use_gpu : bool
        If true, use ``gputools.max_filter`` (on float32 input) instead of
        ``scipy.ndimage.maximum_filter``.
    maxfilter_cache : bool
        If true, precompute ``self.R`` via ``self.no_background_patches`` for
        every (y, x) pair; otherwise ``self.R`` is None.
    maxfilter_patch_size : tuple of int or None
        Max-filter size; defaults to ``patch_size``.
    augmenter : callable or None
        Augmentation function; ``None`` means identity (returns its arguments).

    Raises
    ------
    AssertionError
        On inconsistent numbers, ranks or shapes of ``X``/``Y``.
    ValueError
        If ``augmenter`` is neither None nor callable (passed to ``_raise``).
    """
    X = [x.astype(np.float32, copy=False) for x in X]

    # sanity checks
    assert len(X) == len(Y) and len(X) > 0
    nD = len(patch_size)
    assert nD in (2, 3)
    x_ndim = X[0].ndim
    assert x_ndim in (nD, nD + 1)
    assert all(y.ndim == nD and x.ndim == x_ndim and x.shape[:nD] == y.shape for x, y in zip(X, Y))

    if x_ndim == nD:
        self.n_channel = None
    else:
        self.n_channel = X[0].shape[-1]
        assert all(x.shape[-1] == self.n_channel for x in X)

    self.X, self.Y = X, Y
    self.batch_size = batch_size
    self.n_rays = n_rays
    self.patch_size = patch_size
    # Leading full slice, then one strided slice per grid entry
    # (presumably the leading axis is a batch/sample axis — TODO confirm).
    self.ss_grid = (slice(None),) + tuple(slice(0, None, g) for g in grid)
    self.perm = np.random.permutation(len(self.X))
    self.use_gpu = bool(use_gpu)

    if augmenter is None:
        augmenter = lambda *args: args
    callable(augmenter) or _raise(ValueError("augmenter must be None or callable"))
    self.augmenter = augmenter

    if self.use_gpu:
        from gputools import max_filter
        self.max_filter = lambda y, patch_size: max_filter(y.astype(np.float32), patch_size)
    else:
        # ``scipy.ndimage.filters`` is a deprecated namespace; import from ``scipy.ndimage``.
        from scipy.ndimage import maximum_filter
        self.max_filter = lambda y, patch_size: maximum_filter(y, patch_size, mode='constant')

    self.maxfilter_patch_size = maxfilter_patch_size if maxfilter_patch_size is not None else self.patch_size

    if maxfilter_cache:
        self.R = [self.no_background_patches((y, x)) for x, y in zip(self.X, self.Y)]
    else:
        self.R = None
def __init__(self, X, **kwargs):
    """Build a configuration, deriving defaults from training data ``X``.

    Parameters
    ----------
    X : numpy.ndarray
        Training data with axes 'SYXC' (4D) or 'SZYXC' (5D). Per-channel mean
        and standard deviation are computed from it and stored as strings in
        ``self.means`` / ``self.stds``. If ``X.size == 0`` no defaults are
        derived and only ``kwargs`` are applied.
    **kwargs
        Attributes set on the config last, overriding any defaults. When ``X``
        is non-empty, a ``n_dim`` entry is dropped because ``n_dim`` is derived
        from ``X``.

    Raises
    ------
    AssertionError
        If a non-empty ``X`` is not 4D or 5D.
    ValueError
        From the axes checks (passed to ``_raise``), e.g. when the channel axis
        position is incompatible with the backend.
    """
    # X is empty if config is None
    if X.size != 0:
        assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

        n_dim = len(X.shape) - 2
        n_channel_in = X.shape[-1]
        n_channel_out = n_channel_in
        means = [np.mean(X[..., i]) for i in range(n_channel_in)]
        stds = [np.std(X[..., i]) for i in range(n_channel_in)]

        if n_dim == 2:
            axes = 'SYXC'
        elif n_dim == 3:
            axes = 'SZYXC'

        # parse and check axes
        axes = axes_check_and_normalize(axes)
        ax = axes_dict(axes)
        ax = {a: (ax[a] is not None) for a in ax}

        (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
        not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

        axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
        axes = axes.replace('S', '')  # remove sample axis if it exists

        # Place the channel axis where the Keras backend expects it.
        if backend_channels_last():
            if ax['C']:
                axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
            else:
                axes += 'C'
        else:
            if ax['C']:
                axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
            else:
                axes = 'C' + axes

        # normalization parameters (stored as strings, presumably for serialization)
        self.means = [str(el) for el in means]
        self.stds = [str(el) for el in stds]
        # directly set by parameters
        self.n_dim = n_dim
        self.axes = axes
        self.n_channel_in = int(n_channel_in)
        self.n_channel_out = int(n_channel_out)

        # default config (can be overwritten by kwargs below)
        self.unet_residual = False
        self.unet_n_depth = 2
        self.unet_kern_size = 5 if self.n_dim == 2 else 3
        self.unet_n_first = 32
        self.unet_last_activation = 'linear'
        if backend_channels_last():
            self.unet_input_shape = self.n_dim * (None,) + (self.n_channel_in,)
        else:
            self.unet_input_shape = (self.n_channel_in,) + self.n_dim * (None,)

        self.train_loss = 'mae'
        self.train_epochs = 100
        self.train_steps_per_epoch = 400
        self.train_learning_rate = 0.0004
        self.train_batch_size = 16
        self.train_tensorboard = True
        self.train_checkpoint = 'weights_best.h5'
        self.train_reduce_lr = {'factor': 0.5, 'patience': 10}
        self.batch_norm = True
        self.n2v_perc_pix = 1.5
        self.n2v_patch_shape = (64, 64) if self.n_dim == 2 else (64, 64, 64)
        self.n2v_manipulator = 'uniform_withCP'
        self.n2v_neighborhood_radius = 5

        # disallow setting 'n_dim' manually (it is derived from X);
        # pop() replaces a bare try/except that swallowed every exception.
        kwargs.pop('n_dim', None)

    self.probabilistic = False

    for k in kwargs:
        setattr(self, k, kwargs[k])
def _cpp_star_dist(a, n_rays=32):
    """Thin wrapper around the compiled ``c_star_dist`` helper.

    Parameters
    ----------
    a : numpy.ndarray
        Label image, cast to ``uint16`` (no copy if already ``uint16``).
        NOTE(review): the cast wraps label ids above 65535; confirm inputs stay below that.
    n_rays : scalar
        Number of rays; must be a scalar whose ``int()`` is positive.

    Returns
    -------
    Whatever ``c_star_dist`` returns for the cast image and ``int(n_rays)``.
    """
    rays_ok = np.isscalar(n_rays) and int(n_rays) > 0
    if not rays_ok:
        _raise(ValueError())
    labels = a.astype(np.uint16, copy=False)
    return c_star_dist(labels, int(n_rays))
def train(self, X, Y, validation_data, epochs=None, steps_per_epoch=None, numGPU=1):
    """Train the neural network with the given data.

    Parameters
    ----------
    X : :class:`numpy.ndarray`
        Array of source images.
    Y : :class:`numpy.ndarray`
        Array of target images.
    validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`)
        Tuple of arrays for source and target validation images.
    epochs : int
        Optional argument to use instead of the value from ``config``.
    steps_per_epoch : int
        Optional argument to use instead of the value from ``config``.
    numGPU : int
        Number of GPUs. If greater than 1, ``self.keras_model`` is replaced by a
        ``multi_gpu_model`` wrapper. This happens only after the inputs passed
        validation, so an invalid call leaves the model untouched.

    Returns
    -------
    ``History`` object
        See `Keras training history <https://keras.io/models/model/#fit>`_.

    Raises
    ------
    ValueError
        If ``validation_data`` is not a pair, or if the training images are not
        evenly divisible along an axis as required by the network.
    """
    ((isinstance(validation_data, (list, tuple)) and len(validation_data) == 2)
        or _raise(ValueError('validation_data must be a pair of numpy arrays')))

    n_train, n_val = len(X), len(validation_data[0])
    frac_val = (1.0 * n_val) / (n_train + n_val)
    frac_warn = 0.05
    if frac_val < frac_warn:
        warnings.warn("small number of validation images (only %.1f%% of all images)" % (100 * frac_val))

    axes = axes_check_and_normalize('S' + self.config.axes, X.ndim)
    ax = axes_dict(axes)
    for a, div_by in zip(axes, self._axes_div_by(axes)):
        n = X.shape[ax[a]]
        if n % div_by != 0:
            raise ValueError(
                "training images must be evenly divisible by %d along axis %s"
                " (which has incompatible size %d)" % (div_by, a, n))

    # Wrap for multiple GPUs only once the inputs are known to be valid.
    if numGPU > 1:
        print('Using multiple GPUs for training')
        self.keras_model = multi_gpu_model(self.keras_model, gpus=numGPU, cpu_merge=True, cpu_relocation=False)

    if epochs is None:
        epochs = self.config.train_epochs
    if steps_per_epoch is None:
        steps_per_epoch = self.config.train_steps_per_epoch

    if not self._model_prepared:
        self.prepare_for_training()

    training_data = CryoDataWrapper(X, Y, self.config.train_batch_size, self.config.n_dim)

    history = self.keras_model.fit_generator(generator=training_data, validation_data=validation_data,
                                             epochs=epochs, steps_per_epoch=steps_per_epoch,
                                             callbacks=self.callbacks, verbose=1)

    if self.basedir is not None:
        self.keras_model.save_weights(str(self.logdir / 'weights_last.h5'))

        if self.config.train_checkpoint is not None:
            print()
            self._find_and_load_weights(self.config.train_checkpoint)
            try:
                # remove temporary weights
                (self.logdir / 'weights_now.h5').unlink()
            except FileNotFoundError:
                pass

    return history
def matching_dataset_lazy(y_gen, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):
    """Aggregate ``matching`` statistics over a stream of label-image pairs.

    Parameters
    ----------
    y_gen : iterable of (y_true, y_pred)
        Pairs of label images, consumed once.
    thresh : scalar or iterable of scalars
        Matching threshold(s); one aggregated result is produced per threshold.
    criterion : str
        Matching criterion passed through to ``matching``.
    by_image : bool
        If true, 'precision', 'recall', 'accuracy', 'f1' and 'mean_true_score'
        are averaged over images. Otherwise they are recomputed from the summed
        tp/fp/fn counts, and 'mean_true_score' is weighted by ``n_true``.
    show_progress : bool or int
        Disables the tqdm bar when falsy; an int > 1 is used as the bar's total.
    parallel : bool
        If true, evaluate pairs with a ``ThreadPoolExecutor``.

    Returns
    -------
    A ``DatasetMatching`` namedtuple when ``thresh`` is a scalar, else a tuple
    of them (one per threshold).
    """
    expected_keys = {'fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1',
                     'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score'}

    single_thresh = np.isscalar(thresh)
    if single_thresh:
        thresh = (thresh,)

    tqdm_kwargs = {'disable': not bool(show_progress)}
    if int(show_progress) > 1:
        tqdm_kwargs['total'] = int(show_progress)

    # compute matching stats for every pair of label images
    pairs = tqdm(y_gen, **tqdm_kwargs)
    if parallel:
        from concurrent.futures import ThreadPoolExecutor

        def _match_pair(pair):
            return matching(*pair, thresh=thresh, criterion=criterion, report_matches=False)

        with ThreadPoolExecutor() as pool:
            stats_all = tuple(pool.map(_match_pair, pairs))
    else:
        stats_all = tuple(
            matching(y_t, y_p, thresh=thresh, criterion=criterion, report_matches=False)
            for y_t, y_p in pairs
        )

    # accumulate results over all images, separately for each threshold
    n_images, n_threshs = len(stats_all), len(thresh)
    totals = [{} for _ in range(n_threshs)]
    for per_image in stats_all:
        for idx, stat in enumerate(per_image):
            bucket = totals[idx]
            for key, value in stat._asdict().items():
                if key == 'mean_true_score' and not bool(by_image):
                    # keep a "sum_true_score" so it can be re-normalised by n_true later
                    bucket[key] = bucket.setdefault(key, 0) + value * stat.n_true
                else:
                    try:
                        bucket[key] = bucket.setdefault(key, 0) + value
                    except TypeError:
                        # non-numeric fields keep the 0 placeholder inserted by setdefault
                        pass

    # normalize/compute 'precision', 'recall', 'accuracy', 'f1'
    for thr, bucket in zip(thresh, totals):
        if set(bucket.keys()) != expected_keys:
            _raise(ValueError("unexpected keys"))
        bucket['criterion'] = criterion
        bucket['thresh'] = thr
        bucket['by_image'] = bool(by_image)
        if bool(by_image):
            for key in ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score'):
                bucket[key] /= n_images
        else:
            tp, fp, fn = bucket['tp'], bucket['fp'], bucket['fn']
            total_true = bucket['n_true']
            bucket.update(
                precision=precision(tp, fp, fn),
                recall=recall(tp, fp, fn),
                accuracy=accuracy(tp, fp, fn),
                f1=f1(tp, fp, fn),
                mean_true_score=bucket['mean_true_score'] / total_true if total_true > 0 else 0.0,
            )

    results = tuple(namedtuple('DatasetMatching', bucket.keys())(*bucket.values()) for bucket in totals)
    if single_thresh:
        return results[0]
    return results
def matching(y_true, y_pred, thresh=0.5, criterion='iou', report_matches=False):
    """Match predicted to ground-truth label objects and report detection statistics.

    Labels are relabelled sequentially, pairwise scores are computed with
    ``matching_criteria[criterion]`` (background row/column dropped), and an
    optimal one-to-one assignment is found with ``linear_sum_assignment``.

    if report_matches=True, return (matched_pairs,matched_scores) are independent of 'thresh'

    Parameters
    ----------
    y_true, y_pred : numpy.ndarray
        Label images of identical shape.
    thresh : scalar, iterable of scalars, or None
        Score threshold(s) for a match to count as a true positive; ``None`` means 0.
    criterion : str
        Key into ``matching_criteria``.
    report_matches : bool
        If true, also report matched label pairs, their scores and which of
        them are true positives.

    Returns
    -------
    A ``Matching`` namedtuple for a scalar ``thresh``, else a tuple of them.
    """
    _check_label_array(y_true, 'y_true')
    _check_label_array(y_pred, 'y_pred')
    if y_true.shape != y_pred.shape:
        _raise(ValueError(
            "y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes"
            .format(y_true=y_true, y_pred=y_pred)))
    if criterion not in matching_criteria:
        _raise(ValueError("Matching criterion '%s' not supported." % criterion))
    if thresh is None:
        thresh = 0
    thresh = float(thresh) if np.isscalar(thresh) else map(float, thresh)

    y_true, _, map_rev_true = relabel_sequential(y_true)
    y_pred, _, map_rev_pred = relabel_sequential(y_pred)

    overlap = label_overlap(y_true, y_pred, check=False)
    scores = matching_criteria[criterion](overlap)
    assert 0 <= np.min(scores) <= np.max(scores) <= 1

    # ignoring background
    scores = scores[1:, 1:]
    n_true, n_pred = scores.shape
    n_matched = min(n_true, n_pred)

    def _single(thr):
        not_trivial = n_matched > 0 and np.any(scores >= thr)
        if not_trivial:
            # optimal matching: count of pairs above thr first, scores as tie-breaker
            costs = -(scores >= thr).astype(float) - scores / (2 * n_matched)
            true_ind, pred_ind = linear_sum_assignment(costs)
            assert n_matched == len(true_ind) == len(pred_ind)
            pair_scores = scores[true_ind, pred_ind]
            match_ok = pair_scores >= thr
            tp = np.count_nonzero(match_ok)
        else:
            tp = 0
        fp = n_pred - tp
        fn = n_true - tp

        stats_dict = {
            'criterion': criterion,
            'thresh': thr,
            'fp': fp,
            'tp': tp,
            'fn': fn,
            'precision': precision(tp, fp, fn),
            'recall': recall(tp, fp, fn),
            'accuracy': accuracy(tp, fp, fn),
            'f1': f1(tp, fp, fn),
            'n_true': n_true,
            'n_pred': n_pred,
            'mean_true_score': np.sum(pair_scores[match_ok]) / n_true if not_trivial else 0.0,
        }

        if report_matches:
            if not_trivial:
                # int() to be json serializable
                label_pairs = tuple(
                    (int(map_rev_true[i]), int(map_rev_pred[j]))
                    for i, j in zip(1 + true_ind, 1 + pred_ind)
                )
                extra = {
                    'matched_pairs': label_pairs,
                    'matched_scores': tuple(pair_scores),
                    'matched_tps': tuple(map(int, np.flatnonzero(match_ok))),
                }
            else:
                extra = {'matched_pairs': (), 'matched_scores': (), 'matched_tps': ()}
            stats_dict.update(extra)

        return namedtuple('Matching', stats_dict.keys())(*stats_dict.values())

    if np.isscalar(thresh):
        return _single(thresh)
    return tuple(map(_single, thresh))