def _axes_div_by(self, query_axes):
    if self.config.backbone == "unet":
        query_axes = axes_check_and_normalize(query_axes)
        assert len(self.config.unet_pool) == len(self.config.grid)
        div_by = dict(
            zip(
                self.config.axes.replace('C', ''),
                tuple(p**self.config.unet_n_depth * g
                      for p, g in zip(self.config.unet_pool, self.config.grid))))
        return tuple(div_by.get(a, 1) for a in query_axes)
    elif self.config.backbone == "resnet":
        grid_dict = dict(zip(self.config.axes.replace('C', ''), self.config.grid))
        return tuple(grid_dict.get(a, 1) for a in query_axes)
    else:
        raise NotImplementedError()
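# Illustrative example (values are hypothetical, not from the original code): with a
# U-Net backbone, unet_pool=(2, 2), unet_n_depth=3 and grid=(2, 2), each spatial axis
# must be divisible by 2**3 * 2 = 16, i.e.
#
#     >>> model._axes_div_by('YX')
#     (16, 16)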
def _axes_change(value: str):
    if value != value.upper():
        with plugin.axes.changed.blocked():
            plugin.axes.value = value.upper()
    image = plugin.image.value
    axes = ""
    try:
        image is not None or _raise(ValueError("no image selected"))
        axes = axes_check_and_normalize(value, length=get_data(image).ndim, disallowed="S")
        update("image_axes", True, (axes, image, None))
    except ValueError as err:
        update("image_axes", False, (value, image, err))
    finally:
        widgets_inactive(plugin.timelapse_opts, active=("T" in axes))
def after(self, x, axes):
    # axes can include 'C', which may not have been present in before()
    axes = axes_check_and_normalize(axes, x.ndim)
    assert all(s_pad == s * g for s, s_pad, g in zip(
        x.shape,
        (self.padded_shape.get(a, _s) for a, _s in zip(axes, x.shape)),
        (self.grid.get(a, 1) for a in axes)))
    # print(self.padded_shape)
    # print(self.pad)
    # print(self.grid)
    crop = tuple(
        slice(0, -(math.floor(p[1] / g)) if p[1] >= g else None)
        for p, g in zip((self.pad.get(a, (0, 0)) for a in axes),
                        (self.grid.get(a, 1) for a in axes)))
    # print(crop)
    return x[crop]
def before(self, x, axes, axes_div_by):
    assert all(a % g == 0 for g, a in zip((self.grid.get(a, 1) for a in axes), axes_div_by))
    axes = axes_check_and_normalize(axes, x.ndim)

    def _split(v):
        return 0, v  # only pad at the end

    self.pad = {
        a: _split((div_n - s % div_n) % div_n)
        for a, div_n, s in zip(axes, axes_div_by, x.shape)
    }
    x_pad = np.pad(x, tuple(self.pad[a] for a in axes), mode=self.mode, **self.kwargs)
    self.padded_shape = dict(zip(axes, x_pad.shape))
    if 'C' in self.padded_shape:
        del self.padded_shape['C']
    return x_pad
def cover(shape, axes, block_size, min_overlap, context, grid=1):
    """Return grid-aligned n-dimensional blocks to cover region of the given shape with axes semantics.

    Parameters block_size, min_overlap, and context can be different per dimension/axis
    (if provided as list) or the same (if provided as scalar value).

    Also see `Block.cover`.
    """
    shape = tuple(shape)
    n = len(shape)
    axes = axes_check_and_normalize(axes, length=n)
    if np.isscalar(block_size):  block_size  = n * [block_size]
    if np.isscalar(min_overlap): min_overlap = n * [min_overlap]
    if np.isscalar(context):     context     = n * [context]
    if np.isscalar(grid):        grid        = n * [grid]
    assert n == len(block_size) == len(min_overlap) == len(context) == len(grid)

    # compute cover for each dimension
    cover_1d = [Block.cover(*args) for args in zip(shape, block_size, min_overlap, context, grid)]
    # return cover as Cartesian product of 1-dimensional blocks
    return tuple(BlockND(i, blocks, axes) for i, blocks in enumerate(product(*cover_1d)))
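# Illustrative usage sketch (shape and sizes are made up, not from the original code):
# cover a large 2D image with grid-aligned blocks that can be processed independently.
#
#     >>> blocks = BlockND.cover((4096, 4096), axes='YX',
#     ...                        block_size=1024, min_overlap=64, context=64, grid=2)
#     >>> blocks[0].blocks_for_axes('YX')  # the two 1D blocks that make up the first block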
def __init__(self, X, **kwargs):
    """See class docstring."""

    assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

    n_dim = len(X.shape) - 2
    n_channel_in = X.shape[-1]
    n_channel_out = n_channel_in

    mean = np.mean(X)
    std = np.std(X)

    if n_dim == 2:
        axes = 'SYXC'
    elif n_dim == 3:
        axes = 'SZYXC'

    # parse and check axes
    axes = axes_check_and_normalize(axes)
    ax = axes_dict(axes)
    ax = {a: (ax[a] is not None) for a in ax}

    (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
    not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

    axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
    axes = axes.replace('S', '')  # remove sample axis if it exists

    if backend_channels_last():
        if ax['C']:
            axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
        else:
            axes += 'C'
    else:
        if ax['C']:
            axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
        else:
            axes = 'C' + axes

    # normalization parameters
    self.mean = str(mean)
    self.std = str(std)

    # directly set by parameters
    self.n_dim = n_dim
    self.axes = axes
    self.n_channel_in = int(n_channel_in)
    self.n_channel_out = int(n_channel_out)

    # default config (can be overwritten by kwargs below)
    self.unet_residual = False
    self.unet_n_depth = 2
    self.unet_kern_size = 5 if self.n_dim == 2 else 3
    self.unet_n_first = 32
    self.unet_last_activation = 'linear'
    if backend_channels_last():
        self.unet_input_shape = self.n_dim * (None,) + (self.n_channel_in,)
    else:
        self.unet_input_shape = (self.n_channel_in,) + self.n_dim * (None,)

    self.train_loss = 'mae'
    self.train_epochs = 100
    self.train_steps_per_epoch = 400
    self.train_learning_rate = 0.0004
    self.train_batch_size = 16
    self.train_tensorboard = True
    self.train_checkpoint = 'weights_best.h5'
    self.train_reduce_lr = {'factor': 0.5, 'patience': 10}
    self.batch_norm = True
    self.n2v_perc_pix = 1.5
    self.n2v_patch_shape = (64, 64) if self.n_dim == 2 else (64, 64, 64)
    self.n2v_manipulator = 'uniform_withCP'
    self.n2v_neighborhood_radius = 5

    # disallow setting 'n_dim' manually
    try:
        del kwargs['n_dim']
        # warnings.warn("ignoring parameter 'n_dim'")
    except:
        pass

    for k in kwargs:
        setattr(self, k, kwargs[k])
def blocks_for_axes(self, axes=None):
    axes = self.axes if axes is None else axes_check_and_normalize(axes)
    return tuple(self.axis_to_block[a] for a in axes)
def __init__(self, id, blocks, axes):
    self.id = id
    self.blocks = tuple(blocks)
    self.axes = axes_check_and_normalize(axes, length=len(self.blocks))
    self.axis_to_block = dict(zip(self.axes, self.blocks))
def predict_instances_big(self, img, axes, block_size, min_overlap, context=None,
                          labels_out=None, labels_out_dtype=np.int32, show_progress=True, **kwargs):
    """Predict instance segmentation from very large input images.

    Intended to be used when `predict_instances` cannot be used due to memory limitations.
    This function will break the input image into blocks and process them individually
    via `predict_instances` and assemble all the partial results. If used as intended, the result
    should be the same as if `predict_instances` was used directly on the whole image.

    **Important**: The crucial assumption is that all predicted object instances are smaller than
                   the provided `min_overlap`. Also, it must hold that: min_overlap + 2*context < block_size.

    Example
    -------
    >>> img.shape
    (20000, 20000)
    >>> labels, polys = model.predict_instances_big(img, axes='YX', block_size=4096,
                                                    min_overlap=128, context=128, n_tiles=(4,4))

    Parameters
    ----------
    img: :class:`numpy.ndarray` or similar
        Input image
    axes: str
        Axes of the input ``img`` (such as 'YX', 'ZYX', 'YXC', etc.)
    block_size: int or iterable of int
        Process input image in blocks of the provided shape.
        (If a scalar value is given, it is used for all spatial image dimensions.)
    min_overlap: int or iterable of int
        Amount of guaranteed overlap between blocks.
        (If a scalar value is given, it is used for all spatial image dimensions.)
    context: int or iterable of int, or None
        Amount of image context on all sides of a block, which is discarded.
        If None, uses an automatic estimate that should work in many cases.
        (If a scalar value is given, it is used for all spatial image dimensions.)
    labels_out: :class:`numpy.ndarray` or similar, or None, or False
        numpy array or similar (must be of correct shape) to which the label image is written.
        If None, will allocate a numpy array of the correct shape and data type ``labels_out_dtype``.
        If False, will not write the label image (useful if only the dictionary is needed).
    labels_out_dtype: str or dtype
        Data type of returned label image if ``labels_out=None`` (has no effect otherwise).
    show_progress: bool
        Show progress bar for block processing.
    kwargs: dict
        Keyword arguments for ``predict_instances``.

    Returns
    -------
    (:class:`numpy.ndarray` or False, dict)
        Returns the label image and a dictionary with the details (coordinates, etc.)
        of the polygons/polyhedra.

    """
    from ..big import _grid_divisible, BlockND, OBJECT_KEYS  # , repaint_labels
    from ..matching import relabel_sequential

    n = img.ndim
    axes = axes_check_and_normalize(axes, length=n)
    grid = self._axes_div_by(axes)
    axes_out = self._axes_out.replace('C', '')
    shape_dict = dict(zip(axes, img.shape))
    shape_out = tuple(shape_dict[a] for a in axes_out)

    if context is None:
        context = self._axes_tile_overlap(axes)

    if np.isscalar(block_size):  block_size  = n * [block_size]
    if np.isscalar(min_overlap): min_overlap = n * [min_overlap]
    if np.isscalar(context):     context     = n * [context]
    block_size, min_overlap, context = list(block_size), list(min_overlap), list(context)
    assert n == len(block_size) == len(min_overlap) == len(context)

    if 'C' in axes:
        # single block for channel axis
        i = axes_dict(axes)['C']
        # if (block_size[i], min_overlap[i], context[i]) != (None, None, None):
        #     print("Ignoring values of 'block_size', 'min_overlap', and 'context' for channel axis " +
        #           "(set to 'None' to avoid this warning).", file=sys.stderr, flush=True)
        block_size[i] = img.shape[i]
        min_overlap[i] = context[i] = 0

    block_size  = tuple(_grid_divisible(g, v, name='block_size',  verbose=False)
                        for v, g, a in zip(block_size, grid, axes))
    min_overlap = tuple(_grid_divisible(g, v, name='min_overlap', verbose=False)
                        for v, g, a in zip(min_overlap, grid, axes))
    context     = tuple(_grid_divisible(g, v, name='context',     verbose=False)
                        for v, g, a in zip(context, grid, axes))

    # print(f"input: shape {img.shape} with axes {axes}")
    print(f'effective: block_size={block_size}, min_overlap={min_overlap}, context={context}', flush=True)

    for a, c, o in zip(axes, context, self._axes_tile_overlap(axes)):
        if c < o:
            print(f"{a}: context of {c} is small, recommended to use at least {o}", flush=True)

    # create block cover
    blocks = BlockND.cover(img.shape, axes, block_size, min_overlap, context, grid)

    if np.isscalar(labels_out) and bool(labels_out) is False:
        labels_out = None
    else:
        if labels_out is None:
            labels_out = np.zeros(shape_out, dtype=labels_out_dtype)
        else:
            labels_out.shape == shape_out or _raise(
                ValueError(f"'labels_out' must have shape {shape_out} (axes {axes_out})."))

    polys_all = {}
    # problem_ids = []
    label_offset = 1

    kwargs_override = dict(axes=axes, overlap_label=None)
    if show_progress:
        kwargs_override['show_tile_progress'] = False  # disable progress for predict_instances
    for k, v in kwargs_override.items():
        if k in kwargs:
            print(f"changing '{k}' from {kwargs[k]} to {v}", flush=True)
        kwargs[k] = v

    blocks = tqdm(blocks, disable=(not show_progress))
    # actual computation
    for block in blocks:
        labels, polys = self.predict_instances(block.read(img, axes=axes), **kwargs)
        labels = block.crop_context(labels, axes=axes_out)
        labels, polys = block.filter_objects(labels, polys, axes=axes_out)
        # TODO: relabel_sequential is not very memory-efficient (will allocate memory proportional to label_offset)
        labels = relabel_sequential(labels, label_offset)[0]
        # labels, fwd_map, _ = relabel_sequential(labels, label_offset)
        # if len(incomplete) > 0:
        #     problem_ids.extend([fwd_map[i] for i in incomplete])
        #     if show_progress:
        #         blocks.set_postfix_str(f"found {len(problem_ids)} problematic {'object' if len(problem_ids)==1 else 'objects'}")
        if labels_out is not None:
            block.write(labels_out, labels, axes=axes_out)

        for k, v in polys.items():
            polys_all.setdefault(k, []).append(v)

        label_offset += len(polys['prob'])

    polys_all = {k: (np.concatenate(v) if k in OBJECT_KEYS else v[0]) for k, v in polys_all.items()}

    # if labels_out is not None and len(problem_ids) > 0:
    #     # if show_progress:
    #     #     blocks.write('')
    #     # print(f"Found {len(problem_ids)} objects that violate the 'min_overlap' assumption.", file=sys.stderr, flush=True)
    #     repaint_labels(labels_out, problem_ids, polys_all, show_progress=False)

    return labels_out, polys_all  # , tuple(problem_ids)
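# Illustrative pre-flight check (not part of the original code): the docstring above
# requires that every object fits within `min_overlap` and that
# min_overlap + 2*context < block_size holds per axis, e.g.:
#
#     >>> block_size, min_overlap, context = 4096, 128, 128
#     >>> min_overlap + 2 * context < block_size
#     True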
def train(self, X, Y, validation_data, epochs=None, steps_per_epoch=None):
    """Train the neural network with the given data.

    Parameters
    ----------
    X : :class:`numpy.ndarray`
        Array of source images.
    Y : :class:`numpy.ndarray`
        Array of target images.
    validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`)
        Tuple of arrays for source and target validation images.
    epochs : int
        Optional argument to use instead of the value from ``config``.
    steps_per_epoch : int
        Optional argument to use instead of the value from ``config``.

    Returns
    -------
    ``History`` object
        See `Keras training history <https://keras.io/models/model/#fit>`_.
    """
    ((isinstance(validation_data, (list, tuple)) and len(validation_data) == 2)
     or _raise(ValueError('validation_data must be a pair of numpy arrays')))

    n_train, n_val = len(X), len(validation_data[0])
    frac_val = (1.0 * n_val) / (n_train + n_val)
    frac_warn = 0.05
    if frac_val < frac_warn:
        warnings.warn("small number of validation images (only %.1f%% of all images)" % (100 * frac_val))

    axes = axes_check_and_normalize('S' + self.config.axes, X.ndim)
    ax = axes_dict(axes)

    for a, div_by in zip(axes, self._axes_div_by(axes)):
        n = X.shape[ax[a]]
        if n % div_by != 0:
            raise ValueError(
                "training images must be evenly divisible by %d along axis %s"
                " (which has incompatible size %d)" % (div_by, a, n))

    if epochs is None:
        epochs = self.config.train_epochs
    if steps_per_epoch is None:
        steps_per_epoch = self.config.train_steps_per_epoch

    if not self._model_prepared:
        self.prepare_for_training()

    training_data = CryoDataWrapper(X, Y, self.config.train_batch_size)

    history = self.keras_model.fit_generator(generator=training_data,
                                             validation_data=validation_data,
                                             epochs=epochs,
                                             steps_per_epoch=steps_per_epoch,
                                             callbacks=self.callbacks,
                                             verbose=1)

    if self.basedir is not None:
        self.keras_model.save_weights(str(self.logdir / 'weights_last.h5'))

        if self.config.train_checkpoint is not None:
            print()
            self._find_and_load_weights(self.config.train_checkpoint)
            try:
                # remove temporary weights
                (self.logdir / 'weights_now.h5').unlink()
            except FileNotFoundError:
                pass

    return history
def train(self, X, validation_X, epochs=None, steps_per_epoch=None):
    """Train the neural network with the given data.

    Parameters
    ----------
    X : :class:`numpy.ndarray`
        Array of source images.
    validation_X : :class:`numpy.ndarray`
        Array of validation images.
    epochs : int
        Optional argument to use instead of the value from ``config``.
    steps_per_epoch : int
        Optional argument to use instead of the value from ``config``.

    Returns
    -------
    ``History`` object
        See `Keras training history <https://keras.io/models/model/#fit>`_.
    """
    n_train, n_val = len(X), len(validation_X)
    frac_val = (1.0 * n_val) / (n_train + n_val)
    frac_warn = 0.05
    if frac_val < frac_warn:
        warnings.warn("small number of validation images (only %.1f%% of all images)" % (100 * frac_val))

    axes = axes_check_and_normalize('S' + self.config.axes, X.ndim)
    ax = axes_dict(axes)
    div_by = 2 ** self.config.unet_n_depth
    axes_relevant = ''.join(a for a in 'XYZT' if a in axes)
    val_num_pix = 1
    train_num_pix = 1
    val_patch_shape = ()
    for a in axes_relevant:
        n = X.shape[ax[a]]
        val_num_pix *= validation_X.shape[ax[a]]
        train_num_pix *= X.shape[ax[a]]
        val_patch_shape += tuple([validation_X.shape[ax[a]]])
        if n % div_by != 0:
            raise ValueError(
                "training images must be evenly divisible by %d along axes %s"
                " (axis %s has incompatible size %d)" % (div_by, axes_relevant, a, n))

    if epochs is None:
        epochs = self.config.train_epochs
    if steps_per_epoch is None:
        steps_per_epoch = self.config.train_steps_per_epoch

    if not self._model_prepared:
        self.prepare_for_training()

    manipulator = eval('pm_{0}({1})'.format(self.config.n2v_manipulator,
                                            str(self.config.n2v_neighborhood_radius)))

    mean, std = float(self.config.mean), float(self.config.std)

    X = self.__normalize__(X, mean, std)
    validation_X = self.__normalize__(validation_X, mean, std)

    # Here we prepare the Noise2Void data. Our input is the noisy data X and as target we take X
    # concatenated with a masking channel. The N2V_DataWrapper will take care of the pixel masking
    # and manipulating.
    training_data = N2V_DataWrapper(X,
                                    np.concatenate((X, np.zeros(X.shape, dtype=X.dtype)), axis=axes.index('C')),
                                    self.config.train_batch_size,
                                    int(train_num_pix / 100 * self.config.n2v_perc_pix),
                                    self.config.n2v_patch_shape,
                                    manipulator)

    # validation_Y is also validation_X plus a concatenated masking channel.
    # To speed things up, we precompute the masking of the validation data.
    validation_Y = np.concatenate((validation_X, np.zeros(validation_X.shape, dtype=validation_X.dtype)),
                                  axis=axes.index('C'))
    n2v_utils.manipulate_val_data(validation_X, validation_Y,
                                  num_pix=int(val_num_pix / 100 * self.config.n2v_perc_pix),
                                  shape=val_patch_shape,
                                  value_manipulation=manipulator)

    history = self.keras_model.fit_generator(generator=training_data,
                                             validation_data=(validation_X, validation_Y),
                                             epochs=epochs,
                                             steps_per_epoch=steps_per_epoch,
                                             callbacks=self.callbacks,
                                             verbose=1)

    if self.basedir is not None:
        self.keras_model.save_weights(str(self.logdir / 'weights_last.h5'))

        if self.config.train_checkpoint is not None:
            print()
            self._find_and_load_weights(self.config.train_checkpoint)
            try:
                # remove temporary weights
                (self.logdir / 'weights_now.h5').unlink()
            except FileNotFoundError:
                pass

    return history
def load_training_data_direct(X, Y, validation_split=0, axes=None, n_images=None, verbose=False):
    """Load training data directly from arrays (rather than from a ``.npz`` file).

    The arrays correspond to the keys of a training data file:

    - ``X``    : Array of training input images.
    - ``Y``    : Array of corresponding target images.
    - ``axes`` : Axes of the training images.

    Parameters
    ----------
    X : :class:`numpy.ndarray`
        Array of training input images.
    Y : :class:`numpy.ndarray`
        Array of corresponding target images.
    validation_split : float
        Fraction of images to use as validation set during training.
    axes: str, optional
        Must be provided in case the loaded data does not contain ``axes`` information.
    n_images : int, optional
        Can be used to limit the number of images loaded from data.
    verbose : bool, optional
        Can be used to display information about the loaded images.

    Returns
    -------
    tuple( tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`), tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`), str )
        Returns two tuples (`X_train`, `Y_train`), (`X_val`, `Y_val`) of training and validation sets
        and the axes of the input images.
        The tuple of validation data will be ``None`` if ``validation_split = 0``.
    """
    # f = np.load(file)
    # X, Y = f['X'], f['Y']
    # if axes is None:
    #     axes = f['axes']
    axes = axes_check_and_normalize(axes)
    assert X.shape == Y.shape
    assert len(axes) == X.ndim
    assert 'C' in axes
    if n_images is None:
        n_images = X.shape[0]
    assert X.shape[0] == Y.shape[0]
    assert 0 < n_images <= X.shape[0]
    assert 0 <= validation_split < 1

    X, Y = X[:n_images], Y[:n_images]
    channel = axes_dict(axes)['C']

    if validation_split > 0:
        n_val = int(round(n_images * validation_split))
        n_train = n_images - n_val
        assert 0 < n_val and 0 < n_train
        X_t, Y_t = X[-n_val:], Y[-n_val:]
        X, Y = X[:n_train], Y[:n_train]
        assert X.shape[0] == n_train and X_t.shape[0] == n_val
        X_t = move_channel_for_backend(X_t, channel=channel)
        Y_t = move_channel_for_backend(Y_t, channel=channel)

    X = move_channel_for_backend(X, channel=channel)
    Y = move_channel_for_backend(Y, channel=channel)

    axes = axes.replace('C', '')  # remove channel
    if backend_channels_last():
        axes = axes + 'C'
    else:
        axes = axes[:1] + 'C' + axes[1:]

    data_val = (X_t, Y_t) if validation_split > 0 else None

    if verbose:
        ax = axes_dict(axes)
        n_train, n_val = len(X), len(X_t) if validation_split > 0 else 0
        image_size = tuple(X.shape[ax[a]] for a in 'TZYX' if a in axes)
        n_dim = len(image_size)
        n_channel_in, n_channel_out = X.shape[ax['C']], Y.shape[ax['C']]

        print('number of training images:\t', n_train)
        print('number of validation images:\t', n_val)
        print('image size (%dD):\t\t' % n_dim, image_size)
        print('axes:\t\t\t\t', axes)
        print('channels in / out:\t\t', n_channel_in, '/', n_channel_out)

    return (X, Y), data_val, axes
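# Illustrative usage sketch (array shapes are hypothetical, not from the original code):
# split an in-memory training set into training and validation parts.
#
#     >>> # X, Y e.g. of shape (100, 128, 128, 1) with axes 'SYXC'
#     >>> (X_trn, Y_trn), (X_val, Y_val), axes = load_training_data_direct(
#     ...     X, Y, validation_split=0.1, axes='SYXC', verbose=True)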
def __init__(self, X, **kwargs):
    if X.size != 0:
        assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

        n_dim = len(X.shape) - 2

        if n_dim == 2:
            axes = 'SYXC'
        elif n_dim == 3:
            axes = 'SZYXC'

        # parse and check axes
        axes = axes_check_and_normalize(axes)
        ax = axes_dict(axes)
        ax = {a: (ax[a] is not None) for a in ax}

        (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
        not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

        axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
        axes = axes.replace('S', '')  # remove sample axis if it exists

        if backend_channels_last():
            if ax['C']:
                axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
            else:
                axes += 'C'
        else:
            if ax['C']:
                axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
            else:
                axes = 'C' + axes

        means, stds = [], []
        for i in range(X.shape[-1]):
            means.append(np.mean(X[..., i]))
            stds.append(np.std(X[..., i]))

        # normalization parameters
        self.means = [str(el) for el in means]
        self.stds = [str(el) for el in stds]

        # directly set by parameters
        self.n_dim = n_dim
        self.axes = axes

        # fixed parameters
        if 'C' in axes:
            self.n_channel_in = X.shape[-1]
        else:
            self.n_channel_in = 1
        self.train_loss = 'demix'

        # default config (can be overwritten by kwargs below)
        self.unet_n_depth = 2
        self.unet_kern_size = 3
        self.unet_n_first = 64
        self.unet_last_activation = 'linear'
        self.probabilistic = False
        self.unet_residual = False
        if backend_channels_last():
            self.unet_input_shape = self.n_dim * (None,) + (self.n_channel_in,)
        else:
            self.unet_input_shape = (self.n_channel_in,) + self.n_dim * (None,)

        # fixed parameters
        self.train_epochs = 200
        self.train_steps_per_epoch = 50
        self.train_learning_rate = 0.0004
        self.train_batch_size = 64
        self.train_tensorboard = False
        self.train_checkpoint = 'weights_best.h5'
        self.train_checkpoint_last = 'weights_last.h5'
        self.train_checkpoint_epoch = 'weights_now.h5'
        self.train_reduce_lr = {'monitor': 'val_loss', 'factor': 0.5, 'patience': 10}
        self.batch_norm = False
        self.n2v_perc_pix = 1.5
        self.n2v_patch_shape = (128, 128) if self.n_dim == 2 else (64, 64, 64)
        self.n2v_manipulator = 'uniform_withCP'
        self.n2v_neighborhood_radius = 5
        self.single_net_per_channel = False
        self.channel_denoised = False
        self.multi_objective = True
        self.normalizer = 'none'
        self.weights_objectives = [0.1, 0, 0.45, 0.45]
        self.distributions = 'gauss'
        self.n2v_leave_center = False
        self.scale_aug = False
        self.structN2Vmask = None
        self.n_back_modes = 2
        self.n_fore_modes = 2
        self.n_instance_seg = 0
        self.n_back_i_modes = 1
        self.n_fore_i_modes = 1
        self.fit_std = False
        self.fit_mean = True
        # self.n_channel_out = (self.n_back_modes + self.n_fore_modes) * self.n_channel_in * self.n_components + \
        #                      self.n_instance_seg * (self.n_back_i_modes + self.n_fore_i_modes)

        # disallow setting 'probabilistic' manually
        try:
            kwargs['probabilistic'] = False
        except:
            pass
        # disallow setting 'unet_residual' manually
        try:
            kwargs['unet_residual'] = False
        except:
            pass

        # print('KWARGS')
        for k in kwargs:
            # print(k, kwargs[k])
            setattr(self, k, kwargs[k])

        self.n_components = 3 if (self.fit_std & self.fit_mean) else 2
        self.n_channel_out = (self.n_back_modes + self.n_fore_modes) * self.n_channel_in * self.n_components + \
                             self.n_instance_seg * (self.n_back_i_modes + self.n_fore_i_modes)
def __init__(self, X, **kwargs):
    """See class docstring"""

    assert X.size != 0
    assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

    n_dim = len(X.shape) - 2

    if n_dim == 2:
        axes = 'SYXC'
    elif n_dim == 3:
        axes = 'SZYXC'

    # parse and check axes
    axes = axes_check_and_normalize(axes)
    ax = axes_dict(axes)
    ax = {a: (ax[a] is not None) for a in ax}

    (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
    not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

    axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
    axes = axes.replace('S', '')  # remove sample axis if it exists

    if backend_channels_last():
        if ax['C']:
            axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
        else:
            axes += 'C'
    else:
        if ax['C']:
            axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
        else:
            axes = 'C' + axes

    assert X.shape[-1] == 1
    n_channel_in = 1
    n_channel_out = 3

    # directly set by parameters
    self.n_dim = n_dim
    self.axes = axes

    # fixed parameters
    self.n_channel_in = n_channel_in
    self.n_channel_out = n_channel_out
    self.train_loss = 'seg'

    # default config (can be overwritten by kwargs below)
    self.unet_n_depth = 4
    self.relative_weights = [1.0, 1.0, 5.0]
    self.unet_kern_size = 3
    self.unet_n_first = 32
    self.unet_last_activation = 'linear'
    self.probabilistic = False
    self.unet_residual = False
    if backend_channels_last():
        self.unet_input_shape = self.n_dim * (None,) + (self.n_channel_in,)
    else:
        self.unet_input_shape = (self.n_channel_in,) + self.n_dim * (None,)

    self.train_epochs = 200
    self.train_steps_per_epoch = 400
    self.train_learning_rate = 0.0004
    self.train_batch_size = 128
    self.train_tensorboard = False
    self.train_checkpoint = 'weights_best.h5'
    self.train_checkpoint_last = 'weights_last.h5'
    self.train_checkpoint_epoch = 'weights_now.h5'
    self.train_reduce_lr = {'factor': 0.5, 'patience': 10}
    self.batch_norm = True

    # disallow setting 'n_dim' manually
    try:
        del kwargs['n_dim']
        # warnings.warn("ignoring parameter 'n_dim'")
    except:
        pass
    # disallow setting 'n_channel_in' manually
    try:
        del kwargs['n_channel_in']
        # warnings.warn("ignoring parameter 'n_channel_in'")
    except:
        pass
    # disallow setting 'n_channel_out' manually
    try:
        del kwargs['n_channel_out']
        # warnings.warn("ignoring parameter 'n_channel_out'")
    except:
        pass
    # disallow setting 'train_loss' manually
    try:
        del kwargs['train_loss']
        # warnings.warn("ignoring parameter 'train_loss'")
    except:
        pass
    # disallow setting 'probabilistic' manually
    try:
        del kwargs['probabilistic']
    except:
        pass
    # disallow setting 'unet_residual' manually
    try:
        del kwargs['unet_residual']
    except:
        pass

    for k in kwargs:
        setattr(self, k, kwargs[k])
def plugin(
    viewer: napari.Viewer,
    label_head,
    image: napari.layers.Image,
    axes,
    label_nn,
    model_type,
    model2d,
    model3d,
    model_folder,
    model_axes,
    norm_image,
    perc_low,
    perc_high,
    input_scale,
    label_nms,
    prob_thresh,
    nms_thresh,
    output_type,
    label_adv,
    n_tiles,
    norm_axes,
    timelapse_opts,
    cnn_output,
    set_thresholds,
    defaults_button,
    progress_bar: mw.ProgressBar,
) -> List[napari.types.LayerDataTuple]:
    model = get_model(*model_selected)
    if model._is_multiclass():
        warn("multi-class mode not supported yet, ignoring classification output")

    lkwargs = {}
    x = get_data(image)
    axes = axes_check_and_normalize(axes, length=x.ndim)

    if not (input_scale is None or isinstance(input_scale, numbers.Number)):
        input_scale = tuple(s for a, s in zip(axes, input_scale) if a not in ("T",))
        # print(f'scaling by {input_scale}')

    if not axes.replace("T", "").startswith(model._axes_out.replace("C", "")):
        warn(f"output images have different axes ({model._axes_out.replace('C','')}) than input image ({axes})")
        # TODO: adjust image.scale according to shuffled axes

    if norm_image:
        axes_norm = axes_check_and_normalize(norm_axes)
        axes_norm = "".join(set(axes_norm).intersection(set(axes)))  # relevant axes present in input image
        assert len(axes_norm) > 0
        # always jointly normalize channels for RGB images
        if ("C" in axes and image.rgb == True) and ("C" not in axes_norm):
            axes_norm = axes_norm + "C"
            warn("jointly normalizing channels of RGB input image")
        ax = axes_dict(axes)
        _axis = tuple(sorted(ax[a] for a in axes_norm))
        # # TODO: address joint vs. channel/time-separate normalization properly (let user choose)
        # # also needs to be documented somewhere
        # if 'T' in axes:
        #     if 'C' not in axes or image.rgb == True:
        #         # normalize channels jointly, frames independently
        #         _axis = tuple(i for i in range(x.ndim) if i not in (ax['T'],))
        #     else:
        #         # normalize channels independently, frames independently
        #         _axis = tuple(i for i in range(x.ndim) if i not in (ax['T'], ax['C']))
        # else:
        #     if 'C' not in axes or image.rgb == True:
        #         # normalize channels jointly
        #         _axis = None
        #     else:
        #         # normalize channels independently
        #         _axis = tuple(i for i in range(x.ndim) if i not in (ax['C'],))
        x = normalize(x, perc_low, perc_high, axis=_axis)

    # TODO: progress bar (labels) often don't show up. events not processed?
    if "T" in axes:
        app = use_app()
        t = axes_dict(axes)["T"]
        n_frames = x.shape[t]
        if n_tiles is not None:
            # remove tiling value for time axis
            n_tiles = tuple(v for i, v in enumerate(n_tiles) if i != t)

        def progress(it, **kwargs):
            progress_bar.label = "StarDist Prediction (frames)"
            progress_bar.range = (0, n_frames)
            progress_bar.value = 0
            progress_bar.show()
            app.process_events()
            for item in it:
                yield item
                progress_bar.increment()
                app.process_events()
            app.process_events()
    elif n_tiles is not None and np.prod(n_tiles) > 1:
        n_tiles = tuple(n_tiles)
        app = use_app()

        def progress(it, **kwargs):
            progress_bar.label = "CNN Prediction (tiles)"
            progress_bar.range = (0, kwargs.get("total", 0))
            progress_bar.value = 0
            progress_bar.show()
            app.process_events()
            for item in it:
                yield item
                progress_bar.increment()
                app.process_events()
            #
            progress_bar.label = "NMS Postprocessing"
            progress_bar.range = (0, 0)
            app.process_events()
    else:
        progress = False
        progress_bar.label = "StarDist Prediction"
        progress_bar.range = (0, 0)
        progress_bar.show()
        use_app().process_events()

    # semantic output axes of predictions
    assert model._axes_out[-1] == "C"
    axes_out = list(model._axes_out[:-1])

    if "T" in axes:
        x_reorder = np.moveaxis(x, t, 0)
        axes_reorder = axes.replace("T", "")
        axes_out.insert(t, "T")
        res = tuple(
            zip(*tuple(
                model.predict_instances(
                    _x,
                    axes=axes_reorder,
                    prob_thresh=prob_thresh,
                    nms_thresh=nms_thresh,
                    n_tiles=n_tiles,
                    scale=input_scale,
                    sparse=(not cnn_output),
                    return_predict=cnn_output,
                ) for _x in progress(x_reorder))))

        if cnn_output:
            labels, polys = tuple(zip(*res[0]))
            cnn_output = tuple(np.stack(c, t) for c in tuple(zip(*res[1])))
        else:
            labels, polys = res

        labels = np.asarray(labels)

        if len(polys) > 1:
            if timelapse_opts == TimelapseLabels.Match.value:
                # match labels in consecutive frames (-> simple IoU tracking)
                labels = group_matching_labels(labels)
            elif timelapse_opts == TimelapseLabels.Unique.value:
                # make label ids unique (shift by offset)
                offsets = np.cumsum([len(p["points"]) for p in polys])
                for y, off in zip(labels[1:], offsets):
                    y[y > 0] += off
            elif timelapse_opts == TimelapseLabels.Separate.value:
                # each frame processed separately (nothing to do)
                pass
            else:
                raise NotImplementedError(f"unknown option '{timelapse_opts}' for time-lapse labels")

        labels = np.moveaxis(labels, 0, t)

        if isinstance(model, StarDist3D):
            # TODO: poly output support for 3D timelapse
            polys = None
        else:
            polys = dict(
                coord=np.concatenate(
                    tuple(np.insert(p["coord"], t, _t, axis=-2) for _t, p in enumerate(polys)),
                    axis=0,
                ),
                points=np.concatenate(
                    tuple(np.insert(p["points"], t, _t, axis=-1) for _t, p in enumerate(polys)),
                    axis=0,
                ),
            )

        if cnn_output:
            pred = (labels, polys), cnn_output
        else:
            pred = labels, polys
    else:
        # TODO: possible to run this in a way that it can be canceled?
        pred = model.predict_instances(
            x,
            axes=axes,
            prob_thresh=prob_thresh,
            nms_thresh=nms_thresh,
            n_tiles=n_tiles,
            show_tile_progress=progress,
            scale=input_scale,
            sparse=(not cnn_output),
            return_predict=cnn_output,
        )
    progress_bar.hide()

    # determine scale for output axes
    scale_in_dict = dict(zip(axes, image.scale))
    scale_out = [scale_in_dict.get(a, 1.0) for a in axes_out]

    layers = []
    if cnn_output:
        (labels, polys), cnn_out = pred

        prob, dist = cnn_out[:2]
        dist = np.moveaxis(dist, -1, 0)

        assert len(model.config.grid) == len(model.config.axes) - 1
        grid_dict = dict(zip(model.config.axes.replace("C", ""), model.config.grid))

        # scale output axes to match input axes
        _scale = [s * grid_dict.get(a, 1) for a, s in zip(axes_out, scale_out)]
        # small translation correction if grid > 1 (since napari centers objects)
        _translate = [0.5 * (grid_dict.get(a, 1) - 1) for a in axes_out]

        layers.append((
            dist,
            dict(
                name="StarDist distances",
                scale=[1] + _scale,
                translate=[0] + _translate,
                **lkwargs,
            ),
            "image",
        ))
        layers.append((
            prob,
            dict(
                name="StarDist probability",
                scale=_scale,
                translate=_translate,
                **lkwargs,
            ),
            "image",
        ))
    else:
        labels, polys = pred

    if output_type in (Output.Labels.value, Output.Both.value):
        layers.append((
            labels,
            dict(name="StarDist labels", scale=scale_out, opacity=0.5, **lkwargs),
            "labels",
        ))
    if output_type in (Output.Polys.value, Output.Both.value):
        n_objects = len(polys["points"])
        if isinstance(model, StarDist3D):
            surface = surface_from_polys(polys)
            layers.append((
                surface,
                dict(
                    name="StarDist polyhedra",
                    contrast_limits=(0, surface[-1].max()),
                    scale=scale_out,
                    colormap=label_colormap(n_objects),
                    **lkwargs,
                ),
                "surface",
            ))
        else:
            # TODO: sometimes hangs for long time (indefinitely?) when returning many polygons (?)
            #       seems to be a known issue: https://github.com/napari/napari/issues/2015
            # TODO: coordinates correct or need offset (0.5 or so)?
            shapes = np.moveaxis(polys["coord"], -1, -2)
            layers.append((
                shapes,
                dict(
                    name="StarDist polygons",
                    shape_type="polygon",
                    scale=scale_out,
                    edge_width=0.75,
                    edge_color="yellow",
                    face_color=[0, 0, 0, 0],
                    **lkwargs,
                ),
                "shapes",
            ))
    return layers
def is_valid(self, return_invalid=False):
    """Check if configuration is valid.

    Returns
    -------
    bool
        Flag that indicates whether the current configuration values are valid.
    """
    def _is_int(v, low=None, high=None):
        return (isinstance(v, int) and
                (True if low is None else low <= v) and
                (True if high is None else v <= high))

    ok = {}
    ok['n_dim'] = self.n_dim in (2, 3)
    try:
        axes_check_and_normalize(self.axes, self.n_dim + 1, disallowed='S')
        ok['axes'] = True
    except:
        ok['axes'] = False
    ok['n_channel_in'] = _is_int(self.n_channel_in, 1)
    ok['n_channel_out'] = _is_int(self.n_channel_out, 4)
    ok['train_loss'] = (self.train_loss in ('seg', 'denoiseg'))
    ok['unet_n_depth'] = _is_int(self.unet_n_depth, 1)
    ok['relative_weights'] = (isinstance(self.relative_weights, list) and
                              len(self.relative_weights) == 3 and
                              all(x > 0 for x in self.relative_weights))
    ok['unet_kern_size'] = _is_int(self.unet_kern_size, 1)
    ok['unet_n_first'] = _is_int(self.unet_n_first, 1)
    ok['unet_last_activation'] = self.unet_last_activation in ('linear', 'relu')
    ok['probabilistic'] = isinstance(self.probabilistic, bool) and not self.probabilistic
    ok['unet_residual'] = isinstance(self.unet_residual, bool) and not self.unet_residual
    ok['unet_input_shape'] = (
        isinstance(self.unet_input_shape, (list, tuple)) and
        len(self.unet_input_shape) == self.n_dim + 1 and
        self.unet_input_shape[-1] == self.n_channel_in and
        all((d is None or (_is_int(d) and d % (2 ** self.unet_n_depth) == 0)
             for d in self.unet_input_shape[:-1]))
    )
    ok['train_epochs'] = _is_int(self.train_epochs, 1)
    ok['train_steps_per_epoch'] = _is_int(self.train_steps_per_epoch, 1)
    ok['train_learning_rate'] = np.isscalar(self.train_learning_rate) and self.train_learning_rate > 0
    ok['train_batch_size'] = _is_int(self.train_batch_size, 1)
    ok['train_tensorboard'] = isinstance(self.train_tensorboard, bool)
    ok['train_checkpoint'] = self.train_checkpoint is None or isinstance(self.train_checkpoint, string_types)
    ok['train_reduce_lr'] = (self.train_reduce_lr is None or isinstance(self.train_reduce_lr, dict) and
                             self.train_reduce_lr['monitor'] in ['val_loss', 'val_seg_loss', 'val_denoise_loss'])
    ok['batch_norm'] = isinstance(self.batch_norm, bool)
    ok['n2v_perc_pix'] = self.n2v_perc_pix > 0 and self.n2v_perc_pix <= 100
    ok['n2v_patch_shape'] = (
        isinstance(self.n2v_patch_shape, (list, tuple)) and
        len(self.n2v_patch_shape) == self.n_dim and
        all(d > 0 for d in self.n2v_patch_shape)
    )
    ok['n2v_manipulator'] = self.n2v_manipulator in ['normal_withoutCP', 'uniform_withCP', 'normal_additive',
                                                     'normal_fitted', 'identity']
    ok['n2v_neighborhood_radius'] = _is_int(self.n2v_neighborhood_radius, 0)
    ok['denoiseg_alpha'] = (isinstance(self.denoiseg_alpha, float) and
                            self.denoiseg_alpha >= 0.0 and self.denoiseg_alpha <= 1.0)

    if return_invalid:
        return all(ok.values()), tuple(k for (k, v) in ok.items() if not v)
    else:
        return all(ok.values())
def __init__(self, X, **kwargs):
    """See class docstring"""

    # X is empty if config is None
    if X.size != 0:
        assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

        n_dim = len(X.shape) - 2

        if n_dim == 2:
            axes = 'SYXC'
        elif n_dim == 3:
            axes = 'SZYXC'

        # parse and check axes
        axes = axes_check_and_normalize(axes)
        ax = axes_dict(axes)
        ax = {a: (ax[a] is not None) for a in ax}

        (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
        not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

        axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
        axes = axes.replace('S', '')  # remove sample axis if it exists

        if backend_channels_last():
            if ax['C']:
                axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
            else:
                axes += 'C'
        else:
            if ax['C']:
                axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
            else:
                axes = 'C' + axes

        means, stds = [], []
        for i in range(X.shape[-1]):
            means.append(np.mean(X[..., i]))
            stds.append(np.std(X[..., i]))

        # normalization parameters
        self.means = [str(el) for el in means]
        self.stds = [str(el) for el in stds]

        # directly set by parameters
        self.n_dim = n_dim
        self.axes = axes

        # fixed parameters
        self.n_channel_in = 1
        self.n_channel_out = 4
        self.train_loss = 'denoiseg'

        # default config (can be overwritten by kwargs below)
        self.unet_n_depth = 4
        self.relative_weights = [1.0, 1.0, 5.0]
        self.unet_kern_size = 3
        self.unet_n_first = 32
        self.unet_last_activation = 'linear'
        self.probabilistic = False
        self.unet_residual = False
        if backend_channels_last():
            self.unet_input_shape = self.n_dim * (None,) + (self.n_channel_in,)
        else:
            self.unet_input_shape = (self.n_channel_in,) + self.n_dim * (None,)

        self.train_epochs = 200
        self.train_steps_per_epoch = 400
        self.train_learning_rate = 0.0004
        self.train_batch_size = 128
        self.train_tensorboard = False
        self.train_checkpoint = 'weights_best.h5'
        self.train_checkpoint_last = 'weights_last.h5'
        self.train_checkpoint_epoch = 'weights_now.h5'
        self.train_reduce_lr = {'monitor': 'val_loss', 'factor': 0.5, 'patience': 10}
        self.batch_norm = True
        self.n2v_perc_pix = 1.5
        self.n2v_patch_shape = (64, 64) if self.n_dim == 2 else (64, 64, 64)
        self.n2v_manipulator = 'uniform_withCP'
        self.n2v_neighborhood_radius = 5
        self.denoiseg_alpha = 0.5

        # disallow setting 'probabilistic' manually
        try:
            kwargs['probabilistic'] = False
        except:
            pass
        # disallow setting 'unet_residual' manually
        try:
            kwargs['unet_residual'] = False
        except:
            pass

        for k in kwargs:
            setattr(self, k, kwargs[k])
def is_valid(self, return_invalid=False):
    """Check if configuration is valid.

    Returns
    -------
    bool
        Flag that indicates whether the current configuration values are valid.
    """
    def _is_int(v, low=None, high=None):
        return (isinstance(v, int) and
                (True if low is None else low <= v) and
                (True if high is None else v <= high))

    ok = {}
    ok['means'] = True
    for mean in self.means:
        ok['means'] &= np.isscalar(float(mean))
    ok['stds'] = True
    for std in self.stds:
        ok['stds'] &= np.isscalar(float(std)) and float(std) > 0.0
    ok['n_dim'] = self.n_dim in (2, 3)
    try:
        axes_check_and_normalize(self.axes, self.n_dim + 1, disallowed='S')
        ok['axes'] = True
    except:
        ok['axes'] = False
    ok['n_channel_in'] = _is_int(self.n_channel_in, 1)
    ok['n_channel_out'] = _is_int(self.n_channel_out, 1)
    ok['unet_residual'] = (isinstance(self.unet_residual, bool) and
                           (not self.unet_residual or (self.n_channel_in == self.n_channel_out)))
    ok['unet_n_depth'] = _is_int(self.unet_n_depth, 1)
    ok['unet_kern_size'] = _is_int(self.unet_kern_size, 1)
    ok['unet_n_first'] = _is_int(self.unet_n_first, 1)
    ok['unet_last_activation'] = self.unet_last_activation in ('linear', 'relu')
    ok['unet_input_shape'] = (
        isinstance(self.unet_input_shape, (list, tuple)) and
        len(self.unet_input_shape) == self.n_dim + 1 and
        self.unet_input_shape[-1] == self.n_channel_in and
        all((d is None or (_is_int(d) and d % (2**self.unet_n_depth) == 0)
             for d in self.unet_input_shape[:-1])))
    ok['train_loss'] = (self.train_loss in ('mse', 'mae'))
    ok['train_epochs'] = _is_int(self.train_epochs, 1)
    ok['train_steps_per_epoch'] = _is_int(self.train_steps_per_epoch, 1)
    ok['train_learning_rate'] = np.isscalar(self.train_learning_rate) and self.train_learning_rate > 0
    ok['train_batch_size'] = _is_int(self.train_batch_size, 1)
    ok['train_tensorboard'] = isinstance(self.train_tensorboard, bool)
    ok['train_checkpoint'] = self.train_checkpoint is None or isinstance(self.train_checkpoint, string_types)
    ok['train_reduce_lr'] = self.train_reduce_lr is None or isinstance(self.train_reduce_lr, dict)
    ok['batch_norm'] = isinstance(self.batch_norm, bool)
    ok['n2v_perc_pix'] = self.n2v_perc_pix > 0 and self.n2v_perc_pix <= 100
    ok['n2v_patch_shape'] = (
        isinstance(self.n2v_patch_shape, (list, tuple)) and
        len(self.n2v_patch_shape) == self.n_dim and
        all(d > 0 for d in self.n2v_patch_shape))
    ok['n2v_manipulator'] = self.n2v_manipulator in ['normal_withoutCP', 'uniform_withCP', 'normal_additive',
                                                     'normal_fitted', 'identity']
    ok['n2v_neighborhood_radius'] = _is_int(self.n2v_neighborhood_radius, 0)
    ok['single_net_per_channel'] = isinstance(self.single_net_per_channel, bool)

    if self.structN2Vmask is None:
        ok['structN2Vmask'] = True
    else:
        mask = np.array(self.structN2Vmask)
        t1 = mask.ndim == self.n_dim
        t2 = all(x % 2 == 1 for x in mask.shape)
        t3 = all([x in [0, 1] for x in mask.flat])
        ok['structN2Vmask'] = t1 and t2 and t3

    if return_invalid:
        return all(ok.values()), tuple(k for (k, v) in ok.items() if not v)
    else:
        return all(ok.values())
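# Illustrative usage sketch (the `config` object name is hypothetical): with
# `return_invalid=True`, the method additionally returns the names of the offending
# config entries, mirroring the tuple built in the code above.
#
#     >>> valid, invalid_keys = config.is_valid(return_invalid=True)
#     >>> if not valid:
#     ...     print('invalid config entries:', invalid_keys)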
def train(self, X, Y, validation_data, epochs=None, steps_per_epoch=None):
    n_train, n_val = len(X), len(validation_data[0])
    frac_val = (1.0 * n_val) / (n_train + n_val)
    frac_warn = 0.05
    if frac_val < frac_warn:
        warnings.warn("small number of validation images (only %.1f%% of all images)" % (100 * frac_val))

    axes = axes_check_and_normalize('S' + self.config.axes, X.ndim)
    ax = axes_dict(axes)
    div_by = 2 ** self.config.unet_n_depth
    axes_relevant = ''.join(a for a in 'XYZT' if a in axes)
    val_num_pix = 1
    train_num_pix = 1
    val_patch_shape = ()
    for a in axes_relevant:
        n = X.shape[ax[a]]
        val_num_pix *= validation_data[0].shape[ax[a]]
        train_num_pix *= X.shape[ax[a]]
        val_patch_shape += tuple([validation_data[0].shape[ax[a]]])
        if n % div_by != 0:
            raise ValueError(
                "training images must be evenly divisible by %d along axes %s"
                " (axis %s has incompatible size %d)" % (div_by, axes_relevant, a, n))

    if epochs is None:
        epochs = self.config.train_epochs
    if steps_per_epoch is None:
        steps_per_epoch = self.config.train_steps_per_epoch

    if not self._model_prepared:
        self.prepare_for_training()

    manipulator = eval('pm_{0}({1})'.format(self.config.n2v_manipulator,
                                            str(self.config.n2v_neighborhood_radius)))

    means = np.array([float(mean) for mean in self.config.means], ndmin=len(X.shape), dtype=np.float32)
    stds = np.array([float(std) for std in self.config.stds], ndmin=len(X.shape), dtype=np.float32)

    X = self.__normalize__(X, means, stds)
    validation_X = self.__normalize__(validation_data[0], means, stds)

    # Here we prepare the Noise2Void data. Our input is the noisy data X and as target we take X
    # concatenated with a masking channel. The N2V_DataWrapper will take care of the pixel masking
    # and manipulating.
    training_data = DenoiSeg_DataWrapper(X=X,
                                         n2v_Y=np.concatenate((X, np.zeros(X.shape, dtype=X.dtype)),
                                                              axis=axes.index('C')),
                                         seg_Y=Y,
                                         batch_size=self.config.train_batch_size,
                                         perc_pix=self.config.n2v_perc_pix,
                                         shape=self.config.n2v_patch_shape,
                                         value_manipulation=manipulator)

    # validation_Y is also validation_X plus a concatenated masking channel.
    # To speed things up, we precompute the masking of the validation data.
    validation_Y = np.concatenate((validation_X, np.zeros(validation_X.shape, dtype=validation_X.dtype)),
                                  axis=axes.index('C'))
    n2v_utils.manipulate_val_data(validation_X, validation_Y,
                                  perc_pix=self.config.n2v_perc_pix,
                                  shape=val_patch_shape,
                                  value_manipulation=manipulator)
    validation_Y = np.concatenate((validation_Y, validation_data[1]), axis=-1)

    history = self.keras_model.fit(training_data,
                                   validation_data=(validation_X, validation_Y),
                                   epochs=epochs,
                                   steps_per_epoch=steps_per_epoch,
                                   callbacks=self.callbacks,
                                   verbose=1)

    if self.basedir is not None:
        self.keras_model.save_weights(str(self.logdir / 'weights_last.h5'))

        if self.config.train_checkpoint is not None:
            print()
            self._find_and_load_weights(self.config.train_checkpoint)
            try:
                # remove temporary weights
                (self.logdir / 'weights_now.h5').unlink()
            except FileNotFoundError:
                pass

    return history
def main():
    if not ('__file__' in locals() or '__file__' in globals()):
        print('running interactively, exiting.')
        sys.exit(0)

    # parse arguments
    parser, args = parse_args()
    args_dict = vars(args)

    # exit and show help if no arguments provided at all
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)

    # check for required arguments manually (because of argparse issue)
    required = ('--input-dir', '--input-axes', '--norm-pmin', '--norm-pmax',
                '--model-basedir', '--model-name', '--output-dir')
    for r in required:
        dest = r[2:].replace('-', '_')
        if args_dict[dest] is None:
            parser.print_usage(file=sys.stderr)
            print("%s: error: the following arguments are required: %s" % (parser.prog, r), file=sys.stderr)
            sys.exit(1)

    # show effective arguments (including defaults)
    if not args.quiet:
        print('Arguments')
        print('---------')
        pprint(args_dict)
        print()
        sys.stdout.flush()

    # logging function
    log = (lambda *a, **k: None) if args.quiet else tqdm.write

    # get list of input files and exit if there are none
    file_list = list(Path(args.input_dir).glob(args.input_pattern))
    if len(file_list) == 0:
        log("No files to process in '%s' with pattern '%s'." % (args.input_dir, args.input_pattern))
        sys.exit(0)

    # delay imports after checking that all required arguments are provided
    from tifffile import imread, imsave
    from csbdeep.utils.tf import keras_import
    K = keras_import('backend')
    from csbdeep.models import CARE
    from csbdeep.data import PercentileNormalizer
    sys.stdout.flush()
    sys.stderr.flush()

    # limit gpu memory
    if args.gpu_memory_limit is not None:
        from csbdeep.utils.tf import limit_gpu_memory
        limit_gpu_memory(args.gpu_memory_limit)

    # create CARE model and load weights, create normalizer
    K.clear_session()
    model = CARE(config=None, name=args.model_name, basedir=args.model_basedir)
    if args.model_weights is not None:
        print("Loading network weights from '%s'." % args.model_weights)
        model.load_weights(args.model_weights)
    normalizer = PercentileNormalizer(pmin=args.norm_pmin, pmax=args.norm_pmax, do_after=args.norm_undo)

    n_tiles = args.n_tiles
    if n_tiles is not None and len(n_tiles) == 1:
        n_tiles = n_tiles[0]

    processed = []

    # process all files
    for file_in in tqdm(file_list, disable=args.quiet or (n_tiles is not None and np.prod(n_tiles) > 1)):
        # construct output file name
        file_out = Path(args.output_dir) / args.output_name.format(
            file_path=str(file_in.relative_to(args.input_dir).parent),
            file_name=file_in.stem,
            file_ext=file_in.suffix,
            model_name=args.model_name,
            model_weights=Path(args.model_weights).stem if args.model_weights is not None else None)

        # checks
        (file_in.suffix.lower() in ('.tif', '.tiff') and
         file_out.suffix.lower() in ('.tif', '.tiff')) or _raise(ValueError('only tiff files supported.'))

        # load and predict restored image
        img = imread(str(file_in))
        restored = model.predict(img, axes=args.input_axes, normalizer=normalizer, n_tiles=n_tiles)

        # restored image could be multi-channel even if input image is not
        axes_out = axes_check_and_normalize(args.input_axes)
        if restored.ndim > img.ndim:
            assert restored.ndim == img.ndim + 1
            assert 'C' not in axes_out
            axes_out += 'C'

        # convert data type (if necessary)
        restored = restored.astype(np.dtype(args.output_dtype), copy=False)

        # save to disk
        if not args.dry_run:
            file_out.parent.mkdir(parents=True, exist_ok=True)
            if args.imagej_tiff:
                save_tiff_imagej_compatible(str(file_out), restored, axes_out)
            else:
                imsave(str(file_out), restored)

        processed.append((file_in, file_out))

    # print summary of processed files
    if not args.quiet:
        sys.stdout.flush()
        sys.stderr.flush()
        n_processed = len(processed)
        len_processed = len(str(n_processed))
        log('Finished processing %d %s' % (n_processed, 'files' if n_processed > 1 else 'file'))
        log('-' * (26 + len_processed if n_processed > 1 else 26))
        for i, (file_in, file_out) in enumerate(processed):
            len_file = max(len(str(file_in)), len(str(file_out)))
            log(('{:>%d}. in : {:>%d}' % (len_processed, len_file)).format(1 + i, str(file_in)))
            log(('{:>%d}  out: {:>%d}' % (len_processed, len_file)).format('', str(file_out)))
def predict(self, img, axes=None, normalizer=None, n_tiles=None, show_tile_progress=True, **predict_kwargs):
    """Predict.

    Parameters
    ----------
    img : :class:`numpy.ndarray`
        Input image
    axes : str or None
        Axes of the input ``img``.
        ``None`` denotes that axes of img are the same as denoted in the config.
    normalizer : :class:`csbdeep.data.Normalizer` or None
        (Optional) normalization of input image before prediction.
        Note that the default (``None``) assumes ``img`` to be already normalized.
    n_tiles : iterable or None
        Out of memory (OOM) errors can occur if the input image is too large.
        To avoid this problem, the input image is broken up into (overlapping) tiles
        that are processed independently and re-assembled.
        This parameter denotes a tuple of the number of tiles for every image axis (see ``axes``).
        ``None`` denotes that no tiling should be used.
    show_tile_progress: bool
        Whether to show progress during tiled prediction.
    predict_kwargs: dict
        Keyword arguments for ``predict`` function of Keras model.

    Returns
    -------
    (:class:`numpy.ndarray`, :class:`numpy.ndarray`)
        Returns the tuple (`prob`, `dist`) of per-pixel object probabilities
        and star-convex polygon/polyhedra distances.
    """
    if n_tiles is None:
        n_tiles = [1] * img.ndim
    try:
        n_tiles = tuple(n_tiles)
        img.ndim == len(n_tiles) or _raise(TypeError())
    except TypeError:
        raise ValueError("n_tiles must be an iterable of length %d" % img.ndim)
    all(np.isscalar(t) and 1 <= t and int(t) == t for t in n_tiles) or _raise(
        ValueError("all values of n_tiles must be integer values >= 1"))
    n_tiles = tuple(map(int, n_tiles))

    if axes is None:
        axes = self.config.axes
        assert 'C' in axes
        if img.ndim == len(axes) - 1 and self.config.n_channel_in == 1:
            # img has no dedicated channel axis, but 'C' always part of config axes
            axes = axes.replace('C', '')

    axes = axes_check_and_normalize(axes, img.ndim)
    axes_net = self.config.axes

    _permute_axes = self._make_permute_axes(axes, axes_net)
    x = _permute_axes(img)  # x has axes_net semantics

    channel = axes_dict(axes_net)['C']
    self.config.n_channel_in == x.shape[channel] or _raise(ValueError())
    axes_net_div_by = self._axes_div_by(axes_net)

    grid = tuple(self.config.grid)
    len(grid) == len(axes_net) - 1 or _raise(ValueError())
    grid_dict = dict(zip(axes_net.replace('C', ''), grid))

    normalizer = self._check_normalizer_resizer(normalizer, None)[0]
    resizer = StarDistPadAndCropResizer(grid=grid_dict)

    x = normalizer.before(x, axes_net)
    x = resizer.before(x, axes_net, axes_net_div_by)

    def predict_direct(tile):
        sh = list(tile.shape)
        sh[channel] = 1
        dummy = np.empty(sh, np.float32)
        prob, dist = self.keras_model.predict([tile[np.newaxis], dummy[np.newaxis]], **predict_kwargs)
        return prob[0], dist[0]

    if np.prod(n_tiles) > 1:
        tiling_axes = axes_net.replace('C', '')  # axes eligible for tiling
        x_tiling_axis = tuple(axes_dict(axes_net)[a] for a in tiling_axes)  # numerical axis ids for x
        axes_net_tile_overlaps = self._axes_tile_overlap(axes_net)
        # hack: permute tiling axis in the same way as img -> x was permuted
        # (np.bool is deprecated/removed in recent numpy, so use the builtin bool)
        n_tiles = _permute_axes(np.empty(n_tiles, bool)).shape
        (all(n_tiles[i] == 1 for i in range(x.ndim) if i not in x_tiling_axis) or
         _raise(ValueError("entry of n_tiles > 1 only allowed for axes '%s'" % tiling_axes)))

        sh = [s // grid_dict.get(a, 1) for a, s in zip(axes_net, x.shape)]
        sh[channel] = 1
        prob = np.empty(sh, np.float32)
        sh[channel] = self.config.n_rays
        dist = np.empty(sh, np.float32)

        n_block_overlaps = [int(np.ceil(overlap / blocksize)) for overlap, blocksize
                            in zip(axes_net_tile_overlaps, axes_net_div_by)]

        for tile, s_src, s_dst in tqdm(tile_iterator(x, n_tiles,
                                                     block_sizes=axes_net_div_by,
                                                     n_block_overlaps=n_block_overlaps),
                                       disable=(not show_tile_progress),
                                       total=np.prod(n_tiles)):
            prob_tile, dist_tile = predict_direct(tile)
            # account for grid
            s_src = [slice(s.start // grid_dict.get(a, 1), s.stop // grid_dict.get(a, 1))
                     for s, a in zip(s_src, axes_net)]
            s_dst = [slice(s.start // grid_dict.get(a, 1), s.stop // grid_dict.get(a, 1))
                     for s, a in zip(s_dst, axes_net)]
            # prob and dist have different channel dimensionality than image x
            s_src[channel] = slice(None)
            s_dst[channel] = slice(None)
            s_src, s_dst = tuple(s_src), tuple(s_dst)
            # print(s_src, s_dst)
            prob[s_dst] = prob_tile[s_src]
            dist[s_dst] = dist_tile[s_src]

    else:
        prob, dist = predict_direct(x)

    prob = resizer.after(prob, axes_net)
    dist = resizer.after(dist, axes_net)
    dist = np.maximum(1e-3, dist)  # avoid small/negative dist values to prevent problems with Qhull

    prob = np.take(prob, 0, axis=channel)
    dist = np.moveaxis(dist, channel, -1)

    return prob, dist
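# Illustrative usage sketch (model and image are hypothetical, not from the original code):
# tiled prediction of the raw (prob, dist) outputs for a large 2D image.
#
#     >>> prob, dist = model.predict(img, axes='YX', n_tiles=(4, 4))
#     >>> # prob: per-pixel object probabilities, dist: n_rays distances per pixel
#     >>> # (both on the grid-subsampled output shape)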
def train(self, X, Y, validation_data, epochs=None, steps_per_epoch=None):
    """Train the neural network with the given data.

    Parameters
    ----------
    X : :class:`numpy.ndarray`
        Array of source images.
    Y : :class:`numpy.ndarray`
        Array of target images.
    validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`)
        Tuple of arrays for source and target validation images.
    epochs : int
        Optional argument to use instead of the value from ``config``.
    steps_per_epoch : int
        Optional argument to use instead of the value from ``config``.

    Returns
    -------
    ``History`` object
        See `Keras training history <https://keras.io/models/model/#fit>`_.
    """
    leave_center = self.config.n2v_leave_center
    scale_augmentation = self.config.scale_aug

    # resize validation data if it does not match the configured patch shape
    print((np.sum(np.abs(np.array(validation_data[0][1:-1]) - np.array(self.config.n2v_patch_shape))) != 0))
    if (np.sum(np.abs(np.array(validation_data[0][1:-1]) - np.array(self.config.n2v_patch_shape))) != 0):
        X_val = subpatch_2D(validation_data[0], np.array(self.config.n2v_patch_shape))
        Y_val = subpatch_2D(validation_data[1], np.array(self.config.n2v_patch_shape))
        validation_data = (X_val, Y_val)

    ((isinstance(validation_data, (list, tuple)) and len(validation_data) == 2)
     or _raise(ValueError('validation_data must be a pair of numpy arrays')))

    n_train, n_val = len(X), len(validation_data[0])

    # warn about small validation set size
    frac_val = (1.0 * n_val) / (n_train + n_val)
    frac_warn = 0.05
    if frac_val < frac_warn:
        warnings.warn("small number of validation images (only %.1f%% of all images)" % (100 * frac_val))

    # axes description
    axes = axes_check_and_normalize('S' + self.config.axes, X.ndim)
    ax = axes_dict(axes)

    # for a, div_by in zip(axes, self._axes_div_by(axes)):
    #     n = X.shape[ax[a]]
    #     if n % div_by != 0:
    #         raise ValueError(
    #             "training images must be evenly divisible by %d along axis %s"
    #             " (which has incompatible size %d)" % (div_by, a, n)
    #         )

    # check that all relevant axes are divisible by the U-Net downsampling factor
    div_by = 2 ** self.config.unet_n_depth
    axes_relevant = ''.join(a for a in 'XYZT' if a in axes)
    val_num_pix = 1
    train_num_pix = 1
    val_patch_shape = ()
    for a in axes_relevant:
        n = X.shape[ax[a]]
        val_num_pix *= validation_data[0].shape[ax[a]]
        train_num_pix *= X.shape[ax[a]]
        val_patch_shape += tuple([validation_data[0].shape[ax[a]]])
        if n % div_by != 0:
            raise ValueError(
                "training images must be evenly divisible by %d along axes %s"
                " (axis %s has incompatible size %d)" % (div_by, axes_relevant, a, n))

    # epochs & steps per epoch
    if epochs is None:
        epochs = self.config.train_epochs
    if steps_per_epoch is None:
        steps_per_epoch = self.config.train_steps_per_epoch

    if not self._model_prepared:
        self.prepare_for_training()

    # if (self.config.train_tensorboard and self.basedir is not None and
    #         not any(isinstance(cb, CARETensorBoardImage) for cb in self.callbacks)):
    #     self.callbacks.append(CARETensorBoardImage(model=self.keras_model, data=validation_data,
    #                                                log_dir=str(self.logdir/'logs'/'images'),
    #                                                n_images=3, prob_out=self.config.probabilistic))
    #
    # training_data = DataWrapper(X, Y, self.config.train_batch_size, epochs*steps_per_epoch)

    manipulator = eval('pm_{0}({1})'.format(self.config.n2v_manipulator,
                                            str(self.config.n2v_neighborhood_radius)))

    if self.config.normalizer == 'std':
        means = np.array([float(mean) for mean in self.config.means], ndmin=len(X.shape), dtype=np.float32)
        stds = np.array([float(std) for std in self.config.stds], ndmin=len(X.shape), dtype=np.float32)
        X = self.__normalize__(X, means, stds)
        validation_X = self.__normalize__(validation_data[0], means, stds)
    else:
        validation_X = validation_data[0]
        # TODO: validation normalization if we have one; also pick the type of normalization as an option

    # mask (struct to inpaint)
    _mask = np.array(self.config.structN2Vmask) if self.config.structN2Vmask else None
    # print(_mask, self.config.channel_denoised)

    training_data = BioSeg_DataWrapper(X, Y,
                                       self.config.train_batch_size,
                                       self.config.n2v_perc_pix,
                                       self.config.n2v_patch_shape,
                                       manipulator,
                                       structN2Vmask=_mask,
                                       chan_denoise=self.config.channel_denoised,
                                       multiple_objectives=self.config.multi_objective,
                                       leave_center=leave_center,
                                       scale_augmentation=scale_augmentation)

    # validation_Y is also validation_X plus a concatenated masking channel.
    # To speed things up, we precompute the masking of the validation data.
    if not self.config.channel_denoised:
        validation_Y = np.concatenate((validation_X,
                                       np.zeros(validation_X.shape, dtype=validation_X.dtype)),
                                      axis=axes.index('C'))
    else:
        val_aux = validation_data[1][..., 0:X.shape[-1]]
        # print(val_aux.shape)
        # if X.shape[-1] == 1:
        #     val_aux = val_aux[..., np.newaxis]
        validation_Y = np.concatenate((val_aux, np.zeros(val_aux.shape, dtype=validation_X.dtype)),
                                      axis=axes.index('C'))
    # print(validation_Y.shape, validation_X.shape)

    manipulate_val_data(validation_X, validation_Y,
                        perc_pix=self.config.n2v_perc_pix,
                        shape=val_patch_shape,
                        value_manipulation=manipulator,
                        chan_denoise=self.config.channel_denoised)

    # print(self.config)
    # print(self.config.multi_objective)
    if self.config.multi_objective:
        if (self.config.channel_denoised) and (validation_data[1].shape[-1] > X.shape[-1]):
            # additional channels
            validation_Y = np.concatenate((validation_Y, validation_data[1][..., X.shape[-1]:]), axis=-1)
        if not self.config.channel_denoised:
            validation_Y = np.concatenate((validation_Y, validation_data[1][..., :]), axis=-1)
    # print(validation_Y.shape, validation_X.shape)

    fit = self.keras_model.fit_generator
    # fit = self.keras_model.fit

    history = fit(training_data,
                  validation_data=(validation_X, validation_Y),
                  epochs=epochs,
                  steps_per_epoch=steps_per_epoch,
                  callbacks=self.callbacks,
                  verbose=1)

    # save the final weights and (optionally) reload the best checkpoint
    if self.basedir is not None:
        self.keras_model.save_weights(str(self.logdir / 'weights_last.h5'))

        if self.config.train_checkpoint is not None:
            print()
            self._find_and_load_weights(self.config.train_checkpoint)
            try:
                # remove temporary weights
                (self.logdir / 'weights_now.h5').unlink()
            except FileNotFoundError:
                pass

    # self._training_finished()

    return history