def _create(red_none, img_size, img_axes, patch_size, patch_axes):
    """Create reduced-target patches and verify the targets against the inputs.

    Data comes from ``get_data`` (using the enclosing-scope ``n_images``).
    For every reduction axis, the patch size is replaced either by ``None``
    (when ``red_none`` is true) or by the full image extent along that axis.
    The helper then asserts that the number of patches is as expected and
    that ``Y`` equals the mean of ``X`` over the reduction axes (up to 1e-5).
    """
    raw_data, red_axes, keepdims = get_data(n_images, img_axes, img_size)

    # change patch_size to (img_size or None) for red_axes
    sizes = list(patch_size)
    for a in red_axes:
        new_size = None if red_none else img_size[axes_dict(img_axes)[a]]
        index = axes_dict(img_axes if patch_axes is None else patch_axes)[a]
        sizes[index] = new_size

    # Target axes: random choice between None and the image axes when the
    # reduced dimensions are kept, otherwise the image axes minus red_axes.
    if keepdims:
        target_axes = rng.choice((None, img_axes))
    else:
        target_axes = ''.join(a for a in img_axes if a not in red_axes)

    X, Y, XYaxes = create_patches_reduced_target(
        raw_data=raw_data,
        patch_size=sizes,
        patch_axes=patch_axes,
        n_patches_per_image=n_patches_per_image,
        reduction_axes=red_axes,
        target_axes=target_axes,
        verbose=False,
    )
    assert len(X) == n_images * n_patches_per_image

    reduced = tuple(axes_dict(XYaxes)[a] for a in red_axes)
    expected = np.mean(X, axis=reduced, keepdims=True)
    assert np.max(np.abs(expected - Y)) < 1e-5
def _guess_n_tiles(self, img):
    """Guess a number of tiles per axis of ``img`` from the training setup.

    Each non-channel axis gets ``ceil(size / (patch * b))`` tiles, where
    ``patch`` is the matching entry of ``config.train_patch_size`` and
    ``b = train_batch_size ** (1 / n_dim)``.  A channel axis, if present,
    always gets exactly one tile.  Returns a tuple with one entry per axis.
    """
    axes = self._normalize_axes(img, axes=None)
    has_channel = 'C' in axes

    extents = list(img.shape)
    if has_channel:
        channel = axes_dict(axes)['C']
        extents.pop(channel)

    scale = self.config.train_batch_size ** (1.0 / self.config.n_dim)
    tiles = []
    for extent, patch in zip(extents, self.config.train_patch_size):
        tiles.append(int(np.ceil(extent / (patch * scale))))

    if has_channel:
        # the channel axis is never tiled
        tiles.insert(channel, 1)
    return tuple(tiles)
def _predict(imdims,axes):
    """Run the projection model on a random image and check its output.

    For non-probabilistic configs the plain prediction is compared against a
    prediction tiled along X and Y.  Tiling along the projection axis must
    raise ``ValueError``.  Finally the output shape is checked: the
    projection axis is removed and the channel axis (if any) has
    ``config.n_channel_out`` entries.
    """
    img = rng.uniform(size=imdims)
    tiling = [1 for _ in axes]
    ax = axes_dict(axes)

    if not config.probabilistic:
        mean = model.predict(img, axes, None, None)
        tiling[ax['X']] = 3
        tiling[ax['Y']] = 2
        mean_tiled = model.predict(img, axes, None, None, n_tiles=tiling)
        max_deviation = np.max(np.abs(mean - mean_tiled))
        assert max_deviation < 1e-3
    else:
        prob = model.predict_probabilistic(img, axes, None, None)
        mean, scale = prob.mean(), prob.scale()
        assert mean.shape == scale.shape

    # tiling the projection axis is expected to be rejected
    with pytest.raises(ValueError):
        tiling[ax[proj_axis]] = 2
        model.predict(img, axes, None, None, n_tiles=tiling)

    expected = list(imdims)
    if 'C' in axes:
        expected[ax['C']] = config.n_channel_out
    elif config.n_channel_out > 1:
        expected.append(config.n_channel_out)
    expected.pop(ax[proj_axis])
    assert tuple(expected) == mean.shape
def prepare_input(self, img, axes='YX', normalizer=None):
    """Return the normalized/resized network input for ``img`` (experimental).

    Mirrors the pre-processing steps of prediction so that the result can be
    fed to hidden layers: the image is permuted to the network axes
    (``self.config.axes``), normalized, and padded by a
    ``StarDistPadAndCropResizer`` built from ``self.config.grid``.

    Parameters
    ----------
    img : array
        Input image.
    axes : str
        Axes of ``img`` (default ``'YX'``).
    normalizer : optional
        Passed through ``self._check_normalizer_resizer``.

    Returns
    -------
    tuple
        ``(x[np.newaxis], dummy[np.newaxis])`` where ``dummy`` is an
        uninitialized float32 array with the shape of ``x`` but a single
        channel.

    Raises
    ------
    ValueError
        If the channel count of the permuted image differs from
        ``config.n_channel_in``, or the grid length does not match the
        number of non-channel network axes.
    """
    axes = self._normalize_axes(img, axes)
    axes_net = self.config.axes
    _permute_axes = self._make_permute_axes(axes, axes_net)
    x = _permute_axes(img)  # x has axes_net semantics

    channel = axes_dict(axes_net)['C']
    self.config.n_channel_in == x.shape[channel] or _raise(ValueError(
        "input has %d channel(s) but the model expects %d"
        % (x.shape[channel], self.config.n_channel_in)))

    axes_net_div_by = self._axes_div_by(axes_net)
    grid = tuple(self.config.grid)
    len(grid) == len(axes_net) - 1 or _raise(ValueError(
        "grid %r does not match the non-channel axes of '%s'" % (grid, axes_net)))
    grid_dict = dict(zip(axes_net.replace('C', ''), grid))

    normalizer = self._check_normalizer_resizer(normalizer, None)[0]
    resizer = StarDistPadAndCropResizer(grid=grid_dict)

    x = normalizer.before(x, axes_net)
    x = resizer.before(x, axes_net, axes_net_div_by)

    # placeholder second input: same shape as x, but with one channel
    sh = list(x.shape)
    sh[channel] = 1
    dummy = np.empty(sh, np.float32)
    return x[np.newaxis], dummy[np.newaxis]
def _gen():
    """Yield ``n_images`` tuples ``(x, y, axes, None)``.

    ``x`` is a uniform random array of the enclosing-scope ``shape`` and
    ``y`` is its mean over the reduction axes ``red_axes``.
    """
    remaining = n_images
    while remaining > 0:
        remaining -= 1
        x = rng.uniform(size=shape)
        reduce_over = tuple(axes_dict(axes)[a] for a in red_axes)
        y = np.mean(x, axis=reduce_over, keepdims=keepdims)
        yield x, y, axes, None
def original():
    """Train an IsotropicCARE model from ``isonet_psf_1/my_training_data.npz``.

    Loads the training data (10% validation split), prints a short summary,
    saves example/validation/history figures into ``isonet_psf_1``, trains
    for 200 epochs, reloads the best weights and exports the model.
    """
    out_dir = Path('isonet_psf_1')
    out_dir.mkdir(exist_ok=True)
    # sys.stdout = open(out_dir / 'train_stdout.txt', 'w')
    # sys.stderr = open(out_dir / 'train_stderr.txt', 'w')

    (X, Y), (X_val, Y_val), data_axes = load_training_data(
        out_dir / 'my_training_data.npz', validation_split=0.1)
    ax = axes_dict(data_axes)

    n_train = len(X)
    n_val = len(X_val)
    # spatial axes are ZYX when a Z axis exists, otherwise YX
    spatial = ('Z', 'Y', 'X') if ax['Z'] is not None else ('Y', 'X')
    image_size = tuple(X.shape[ax[name]] for name in spatial)
    n_dim = len(image_size)
    n_channel_in = X.shape[ax['C']]
    n_channel_out = Y.shape[ax['C']]

    print('number of training images:\t', n_train)
    print('number of validation images:\t', n_val)
    print('image size (%dD):\t\t' % n_dim, image_size)
    print('Channels in / out:\t\t', n_channel_in, '/', n_channel_out)

    plt.figure(figsize=(10, 4))
    plot_some(X_val[:5], Y_val[:5])
    plt.suptitle(
        '5 example validation patches (top row: source, bottom row: target)')
    plt.savefig(out_dir / 'train_1.png')

    config = Config(data_axes, n_channel_in, n_channel_out, train_epochs=200)
    print(config)
    vars(config)

    model = IsotropicCARE(config, str(out_dir / 'my_model'))
    history = model.train(X, Y, validation_data=(X_val, Y_val))

    print(sorted(history.history.keys()))
    plt.figure(figsize=(16, 5))
    plot_history(history, ['loss', 'val_loss'],
                 ['mse', 'val_mse', 'mae', 'val_mae'])
    plt.savefig(out_dir / 'train_history.png')

    model.load_weights()  # load best weights according to validation loss

    plt.figure(figsize=(12, 7))
    predicted = model.keras_model.predict(X_val[:5])
    if config.probabilistic:
        # keep only the first half of the last axis
        predicted = predicted[..., :(predicted.shape[-1] // 2)]
    plot_some(X_val[:5], Y_val[:5], predicted, pmax=99.5)
    plt.suptitle('5 example validation patches\n'
                 'top row: input (source), '
                 'middle row: target (ground truth), '
                 'bottom row: predicted from source')
    plt.tight_layout()
    plt.savefig(out_dir / 'train_2.png')

    model.export_TF()
def train(self, channels=None, **config_args):
    """Train, plot and export one CARE model per channel.

    ``channels`` defaults to ``self.train_channels``.  For each channel the
    patches ``CH_<ch>_training_patches.npz`` are loaded (10% validation
    split), a model ``CH_<ch>_model`` is trained below ``<out_dir>/models``,
    the learning curve and five validation predictions are shown, and the
    model is exported.  Extra keyword arguments are forwarded to ``Config``.
    Training aborts (returns) on ``tf.errors.ResourceExhaustedError``.
    """
    #limit_gpu_memory(fraction=1)
    selected = self.train_channels if channels is None else channels

    for ch in selected:
        print("-- Training channel {}...".format(ch))

        patch_file = self.get_training_patch_path() / f'CH_{ch}_training_patches.npz'
        (X, Y), (X_val, Y_val), axes = load_training_data(
            patch_file, validation_split=0.1, verbose=False)

        c = axes_dict(axes)['C']
        n_channel_in = X.shape[c]
        n_channel_out = Y.shape[c]

        config = Config(axes, n_channel_in, n_channel_out,
                        train_epochs=self.train_epochs,
                        train_steps_per_epoch=self.train_steps_per_epoch,
                        train_batch_size=self.train_batch_size,
                        **config_args)

        # Training
        model = CARE(config, f'CH_{ch}_model',
                     basedir=pathlib.Path(self.out_dir) / 'models')

        # Show learning curve and example validation results
        try:
            history = model.train(X, Y, validation_data=(X_val, Y_val))
        except tf.errors.ResourceExhaustedError:
            print(
                "ResourceExhaustedError: Aborting...\n Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?"
            )
            return

        #print(sorted(list(history.history.keys())))
        plt.figure(figsize=(16, 5))
        plot_history(history, ['loss', 'val_loss'],
                     ['mse', 'val_mse', 'mae', 'val_mae'])

        plt.figure(figsize=(12, 7))
        predicted = model.keras_model.predict(X_val[:5])
        plot_some(X_val[:5], Y_val[:5], predicted, pmax=99.5, cmap="gray")
        plt.suptitle('5 example validation patches\n'
                     'top row: input (source), '
                     'middle row: target (ground truth), '
                     'bottom row: predicted from source')
        plt.show()

        print("-- Export model for use in Fiji...")
        model.export_TF()
        print("Done")
def _predict(imdims,axes):
    """Predict on a random image and check the Z axis is scaled by ``factor``.

    Uses the probabilistic prediction (checking mean/scale shapes agree)
    when ``config.probabilistic`` is set, the plain prediction otherwise.
    """
    img = rng.uniform(size=imdims)

    if not config.probabilistic:
        mean = model.predict(img, axes, factor, None, None)
    else:
        prob = model.predict_probabilistic(img, axes, factor, None, None)
        mean = prob.mean()
        scale = prob.scale()
        assert mean.shape == scale.shape

    z = axes_dict(axes)['Z']
    assert imdims[z]*factor == mean.shape[z]
def dev(args):
    """Train (or resume training of) a CARE model from command-line arguments.

    Reads from ``args``: ``train_data``, ``valid_split``, ``axes``,
    ``resume``, ``config``, ``prob``, ``steps``, ``epochs`` and
    ``model_name``.  The model configuration is chosen as follows:

    * ``args.resume`` set: ``config=None`` is passed to ``CARE``
      (per the original comment, this reloads the saved config);
    * else ``args.config`` set: the config is built from that JSON file;
    * otherwise: a new ``Config`` is built from the data and arguments.

    After training, the history is plotted to ``<model_name>_training.png``
    and the model is exported.
    """
    import json

    # Load and parse training data
    (X, Y), (X_val, Y_val), axes = load_training_data(
        args.train_data,
        validation_split=args.valid_split,
        axes=args.axes,
        verbose=True)
    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    # Model config
    print('args.resume: ', args.resume)
    if args.resume:
        # If resuming, config=None will reload the saved config
        config = None
        print('Attempting to resume')
    elif args.config:
        print('loading config from args')
        # context manager so the file handle is closed deterministically
        with open(args.config) as config_file:
            config_args = json.load(config_file)
        config = Config(**config_args)
    else:
        config = Config(axes, n_channel_in, n_channel_out,
                        probabilistic=args.prob,
                        train_steps_per_epoch=args.steps,
                        train_epochs=args.epochs)

    # vars(None) raises TypeError, so only dump an actual config object
    if config is not None:
        print(vars(config))

    # Load or init model
    model = CARE(config, args.model_name, basedir='models')

    # Training, tensorboard available
    history = model.train(X, Y, validation_data=(X_val, Y_val))

    # Plot training results
    print(sorted(list(history.history.keys())))
    plt.figure(figsize=(16, 5))
    plot_history(history, ['loss', 'val_loss'],
                 ['mse', 'val_mse', 'mae', 'val_mae'])
    plt.savefig(args.model_name + '_training.png')

    # Export model to be used w/ csbdeep fiji plugins and KNIME flows
    model.export_TF()
def _predict(imdims, axes):
    """Check ``model._predict_mean_and_scale`` on a random image.

    Asserts that ``scale`` matches ``mean`` in shape for probabilistic
    configs and is ``None`` otherwise, and that ``mean`` has the expected
    shape: the input shape with the channel axis replaced by (or, without a
    channel axis and more than one output channel, extended by)
    ``config.n_channel_out``.

    ``imdims`` is not modified; the expected shape is built on a copy so the
    caller's sequence can be reused (and may be a tuple).
    """
    img = rng.uniform(size=imdims)
    # print(img.shape, axes, config.n_channel_out)
    mean, scale = model._predict_mean_and_scale(img, axes, normalizer, resizer)

    if config.probabilistic:
        assert mean.shape == scale.shape
    else:
        assert scale is None

    if 'C' not in axes:
        if config.n_channel_out == 1:
            assert mean.shape == img.shape
        else:
            assert mean.shape == img.shape + (config.n_channel_out, )
    else:
        channel = axes_dict(axes)['C']
        expected = list(imdims)
        expected[channel] = config.n_channel_out
        assert mean.shape == tuple(expected)
def __init__(self, X, **kwargs):
    """See class docstring.

    When ``X`` is non-empty, the axes, per-channel normalization statistics
    and all default parameters are derived/set here, and ``probabilistic``
    and ``unet_residual`` are forced to ``False`` regardless of ``kwargs``.
    In every case, each entry of ``kwargs`` is finally set as an attribute.
    """
    # X is empty if config is None
    if X.size != 0:
        assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

        n_dim = len(X.shape) - 2
        if n_dim == 2:
            axes = 'SYXC'
        elif n_dim == 3:
            axes = 'SZYXC'

        # parse and check axes
        axes = axes_check_and_normalize(axes)
        ax = axes_dict(axes)
        ax = {a: (ax[a] is not None) for a in ax}

        (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
        not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

        axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
        axes = axes.replace('S', '')  # remove sample axis if it exists

        if backend_channels_last():
            if ax['C']:
                axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
            else:
                axes += 'C'
        else:
            if ax['C']:
                axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
            else:
                axes = 'C' + axes

        # normalization parameters: per-channel mean/std (last axis), stored as strings
        n_channels = X.shape[-1]
        self.means = [str(np.mean(X[..., i])) for i in range(n_channels)]
        self.stds = [str(np.std(X[..., i])) for i in range(n_channels)]

        # directly set by parameters
        self.n_dim = n_dim
        self.axes = axes

        # fixed parameters
        self.n_channel_in = 1
        self.n_channel_out = 4
        self.train_loss = 'denoiseg'

        # default config (can be overwritten by kwargs below)
        self.unet_n_depth = 4
        self.relative_weights = [1.0, 1.0, 5.0]
        self.unet_kern_size = 3
        self.unet_n_first = 32
        self.unet_last_activation = 'linear'
        self.probabilistic = False
        self.unet_residual = False
        if backend_channels_last():
            self.unet_input_shape = self.n_dim * (None,) + (self.n_channel_in,)
        else:
            self.unet_input_shape = (self.n_channel_in,) + self.n_dim * (None,)

        self.train_epochs = 200
        self.train_steps_per_epoch = 400
        self.train_learning_rate = 0.0004
        self.train_batch_size = 128
        self.train_tensorboard = False
        self.train_checkpoint = 'weights_best.h5'
        self.train_checkpoint_last = 'weights_last.h5'
        self.train_checkpoint_epoch = 'weights_now.h5'
        self.train_reduce_lr = {'monitor': 'val_loss', 'factor': 0.5, 'patience': 10}
        self.batch_norm = True
        self.n2v_perc_pix = 1.5
        self.n2v_patch_shape = (64, 64) if self.n_dim == 2 else (64, 64, 64)
        self.n2v_manipulator = 'uniform_withCP'
        self.n2v_neighborhood_radius = 5

        self.denoiseg_alpha = 0.5

        # disallow setting 'probabilistic' and 'unet_residual' manually
        # (plain dict assignment cannot fail, so no try/except is needed)
        kwargs['probabilistic'] = False
        kwargs['unet_residual'] = False

    for k in kwargs:
        setattr(self, k, kwargs[k])
def __init__(self, X, **kwargs):
    """See class docstring.

    Derives ``n_dim``, ``axes``, the channel counts and the global mean/std
    (stored as strings) from ``X``, sets the default parameters, and then
    sets every entry of ``kwargs`` as an attribute.  A user-supplied
    ``n_dim`` is silently ignored.
    """
    assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

    n_dim = len(X.shape) - 2
    n_channel_in = X.shape[-1]
    n_channel_out = n_channel_in
    mean = np.mean(X)
    std = np.std(X)

    if n_dim == 2:
        axes = 'SYXC'
    elif n_dim == 3:
        axes = 'SZYXC'

    # parse and check axes
    axes = axes_check_and_normalize(axes)
    ax = axes_dict(axes)
    ax = {a: (ax[a] is not None) for a in ax}

    (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
    not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

    axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
    axes = axes.replace('S','')  # remove sample axis if it exists

    if backend_channels_last():
        if ax['C']:
            axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
        else:
            axes += 'C'
    else:
        if ax['C']:
            axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
        else:
            axes = 'C'+axes

    # normalization parameters
    self.mean = str(mean)
    self.std = str(std)
    # directly set by parameters
    self.n_dim = n_dim
    self.axes = axes
    self.n_channel_in = int(n_channel_in)
    self.n_channel_out = int(n_channel_out)

    # default config (can be overwritten by kwargs below)
    self.unet_residual = False
    self.unet_n_depth = 2
    self.unet_kern_size = 5 if self.n_dim==2 else 3
    self.unet_n_first = 32
    self.unet_last_activation = 'linear'
    if backend_channels_last():
        self.unet_input_shape = self.n_dim*(None,) + (self.n_channel_in,)
    else:
        self.unet_input_shape = (self.n_channel_in,) + self.n_dim*(None,)

    self.train_loss = 'mae'
    self.train_epochs = 100
    self.train_steps_per_epoch = 400
    self.train_learning_rate = 0.0004
    self.train_batch_size = 16
    self.train_tensorboard = True
    self.train_checkpoint = 'weights_best.h5'
    self.train_reduce_lr = {'factor': 0.5, 'patience': 10}
    self.batch_norm = True
    self.n2v_perc_pix = 1.5
    self.n2v_patch_shape = (64, 64) if self.n_dim==2 else (64, 64, 64)
    self.n2v_manipulator = 'uniform_withCP'
    self.n2v_neighborhood_radius = 5

    # disallow setting 'n_dim' manually: drop it if present
    # (pop with a default replaces the former bare try/except around `del`)
    kwargs.pop('n_dim', None)

    for k in kwargs:
        setattr(self, k, kwargs[k])
def train(self, X,Y, validation_data, augmenter=None, seed=None, epochs=None, steps_per_epoch=None):
    """Train the neural network with the given data.

    Parameters
    ----------
    X : :class:`numpy.ndarray`
        Array of input images.
    Y : :class:`numpy.ndarray`
        Array of label masks.
    validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`)
        Tuple of X,Y validation arrays.
    augmenter : None or callable
        Function with expected signature ``Xbt, Ybt = augmenter(Xb, Yb)``
        that takes in batch input/label images (Xb,Yb) and returns
        transformed images (Xbt, Ybt) for the purpose of data augmentation
        during training. Not applied to validation images.
    seed : int
        Convenience to set ``np.random.seed(seed)``. (To obtain reproducible validation patches, etc.)
    epochs : int
        Optional argument to use instead of the value from ``config``.
    steps_per_epoch : int
        Optional argument to use instead of the value from ``config``.

    Returns
    -------
    ``History`` object
        See `Keras training history <https://keras.io/models/model/#fit>`_.

    Raises
    ------
    ValueError
        Via ``_raise`` if ``validation_data`` is ``None`` or not a
        list/tuple of length 2, or if an entry of ``config.train_patch_size``
        is not divisible by the corresponding value from ``_axes_div_by``.

    """
    if seed is not None:
        # https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development
        np.random.seed(seed)
    # fall back to the configured values when not given explicitly
    if epochs is None:
        epochs = self.config.train_epochs
    if steps_per_epoch is None:
        steps_per_epoch = self.config.train_steps_per_epoch

    # validation data is mandatory and must be an (X, Y) pair
    validation_data is not None or _raise(ValueError())
    ((isinstance(validation_data,(list,tuple)) and len(validation_data)==2)
        or _raise(ValueError('validation_data must be a pair of numpy arrays')))

    # check the patch size against the per-axis divisibility requirement
    # (list comprehension is used purely for its side effect of raising)
    patch_size = self.config.train_patch_size
    axes = self.config.axes.replace('C','')
    div_by = self._axes_div_by(axes)
    [p % d == 0 or _raise(ValueError(
        "'train_patch_size' must be divisible by {d} along axis '{a}'".format(a=a,d=d)
     )) for p,d,a in zip(patch_size,div_by,axes)]

    if not self._model_prepared:
        self.prepare_for_training()

    # keyword arguments shared by the training and validation data objects
    data_kwargs = dict (
        rays = rays_from_json(self.config.rays_json),
        grid = self.config.grid,
        patch_size = self.config.train_patch_size,
        anisotropy = self.config.anisotropy,
        use_gpu = self.config.use_gpu,
    )

    # generate validation data and store in numpy arrays
    # data_val = StarDistData3D(*validation_data, batch_size=1, augment=False, **data_kwargs)
    _data_val = StarDistData3D(*validation_data, batch_size=1, **data_kwargs)
    n_data_val = len(_data_val)
    # number of validation items to draw; all of them unless configured otherwise
    n_take = self.config.train_n_val_patches if self.config.train_n_val_patches is not None else n_data_val
    # sample indices; with replacement only if more are requested than exist
    ids = tuple(np.random.choice(n_data_val, size=n_take, replace=(n_take > n_data_val)))
    Xv, Mv, Pv, Dv = [None]*n_take, [None]*n_take, [None]*n_take, [None]*n_take
    for i,k in enumerate(ids):
        # each item unpacks as ((X, M), (P, D)); presumably inputs/mask and
        # prob/dist targets -- TODO confirm against StarDistData3D
        (Xv[i],Mv[i]),(Pv[i],Dv[i]) = _data_val[k]
    Xv, Mv, Pv, Dv = np.concatenate(Xv,axis=0), np.concatenate(Mv,axis=0), np.concatenate(Pv,axis=0), np.concatenate(Dv,axis=0)
    data_val = [[Xv,Mv],[Pv,Dv]]

    data_train = StarDistData3D(X, Y, batch_size=self.config.train_batch_size, augmenter=augmenter, **data_kwargs)

    for cb in self.callbacks:
        if isinstance(cb,CARETensorBoard):
            # only show middle slice of 3D inputs/outputs
            cb.input_slices, cb.output_slices = [[slice(None)]*5,[slice(None)]*5], [[slice(None)]*5,[slice(None)]*5]
            i = axes_dict(self.config.axes)['Z']
            # middle index along Z at input resolution and at grid-reduced resolution
            _n_in  = _data_val.patch_size[i] // 2
            _n_out = _data_val.patch_size[i] // (2 * (self.config.grid[i] if self.config.grid is not None else 1))
            # 1+i skips the leading (batch) dimension
            cb.input_slices[0][1+i] = _n_in
            cb.input_slices[1][1+i] = _n_out
            cb.output_slices[0][1+i] = _n_out
            cb.output_slices[1][1+i] = _n_out
            # show dist for three rays
            _n = min(3, self.config.n_rays)
            cb.output_slices[1][1+axes_dict(self.config.axes)['C']] = slice(0,(self.config.n_rays//_n)*_n,self.config.n_rays//_n)

    history = self.keras_model.fit_generator(generator=data_train, validation_data=data_val,
                                             epochs=epochs, steps_per_epoch=steps_per_epoch,
                                             callbacks=self.callbacks, verbose=1)
    self._training_finished()

    return history
def plugin(
    viewer: napari.Viewer,
    label_head,
    image: napari.layers.Image,
    axes,
    label_nn,
    model_type,
    model2d,
    model3d,
    model_folder,
    model_axes,
    norm_image,
    perc_low,
    perc_high,
    input_scale,
    label_nms,
    prob_thresh,
    nms_thresh,
    output_type,
    label_adv,
    n_tiles,
    norm_axes,
    timelapse_opts,
    cnn_output,
    set_thresholds,
    defaults_button,
    progress_bar: mw.ProgressBar,
) -> List[napari.types.LayerDataTuple]:
    # Run a StarDist prediction on the selected napari image layer and return
    # the results as a list of napari layer-data tuples.
    #
    # Steps: fetch the model, optionally normalize the image, set up a
    # progress callback (per frame for time-lapses, per tile otherwise),
    # predict (frame by frame if a 'T' axis is present), then assemble the
    # output layers (CNN distances/probability, labels, polygons/polyhedra).
    #
    # NOTE(review): `model_selected` is not a parameter; it is read from an
    # enclosing scope that is not visible here.
    # (Comments are used instead of a docstring in case the GUI framework
    # consumes docstrings -- verify before converting.)

    model = get_model(*model_selected)
    if model._is_multiclass():
        warn(
            "multi-class mode not supported yet, ignoring classification output"
        )

    lkwargs = {}
    x = get_data(image)
    axes = axes_check_and_normalize(axes, length=x.ndim)

    # a per-axis input scale drops its entry for the time axis
    if not (input_scale is None or isinstance(input_scale, numbers.Number)):
        input_scale = tuple(s for a, s in zip(axes, input_scale)
                            if a not in ("T", ))
    # print(f'scaling by {input_scale}')

    if not axes.replace("T", "").startswith(
            model._axes_out.replace("C", "")):
        warn(
            f"output images have different axes ({model._axes_out.replace('C','')}) than input image ({axes})"
        )
        # TODO: adjust image.scale according to shuffled axes

    if norm_image:
        axes_norm = axes_check_and_normalize(norm_axes)
        axes_norm = "".join(set(axes_norm).intersection(
            set(axes)))  # relevant axes present in input image
        assert len(axes_norm) > 0
        # always jointly normalize channels for RGB images
        if ("C" in axes and image.rgb == True) and ("C" not in axes_norm):
            axes_norm = axes_norm + "C"
            warn("jointly normalizing channels of RGB input image")
        ax = axes_dict(axes)
        _axis = tuple(sorted(ax[a] for a in axes_norm))
        # # TODO: address joint vs. channel/time-separate normalization properly (let user choose)
        # #       also needs to be documented somewhere
        # if 'T' in axes:
        #     if 'C' not in axes or image.rgb == True:
        #          # normalize channels jointly, frames independently
        #          _axis = tuple(i for i in range(x.ndim) if i not in (ax['T'],))
        #     else:
        #         # normalize channels independently, frames independently
        #         _axis = tuple(i for i in range(x.ndim) if i not in (ax['T'],ax['C']))
        # else:
        #     if 'C' not in axes or image.rgb == True:
        #          # normalize channels jointly
        #         _axis = None
        #     else:
        #         # normalize channels independently
        #         _axis = tuple(i for i in range(x.ndim) if i not in (ax['C'],))
        x = normalize(x, perc_low, perc_high, axis=_axis)

    # TODO: progress bar (labels) often don't show up. events not processed?
    if "T" in axes:
        # time-lapse: progress is reported per frame
        app = use_app()
        t = axes_dict(axes)["T"]
        n_frames = x.shape[t]
        if n_tiles is not None:
            # remove tiling value for time axis
            n_tiles = tuple(v for i, v in enumerate(n_tiles) if i != t)

        def progress(it, **kwargs):
            # generator wrapper: advances the progress bar after each item
            progress_bar.label = "StarDist Prediction (frames)"
            progress_bar.range = (0, n_frames)
            progress_bar.value = 0
            progress_bar.show()
            app.process_events()
            for item in it:
                yield item
                progress_bar.increment()
                app.process_events()
            app.process_events()

    elif n_tiles is not None and np.prod(n_tiles) > 1:
        # single image with tiling: progress is reported per tile
        n_tiles = tuple(n_tiles)
        app = use_app()

        def progress(it, **kwargs):
            progress_bar.label = "CNN Prediction (tiles)"
            progress_bar.range = (0, kwargs.get("total", 0))
            progress_bar.value = 0
            progress_bar.show()
            app.process_events()
            for item in it:
                yield item
                progress_bar.increment()
                app.process_events()
            # after the last tile: relabel the bar and give it an empty range
            progress_bar.label = "NMS Postprocessing"
            progress_bar.range = (0, 0)
            app.process_events()

    else:
        # no per-item progress; just show the bar with an empty range
        progress = False
        progress_bar.label = "StarDist Prediction"
        progress_bar.range = (0, 0)
        progress_bar.show()
        use_app().process_events()

    # semantic output axes of predictions
    assert model._axes_out[-1] == "C"
    axes_out = list(model._axes_out[:-1])

    if "T" in axes:
        # move time to the front and predict each frame separately
        x_reorder = np.moveaxis(x, t, 0)
        axes_reorder = axes.replace("T", "")
        axes_out.insert(t, "T")
        res = tuple(
            zip(*tuple(
                model.predict_instances(
                    _x,
                    axes=axes_reorder,
                    prob_thresh=prob_thresh,
                    nms_thresh=nms_thresh,
                    n_tiles=n_tiles,
                    scale=input_scale,
                    sparse=(not cnn_output),
                    return_predict=cnn_output,
                ) for _x in progress(x_reorder))))

        if cnn_output:
            labels, polys = tuple(zip(*res[0]))
            # note: cnn_output is rebound from a flag to the stacked CNN outputs
            cnn_output = tuple(np.stack(c, t) for c in tuple(zip(*res[1])))
        else:
            labels, polys = res

        labels = np.asarray(labels)

        if len(polys) > 1:
            if timelapse_opts == TimelapseLabels.Match.value:
                # match labels in consecutive frames (-> simple IoU tracking)
                labels = group_matching_labels(labels)
            elif timelapse_opts == TimelapseLabels.Unique.value:
                # make label ids unique (shift by offset)
                offsets = np.cumsum([len(p["points"]) for p in polys])
                for y, off in zip(labels[1:], offsets):
                    y[y > 0] += off
            elif timelapse_opts == TimelapseLabels.Separate.value:
                # each frame processed separately (nothing to do)
                pass
            else:
                raise NotImplementedError(
                    f"unknown option '{timelapse_opts}' for time-lapse labels"
                )

        # put the time axis back at its original position
        labels = np.moveaxis(labels, 0, t)

        if isinstance(model, StarDist3D):
            # TODO poly output support for 3D timelapse
            # NOTE(review): with polys = None, the polygon output branch
            # below (len(polys["points"])) would fail -- confirm it is
            # unreachable for 3D time-lapses.
            polys = None
        else:
            # merge per-frame polygons, inserting the frame index as the
            # time coordinate
            polys = dict(
                coord=np.concatenate(
                    tuple(
                        np.insert(p["coord"], t, _t, axis=-2)
                        for _t, p in enumerate(polys)),
                    axis=0,
                ),
                points=np.concatenate(
                    tuple(
                        np.insert(p["points"], t, _t, axis=-1)
                        for _t, p in enumerate(polys)),
                    axis=0,
                ),
            )

        if cnn_output:
            pred = (labels, polys), cnn_output
        else:
            pred = labels, polys

    else:
        # TODO: possible to run this in a way that it can be canceled?
        pred = model.predict_instances(
            x,
            axes=axes,
            prob_thresh=prob_thresh,
            nms_thresh=nms_thresh,
            n_tiles=n_tiles,
            show_tile_progress=progress,
            scale=input_scale,
            sparse=(not cnn_output),
            return_predict=cnn_output,
        )
    progress_bar.hide()

    # determine scale for output axes
    scale_in_dict = dict(zip(axes, image.scale))
    scale_out = [scale_in_dict.get(a, 1.0) for a in axes_out]

    layers = []
    if cnn_output:
        (labels, polys), cnn_out = pred
        prob, dist = cnn_out[:2]
        # move the last axis of dist to the front
        dist = np.moveaxis(dist, -1, 0)

        assert len(model.config.grid) == len(model.config.axes) - 1
        grid_dict = dict(
            zip(model.config.axes.replace("C", ""), model.config.grid))
        # scale output axes to match input axes
        _scale = [
            s * grid_dict.get(a, 1) for a, s in zip(axes_out, scale_out)
        ]
        # small translation correction if grid > 1 (since napari centers objects)
        _translate = [0.5 * (grid_dict.get(a, 1) - 1) for a in axes_out]

        layers.append((
            dist,
            dict(
                name="StarDist distances",
                scale=[1] + _scale,
                translate=[0] + _translate,
                **lkwargs,
            ),
            "image",
        ))
        layers.append((
            prob,
            dict(
                name="StarDist probability",
                scale=_scale,
                translate=_translate,
                **lkwargs,
            ),
            "image",
        ))
    else:
        labels, polys = pred

    if output_type in (Output.Labels.value, Output.Both.value):
        layers.append((
            labels,
            dict(name="StarDist labels",
                 scale=scale_out,
                 opacity=0.5,
                 **lkwargs),
            "labels",
        ))
    if output_type in (Output.Polys.value, Output.Both.value):
        n_objects = len(polys["points"])
        if isinstance(model, StarDist3D):
            surface = surface_from_polys(polys)
            layers.append((
                surface,
                dict(
                    name="StarDist polyhedra",
                    contrast_limits=(0, surface[-1].max()),
                    scale=scale_out,
                    colormap=label_colormap(n_objects),
                    **lkwargs,
                ),
                "surface",
            ))
        else:
            # TODO: sometimes hangs for long time (indefinitely?) when returning many polygons (?)
            #       seems to be a known issue: https://github.com/napari/napari/issues/2015
            # TODO: coordinates correct or need offset (0.5 or so)?
            shapes = np.moveaxis(polys["coord"], -1, -2)
            layers.append((
                shapes,
                dict(
                    name="StarDist polygons",
                    shape_type="polygon",
                    scale=scale_out,
                    edge_width=0.75,
                    edge_color="yellow",
                    face_color=[0, 0, 0, 0],
                    **lkwargs,
                ),
                "shapes",
            ))
    return layers
from __future__ import print_function, unicode_literals, absolute_import, division import numpy as np from tifffile import imread from csbdeep.utils import axes_dict, plot_some, plot_history from csbdeep.utils.tf import limit_gpu_memory from csbdeep.io import load_training_data from csbdeep.models import Config, ProjectionCARE # In[2]: (X, Y), (X_val, Y_val), axes = load_training_data( '/local/u934/private/v_kapoor/CurieTrainingDatasets/Drosophilla/DenoisingProjection.npz', validation_split=0.1, verbose=True) c = axes_dict(axes)['C'] n_channel_in, n_channel_out = X.shape[c], Y.shape[c] # In[3]: # In[4]: config = Config(axes, n_channel_in, n_channel_out, unet_n_depth=4, train_epochs=50, train_steps_per_epoch=400, train_batch_size=16, train_reduce_lr={ 'patience': 5,
def train(Training_source=".",
          Training_target=".",
          model_name="No_name",
          model_path=".",
          Visual_validation_after_training=True,
          number_of_epochs=100,
          patch_size=64,
          number_of_patches=10,
          Use_Default_Advanced_Parameters=True,
          number_of_steps=300,
          batch_size=32,
          percentage_validation=15):
    '''
    Main function of the script. Train the model and save it in model_path.

    Any existing model directory ``model_path/model_name`` is deleted first.
    Patches are created from the source/target folders, saved to
    ``model_path/rawdata.npz``, reloaded with a validation split, and used
    to train a CARE model.

    Parameters
    ----------
    Training_source : (str) Path to the noisy images
    Training_target : (str) Path to the GT images
    model_name : (str) name of the model
    model_path : (str) path of the model
    Visual_validation_after_training : (bool) Predict an image after training
    number_of_epochs : (int) epochs
    patch_size : (int) patch size (used for both patch dimensions)
    number_of_patches : (int) number of patches for each image
    Use_Default_Advanced_Parameters : (bool) Use default parameters for the
        training. When true, ``batch_size`` is set to 64,
        ``percentage_validation`` to 10, and ``number_of_steps`` is derived
        from the amount of training data and the batch size.
    number_of_steps : (int) number of steps per epoch
    batch_size : (int) batch size
    percentage_validation : (int) percentage of data used for validation

    Return
    -------
    void
    '''
    import time

    OutputFile = Training_target + "/*.tif"
    InputFile = Training_source + "/*.tif"
    base = "/content/"

    if Use_Default_Advanced_Parameters:
        print("Default advanced parameters enabled")
        batch_size = 64
        percentage_validation = 10

    percentage = percentage_validation / 100

    # here we check that no model with the same name already exists; if so, delete it
    if os.path.exists(model_path + '/' + model_name):
        shutil.rmtree(model_path + '/' + model_name)

    # The shape of the images.
    x = imread(InputFile)
    y = imread(OutputFile)

    print('Loaded Input images (number, width, length) =', x.shape)
    print('Loaded Output images (number, width, length) =', y.shape)
    print("Parameters initiated.")

    # RawData Object
    # This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.
    # This file is saved in .npz format and later called when loading the training data.
    raw_data = data.RawData.from_folder(basepath=base,
                                        source_dirs=[Training_source],
                                        target_dir=Training_target,
                                        axes='CYX',
                                        pattern='*.tif*')

    X, Y, XY_axes = data.create_patches(raw_data,
                                        patch_filter=None,
                                        patch_size=(patch_size, patch_size),
                                        n_patches_per_image=number_of_patches)

    print('Creating 2D training dataset')
    training_path = model_path + "/rawdata"
    # np.savez appends the ".npz" suffix to the given path
    rawdata1 = training_path + ".npz"
    np.savez(training_path, X=X, Y=Y, axes=XY_axes)

    # Load Training Data
    (X, Y), (X_val, Y_val), axes = load_training_data(rawdata1,
                                                      validation_split=percentage,
                                                      verbose=True)
    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    # Here we automatically define number_of_steps as a function of training data and batch size
    if Use_Default_Advanced_Parameters:
        number_of_steps = int(X.shape[0] / batch_size) + 1

    print(number_of_steps)

    # Here we create the configuration file
    config = Config(axes,
                    n_channel_in,
                    n_channel_out,
                    probabilistic=False,
                    train_steps_per_epoch=number_of_steps,
                    train_epochs=number_of_epochs,
                    unet_kern_size=5,
                    unet_n_depth=3,
                    train_batch_size=batch_size,
                    train_learning_rate=0.0004)

    print(config)

    # Compile the CARE model for network training
    model_training = CARE(config, model_name, basedir=model_path)

    start = time.time()

    # Start Training
    history = model_training.train(X, Y, validation_data=(X_val, Y_val))

    print("Training, done.")
    if Visual_validation_after_training:
        Predict_a_image(Training_source, Training_target, model_path,
                        model_training)

    # Displaying the time elapsed for training
    # (local names chosen so the builtin `min` is not shadowed)
    elapsed = time.time() - start
    minutes, seconds = divmod(elapsed, 60)
    hours, minutes = divmod(minutes, 60)
    print("Time elapsed:", hours, "hour(s)", minutes, "min(s)", round(seconds),
          "sec(s)")

    Show_loss_function(history, model_path)
def __init__(self, X,**kwargs):
    """Build the configuration from training data ``X`` and ``kwargs``.

    When ``X`` is non-empty, the axes, per-channel normalization statistics
    and all default parameters are derived/set here, and ``probabilistic``
    and ``unet_residual`` are forced to ``False`` regardless of ``kwargs``.
    In every case, each entry of ``kwargs`` is then set as an attribute, and
    ``n_components`` / ``n_channel_out`` are computed from the final
    attribute values:

    * ``n_components`` is 3 if both ``fit_std`` and ``fit_mean`` are truthy,
      otherwise 2;
    * ``n_channel_out = (n_back_modes + n_fore_modes) * n_channel_in
      * n_components + n_instance_seg * (n_back_i_modes + n_fore_i_modes)``.
    """
    if X.size != 0:
        assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

        n_dim = len(X.shape) - 2
        if n_dim == 2:
            axes = 'SYXC'
        elif n_dim == 3:
            axes = 'SZYXC'

        # parse and check axes
        axes = axes_check_and_normalize(axes)
        ax = axes_dict(axes)
        ax = {a: (ax[a] is not None) for a in ax}

        (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
        not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

        axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
        axes = axes.replace('S', '')  # remove sample axis if it exists

        if backend_channels_last():
            if ax['C']:
                axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
            else:
                axes += 'C'
        else:
            if ax['C']:
                axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
            else:
                axes = 'C' + axes

        # normalization parameters: per-channel mean/std (last axis), stored as strings
        n_channels = X.shape[-1]
        self.means = [str(np.mean(X[..., i])) for i in range(n_channels)]
        self.stds = [str(np.std(X[..., i])) for i in range(n_channels)]

        # directly set by parameters
        self.n_dim = n_dim
        self.axes = axes

        # fixed parameters
        if 'C' in axes:
            self.n_channel_in = X.shape[-1]
        else:
            self.n_channel_in = 1
        self.train_loss = 'demix'

        # default config (can be overwritten by kwargs below)
        self.unet_n_depth = 2
        self.unet_kern_size = 3
        self.unet_n_first = 64
        self.unet_last_activation = 'linear'
        self.probabilistic = False
        self.unet_residual = False
        if backend_channels_last():
            self.unet_input_shape = self.n_dim * (None,) + (self.n_channel_in,)
        else:
            self.unet_input_shape = (self.n_channel_in,) + self.n_dim * (None,)

        # fixed parameters
        self.train_epochs = 200
        self.train_steps_per_epoch = 50
        self.train_learning_rate = 0.0004
        self.train_batch_size = 64
        self.train_tensorboard = False
        self.train_checkpoint = 'weights_best.h5'
        self.train_checkpoint_last = 'weights_last.h5'
        self.train_checkpoint_epoch = 'weights_now.h5'
        self.train_reduce_lr = {'monitor': 'val_loss', 'factor': 0.5, 'patience': 10}
        self.batch_norm = False
        self.n2v_perc_pix = 1.5
        self.n2v_patch_shape = (128, 128) if self.n_dim == 2 else (64, 64, 64)
        self.n2v_manipulator = 'uniform_withCP'
        self.n2v_neighborhood_radius = 5

        self.single_net_per_channel = False

        self.channel_denoised = False
        self.multi_objective = True
        self.normalizer = 'none'
        self.weights_objectives = [0.1, 0, 0.45, 0.45]
        self.distributions = 'gauss'
        self.n2v_leave_center = False
        self.scale_aug = False
        self.structN2Vmask = None

        self.n_back_modes = 2
        self.n_fore_modes = 2
        self.n_instance_seg = 0
        self.n_back_i_modes = 1
        self.n_fore_i_modes = 1

        self.fit_std = False
        self.fit_mean = True

        # disallow setting 'probabilistic' and 'unet_residual' manually
        # (plain dict assignment cannot fail, so no try/except is needed)
        kwargs['probabilistic'] = False
        kwargs['unet_residual'] = False

    for k in kwargs:
        setattr(self, k, kwargs[k])

    # logical `and` (not bitwise `&`) so that any truthy flag values work
    self.n_components = 3 if (self.fit_std and self.fit_mean) else 2
    self.n_channel_out = (self.n_back_modes + self.n_fore_modes) * self.n_channel_in * self.n_components + \
                         self.n_instance_seg * (self.n_back_i_modes + self.n_fore_i_modes)
def __init__(self, X, **kwargs):
    """Build the segmentation configuration from single-channel data ``X``.

    Parameters
    ----------
    X : numpy.ndarray
        Non-empty training images; 4 dimensions are treated as ``SYXC``
        and 5 dimensions as ``SZYXC``.  The last axis must have size 1.
    **kwargs
        Attribute overrides applied with ``setattr`` after the defaults.
        The keys ``n_dim``, ``n_channel_in``, ``n_channel_out``,
        ``train_loss``, ``probabilistic`` and ``unet_residual`` are
        silently dropped because they are fixed by this class.

    Raises
    ------
    ValueError
        If the normalized axes violate the layout checks below.
    """
    assert X.size != 0
    assert len(X.shape) in (4, 5), "Only 'SZYXC' or 'SYXC' as dimensions is supported."

    n_dim = len(X.shape) - 2
    if n_dim == 2:
        axes = 'SYXC'
    elif n_dim == 3:
        axes = 'SZYXC'

    # parse and check axes
    axes = axes_check_and_normalize(axes)
    ax = axes_dict(axes)
    ax = {a: (ax[a] is not None) for a in ax}

    if not (ax['X'] and ax['Y']):
        raise ValueError('lateral axes X and Y must be present.')
    if ax['Z'] and ax['T']:
        raise ValueError('using Z and T axes together not supported.')
    if ax['S'] and not axes.startswith('S'):
        raise ValueError('sample axis S must be first.')
    axes = axes.replace('S', '')  # remove sample axis if it exists

    if backend_channels_last():
        if ax['C']:
            if axes[-1] != 'C':
                raise ValueError('channel axis must be last for backend (%s).' % K.backend())
        else:
            axes += 'C'
    else:
        if ax['C']:
            if axes[0] != 'C':
                raise ValueError('channel axis must be first for backend (%s).' % K.backend())
        else:
            axes = 'C' + axes

    assert X.shape[-1] == 1
    n_channel_in = 1
    n_channel_out = 3

    # directly set by parameters
    self.n_dim = n_dim
    self.axes = axes

    # fixed parameters
    self.n_channel_in = n_channel_in
    self.n_channel_out = n_channel_out
    self.train_loss = 'seg'

    # default config (can be overwritten by kwargs below)
    self.unet_n_depth = 4
    self.relative_weights = [1.0, 1.0, 5.0]
    self.unet_kern_size = 3
    self.unet_n_first = 32
    self.unet_last_activation = 'linear'
    self.probabilistic = False
    self.unet_residual = False
    if backend_channels_last():
        self.unet_input_shape = self.n_dim * (None,) + (self.n_channel_in,)
    else:
        self.unet_input_shape = (self.n_channel_in,) + self.n_dim * (None,)

    self.train_epochs = 200
    self.train_steps_per_epoch = 400
    self.train_learning_rate = 0.0004
    self.train_batch_size = 128
    self.train_tensorboard = False
    self.train_checkpoint = 'weights_best.h5'
    self.train_checkpoint_last = 'weights_last.h5'
    self.train_checkpoint_epoch = 'weights_now.h5'
    self.train_reduce_lr = {'factor': 0.5, 'patience': 10}
    self.batch_norm = True

    # disallow setting the fixed parameters manually; pop() with a default
    # ignores absent keys without a catch-all exception handler
    for fixed in ('n_dim', 'n_channel_in', 'n_channel_out',
                  'train_loss', 'probabilistic', 'unet_residual'):
        kwargs.pop(fixed, None)

    for key, value in kwargs.items():
        setattr(self, key, value)
def _update(self):
    """Re-check all plugin inputs and refresh widget validity and tooltips.

    Reads the parsed inputs from ``self.args`` and the per-input flags from
    ``self.valid``, marks the corresponding widgets through
    ``widgets_valid``, publishes a help message via ``self.help`` and
    enables ``plugin.call_button`` only if every input is valid and the
    selected image and model are compatible.
    """
    # try to get a hold of the viewer (can be None when plugin starts)
    # TODO: when is this not safe to do and will hang forever?
    if self.viewer is None and plugin.viewer.value is not None:
        self.viewer = plugin.viewer.value
        if DEBUG:
            print("GOT viewer")

        @self.viewer.layers.events.removed.connect
        def _layer_removed(event):
            # reset the input fields once the last layer has been removed
            if len(event.source) > 0:
                return
            plugin.image.tooltip = ""
            plugin.axes.value = ""
            plugin.n_tiles.value = "None"
            plugin.input_scale.value = "None"

    def _polys_requested():
        # True if the selected output type includes polygons/polyhedra
        return plugin.output_type.value in (Output.Polys.value, Output.Both.value)

    def _error_tooltip(err):
        # error text without a trailing period; empty string if no error
        if err is None:
            return ""
        text = str(err)
        return text[:-1] if text.endswith(".") else text

    def _model(valid):
        widgets_valid(
            plugin.model2d,
            plugin.model3d,
            plugin.model_folder.line_edit,
            valid=valid,
        )
        if not valid:
            plugin.model_axes.value = ""
            plugin.model_folder.line_edit.tooltip = "Invalid model directory"
            return None
        config = self.args.model
        axes = config.get("axes", "ZYXC"[-len(config["net_input_shape"]):])
        if "T" in axes:
            raise RuntimeError("model with axis 'T' not supported")
        plugin.model_axes.value = axes.replace("C", f"C[{config['n_channel_in']}]")
        plugin.model_folder.line_edit.tooltip = ""
        return axes, config

    def _image_axes(valid):
        axes, image, err = getattr(self.args, "image_axes", (None, None, None))
        widgets_valid(
            plugin.axes,
            valid=(valid or (image is None and (axes is None or len(axes) == 0))),
        )
        if valid and "T" in axes and _polys_requested():
            plugin.output_type.native.setStyleSheet("background-color: orange")
            plugin.output_type.tooltip = (
                "Displaying many polygons/polyhedra can be very slow.")
        else:
            plugin.output_type.native.setStyleSheet("")
            plugin.output_type.tooltip = ""
        if not valid:
            # tooltip instead of warn(): gui doesn't show up in ipython
            plugin.axes.tooltip = _error_tooltip(err)
            return None
        plugin.axes.tooltip = "\n".join(
            [f"{a} = {s}" for a, s in zip(axes, get_data(image).shape)])
        return axes, image

    def _norm_axes(valid):
        norm_axes, err = getattr(self.args, "norm_axes", (None, None))
        widgets_valid(plugin.norm_axes, valid=valid)
        if not valid:
            plugin.norm_axes.tooltip = _error_tooltip(err)
            return None
        plugin.norm_axes.tooltip = "Axes to jointly normalize (if present in selected input image). Note: channels of RGB images are always normalized together."
        return norm_axes

    def _n_tiles(valid):
        n_tiles, image, err = getattr(self.args, "n_tiles", (None, None, None))
        widgets_valid(plugin.n_tiles, valid=(valid or image is None))
        if not valid:
            plugin.n_tiles.tooltip = str(err) if err is not None else ""
            return None
        if n_tiles is None:
            plugin.n_tiles.tooltip = "no tiling"
        else:
            plugin.n_tiles.tooltip = "\n".join(
                [f"{t}: {s}" for t, s in zip(n_tiles, get_data(image).shape)])
        return n_tiles

    def _no_tiling_for_axis(axes_image, n_tiles, axis):
        if n_tiles is None or axis not in axes_image:
            return True
        return n_tiles[axes_dict(axes_image)[axis]] == 1

    def _input_scale(valid):
        input_scale, image, err = getattr(self.args, "input_scale", (None, None, None))
        widgets_valid(plugin.input_scale, valid=(valid or image is None))
        if not valid:
            plugin.input_scale.tooltip = str(err) if err is not None else ""
            return None
        if input_scale is None:
            plugin.input_scale.tooltip = "no scaling"
        elif isinstance(input_scale, numbers.Number):
            plugin.input_scale.tooltip = f"{input_scale} for all spatial axes"
        else:
            assert len(input_scale) == len(get_data(image).shape)
            plugin.input_scale.tooltip = "\n".join([f"{s}" for s in input_scale])
        return input_scale

    def _input_scale_check(axes_image, input_scale):
        if input_scale is None or isinstance(input_scale, numbers.Number):
            return True
        assert len(input_scale) == len(axes_image)
        # s != 1 only allowed for spatial axes XYZ
        return all(s == 1 or a in "XYZ" for a, s in zip(axes_image, input_scale))

    def _restore():
        widgets_valid(plugin.image, valid=plugin.image.value is not None)

    all_valid = False
    help_msg = ""

    flags = self.valid
    if (flags.image_axes and flags.n_tiles and flags.model
            and flags.norm_axes and flags.input_scale):
        axes_image, image = _image_axes(True)
        axes_model, config = _model(True)
        axes_norm = _norm_axes(True)
        n_tiles = _n_tiles(True)
        input_scale = _input_scale(True)

        if not _no_tiling_for_axis(axes_image, n_tiles, "C"):
            # check if image axes and n_tiles are compatible
            widgets_valid(plugin.n_tiles, valid=False)
            plugin.n_tiles.tooltip = "number of tiles must be 1 for C axis"
            _restore()
        elif not _no_tiling_for_axis(axes_image, n_tiles, "T"):
            # check if image axes and n_tiles are compatible
            widgets_valid(plugin.n_tiles, valid=False)
            plugin.n_tiles.tooltip = "number of tiles must be 1 for T axis"
            _restore()
        elif not _input_scale_check(axes_image, input_scale):
            # check if image axes and input_scale are compatible
            widgets_valid(plugin.input_scale, valid=False)
            _violations = ", ".join(
                a for a, s in zip(axes_image, input_scale)
                if not (s == 1 or a in "XYZ"))
            plugin.input_scale.tooltip = (
                f"values for non-spatial axes ({_violations}) must be 1")
            _restore()
        elif set(axes_norm).isdisjoint(set(axes_image)):
            # check if image axes and normalization axes are compatible
            widgets_valid(plugin.norm_axes, valid=False)
            plugin.norm_axes.tooltip = f"Image axes ({axes_image}) must contain at least one of the normalization axes ({', '.join(axes_norm)})"
            _restore()
        elif "T" in axes_image and config.get("n_dim") == 3 and _polys_requested():
            # not supported
            widgets_valid(plugin.output_type, valid=False)
            plugin.output_type.tooltip = (
                "Polyhedra output currently not supported for 3D timelapse data")
            _restore()
        else:
            # tooltip for input_scale
            if isinstance(input_scale, numbers.Number):
                plugin.input_scale.tooltip = "\n".join(
                    [f'{a} = {input_scale if a in "XYZ" else 1}' for a in axes_image])
            elif input_scale is not None:
                plugin.input_scale.tooltip = "\n".join(
                    [f"{a} = {s}" for a, s in zip(axes_image, input_scale)])

            # check if image and model are compatible
            ch_model = config["n_channel_in"]
            if "C" in axes_image:
                ch_image = get_data(image).shape[axes_dict(axes_image)["C"]]
            else:
                ch_image = 1
            model_spatial = set(axes_model.replace("C", ""))
            image_spatial = set(axes_image.replace("C", "").replace("T", ""))
            all_valid = model_spatial == image_spatial and ch_model == ch_image

            widgets_valid(
                plugin.image,
                plugin.model2d,
                plugin.model3d,
                plugin.model_folder.line_edit,
                valid=all_valid,
            )
            if not all_valid:
                model_desc = axes_model.replace("C", f"C[{ch_model}]")
                image_desc = axes_image.replace("C", f"C[{ch_image}]")
                help_msg = f'Model with axes {model_desc} and image with axes {image_desc} not compatible'
    else:
        # at least one input is invalid: refresh each widget individually
        _image_axes(flags.image_axes)
        _norm_axes(flags.norm_axes)
        _n_tiles(flags.n_tiles)
        _input_scale(flags.input_scale)
        _model(flags.model)
        _restore()

    self.help(help_msg)
    plugin.call_button.enabled = all_valid
    if self.debug:
        print(
            f"valid ({all_valid}):",
            ", ".join([f"{k}={v}" for k, v in vars(self.valid).items()]),
        )
def train(self, X, Y, validation_data, epochs=None, steps_per_epoch=None):
    """Fit the network on ``X``/``Y`` and return the Keras training history.

    Parameters
    ----------
    X : :class:`numpy.ndarray`
        Array of source images.
    Y : :class:`numpy.ndarray`
        Array of target images.
    validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`)
        Pair of source and target validation arrays.
    epochs : int
        Overrides ``config.train_epochs`` when given.
    steps_per_epoch : int
        Overrides ``config.train_steps_per_epoch`` when given.

    Returns
    -------
    ``History`` object
        See `Keras training history <https://keras.io/models/model/#fit>`_.

    Raises
    ------
    ValueError
        If ``validation_data`` is not a list/tuple of length 2, or if a
        training image axis is not divisible by the required factor.
    """
    is_pair = isinstance(validation_data, (list, tuple)) and len(validation_data) == 2
    if not is_pair:
        raise ValueError('validation_data must be a pair of numpy arrays')

    n_train = len(X)
    n_val = len(validation_data[0])
    frac_val = (1.0 * n_val) / (n_train + n_val)
    if frac_val < 0.05:
        warnings.warn(
            "small number of validation images (only %.1f%% of all images)"
            % (100 * frac_val))

    # every axis of the training data must be divisible by the factor
    # that the model reports for it
    axes = axes_check_and_normalize('S' + self.config.axes, X.ndim)
    ax = axes_dict(axes)
    for a, div_by in zip(axes, self._axes_div_by(axes)):
        n = X.shape[ax[a]]
        if n % div_by != 0:
            raise ValueError(
                "training images must be evenly divisible by %d along axis %s"
                " (which has incompatible size %d)" % (div_by, a, n))

    epochs = self.config.train_epochs if epochs is None else epochs
    if steps_per_epoch is None:
        steps_per_epoch = self.config.train_steps_per_epoch

    if not self._model_prepared:
        self.prepare_for_training()

    training_data = CryoDataWrapper(X, Y, self.config.train_batch_size)

    history = self.keras_model.fit_generator(
        generator=training_data,
        validation_data=validation_data,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=self.callbacks,
        verbose=1,
    )

    if self.basedir is None:
        return history

    self.keras_model.save_weights(str(self.logdir / 'weights_last.h5'))
    if self.config.train_checkpoint is not None:
        print()
        self._find_and_load_weights(self.config.train_checkpoint)
        try:
            # remove temporary weights
            (self.logdir / 'weights_now.h5').unlink()
        except FileNotFoundError:
            pass

    return history
def train(self, X, Y, validation_data, epochs=None, steps_per_epoch=None):
    """Train the network jointly on Noise2Void masking and segmentation targets.

    Parameters
    ----------
    X : :class:`numpy.ndarray`
        Array of noisy source images.
    Y : :class:`numpy.ndarray`
        Array of segmentation targets, passed to the data wrapper as ``seg_Y``.
    validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`)
        Pair of validation images and validation segmentation targets.
    epochs : int
        Overrides ``config.train_epochs`` when given.
    steps_per_epoch : int
        Overrides ``config.train_steps_per_epoch`` when given.

    Returns
    -------
    ``History`` object
        See `Keras training history <https://keras.io/models/model/#fit>`_.

    Raises
    ------
    ValueError
        If a training image axis is not divisible by ``2**unet_n_depth``.
    """
    n_train, n_val = len(X), len(validation_data[0])
    frac_val = (1.0 * n_val) / (n_train + n_val)
    frac_warn = 0.05
    if frac_val < frac_warn:
        # one decimal place, as in the other train() implementations
        warnings.warn(
            "small number of validation images (only %.1f%% of all images)"
            % (100 * frac_val))

    axes = axes_check_and_normalize('S' + self.config.axes, X.ndim)
    ax = axes_dict(axes)
    div_by = 2**self.config.unet_n_depth
    axes_relevant = ''.join(a for a in 'XYZT' if a in axes)
    val_patch_shape = ()
    for a in axes_relevant:
        n = X.shape[ax[a]]
        val_patch_shape += (validation_data[0].shape[ax[a]],)
        if n % div_by != 0:
            raise ValueError(
                "training images must be evenly divisible by %d along axes %s"
                " (axis %s has incompatible size %d)"
                % (div_by, axes_relevant, a, n))

    if epochs is None:
        epochs = self.config.train_epochs
    if steps_per_epoch is None:
        steps_per_epoch = self.config.train_steps_per_epoch

    if not self._model_prepared:
        self.prepare_for_training()

    # NOTE(security): eval() executes text assembled from config values.
    # If configs can come from untrusted sources, replace this with a lookup
    # in an explicit table of pm_* manipulators.
    manipulator = eval('pm_{0}({1})'.format(
        self.config.n2v_manipulator,
        str(self.config.n2v_neighborhood_radius)))

    # config stores the statistics as strings; ndmin pads them with leading
    # axes up to the rank of X
    means = np.array([float(mean) for mean in self.config.means],
                     ndmin=len(X.shape), dtype=np.float32)
    stds = np.array([float(std) for std in self.config.stds],
                    ndmin=len(X.shape), dtype=np.float32)

    X = self.__normalize__(X, means, stds)
    validation_X = self.__normalize__(validation_data[0], means, stds)

    # Here we prepare the Noise2Void data. Our input is the noisy data X and
    # as target we take X concatenated with a masking channel. The data
    # wrapper will take care of the pixel masking and manipulating.
    channel_axis = axes.index('C')
    training_data = DenoiSeg_DataWrapper(
        X=X,
        n2v_Y=np.concatenate((X, np.zeros(X.shape, dtype=X.dtype)),
                             axis=channel_axis),
        seg_Y=Y,
        batch_size=self.config.train_batch_size,
        perc_pix=self.config.n2v_perc_pix,
        shape=self.config.n2v_patch_shape,
        value_manipulation=manipulator)

    # validation_Y is also validation_X plus a concatenated masking channel.
    # To speed things up, we precompute the masking of the validation data.
    validation_Y = np.concatenate(
        (validation_X,
         np.zeros(validation_X.shape, dtype=validation_X.dtype)),
        axis=channel_axis)
    n2v_utils.manipulate_val_data(validation_X, validation_Y,
                                  perc_pix=self.config.n2v_perc_pix,
                                  shape=val_patch_shape,
                                  value_manipulation=manipulator)

    # append the validation segmentation targets along the last axis
    validation_Y = np.concatenate((validation_Y, validation_data[1]), axis=-1)

    history = self.keras_model.fit(training_data,
                                   validation_data=(validation_X, validation_Y),
                                   epochs=epochs,
                                   steps_per_epoch=steps_per_epoch,
                                   callbacks=self.callbacks,
                                   verbose=1)

    if self.basedir is not None:
        self.keras_model.save_weights(str(self.logdir / 'weights_last.h5'))

        if self.config.train_checkpoint is not None:
            print()
            self._find_and_load_weights(self.config.train_checkpoint)
            try:
                # remove temporary weights
                (self.logdir / 'weights_now.h5').unlink()
            except FileNotFoundError:
                pass

    return history
def load_training_data_direct(X, Y, validation_split=0, axes=None, n_images=None, verbose=False):
    """Split in-memory training arrays into training and validation sets.

    The arrays are passed in directly; nothing is loaded from disk.

    Parameters
    ----------
    X : :class:`numpy.ndarray`
        Array of training input images.
    Y : :class:`numpy.ndarray`
        Array of corresponding target images; must have the same shape as ``X``.
    validation_split : float
        Fraction of images to use as validation set during training.
        The validation images are taken from the end of the arrays.
    axes : str
        Axes of the training images; must contain ``C`` and have one
        character per dimension of ``X``.
    n_images : int, optional
        Can be used to limit the number of images taken from the data.
    verbose : bool, optional
        Can be used to display information about the images.

    Returns
    -------
    tuple( tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`), tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`), str )
        Returns two tuples (`X_train`, `Y_train`), (`X_val`, `Y_val`) of training and validation sets
        and the axes of the returned images (with the channel axis moved for the backend).
        The tuple of validation data will be ``None`` if ``validation_split = 0``.
    """
    axes = axes_check_and_normalize(axes)

    assert X.shape == Y.shape
    assert len(axes) == X.ndim
    assert 'C' in axes
    if n_images is None:
        n_images = X.shape[0]
    assert X.shape[0] == Y.shape[0]
    assert 0 < n_images <= X.shape[0]
    assert 0 <= validation_split < 1

    X, Y = X[:n_images], Y[:n_images]
    channel = axes_dict(axes)['C']

    if validation_split > 0:
        n_val = int(round(n_images * validation_split))
        n_train = n_images - n_val
        assert 0 < n_val and 0 < n_train
        # validation images are the last n_val entries
        X_t, Y_t = X[-n_val:], Y[-n_val:]
        X, Y = X[:n_train], Y[:n_train]
        assert X.shape[0] == n_train and X_t.shape[0] == n_val
        X_t = move_channel_for_backend(X_t, channel=channel)
        Y_t = move_channel_for_backend(Y_t, channel=channel)

    X = move_channel_for_backend(X, channel=channel)
    Y = move_channel_for_backend(Y, channel=channel)

    # rebuild the axes string so that it describes the moved channel axis
    axes = axes.replace('C', '')  # remove channel
    if backend_channels_last():
        axes = axes + 'C'
    else:
        axes = axes[:1] + 'C' + axes[1:]

    data_val = (X_t, Y_t) if validation_split > 0 else None

    if verbose:
        ax = axes_dict(axes)
        n_train, n_val = len(X), len(X_t) if validation_split > 0 else 0
        image_size = tuple(X.shape[ax[a]] for a in 'TZYX' if a in axes)
        n_dim = len(image_size)
        n_channel_in, n_channel_out = X.shape[ax['C']], Y.shape[ax['C']]

        print('number of training images:\t', n_train)
        print('number of validation images:\t', n_val)
        print('image size (%dD):\t\t' % n_dim, image_size)
        print('axes:\t\t\t\t', axes)
        print('channels in / out:\t\t', n_channel_in, '/', n_channel_out)

    return (X, Y), data_val, axes
def train(self, X, Y, validation_data, augmenter=None, seed=None, epochs=None, steps_per_epoch=None):
    """Train the neural network with the given data.

    Parameters
    ----------
    X : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
        Input images
    Y : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
        Label masks
    validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`)
        Tuple of X,Y validation arrays.
    augmenter : None or callable
        Function with expected signature ``xt, yt = augmenter(x, y)``
        that takes in a single pair of input/label image (x,y) and returns
        the transformed images (xt, yt) for the purpose of data augmentation
        during training. Not applied to validation images.
        Example:
        def simple_augmenter(x,y):
            x = x + 0.05*np.random.normal(0,1,x.shape)
            return x,y
    seed : int
        Convenience to set ``np.random.seed(seed)``. (To obtain reproducible validation patches, etc.)
    epochs : int
        Optional argument to use instead of the value from ``config``.
    steps_per_epoch : int
        Optional argument to use instead of the value from ``config``.

    Returns
    -------
    ``History`` object
        See `Keras training history <https://keras.io/models/model/#fit>`_.

    Raises
    ------
    ValueError
        If ``validation_data`` is ``None`` or not a list/tuple of length 2,
        or if the configured patch size is not divisible as required.
    """
    if seed is not None:
        # https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development
        np.random.seed(seed)
    if epochs is None:
        epochs = self.config.train_epochs
    if steps_per_epoch is None:
        steps_per_epoch = self.config.train_steps_per_epoch

    if validation_data is None:
        raise ValueError('validation_data must not be None')
    if not (isinstance(validation_data, (list, tuple)) and len(validation_data) == 2):
        raise ValueError('validation_data must be a pair of numpy arrays')

    patch_size = self.config.train_patch_size
    axes = self.config.axes.replace('C', '')
    shape_completion = self.config.train_shape_completion
    b = self.config.train_completion_crop if shape_completion else 0
    div_by = self._axes_div_by(axes)
    for p, d, a in zip(patch_size, div_by, axes):
        if (p - 2 * b) % d != 0:
            if shape_completion:
                msg = "'train_patch_size' - 2*'train_completion_crop' must be divisible by {d} along axis '{a}'"
            else:
                msg = "'train_patch_size' must be divisible by {d} along axis '{a}'"
            raise ValueError(msg.format(a=a, d=d))

    if not self._model_prepared:
        self.prepare_for_training()

    data_kwargs = dict(
        n_params=self.config.n_params,
        patch_size=self.config.train_patch_size,
        grid=self.config.grid,
        shape_completion=self.config.train_shape_completion,
        b=self.config.train_completion_crop,
        use_gpu=self.config.use_gpu,
        foreground_prob=self.config.train_foreground_only,
        contoursize_max=self.config.contoursize_max,
    )

    # generate validation data and store in numpy arrays
    n_data_val = len(validation_data[0])
    n_take = self.config.train_n_val_patches if self.config.train_n_val_patches is not None else n_data_val
    _data_val = SplineDistData2D(*validation_data, batch_size=n_take, length=1, **data_kwargs)
    data_val = _data_val[0]

    data_train = SplineDistData2D(X, Y,
                                  batch_size=self.config.train_batch_size,
                                  augmenter=augmenter,
                                  length=epochs * steps_per_epoch,
                                  **data_kwargs)

    if self.config.train_tensorboard:
        # show dist for three rays
        _n = min(3, self.config.n_params)
        channel = axes_dict(self.config.axes)['C']
        output_slices = [[slice(None)] * 4, [slice(None)] * 4]
        output_slices[1][1 + channel] = slice(
            0, (self.config.n_params // _n) * _n, self.config.n_params // _n)
        if IS_TF_1:
            for cb in self.callbacks:
                if isinstance(cb, CARETensorBoard):
                    cb.output_slices = output_slices
                    # target image for dist includes dist_mask and thus has more channels than dist output
                    cb.output_target_shapes = [None, [None] * 4]
                    cb.output_target_shapes[1][1 + channel] = data_val[1][1].shape[1 + channel]
        elif self.basedir is not None and not any(
                isinstance(cb, CARETensorBoardImage) for cb in self.callbacks):
            self.callbacks.append(
                CARETensorBoardImage(model=self.keras_model,
                                     data=data_val,
                                     log_dir=str(self.logdir / 'logs' / 'images'),
                                     n_images=3,
                                     prob_out=False,
                                     output_slices=output_slices))

    fit = self.keras_model.fit_generator if IS_TF_1 else self.keras_model.fit
    history = fit(iter(data_train),
                  validation_data=data_val,
                  epochs=epochs,
                  steps_per_epoch=steps_per_epoch,
                  callbacks=self.callbacks,
                  verbose=1)
    self._training_finished()

    return history
def train(self, X, validation_X, epochs=None, steps_per_epoch=None):
    """Fit the network on ``X`` with Noise2Void masking and return the history.

    Parameters
    ----------
    X : :class:`numpy.ndarray`
        Array of source images.
    validation_X : :class:`numpy.ndarray`
        Array of validation images.
    epochs : int
        Overrides ``config.train_epochs`` when given.
    steps_per_epoch : int
        Overrides ``config.train_steps_per_epoch`` when given.

    Returns
    -------
    ``History`` object
        See `Keras training history <https://keras.io/models/model/#fit>`_.

    Raises
    ------
    ValueError
        If a training image axis is not divisible by ``2**unet_n_depth``.
    """
    n_train = len(X)
    n_val = len(validation_X)
    frac_val = (1.0 * n_val) / (n_train + n_val)
    if frac_val < 0.05:
        warnings.warn("small number of validation images (only %.1f%% of all images)" % (100*frac_val))

    axes = axes_check_and_normalize('S'+self.config.axes, X.ndim)
    ax = axes_dict(axes)
    div_by = 2**self.config.unet_n_depth
    axes_relevant = ''.join(a for a in 'XYZT' if a in axes)

    # accumulate pixel counts per image and the validation patch shape over
    # the X/Y/Z/T axes that are present
    train_num_pix = 1
    val_num_pix = 1
    val_patch_shape = ()
    for a in axes_relevant:
        idx = ax[a]
        n = X.shape[idx]
        val_extent = validation_X.shape[idx]
        val_num_pix *= val_extent
        train_num_pix *= n
        val_patch_shape += (val_extent,)
        if n % div_by != 0:
            raise ValueError(
                "training images must be evenly divisible by %d along axes %s"
                " (axis %s has incompatible size %d)" % (div_by, axes_relevant, a, n)
            )

    epochs = self.config.train_epochs if epochs is None else epochs
    if steps_per_epoch is None:
        steps_per_epoch = self.config.train_steps_per_epoch

    if not self._model_prepared:
        self.prepare_for_training()

    # NOTE(security): eval() executes text assembled from config values;
    # only safe as long as the config comes from a trusted source.
    manipulator = eval('pm_{0}({1})'.format(self.config.n2v_manipulator, str(self.config.n2v_neighborhood_radius)))

    mean = float(self.config.mean)
    std = float(self.config.std)
    X = self.__normalize__(X, mean, std)
    validation_X = self.__normalize__(validation_X, mean, std)

    # Here we prepare the Noise2Void data. Our input is the noisy data X and
    # as target we take X concatenated with a masking channel. The
    # N2V_DataWrapper will take care of the pixel masking and manipulating.
    channel_axis = axes.index('C')
    target = np.concatenate((X, np.zeros(X.shape, dtype=X.dtype)), axis=channel_axis)
    training_data = N2V_DataWrapper(X, target,
                                    self.config.train_batch_size,
                                    int(train_num_pix/100 * self.config.n2v_perc_pix),
                                    self.config.n2v_patch_shape,
                                    manipulator)

    # validation_Y is also validation_X plus a concatenated masking channel.
    # To speed things up, we precompute the masking of the validation data.
    validation_Y = np.concatenate(
        (validation_X, np.zeros(validation_X.shape, dtype=validation_X.dtype)),
        axis=channel_axis)
    n2v_utils.manipulate_val_data(validation_X, validation_Y,
                                  num_pix=int(val_num_pix/100 * self.config.n2v_perc_pix),
                                  shape=val_patch_shape,
                                  value_manipulation=manipulator)

    history = self.keras_model.fit_generator(generator=training_data,
                                             validation_data=(validation_X, validation_Y),
                                             epochs=epochs,
                                             steps_per_epoch=steps_per_epoch,
                                             callbacks=self.callbacks,
                                             verbose=1)

    if self.basedir is None:
        return history

    self.keras_model.save_weights(str(self.logdir / 'weights_last.h5'))
    if self.config.train_checkpoint is not None:
        print()
        self._find_and_load_weights(self.config.train_checkpoint)
        try:
            # remove temporary weights
            (self.logdir / 'weights_now.h5').unlink()
        except FileNotFoundError:
            pass

    return history
def Train(self):
    """Prepare the mask folders and train the enabled models.

    Steps, in order:

    1. Create ``BinaryMask/`` and ``RealMask/`` under ``self.BaseDir`` if
       missing.  If no instance masks exist they are derived from the
       binary masks with ``label``; if no binary masks exist they are
       derived from the instance masks by thresholding at ``> 0``.
    2. If ``self.GenerateNPZ`` is set, create training patches and save
       them to ``self.BaseDir + self.NPZfilename + '.npz'``.
    3. If ``self.TrainUNET`` is set, train a CARE model on that file.
    4. If ``self.TrainSTAR`` is set, train a StarDist3D model on the raw
       images and instance masks.

    Existing weights (``weights_now.h5``, ``weights_last.h5``,
    ``weights_best.h5``) are loaded in that order before training, so the
    last one found is the one that stays loaded.

    NOTE(review): the source of this method was whitespace-flattened; the
    block nesting was reconstructed from the statement order -- confirm
    against the upstream file.

    Raises
    ------
    ValueError
        If ``self.backbone`` is neither ``'resnet'`` nor ``'unet'`` when
        training the StarDist model.
    """
    BinaryName = 'BinaryMask/'
    RealName = 'RealMask/'
    weight_files = ('weights_now.h5', 'weights_last.h5', 'weights_best.h5')

    def load_copied_weights(model, copy_folder, folder, names):
        # Load weights from the copy folder when this model has none of that name.
        for weights in names:
            if os.path.exists(copy_folder + weights) and not os.path.exists(folder + weights):
                print('Loading copy model')
                model.load_weights(copy_folder + weights)

    def load_own_weights(model, folder):
        # Load every existing checkpoint in turn; the last one found wins.
        for weights in weight_files:
            if os.path.exists(folder + weights):
                print('Loading checkpoint model')
                model.load_weights(folder + weights)

    Raw = sorted(glob.glob(self.BaseDir + '/Raw/' + '*.tif'))
    Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
    Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
    RealMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*.tif'))
    ValRaw = sorted(glob.glob(self.BaseDir + '/ValRaw/' + '*.tif'))
    ValRealMask = sorted(glob.glob(self.BaseDir + '/ValRealMask/' + '*.tif'))

    print('Instance segmentation masks:', len(RealMask))
    if len(RealMask) == 0:
        print('Making labels')
        Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
        for fname in Mask:
            image = imread(fname)
            Name = os.path.basename(os.path.splitext(fname)[0])
            Binaryimage = label(image)
            imwrite((self.BaseDir + '/' + RealName + Name + '.tif'),
                    Binaryimage.astype('uint16'))

    Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
    print('Semantic segmentation masks:', len(Mask))
    if len(Mask) == 0:
        print('Generating Binary images')
        RealfilesMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*tif'))
        for fname in RealfilesMask:
            image = imread(fname)
            Name = os.path.basename(os.path.splitext(fname)[0])
            Binaryimage = image > 0
            imwrite((self.BaseDir + '/' + BinaryName + Name + '.tif'),
                    Binaryimage.astype('uint16'))

    if self.GenerateNPZ:
        raw_data = RawData.from_folder(
            basepath=self.BaseDir,
            source_dirs=['Raw/'],
            target_dir='BinaryMask/',
            axes='ZYX',
        )
        X, Y, XY_axes = create_patches(
            raw_data=raw_data,
            patch_size=(self.PatchZ, self.PatchY, self.PatchX),
            n_patches_per_image=self.n_patches_per_image,
            save_file=self.BaseDir + self.NPZfilename + '.npz',
        )

    # Training UNET model
    if self.TrainUNET:
        print('Training UNET model')
        load_path = self.BaseDir + self.NPZfilename + '.npz'

        (X, Y), (X_val, Y_val), axes = load_training_data(
            load_path, validation_split=0.1, verbose=True)
        c = axes_dict(axes)['C']
        n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

        config = Config(axes,
                        n_channel_in,
                        n_channel_out,
                        unet_n_depth=self.depth,
                        train_epochs=self.epochs,
                        train_batch_size=self.batch_size,
                        unet_n_first=self.startfilter,
                        train_loss='mse',
                        unet_kern_size=self.kern_size,
                        train_learning_rate=self.learning_rate,
                        train_reduce_lr={'patience': 5, 'factor': 0.5})
        print(config)

        model = CARE(config, name='UNET' + self.model_name, basedir=self.model_dir)

        unet_folder = self.model_dir + 'UNET' + self.model_name + '/'
        if self.copy_model_dir is not None:
            # only the running checkpoint is taken over from the copy model
            load_copied_weights(
                model,
                self.copy_model_dir + 'UNET' + self.copy_model_name + '/',
                unet_folder,
                ('weights_now.h5',))
        load_own_weights(model, unet_folder)

        history = model.train(X, Y, validation_data=(X_val, Y_val))

        print(sorted(history.history))
        plt.figure(figsize=(16, 5))
        plot_history(history, ['loss', 'val_loss'],
                     ['mse', 'val_mse', 'mae', 'val_mae'])

    if self.TrainSTAR:
        print('Training StarDistModel model with', self.backbone, 'backbone')
        self.axis_norm = (0, 1, 2)
        if not self.CroppedLoad:
            assert len(Raw) > 1, "not enough training data"
            print(len(Raw))
            rng = np.random.RandomState(42)
            ind = rng.permutation(len(Raw))

            X_train = list(map(ReadFloat, Raw))
            Y_train = list(map(ReadInt, RealMask))
            self.Y = [
                label(DownsampleData(y, self.DownsampleFactor))
                for y in tqdm(Y_train)
            ]
            self.X = [
                normalize(DownsampleData(x, self.DownsampleFactor), 1, 99.8,
                          axis=self.axis_norm)
                for x in tqdm(X_train)
            ]
            # hold out 15% of the images (at least one) for validation
            n_val = max(1, int(round(0.15 * len(ind))))
            ind_train, ind_val = ind[:-n_val], ind[-n_val:]

            self.X_val = [self.X[i] for i in ind_val]
            self.Y_val = [self.Y[i] for i in ind_val]
            self.X_trn = [self.X[i] for i in ind_train]
            self.Y_trn = [self.Y[i] for i in ind_train]

            print('number of images: %3d' % len(self.X))
            print('- training: %3d' % len(self.X_trn))
            print('- validation: %3d' % len(self.X_val))
        else:
            self.X_trn = self.DataSequencer(Raw, self.axis_norm,
                                            Normalize=True, labelMe=False)
            self.Y_trn = self.DataSequencer(RealMask, self.axis_norm,
                                            Normalize=False, labelMe=True)
            self.X_val = self.DataSequencer(ValRaw, self.axis_norm,
                                            Normalize=True, labelMe=False)
            self.Y_val = self.DataSequencer(ValRealMask, self.axis_norm,
                                            Normalize=False, labelMe=True)
            self.train_sample_cache = False

        print(Config3D.__doc__)

        anisotropy = (1, 1, 1)
        rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)
        kernel = (self.kern_size, self.kern_size, self.kern_size)

        # settings shared by both backbones
        common = dict(
            rays=rays,
            anisotropy=anisotropy,
            backbone=self.backbone,
            train_epochs=self.epochs,
            train_learning_rate=self.learning_rate,
            train_checkpoint=self.model_dir + self.model_name + '.h5',
            # ZYX order, matching the patch order passed to create_patches
            train_patch_size=(self.PatchZ, self.PatchY, self.PatchX),
            train_batch_size=self.batch_size,
            train_dist_loss='mse',
            grid=(1, 1, 1),
            use_gpu=self.use_gpu,
            n_channel_in=1,
        )
        if self.backbone == 'resnet':
            conf = Config3D(resnet_n_blocks=self.depth,
                            resnet_kernel_size=kernel,
                            resnet_n_filter_base=self.startfilter,
                            **common)
        elif self.backbone == 'unet':
            conf = Config3D(unet_n_depth=self.depth,
                            unet_kernel_size=kernel,
                            unet_n_filter_base=self.startfilter,
                            train_sample_cache=False,
                            **common)
        else:
            raise ValueError(
                "backbone must be 'resnet' or 'unet', got %r" % (self.backbone,))

        print(conf)

        Starmodel = StarDist3D(conf, name=self.model_name, basedir=self.model_dir)
        star_folder = self.model_dir + self.model_name + '/'
        print(Starmodel._axes_tile_overlap('ZYX'),
              os.path.exists(star_folder + 'weights_now.h5'))

        if self.copy_model_dir is not None:
            load_copied_weights(
                Starmodel,
                self.copy_model_dir + self.copy_model_name + '/',
                star_folder,
                weight_files)
        load_own_weights(Starmodel, star_folder)

        historyStar = Starmodel.train(self.X_trn,
                                      self.Y_trn,
                                      validation_data=(self.X_val, self.Y_val),
                                      epochs=self.epochs)
        print(sorted(historyStar.history))
        plt.figure(figsize=(16, 5))
        plot_history(historyStar, ['loss', 'val_loss'], [
            'dist_relevant_mae', 'val_dist_relevant_mae',
            'dist_relevant_mse', 'val_dist_relevant_mse'
        ])
def train(self, X, Y, validation_data, classes='auto', augmenter=None,
          seed=None, epochs=None, steps_per_epoch=None, workers=1):
    """Train the neural network with the given data.

    Parameters
    ----------
    X : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
        Input images
    Y : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
        Label masks
    classes : 'auto' or iterable of same length as X, optional
        label id -> class id mapping for each label mask of Y if multiclass
        prediction is activated (n_classes > 0):
        list of dicts with label id -> class id (1,...,n_classes);
        'auto' -> all objects will be assigned to the first non-background
        class, or will be ignored if config.n_classes is None
    validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`) or triple (if multiclass)
        Tuple (triple if multiclass) of X,Y,[classes] validation data.
        In the multiclass case a pair is accepted too; the classes entry
        then defaults to 'auto'.
    augmenter : None or callable
        Function with expected signature ``xt, yt = augmenter(x, y)``
        that takes in a single pair of input/label image (x,y) and returns
        the transformed images (xt, yt) for the purpose of data augmentation
        during training. Not applied to validation images.
        Example:
        def simple_augmenter(x,y):
            x = x + 0.05*np.random.normal(0,1,x.shape)
            return x,y
    seed : int
        Convenience to set ``np.random.seed(seed)``. (To obtain reproducible validation patches, etc.)
    epochs : int
        Optional argument to use instead of the value from ``config``.
    steps_per_epoch : int
        Optional argument to use instead of the value from ``config``.
    workers : int
        Forwarded to the Keras fit call as ``workers``; multiprocessing is
        requested from Keras when ``workers > 1``.

    Returns
    -------
    ``History`` object
        See `Keras training history <https://keras.io/models/model/#fit>`_.

    Raises
    ------
    ValueError
        If ``validation_data`` is not a list/tuple of the expected length, or
        if the configured training patch size is not divisible as required
        by the network along some axis.

    """
    if seed is not None:
        # https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development
        np.random.seed(seed)
    if epochs is None:
        epochs = self.config.train_epochs
    if steps_per_epoch is None:
        steps_per_epoch = self.config.train_steps_per_epoch

    classes = self._parse_classes_arg(classes, len(X))

    if not self._is_multiclass() and classes is not None:
        warnings.warn("Ignoring given classes as n_classes is set to None")

    if not isinstance(validation_data, (list, tuple)):
        raise ValueError("validation_data must be a list or tuple")
    if self._is_multiclass() and len(validation_data) == 2:
        validation_data = tuple(validation_data) + ('auto', )
    n_expected = 3 if self._is_multiclass() else 2
    if len(validation_data) != n_expected:
        raise ValueError(
            f'len(validation_data) = {len(validation_data)}, but should be {n_expected}')

    patch_size = self.config.train_patch_size
    axes = self.config.axes.replace('C', '')
    # the completion crop only shrinks the effective patch if shape completion is enabled
    b = self.config.train_completion_crop if self.config.train_shape_completion else 0
    div_by = self._axes_div_by(axes)
    for p, d, a in zip(patch_size, div_by, axes):
        if (p - 2 * b) % d != 0:
            raise ValueError(
                "'train_patch_size' - 2*'train_completion_crop' must be divisible by {d} along axis '{a}'"
                .format(a=a, d=d) if self.config.train_shape_completion else
                "'train_patch_size' must be divisible by {d} along axis '{a}'"
                .format(a=a, d=d))

    if not self._model_prepared:
        self.prepare_for_training()

    data_kwargs = dict(
        n_rays=self.config.n_rays,
        patch_size=self.config.train_patch_size,
        grid=self.config.grid,
        shape_completion=self.config.train_shape_completion,
        b=self.config.train_completion_crop,
        use_gpu=self.config.use_gpu,
        foreground_prob=self.config.train_foreground_only,
        n_classes=self.config.n_classes,
        sample_ind_cache=self.config.train_sample_cache,
    )

    # generate validation data and store in numpy arrays
    n_data_val = len(validation_data[0])
    classes_val = self._parse_classes_arg(
        validation_data[2], n_data_val) if self._is_multiclass() else None
    n_take = (self.config.train_n_val_patches
              if self.config.train_n_val_patches is not None else n_data_val)
    # a generator of length 1 with batch size n_take: its single batch is the validation set
    _data_val = StarDistData2D(validation_data[0],
                               validation_data[1],
                               classes=classes_val,
                               batch_size=n_take,
                               length=1,
                               **data_kwargs)
    data_val = _data_val[0]

    # expose data generator as member for general diagnostics
    self.data_train = StarDistData2D(
        X,
        Y,
        classes=classes,
        batch_size=self.config.train_batch_size,
        augmenter=augmenter,
        length=epochs * steps_per_epoch,
        **data_kwargs)

    if self.config.train_tensorboard:
        # show dist for three rays
        _n = min(3, self.config.n_rays)
        channel = axes_dict(self.config.axes)['C']
        output_slices = [[slice(None)] * 4, [slice(None)] * 4]
        output_slices[1][1 + channel] = slice(
            0, (self.config.n_rays // _n) * _n, self.config.n_rays // _n)
        if self._is_multiclass():
            # show (at most) three evenly spaced class channels, skipping channel 0;
            # the stop must be 1 + step*_n so that exactly _n channels are selected
            # (same construction as for the rays above)
            _n = min(3, self.config.n_classes)
            class_step = self.config.n_classes // _n
            output_slices += [[slice(None)] * 4]
            output_slices[2][1 + channel] = slice(
                1, 1 + class_step * _n, class_step)

        if IS_TF_1:
            for cb in self.callbacks:
                if isinstance(cb, CARETensorBoard):
                    cb.output_slices = output_slices
                    # target image for dist includes dist_mask and thus has more channels than dist output
                    cb.output_target_shapes = [None, [None] * 4, None]
                    cb.output_target_shapes[1][
                        1 + channel] = data_val[1][1].shape[1 + channel]
        elif self.basedir is not None and not any(
                isinstance(cb, CARETensorBoardImage)
                for cb in self.callbacks):
            self.callbacks.append(
                CARETensorBoardImage(model=self.keras_model,
                                     data=data_val,
                                     log_dir=str(self.logdir / 'logs' /
                                                 'images'),
                                     n_images=3,
                                     prob_out=False,
                                     output_slices=output_slices))

    fit = self.keras_model.fit_generator if IS_TF_1 else self.keras_model.fit
    history = fit(
        iter(self.data_train),
        validation_data=data_val,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        workers=workers,
        use_multiprocessing=workers > 1,
        callbacks=self.callbacks,
        verbose=1,
        # set validation batchsize to training batchsize (only works for tf >= 2.2)
        **(dict(validation_batch_size=self.config.train_batch_size)
           if _tf_version_at_least("2.2.0") else {}))
    self._training_finished()

    return history
def train(self, X, Y, validation_data, augmenter=None, seed=None, epochs=None, steps_per_epoch=None, multi=False, ncpu=1):
    """Train the neural network with the given data.

    Parameters
    ----------
    X : :class:`numpy.ndarray`
        Array of input images.
    Y : :class:`numpy.ndarray`
        Array of label masks.
    validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`)
        Tuple of X,Y validation arrays.
    augmenter : None or callable
        Function with expected signature ``xt, yt = augmenter(x, y)``
        that takes in a single pair of input/label image (x,y) and returns
        the transformed images (xt, yt) for the purpose of data augmentation
        during training. Not applied to validation images.
        Example:
        def simple_augmenter(x,y):
            x = x + 0.05*np.random.normal(0,1,x.shape)
            return x,y
    seed : int
        Convenience to set ``np.random.seed(seed)``. (To obtain reproducible validation patches, etc.)
    epochs : int
        Optional argument to use instead of the value from ``config``.
    steps_per_epoch : int
        Optional argument to use instead of the value from ``config``.
    multi : bool
        Forwarded to ``fit_generator`` as ``use_multiprocessing``.
    ncpu : int
        Forwarded to ``fit_generator`` as ``workers``.

    Returns
    -------
    ``History`` object
        See `Keras training history <https://keras.io/models/model/#fit>`_.

    """
    if seed is not None:
        # https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development
        np.random.seed(seed)
    epochs = self.config.train_epochs if epochs is None else epochs
    steps_per_epoch = self.config.train_steps_per_epoch if steps_per_epoch is None else steps_per_epoch

    # validation data is mandatory and must come as an (X, Y) pair
    if validation_data is None:
        _raise(ValueError())
    is_pair = isinstance(validation_data, (list, tuple)) and len(validation_data) == 2
    if not is_pair:
        _raise(ValueError('validation_data must be a pair of numpy arrays'))

    patch_size = self.config.train_patch_size
    axes = self.config.axes.replace('C', '')
    # the completion crop only counts if shape completion is switched on
    b = self.config.train_completion_crop if self.config.train_shape_completion else 0
    div_by = self._axes_div_by(axes)
    for p, d, a in zip(patch_size, div_by, axes):
        if (p - 2 * b) % d == 0:
            continue
        if self.config.train_shape_completion:
            message = "'train_patch_size' - 2*'train_completion_crop' must be divisible by {d} along axis '{a}'".format(a=a, d=d)
        else:
            message = "'train_patch_size' must be divisible by {d} along axis '{a}'".format(a=a, d=d)
        _raise(ValueError(message))

    if not self._model_prepared:
        self.prepare_for_training()

    # settings shared by the training and the validation data generator
    data_kwargs = {
        'n_rays': self.config.n_rays,
        'patch_size': self.config.train_patch_size,
        'grid': self.config.grid,
        'shape_completion': self.config.train_shape_completion,
        'b': self.config.train_completion_crop,
        'use_gpu': self.config.use_gpu,
        'foreground_prob': self.config.train_foreground_only,
        'prob_thr': self.config.EDT_prob_threshold,
        'border_R': self.config.EDT_border_R,
        'black_border': self.config.EDT_black_border,
    }

    # generate validation data and store in numpy arrays:
    # draw single-item batches (with replacement only if more are requested than exist)
    val_generator = StarDistData2D(*validation_data, batch_size=1, **data_kwargs)
    n_data_val = len(val_generator)
    n_take = n_data_val if self.config.train_n_val_patches is None else self.config.train_n_val_patches
    ids = tuple(np.random.choice(n_data_val, size=n_take, replace=(n_take > n_data_val)))
    xs, ms, ps, ds = [], [], [], []
    for k in ids:
        (x_k, m_k), (p_k, d_k) = val_generator[k]
        xs.append(x_k)
        ms.append(m_k)
        ps.append(p_k)
        ds.append(d_k)
    Xv = np.concatenate(xs, axis=0)
    Mv = np.concatenate(ms, axis=0)
    Pv = np.concatenate(ps, axis=0)
    Dv = np.concatenate(ds, axis=0)
    data_val = [[Xv, Mv], [Pv, Dv]]

    data_train = StarDistData2D(X, Y, batch_size=self.config.train_batch_size, augmenter=augmenter, **data_kwargs)

    for cb in self.callbacks:
        if not isinstance(cb, CARETensorBoard):
            continue
        # show dist for three rays
        n_show = min(3, self.config.n_rays)
        cb.output_slices = [[slice(None)] * 4, [slice(None)] * 4]
        ray_step = self.config.n_rays // n_show
        cb.output_slices[1][1 + axes_dict(self.config.axes)['C']] = slice(0, ray_step * n_show, ray_step)

    fit_kwargs = {
        'generator': data_train,
        'validation_data': data_val,
        'epochs': epochs,
        'steps_per_epoch': steps_per_epoch,
        'callbacks': self.callbacks,
        'verbose': 1,
        'use_multiprocessing': multi,
        'workers': ncpu,
    }
    history = self.keras_model.fit_generator(**fit_kwargs)
    self._training_finished()

    return history
def predict_instances_big(self, img, axes, block_size, min_overlap, context=None,
                          labels_out=None, labels_out_dtype=np.int32, show_progress=True, **kwargs):
    """Predict instance segmentation from very large input images.

    Intended to be used when `predict_instances` cannot be used due to memory limitations.
    The input image is split into blocks, each block is processed via
    `predict_instances`, and the partial results are assembled.
    If used as intended, the result should be the same as if `predict_instances`
    was used directly on the whole image.

    **Important**: The crucial assumption is that all predicted object instances are smaller than
                   the provided `min_overlap`. Also, it must hold that: min_overlap + 2*context < block_size.

    Example
    -------
    >>> img.shape
    (20000, 20000)
    >>> labels, polys = model.predict_instances_big(img, axes='YX', block_size=4096,
                                                    min_overlap=128, context=128, n_tiles=(4,4))

    Parameters
    ----------
    img: :class:`numpy.ndarray` or similar
        Input image
    axes: str
        Axes of the input ``img`` (such as 'YX', 'ZYX', 'YXC', etc.)
    block_size: int or iterable of int
        Process input image in blocks of the provided shape.
        (If a scalar value is given, it is used for all spatial image dimensions.)
    min_overlap: int or iterable of int
        Amount of guaranteed overlap between blocks.
        (If a scalar value is given, it is used for all spatial image dimensions.)
    context: int or iterable of int, or None
        Amount of image context on all sides of a block, which is discarded.
        If None, uses an automatic estimate that should work in many cases.
        (If a scalar value is given, it is used for all spatial image dimensions.)
    labels_out: :class:`numpy.ndarray` or similar, or None, or False
        numpy array or similar (must be of correct shape) to which the label image is written.
        If None, will allocate a numpy array of the correct shape and data type ``labels_out_dtype``.
        If False, will not write the label image (useful if only the dictionary is needed).
    labels_out_dtype: str or dtype
        Data type of returned label image if ``labels_out=None`` (has no effect otherwise).
    show_progress: bool
        Show progress bar for block processing.
    kwargs: dict
        Keyword arguments for ``predict_instances``. The entries ``axes`` and
        ``overlap_label`` (and ``show_tile_progress`` if ``show_progress`` is set)
        are always overridden; a message is printed if a given value is replaced.

    Returns
    -------
    (:class:`numpy.ndarray` or None, dict)
        Returns the label image (``None`` if ``labels_out=False`` was given) and a
        dictionary with the details (coordinates, etc.) of the polygons/polyhedra.

    """
    from ..big import _grid_divisible, BlockND, OBJECT_KEYS
    from ..matching import relabel_sequential

    n = img.ndim
    axes = axes_check_and_normalize(axes, length=n)
    grid = self._axes_div_by(axes)
    axes_out = self._axes_out.replace('C', '')
    size_of_axis = dict(zip(axes, img.shape))
    shape_out = tuple(size_of_axis[a] for a in axes_out)

    if context is None:
        context = self._axes_tile_overlap(axes)

    # broadcast scalar settings to one value per image axis
    if np.isscalar(block_size):
        block_size = n * [block_size]
    if np.isscalar(min_overlap):
        min_overlap = n * [min_overlap]
    if np.isscalar(context):
        context = n * [context]
    block_size = list(block_size)
    min_overlap = list(min_overlap)
    context = list(context)
    assert n == len(block_size) == len(min_overlap) == len(context)

    if 'C' in axes:
        # single block for channel axis
        c_idx = axes_dict(axes)['C']
        block_size[c_idx] = img.shape[c_idx]
        min_overlap[c_idx] = context[c_idx] = 0

    def _per_axis_divisible(values, name):
        # pass every per-axis value through _grid_divisible together with that axis' grid entry
        return tuple(_grid_divisible(g, v, name=name, verbose=False)
                     for v, g, _ in zip(values, grid, axes))

    block_size = _per_axis_divisible(block_size, 'block_size')
    min_overlap = _per_axis_divisible(min_overlap, 'min_overlap')
    context = _per_axis_divisible(context, 'context')

    print(f'effective: block_size={block_size}, min_overlap={min_overlap}, context={context}', flush=True)

    # hint at contexts that are below the model's own tile-overlap estimate
    recommended = self._axes_tile_overlap(axes)
    for a, c, o in zip(axes, context, recommended):
        if c < o:
            print(f"{a}: context of {c} is small, recommended to use at least {o}", flush=True)

    # create block cover
    blocks = BlockND.cover(img.shape, axes, block_size, min_overlap, context, grid)

    # a scalar falsy value (e.g. False) disables writing of the label image
    skip_label_image = np.isscalar(labels_out) and bool(labels_out) is False
    if skip_label_image:
        labels_out = None
    elif labels_out is None:
        labels_out = np.zeros(shape_out, dtype=labels_out_dtype)
    elif not (labels_out.shape == shape_out):
        _raise(ValueError(f"'labels_out' must have shape {shape_out} (axes {axes_out})."))

    polys_all = {}
    label_offset = 1

    forced_kwargs = {'axes': axes, 'overlap_label': None}
    if show_progress:
        # disable progress for predict_instances
        forced_kwargs['show_tile_progress'] = False
    for key, value in forced_kwargs.items():
        if key in kwargs:
            print(f"changing '{key}' from {kwargs[key]} to {value}", flush=True)
        kwargs[key] = value

    # actual computation
    for block in tqdm(blocks, disable=(not show_progress)):
        block_img = block.read(img, axes=axes)
        labels, polys = self.predict_instances(block_img, **kwargs)
        labels = block.crop_context(labels, axes=axes_out)
        labels, polys = block.filter_objects(labels, polys, axes=axes_out)
        # TODO: relabel_sequential is not very memory-efficient (will allocate memory proportional to label_offset)
        labels = relabel_sequential(labels, label_offset)[0]
        if labels_out is not None:
            block.write(labels_out, labels, axes=axes_out)
        for key, value in polys.items():
            if key not in polys_all:
                polys_all[key] = []
            polys_all[key].append(value)
        label_offset += len(polys['prob'])

    # keys in OBJECT_KEYS are concatenated over all blocks; for any other key
    # the value from the first block is kept
    merged = {}
    for key, parts in polys_all.items():
        merged[key] = np.concatenate(parts) if key in OBJECT_KEYS else parts[0]

    return labels_out, merged
def predict(self, img, axes=None, normalizer=None, n_tiles=None, show_tile_progress=True, **predict_kwargs):
    """Predict.

    Parameters
    ----------
    img : :class:`numpy.ndarray`
        Input image
    axes : str or None
        Axes of the input ``img``.
        ``None`` denotes that axes of img are the same as denoted in the config.
    normalizer : :class:`csbdeep.data.Normalizer` or None
        (Optional) normalization of input image before prediction.
        Note that the default (``None``) assumes ``img`` to be already normalized.
    n_tiles : iterable or None
        Out of memory (OOM) errors can occur if the input image is too large.
        To avoid this problem, the input image is broken up into (overlapping) tiles
        that are processed independently and re-assembled.
        This parameter denotes a tuple of the number of tiles for every image axis (see ``axes``).
        ``None`` denotes that no tiling should be used.
    show_tile_progress: bool
        Whether to show progress during tiled prediction.
    predict_kwargs: dict
        Keyword arguments for ``predict`` function of Keras model.

    Returns
    -------
    (:class:`numpy.ndarray`,:class:`numpy.ndarray`)
        Returns the tuple (`prob`, `dist`) of per-pixel object probabilities and star-convex polygon/polyhedra distances.

    Raises
    ------
    ValueError
        If ``n_tiles`` is not an iterable of ``img.ndim`` integer values >= 1,
        if a tile count > 1 is given for a non-tileable (channel) axis,
        if the number of image channels differs from ``config.n_channel_in``,
        or if ``config.grid`` does not have one entry per non-channel axis.

    """
    if n_tiles is None:
        n_tiles = [1] * img.ndim
    try:
        n_tiles = tuple(n_tiles)
        if img.ndim != len(n_tiles):
            raise TypeError()
    except TypeError:
        raise ValueError("n_tiles must be an iterable of length %d" % img.ndim)
    if not all(np.isscalar(t) and 1 <= t and int(t) == t for t in n_tiles):
        raise ValueError("all values of n_tiles must be integer values >= 1")
    n_tiles = tuple(map(int, n_tiles))

    axes = self._normalize_axes(img, axes)
    axes_net = self.config.axes

    _permute_axes = self._make_permute_axes(axes, axes_net)
    x = _permute_axes(img)  # x has axes_net semantics

    channel = axes_dict(axes_net)['C']
    if self.config.n_channel_in != x.shape[channel]:
        raise ValueError(
            f"expected {self.config.n_channel_in} input channel(s), but image has {x.shape[channel]}")
    axes_net_div_by = self._axes_div_by(axes_net)

    grid = tuple(self.config.grid)
    if len(grid) != len(axes_net) - 1:
        raise ValueError(
            f"config.grid must have {len(axes_net) - 1} entries, but has {len(grid)}")
    grid_dict = dict(zip(axes_net.replace('C', ''), grid))

    normalizer = self._check_normalizer_resizer(normalizer, None)[0]
    resizer = StarDistPadAndCropResizer(grid=grid_dict)

    x = normalizer.before(x, axes_net)
    x = resizer.before(x, axes_net, axes_net_div_by)

    def predict_direct(tile):
        # The Keras model is called with two inputs: the tile and an (uninitialized)
        # single-channel array of the same spatial shape.
        # NOTE(review): presumably the second input is only relevant during training - confirm.
        sh = list(tile.shape)
        sh[channel] = 1
        dummy = np.empty(sh, np.float32)
        prob, dist = self.keras_model.predict(
            [tile[np.newaxis], dummy[np.newaxis]], **predict_kwargs)
        return prob[0], dist[0]

    def to_grid(slices):
        # account for grid: the output arrays below are allocated with every
        # spatial extent divided by its grid factor, so tile slices are scaled alike
        return [
            slice(s.start // grid_dict.get(a, 1), s.stop // grid_dict.get(a, 1))
            for s, a in zip(slices, axes_net)
        ]

    if np.prod(n_tiles) > 1:
        tiling_axes = axes_net.replace('C', '')  # axes eligible for tiling
        x_tiling_axis = tuple(
            axes_dict(axes_net)[a]
            for a in tiling_axes)  # numerical axis ids for x
        axes_net_tile_overlaps = self._axes_tile_overlap(axes_net)
        # hack: permute tiling axis in the same way as img -> x was permuted
        # (builtin `bool` instead of the `np.bool` alias, which was removed in NumPy 1.24)
        n_tiles = _permute_axes(np.empty(n_tiles, bool)).shape
        if not all(n_tiles[i] == 1 for i in range(x.ndim)
                   if i not in x_tiling_axis):
            raise ValueError(
                "entry of n_tiles > 1 only allowed for axes '%s'" % tiling_axes)

        sh = [s // grid_dict.get(a, 1) for a, s in zip(axes_net, x.shape)]
        sh[channel] = 1
        prob = np.empty(sh, np.float32)
        sh[channel] = self.config.n_rays
        dist = np.empty(sh, np.float32)

        n_block_overlaps = [
            int(np.ceil(overlap / blocksize)) for overlap, blocksize in zip(
                axes_net_tile_overlaps, axes_net_div_by)
        ]

        for tile, s_src, s_dst in tqdm(tile_iterator(
                x,
                n_tiles,
                block_sizes=axes_net_div_by,
                n_block_overlaps=n_block_overlaps),
                                       disable=(not show_tile_progress),
                                       total=np.prod(n_tiles)):
            prob_tile, dist_tile = predict_direct(tile)
            s_src = to_grid(s_src)
            s_dst = to_grid(s_dst)
            # prob and dist have different channel dimensionality than image x
            s_src[channel] = slice(None)
            s_dst[channel] = slice(None)
            s_src, s_dst = tuple(s_src), tuple(s_dst)
            prob[s_dst] = prob_tile[s_src]
            dist[s_dst] = dist_tile[s_src]

    else:
        prob, dist = predict_direct(x)

    prob = resizer.after(prob, axes_net)
    dist = resizer.after(dist, axes_net)
    # avoid small/negative dist values to prevent problems with Qhull
    dist = np.maximum(1e-3, dist)

    # drop the singleton channel of prob; move the rays (channel) axis of dist to the end
    prob = np.take(prob, 0, axis=channel)
    dist = np.moveaxis(dist, channel, -1)

    return prob, dist
def _no_tiling_for_axis(axes_image, n_tiles, axis): if n_tiles is not None and axis in axes_image: return n_tiles[axes_dict(axes_image)[axis]] == 1 return True