def dev(args):
    """Train a CARE model configured from command-line arguments.

    Loads training data from ``args.train_data`` (validation split
    ``args.valid_split``, axes ``args.axes``), chooses a configuration, trains
    a CARE model named ``args.model_name`` under ``models/``, saves the
    learning curves to ``<model_name>_training.png`` and exports the model
    with ``export_TF()``.

    Configuration selection:

    * ``args.resume`` truthy: ``config=None`` so CARE reloads the saved config.
    * ``args.config`` set: ``Config`` keyword arguments are read from that
      JSON file.
    * otherwise: a ``Config`` is built from the data's axes/channel counts and
      ``args.prob``, ``args.steps`` and ``args.epochs``.
    """
    import json

    # Load and parse training data
    (X, Y), (X_val, Y_val), axes = load_training_data(
        args.train_data,
        validation_split=args.valid_split,
        axes=args.axes,
        verbose=True,
    )
    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    # Model config
    print('args.resume: ', args.resume)
    if args.resume:
        # If resuming, config=None will reload the saved config
        config = None
        print('Attempting to resume')
    elif args.config:
        print('loading config from args')
        # Context manager so the JSON file handle is closed promptly.
        with open(args.config) as config_file:
            config_args = json.load(config_file)
        config = Config(**config_args)
    else:
        config = Config(axes, n_channel_in, n_channel_out,
                        probabilistic=args.prob,
                        train_steps_per_epoch=args.steps,
                        train_epochs=args.epochs)

    # vars(None) raises TypeError, so only print a config that was built here.
    if config is not None:
        print(vars(config))

    # Load or init model
    model = CARE(config, args.model_name, basedir='models')
    if config is None:
        # When resuming, the effective config exists only after CARE loaded it.
        print(vars(model.config))

    # Training, tensorboard available
    history = model.train(X, Y, validation_data=(X_val, Y_val))

    # Plot training results
    print(sorted(history.history.keys()))
    plt.figure(figsize=(16, 5))
    plot_history(history, ['loss', 'val_loss'],
                 ['mse', 'val_mse', 'mae', 'val_mae'])
    plt.savefig(args.model_name + '_training.png')

    # Export model to be used w/ csbdeep fiji plugins and KNIME flows
    model.export_TF()
def train(self, channels=None, **config_args):
    """Train, preview and export a separate CARE model for each channel.

    For every channel ``ch`` the patches in ``CH_<ch>_training_patches.npz``
    (under ``self.get_training_patch_path()``) are loaded with 10% held out
    for validation. A model named ``CH_<ch>_model`` is then trained under
    ``<self.out_dir>/models``. Its learning curves and five validation
    predictions are plotted, and it is exported with ``export_TF()``.

    Args:
        channels: Channels to train; ``None`` means ``self.train_channels``.
        **config_args: Extra keyword arguments forwarded to ``Config``.

    Returns:
        None. On ``tf.errors.ResourceExhaustedError`` a message is printed and
        the remaining channels are skipped.
    """
    # limit_gpu_memory(fraction=1)
    to_train = self.train_channels if channels is None else channels

    for channel in to_train:
        print("-- Training channel {}...".format(channel))
        patch_file = (self.get_training_patch_path()
                      / 'CH_{}_training_patches.npz'.format(channel))
        (x_trn, y_trn), (x_hold, y_hold), axes = load_training_data(
            patch_file, validation_split=0.1, verbose=False)

        chan_axis = axes_dict(axes)['C']
        config = Config(
            axes,
            x_trn.shape[chan_axis],
            y_trn.shape[chan_axis],
            train_epochs=self.train_epochs,
            train_steps_per_epoch=self.train_steps_per_epoch,
            train_batch_size=self.train_batch_size,
            **config_args,
        )
        model = CARE(config, 'CH_{}_model'.format(channel),
                     basedir=pathlib.Path(self.out_dir) / 'models')

        try:
            history = model.train(x_trn, y_trn, validation_data=(x_hold, y_hold))
        except tf.errors.ResourceExhaustedError:
            print(
                "ResourceExhaustedError: Aborting...\n Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?"
            )
            return

        # Learning curves
        plt.figure(figsize=(16, 5))
        plot_history(history, ['loss', 'val_loss'],
                     ['mse', 'val_mse', 'mae', 'val_mae'])

        # Five validation patches: input, ground truth, prediction
        plt.figure(figsize=(12, 7))
        preview_in, preview_gt = x_hold[:5], y_hold[:5]
        preview_out = model.keras_model.predict(preview_in)
        plot_some(preview_in, preview_gt, preview_out, pmax=99.5, cmap="gray")
        plt.suptitle('5 example validation patches\n'
                     'top row: input (source), '
                     'middle row: target (ground truth), '
                     'bottom row: predicted from source')
        plt.show()

        print("-- Export model for use in Fiji...")
        model.export_TF()
        print("Done")
def gen_care_single_model(name: str) -> CARE:
    """
    Return the single-channel CARE model called ``name``, creating it if needed.

    An existing model is loaded with ``CARE(None, name)``. If that raises
    ``FileNotFoundError``, a new model is created instead, with axes
    ``'xyzc'`` and one input and one output channel.

    :param name: The name of the model
    :return: The CARE model
    """
    try:
        existing = CARE(None, name)
    except FileNotFoundError:
        fresh_config = Config('xyzc', n_channel_in=1, n_channel_out=1)
        return CARE(fresh_config, name)
    return existing
def gen_care_dual_model(name: str, batch_size: int = 16, **kwargs):
    """
    Return the dual-channel CARE model called ``name``, creating it if needed.

    An existing model is loaded with ``CARE(None, name)``. If that raises
    ``FileNotFoundError``, a new model is created instead, with axes
    ``'xyzc'``, two input channels and one output channel.

    :param name: The name of the model
    :param batch_size: The training batch size (only used if the model doesn't exist yet)
    :param kwargs: Extra ``Config`` parameters (only used if the model doesn't exist yet)
    :return: The CARE model
    """
    try:
        existing = CARE(None, name)
    except FileNotFoundError:
        fresh_config = Config('xyzc', n_channel_in=2, n_channel_out=1,
                              train_batch_size=batch_size, **kwargs)
        return CARE(fresh_config, name)
    return existing
def test_config():
    """Config must normalise valid axes strings and reject invalid ones.

    Valid inputs are expected to come back upper-cased, without a leading
    ``S`` axis, and with a ``C`` axis placed according to the Keras image
    data format. Each group of invalid inputs must raise ``ValueError``.
    """
    assert K.image_data_format() in ('channels_first', 'channels_last')

    def add_channel(spatial):
        # Upper-case; append 'C' for channels_last, prepend it for
        # channels_first, unless a 'C' is already present.
        spatial = spatial.upper()
        if 'C' in spatial:
            return spatial
        if K.image_data_format() == 'channels_last':
            return spatial + 'C'
        return 'C' + spatial

    # (given axes, expected axes before the channel axis is added)
    without_channel = [
        ('yx', 'YX'), ('ytx', 'YTX'), ('zyx', 'ZYX'),
        ('YX', 'YX'), ('XYZ', 'XYZ'), ('XYT', 'XYT'),
        ('SYX', 'YX'), ('SXYZ', 'XYZ'), ('SXTY', 'XTY'),
    ]
    # Same, but the given axes already carry a channel axis.
    with_channel = [
        ('YX', 'YX'), ('XYZ', 'XYZ'), ('XTY', 'XTY'),
        ('SYX', 'YX'), ('STYX', 'TYX'), ('SXYZ', 'XYZ'),
    ]
    cases = [(given, add_channel(ref)) for given, ref in without_channel]
    cases += [(add_channel(given), add_channel(ref))
              for given, ref in with_channel]

    for given, expected in cases:
        assert Config(given).axes == expected

    # Every group has to raise ValueError. In the paired groups, which member
    # raises presumably depends on where the data format wants 'C'.
    rejected = [
        ('XYC', 'CXY'),
        ('XYZC', 'CXYZ'),
        ('XTYC', 'CXTY'),
        ('XYZT',),
        ('tXYZ',),
        ('XYS',),
        ('XSYZ',),
    ]
    for group in rejected:
        with pytest.raises(ValueError):
            for bad_axes in group:
                Config(bad_axes)
def config_generator(**kwargs):
    """Yield one ``Config`` per combination of the given parameter values.

    Each keyword maps to a single value or to a list/tuple of candidates.
    Anything that is not a list/tuple (including strings) counts as one
    candidate. Combinations follow ``itertools.product`` order over the
    keywords as passed. ``axes`` must be among the keywords (checked with
    ``assert`` when iteration starts).
    """
    assert 'axes' in kwargs
    grid = {
        key: (choice if isinstance(choice, (list, tuple)) else [choice])
        for key, choice in kwargs.items()
    }
    for combo in product(*grid.values()):
        yield Config(**dict(zip(grid, combo)))
# Preview the first five validation pairs.
plt.figure(figsize=(12, 5))
plot_some(X_val[:5], Y_val[:5])
plt.suptitle(
    '5 example validation patches (top row: source, bottom row: target)')

# In[5]:

# 5-level U-Net, 48 first-layer filters, kernel size 7, MAE loss, batch size 1.
config = Config(axes,
                n_channel_in,
                n_channel_out,
                probabilistic=False,
                unet_n_depth=5,
                unet_n_first=48,
                unet_kern_size=7,
                train_loss='mae',
                train_epochs=150,
                train_learning_rate=1.0E-4,
                train_batch_size=1,
                train_reduce_lr={
                    'patience': 5,
                    'factor': 0.5
                })
print(config)
vars(config)  # result discarded; only displays in an interactive session

# In[6]:

model = CARE(config=config, name=ModelName, basedir=ModelDir)
# To warm-start from previously saved weights:
#input_weights = ModelDir + ModelName + '/' +'weights_best.h5'
#model.load_weights(input_weights)
def Train(self):
    """Prepare masks/patches, then train the UNET and/or StarDist3D models.

    Every step is driven by attributes of ``self``:

    1. Create ``BaseDir/BinaryMask/`` and ``BaseDir/RealMask/`` if missing.
       If ``RealMask/`` holds no ``.tif`` files, build instance labels from
       the binary masks with ``label``. If ``BinaryMask/`` holds none, build
       binary masks (``image > 0``) from the instance labels.
    2. If ``GenerateNPZ``: cut ``ZYX`` patches from the ``Raw/`` →
       ``BinaryMask/`` pairs and save them to
       ``BaseDir + NPZfilename + '.npz'``.
    3. If ``TrainUNET``: train a CARE model ``'UNET' + model_name`` on that
       file, warm-started from copy-model/checkpoint weights when present.
    4. If ``TrainSTAR``: train a StarDist3D model ``model_name`` with the
       ``'resnet'`` or ``'unet'`` backbone. Images come either from memory
       (random 85/15 split, seed 42) or, when ``CroppedLoad`` is truthy,
       from ``self.DataSequencer`` over ``Raw``/``RealMask`` and
       ``ValRaw``/``ValRealMask``.

    Raises:
        ValueError: if ``TrainSTAR`` is set and ``backbone`` is neither
            ``'resnet'`` nor ``'unet'``.
    """
    BinaryName = 'BinaryMask/'
    RealName = 'RealMask/'

    def _tifs(folder, pattern='*.tif'):
        # Sorted paths matching self.BaseDir + '/' + folder + pattern.
        return sorted(glob.glob(self.BaseDir + '/' + folder + pattern))

    def _restore_weights(model, local_prefix, copy_prefix, copy_kinds):
        # Warm-start `model`. First load each file in `copy_kinds` from the
        # copy model when it exists there but not locally, then load every
        # local checkpoint that exists. Loads run in order, so the last file
        # found ('weights_best.h5' when present) gives the final weights.
        if copy_prefix is not None:
            for kind in copy_kinds:
                if (os.path.exists(copy_prefix + kind)
                        and not os.path.exists(local_prefix + kind)):
                    print('Loading copy model')
                    model.load_weights(copy_prefix + kind)
        for kind in ('weights_now.h5', 'weights_last.h5', 'weights_best.h5'):
            if os.path.exists(local_prefix + kind):
                print('Loading checkpoint model')
                model.load_weights(local_prefix + kind)

    Raw = _tifs('Raw/')
    Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
    Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
    RealMask = _tifs(RealName)
    ValRaw = _tifs('ValRaw/')
    ValRealMask = _tifs('ValRealMask/')

    print('Instance segmentation masks:', len(RealMask))
    if len(RealMask) == 0:
        print('Making labels')
        for fname in _tifs(BinaryName):
            image = imread(fname)
            Name = os.path.basename(os.path.splitext(fname)[0])
            Binaryimage = label(image)
            imwrite(self.BaseDir + '/' + RealName + Name + '.tif',
                    Binaryimage.astype('uint16'))
        # Re-scan: the list above was taken before these label files existed,
        # and StarDist training below pairs Raw with RealMask.
        RealMask = _tifs(RealName)

    Mask = _tifs(BinaryName)
    print('Semantic segmentation masks:', len(Mask))
    if len(Mask) == 0:
        print('Generating Binary images')
        for fname in _tifs(RealName, '*tif'):
            image = imread(fname)
            Name = os.path.basename(os.path.splitext(fname)[0])
            Binaryimage = image > 0
            imwrite(self.BaseDir + '/' + BinaryName + Name + '.tif',
                    Binaryimage.astype('uint16'))

    if self.GenerateNPZ:
        raw_data = RawData.from_folder(
            basepath=self.BaseDir,
            source_dirs=['Raw/'],
            target_dir='BinaryMask/',
            axes='ZYX',
        )
        create_patches(
            raw_data=raw_data,
            patch_size=(self.PatchZ, self.PatchY, self.PatchX),
            n_patches_per_image=self.n_patches_per_image,
            save_file=self.BaseDir + self.NPZfilename + '.npz',
        )

    # Training UNET model
    if self.TrainUNET:
        print('Training UNET model')
        load_path = self.BaseDir + self.NPZfilename + '.npz'
        (X, Y), (X_val, Y_val), axes = load_training_data(load_path,
                                                          validation_split=0.1,
                                                          verbose=True)
        c = axes_dict(axes)['C']
        n_channel_in, n_channel_out = X.shape[c], Y.shape[c]
        config = Config(axes,
                        n_channel_in,
                        n_channel_out,
                        unet_n_depth=self.depth,
                        train_epochs=self.epochs,
                        train_batch_size=self.batch_size,
                        unet_n_first=self.startfilter,
                        train_loss='mse',
                        unet_kern_size=self.kern_size,
                        train_learning_rate=self.learning_rate,
                        train_reduce_lr={
                            'patience': 5,
                            'factor': 0.5
                        })
        print(config)
        model = CARE(config, name='UNET' + self.model_name, basedir=self.model_dir)

        unet_copy_prefix = None
        if self.copy_model_dir is not None:
            unet_copy_prefix = self.copy_model_dir + 'UNET' + self.copy_model_name + '/'
        # Only 'weights_now.h5' is taken over from the UNET copy model.
        _restore_weights(model, self.model_dir + 'UNET' + self.model_name + '/',
                         unet_copy_prefix, ('weights_now.h5',))

        history = model.train(X, Y, validation_data=(X_val, Y_val))
        print(sorted(history.history.keys()))
        plt.figure(figsize=(16, 5))
        plot_history(history, ['loss', 'val_loss'],
                     ['mse', 'val_mse', 'mae', 'val_mae'])

    if self.TrainSTAR:
        print('Training StarDistModel model with', self.backbone, 'backbone')
        self.axis_norm = (0, 1, 2)
        if not self.CroppedLoad:
            assert len(Raw) > 1, "not enough training data"
            print(len(Raw))
            rng = np.random.RandomState(42)
            ind = rng.permutation(len(Raw))
            X_train = list(map(ReadFloat, Raw))
            Y_train = list(map(ReadInt, RealMask))
            self.Y = [
                label(DownsampleData(y, self.DownsampleFactor))
                for y in tqdm(Y_train)
            ]
            self.X = [
                normalize(DownsampleData(x, self.DownsampleFactor),
                          1,
                          99.8,
                          axis=self.axis_norm) for x in tqdm(X_train)
            ]
            # Hold out 15% (at least one image) for validation.
            n_val = max(1, int(round(0.15 * len(ind))))
            ind_train, ind_val = ind[:-n_val], ind[-n_val:]
            self.X_val = [self.X[i] for i in ind_val]
            self.Y_val = [self.Y[i] for i in ind_val]
            self.X_trn = [self.X[i] for i in ind_train]
            self.Y_trn = [self.Y[i] for i in ind_train]
            print('number of images: %3d' % len(self.X))
            print('- training: %3d' % len(self.X_trn))
            print('- validation: %3d' % len(self.X_val))
        else:
            self.X_trn = self.DataSequencer(Raw, self.axis_norm, Normalize=True, labelMe=False)
            self.Y_trn = self.DataSequencer(RealMask, self.axis_norm, Normalize=False, labelMe=True)
            self.X_val = self.DataSequencer(ValRaw, self.axis_norm, Normalize=True, labelMe=False)
            self.Y_val = self.DataSequencer(ValRealMask, self.axis_norm, Normalize=False, labelMe=True)
            self.train_sample_cache = False

        print(Config3D.__doc__)
        anisotropy = (1, 1, 1)
        rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)
        kernel = (self.kern_size, self.kern_size, self.kern_size)
        # Parameters shared by both backbones. The patch size follows the
        # ZYX data order, matching the UNET patches above.
        shared = dict(
            rays=rays,
            anisotropy=anisotropy,
            backbone=self.backbone,
            train_epochs=self.epochs,
            train_learning_rate=self.learning_rate,
            train_checkpoint=self.model_dir + self.model_name + '.h5',
            train_patch_size=(self.PatchZ, self.PatchY, self.PatchX),
            train_batch_size=self.batch_size,
            train_dist_loss='mse',
            grid=(1, 1, 1),
            use_gpu=self.use_gpu,
            n_channel_in=1,
        )
        if self.backbone == 'resnet':
            conf = Config3D(resnet_n_blocks=self.depth,
                            resnet_kernel_size=kernel,
                            resnet_n_filter_base=self.startfilter,
                            **shared)
        elif self.backbone == 'unet':
            conf = Config3D(unet_n_depth=self.depth,
                            unet_kernel_size=kernel,
                            unet_n_filter_base=self.startfilter,
                            train_sample_cache=False,
                            **shared)
        else:
            raise ValueError(
                "backbone must be 'resnet' or 'unet', got %r" % (self.backbone,))
        print(conf)

        Starmodel = StarDist3D(conf, name=self.model_name, basedir=self.model_dir)
        star_local_prefix = self.model_dir + self.model_name + '/'
        print(Starmodel._axes_tile_overlap('ZYX'),
              os.path.exists(star_local_prefix + 'weights_now.h5'))

        star_copy_prefix = None
        if self.copy_model_dir is not None:
            star_copy_prefix = self.copy_model_dir + self.copy_model_name + '/'
        _restore_weights(Starmodel, star_local_prefix, star_copy_prefix,
                         ('weights_now.h5', 'weights_last.h5', 'weights_best.h5'))

        historyStar = Starmodel.train(self.X_trn,
                                      self.Y_trn,
                                      validation_data=(self.X_val, self.Y_val),
                                      epochs=self.epochs)
        print(sorted(historyStar.history.keys()))
        plt.figure(figsize=(16, 5))
        plot_history(historyStar, ['loss', 'val_loss'], [
            'dist_relevant_mae', 'val_dist_relevant_mae', 'dist_relevant_mse',
            'val_dist_relevant_mse'
        ])
# Tail of a command-line training script: the argument parser `args` is
# created before this chunk.
args.add_argument("--models_dir", default="care_probabilistic_models")
# From here on `args` is the parsed namespace, not the parser.
args = args.parse_args()
# validation_split=0: the second returned element is discarded (`_`);
# validation pairs are read from the separate args.val_data file instead.
tr_data, _, tr_axes = load_training_data(args.train_data, validation_split=0)
val_data, _, val_axes = load_training_data(args.val_data, validation_split=0)
#val_data = np.load(args.val_data)
#val_data = (val_data["X"], val_data["Y"])
#val_axes = tr_axes # we assume that both training and validation data are saved in the same format
is_3d = bool(args.is_3d)
axes = "ZYX" if is_3d else "YX"
n_dim = 3 if is_3d else 2
# One single-channel configuration shared by every model trained below.
config = Config(axes, n_dim=n_dim, n_channel_in=1, n_channel_out=1, probabilistic=int(args.probabilistic), train_batch_size=int(args.batch_size), unet_kern_size=3)
# Train int(args.num_models) models named model_0, model_1, ... in args.models_dir.
for i in range(int(args.num_models)):
    model = CARE(config, f"model_{i}", args.models_dir)
    train_history = model.train(tr_data[0], tr_data[1], validation_data=val_data, epochs=int(args.epochs), steps_per_epoch=int(args.steps_per_epoch))
    # NOTE(review): exit() terminates the script here, so (if it sits inside
    # the loop as written) only model_0 is trained and the plot_history call
    # below never runs. Looks like a debugging leftover; confirm before
    # relying on num_models > 1.
    exit()
    plot_history(train_history, ["loss", "val_loss"],
# Fragment of a training script: X, Y, X_val, Y_val, axes and the parsed
# `args` namespace are defined before this chunk.
c = axes_dict(axes)['C']  # index of the channel axis
n_channel_in, n_channel_out = X.shape[c], Y.shape[c]
if args.plot:
    # Preview the first five validation pairs.
    plt.figure(figsize=(12, 5))
    plot_some(X_val[:5], Y_val[:5])
    plt.suptitle(
        '5 example validation patches (top row: source, bottom row: target)'
    )
    plt.show()
# Construct a CARE model, defining its configuration via a Config object
config = Config(
    axes,
    n_channel_in,
    n_channel_out,
    probabilistic=False,  # We don't need detailed stats just yet
    train_steps_per_epoch=args.train_steps_per_epoch,
    train_epochs=args.num_epochs)
print(config)
vars(config)  # result discarded; only displays in an interactive session
model = CARE(config, args.model_name, basedir=args.model_dir)
# Use tensorboard to check the training progress with logdir = basedir
history = model.train(X, Y, validation_data=(X_val, Y_val))
if args.plot:
    # Learning curves (loss plus MSE/MAE metrics).
    plt.figure(figsize=(16, 5))
    plot_history(history, ['loss', 'val_loss'],
                 ['mse', 'val_mse', 'mae', 'val_mae'])
# Fragment of a training script: X_train, Y_train, X_val, Y_val, axes, c,
# stepPerEpoch, modelName and baseDir are defined before this chunk.
n_channel_in, n_channel_out = X_train.shape[c], Y_train.shape[c]
# Preview the first five validation pairs.
plt.figure(figsize=(12, 5))
plot_some(X_val[:5], Y_val[:5])
plt.suptitle(
    '5 example validation patches (top row: source, bottom row: target)')
#################
# Configuration #
#################
# A Config object holds the parameters of the underlying neural network, the
# learning rate, the number of parameter updates per epoch, the loss function
# and whether the model is probabilistic. Here only the channel counts,
# probabilistic=True and the steps per epoch are set explicitly.
config = Config(axes,
                n_channel_in,
                n_channel_out,
                probabilistic=True,
                train_steps_per_epoch=stepPerEpoch)
print(config)
vars(config)  # result discarded; only displays in an interactive session
############
# TRAINING #
############
# Training progress can be monitored with TensorBoard (see https://www.tensorflow.org/guide/summaries_and_tensorboard)
# Model instantiation
#model = CARE(config=None, name='my_model', basedir='models') # used to load a model
model = CARE(config, modelName, basedir=baseDir)  # used to train a new model
# training model
def training(self):
    """Train a 2D CARE model on augmented patches and export it for Fiji.

    Steps:

    1. Read paired ``.tif`` images from ``data/training/{original,ground_truth}/``
       and ``data/validation/{original,ground_truth}/``.
    2. Write one 64x64 patch per image pair to ``data/training.npz`` and
       ``data/validation.npz``.
    3. Augment the patches on the fly with Keras ``ImageDataGenerator``.
    4. Train for ``epochs`` rounds. Each round rebuilds the model with a
       decayed learning rate, starting from the best weights so far. The
       weights with the lowest validation MAE are kept in
       ``models/CSBDeep/model.h5``.
    5. Load those best weights and export them with ``export_TF()``.
    """
    import glob
    import os

    from matplotlib import pyplot as plt
    from skimage import io

    # Loading files and plotting examples
    basepath = 'data/'
    training_original_dir = 'training/original/'
    training_ground_truth_dir = 'training/ground_truth/'
    validation_original_dir = 'validation/original/'
    validation_ground_truth_dir = 'validation/ground_truth/'

    def _first_image(folder):
        # Sorted .tif paths in basepath/folder, plus the first image read from them.
        files = sorted(glob.glob(basepath + folder + '*.tif'))
        return files, io.imread(files[0])

    def _show_four(images):
        # Plot four images side by side in one 16x4 figure.
        plt.figure(figsize=(16, 4))
        for position, image in enumerate(images, start=1):
            plt.subplot(1, 4, position)
            plt.imshow(image)

    training_original_files, training_original_file = _first_image(
        training_original_dir)
    training_ground_truth_files, training_ground_truth_file = _first_image(
        training_ground_truth_dir)
    print("Training dataset's number of files and dimensions: ",
          len(training_original_files), training_original_file.shape,
          len(training_ground_truth_files), training_ground_truth_file.shape)

    validation_original_files, validation_original_file = _first_image(
        validation_original_dir)
    validation_ground_truth_files, validation_ground_truth_file = _first_image(
        validation_ground_truth_dir)
    print("Validation dataset's number of files and dimensions: ",
          len(validation_original_files), validation_original_file.shape,
          len(validation_ground_truth_files), validation_ground_truth_file.shape)

    # Compare the full image shapes, not just the number of rows.
    if training_original_file.shape != validation_original_file.shape:
        print('Training and validation images should be of the same dimensions!')

    _show_four([training_original_file, training_ground_truth_file,
                validation_original_file, validation_ground_truth_file])

    # Prepare network inputs from pairs of 32-bit TIFF images with intensities
    # in the range 0..1. This only needs to run once per new dataset.
    from csbdeep.data import RawData, create_patches, no_background_patches
    training_data = RawData.from_folder(
        basepath=basepath,
        source_dirs=[training_original_dir],
        target_dir=training_ground_truth_dir,
        axes='YX',
    )
    validation_data = RawData.from_folder(
        basepath=basepath,
        source_dirs=[validation_original_dir],
        target_dir=validation_ground_truth_dir,
        axes='YX',
    )
    # Augmented variants are produced later by the data generators, so a
    # single 64x64 patch per image is extracted here.
    size1 = 64
    create_patches(
        raw_data=training_data,
        patch_size=(size1, size1),
        patch_filter=no_background_patches(0),
        n_patches_per_image=1,
        save_file=basepath + 'training.npz',
    )
    create_patches(
        raw_data=validation_data,
        patch_size=(size1, size1),
        patch_filter=no_background_patches(0),
        n_patches_per_image=1,
        save_file=basepath + 'validation.npz',
    )

    # Load training and validation data into memory
    from csbdeep.io import load_training_data
    from csbdeep.utils import axes_dict
    (X, Y), _, axes = load_training_data(basepath + 'training.npz', verbose=False)
    (X_val, Y_val), _, axes = load_training_data(basepath + 'validation.npz', verbose=False)
    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    # Samples per generator batch; len(X) uses every patch at once. Lower it
    # (e.g. to 1) if memory is insufficient.
    batch = len(X)
    seed = 1
    from keras.preprocessing.image import ImageDataGenerator
    data_gen_args = dict(
        samplewise_center=False,
        samplewise_std_normalization=False,
        zca_whitening=False,
        fill_mode='reflect',
        rotation_range=30,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        # width_shift_range=0.2,
        # height_shift_range=0.2,
    )

    def _paired_flow(inputs, targets):
        # Zip two augmenting generators built with identical arguments and the
        # same seed, intended to apply matching transforms to inputs/targets.
        image_datagen = ImageDataGenerator(**data_gen_args)
        mask_datagen = ImageDataGenerator(**data_gen_args)
        image_datagen.fit(inputs, augment=True, seed=seed)
        mask_datagen.fit(targets, augment=True, seed=seed)
        return zip(image_datagen.flow(inputs, batch_size=batch, seed=seed),
                   mask_datagen.flow(targets, batch_size=batch, seed=seed))

    generator = _paired_flow(X, Y)
    generator_val = _paired_flow(X_val, Y_val)

    # Plot one augmented training pair and one augmented validation pair.
    x, y = next(generator)
    x_val, y_val = next(generator_val)
    _show_four([x[0, :, :, 0], y[0, :, :, 0], x_val[0, :, :, 0], y_val[0, :, :, 0]])

    blocks = 2
    channels = 16
    learning_rate = 0.0004
    learning_rate_decay_factor = 0.95
    epoch_size_multiplicator = 20  # basically, how often do we decrease learning rate
    epochs = 10
    kernel_size = 3
    model_path = 'models/CSBDeep/model.h5'

    from csbdeep.models import Config, CARE
    from keras import backend as K

    def _build_model(lr):
        # Fresh CARE model (logdir 'models/') with the fixed U-Net settings.
        config = Config(axes,
                        n_channel_in,
                        n_channel_out,
                        unet_kern_size=kernel_size,
                        train_learning_rate=lr,
                        unet_n_depth=blocks,
                        unet_n_first=channels)
        return CARE(config, '.', basedir='models')

    if os.path.isfile(model_path):
        print('Your model will be overwritten in the next cell')
        os.remove(model_path)

    # Start at +inf so the first round is always saved: later rounds and the
    # final export reload model_path and would fail if it were never written.
    best_mae = float('inf')
    steps_per_epoch = len(X) * epoch_size_multiplicator
    validation_steps = len(X_val) * epoch_size_multiplicator

    for i in range(epochs):
        print('Epoch:', i + 1)
        learning_rate = learning_rate * learning_rate_decay_factor
        model = _build_model(learning_rate)
        if i > 0:
            # Continue from the best weights saved so far.
            model.keras_model.load_weights(model_path)
        model.prepare_for_training()
        history = model.keras_model.fit_generator(
            generator,
            validation_data=generator_val,
            validation_steps=validation_steps,
            epochs=1,
            verbose=0,
            shuffle=True,
            steps_per_epoch=steps_per_epoch)
        val_mae = history.history['val_mae'][0]
        if val_mae < best_mae:
            best_mae = val_mae
            # The checkpoint lives in models/CSBDeep/, which must exist.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            model.keras_model.save(model_path)
            print(f'Validation MAE:{best_mae:.3E}')
        # Free the Keras graph before building the next round's model.
        del model
        K.clear_session()

    model = _build_model(learning_rate)
    model.keras_model.load_weights(model_path)
    model.export_TF()
def n2v_flim(project, n2v_num_pix=32):
    """Denoise FLIM fit results with Noise2Void into a copy of the project.

    Steps:

    1. Read ``fit_results.hdf5`` from ``project`` via ``extract_results``.
    2. Normalise the data with its global mean/std.
    3. Train a CARE model in ``'Noise2Void'`` mode, saved as ``n2v_model``
       inside ``project``.
    4. Predict every image and set masked entries to NaN.
    5. Write the result into ``fit_results.hdf5`` of a copy of the project.
       The copy's path is ``project`` with ``.flimfit`` replaced by
       ``-n2v.flimfit``; an existing copy is removed first.

    Args:
        project: Path to the ``.flimfit`` project directory.
        n2v_num_pix: Forwarded as ``n2v_num_pix`` and used as the square
            ``n2v_patch_shape``. It also scales ``num_pix`` for the
            validation-data manipulation.

    Raises:
        ValueError: if ``project`` does not contain ``'.flimfit'``. The output
            path would then equal the input path, and deleting the "old
            output" would delete the input project.
    """
    output_project = project.replace('.flimfit', '-n2v.flimfit')
    if output_project == project:
        raise ValueError(
            "project path must contain '.flimfit': %r" % (project,))

    results_file = os.path.join(project, 'fit_results.hdf5')
    X, groups, mask = extract_results(results_file)
    print(np.shape(X))

    mean, std = np.mean(X), np.std(X)
    X = normalize(X, mean, std)
    XA = X  # augment_data(X)

    # manipulate_val_data's result is not used, so it works by mutating its
    # arguments. Take a copy: a slice view would also alter the first 10
    # images of the training data and the images predicted below.
    X_val = X[0:10, ...].copy()

    # We concatenate an extra channel filled with zeros. It will be internally used for the masking.
    Y = np.concatenate((XA, np.zeros(XA.shape)), axis=-1)
    Y_val = np.concatenate((X_val, np.zeros(X_val.shape)), axis=-1)

    n_x = X.shape[1]
    n_chan = X.shape[-1]
    manipulate_val_data(X_val, Y_val, num_pix=n_x * n_x * 2 / n2v_num_pix,
                        shape=(n_x, n_x))

    # You can increase "train_steps_per_epoch" to get even better results at the price of longer computation.
    config = Config('SYXC',
                    n_channel_in=n_chan,
                    n_channel_out=n_chan,
                    unet_kern_size=5,
                    unet_n_depth=2,
                    train_steps_per_epoch=200,
                    train_loss='mae',
                    train_epochs=35,
                    batch_norm=False,
                    train_scheme='Noise2Void',
                    train_batch_size=128,
                    n2v_num_pix=n2v_num_pix,
                    n2v_patch_shape=(n2v_num_pix, n2v_num_pix),
                    n2v_manipulator='uniform_withCP',
                    # NOTE(review): the radius is passed as the string '5';
                    # confirm the N2V code accepts a string (an int is more
                    # likely expected).
                    n2v_neighborhood_radius='5')
    model = CARE(config, 'n2v_model', basedir=project)
    model.train(XA, Y, validation_data=(X_val, Y_val))
    model.load_weights(name='weights_best.h5')

    if os.path.exists(output_project):
        shutil.rmtree(output_project)
    shutil.copytree(project, output_project)
    output_file = os.path.join(output_project, 'fit_results.hdf5')

    X_pred = np.zeros(X.shape)
    for i in range(X.shape[0]):
        X_pred[i, ...] = denormalize(
            model.predict(X[i], axes='YXC', normalizer=None), mean, std)
    # np.nan rather than np.NaN: the NaN alias was removed in NumPy 2.0.
    X_pred[mask] = np.nan

    insert_results(output_file, X_pred, groups)
# Fragment: chans, keepers, x_train, x_test, y_train, y_test, num_epochs and
# model_name are defined before this chunk. chans must support boolean-mask
# indexing (presumably a NumPy array).
# Boolean mask selecting the requested channels.
keep_idx = np.isin(chans, keepers)
if np.sum(keep_idx) == 0:
    raise ValueError("Did not supply valid channel name")
print('analyzing the following channels: {}'.format(chans[keep_idx]))
# Keep only the selected channels (last axis, 'C' in 'SYXC') in all four arrays.
x_train, x_test = x_train[:, :, :, keep_idx], x_test[:, :, :, keep_idx]
y_train, y_test = y_train[:, :, :, keep_idx], y_test[:, :, :, keep_idx]
# Taken from CARE FAQ, runs the model
axes = 'SYXC'
c = axes_dict(axes)['C']
n_channel_in, n_channel_out = x_train.shape[c], y_train.shape[c]
config = Config(axes, n_channel_in, n_channel_out, probabilistic=True, train_epochs=num_epochs, train_steps_per_epoch=30)
model = CARE(config, model_name, basedir='models')
# The test split doubles as validation data.
history = model.train(x_train, y_train, validation_data=(x_test, y_test))
fig = plt.figure(figsize=(30, 30))
_P = model.keras_model.predict(x_test[:5, :, :, :])
# The raw output is split in half along the channel axis: the first half is
# used as the mean and the second half as the scale.
_P_mean = _P[..., :(_P.shape[-1] // 2)]
_P_scale = _P[..., (_P.shape[-1] // 2):]
# NOTE(review): only input channel 0 is plotted while all target channels are;
# confirm this is intended for multi-channel data.
plot_some(x_test[:5, :, :, 0], y_test[:5, :, :, :], _P_mean, _P_scale, pmax=99.5)
def train(Training_source=".",
          Training_target=".",
          model_name="No_name",
          model_path=".",
          Visual_validation_after_training=True,
          number_of_epochs=100,
          patch_size=64,
          number_of_patches=10,
          Use_Default_Advanced_Parameters=True,
          number_of_steps=300,
          batch_size=32,
          percentage_validation=15):
    '''
    Main function of the script. Train the model and save it in model_path.

    Parameters
    ----------
    Training_source : (str) Folder containing the noisy (input) .tif images
    Training_target : (str) Folder containing the ground-truth .tif images
    model_name : (str) Name of the model. An existing
        model_path/model_name folder is deleted before training.
    model_path : (str) Folder in which the model and rawdata.npz are saved
    Visual_validation_after_training : (bool) If True, call Predict_a_image
        after training
    number_of_epochs : (int) Number of training epochs
    patch_size : (int) Side length of the square training patches
    number_of_patches : (int) Number of patches extracted per image
    Use_Default_Advanced_Parameters : (bool) If True, batch_size is set to 64,
        percentage_validation to 10, and number_of_steps is computed from the
        number of training patches
    number_of_steps : (int) Training steps per epoch (overridden when the
        defaults are used)
    batch_size : (int) Batch size (overridden when the defaults are used)
    percentage_validation : (int) Percentage of patches held out for
        validation (overridden when the defaults are used)

    Return
    ------
    void
    '''
    import time

    OutputFile = Training_target + "/*.tif"
    InputFile = Training_source + "/*.tif"
    base = "/content/"

    if Use_Default_Advanced_Parameters:
        print("Default advanced parameters enabled")
        batch_size = 64
        percentage_validation = 10
    percentage = percentage_validation / 100

    # Delete any existing model with the same name so training starts fresh.
    model_dir = model_path + '/' + model_name
    if os.path.exists(model_dir):
        shutil.rmtree(model_dir)

    # Read the image stacks once to report their shapes.
    x = imread(InputFile)
    y = imread(OutputFile)
    print('Loaded Input images (number, width, length) =', x.shape)
    print('Loaded Output images (number, width, length) =', y.shape)
    print("Parameters initiated.")

    # The RawData object holds the (low, GT) image pairs so that CARE compares
    # corresponding images. The patches cut from them are saved in .npz
    # format and loaded again as training data below.
    raw_data = data.RawData.from_folder(basepath=base,
                                        source_dirs=[Training_source],
                                        target_dir=Training_target,
                                        axes='CYX',
                                        pattern='*.tif*')
    X, Y, XY_axes = data.create_patches(raw_data,
                                        patch_filter=None,
                                        patch_size=(patch_size, patch_size),
                                        n_patches_per_image=number_of_patches)

    print('Creating 2D training dataset')
    training_path = model_path + "/rawdata"
    rawdata1 = training_path + ".npz"
    # np.savez cannot create missing parent folders.
    os.makedirs(model_path, exist_ok=True)
    np.savez(training_path, X=X, Y=Y, axes=XY_axes)

    # Load Training Data
    (X, Y), (X_val, Y_val), axes = load_training_data(rawdata1,
                                                      validation_split=percentage,
                                                      verbose=True)
    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    # With default parameters, choose enough steps per epoch to cover every
    # training patch at this batch size.
    if Use_Default_Advanced_Parameters:
        number_of_steps = int(X.shape[0] / batch_size) + 1
    print(number_of_steps)

    # Here we create the configuration file
    config = Config(axes,
                    n_channel_in,
                    n_channel_out,
                    probabilistic=False,
                    train_steps_per_epoch=number_of_steps,
                    train_epochs=number_of_epochs,
                    unet_kern_size=5,
                    unet_n_depth=3,
                    train_batch_size=batch_size,
                    train_learning_rate=0.0004)
    print(config)

    # Compile the CARE model for network training
    model_training = CARE(config, model_name, basedir=model_path)

    start = time.time()
    # Start Training
    history = model_training.train(X, Y, validation_data=(X_val, Y_val))
    print("Training, done.")

    if Visual_validation_after_training:
        Predict_a_image(Training_source, Training_target, model_path, model_training)

    # Elapsed time for training (and the optional prediction). `minutes`
    # avoids shadowing the built-in min().
    dt = time.time() - start
    minutes, sec = divmod(dt, 60)
    hour, minutes = divmod(minutes, 60)
    print("Time elapsed:", hour, "hour(s)", minutes, "min(s)", round(sec), "sec(s)")

    Show_loss_function(history, model_path)
def train(X, Y, X_val, Y_val, axes, model_name, csv_file, probabilistic,
          validation_split, patch_size, **kwargs):
    """Train a single-channel CARE model on previously created patches.

    The model is configured through a ``Config`` object (network parameters,
    learning rate, updates per epoch, loss, and whether it is probabilistic).
    It is saved under ``models/<model_name>`` and exported with
    ``export_TF()`` after training.

    Parameters
    ----------
    X : np.array
        Input data X for training
    Y : np.array
        Ground truth data Y for training
    X_val : np.array
        Input data X for validation
    Y_val : np.array
        Ground truth Y for validation
    axes : str
        Semantic order of the axes in the image
    model_name : str
        Name of the model to be saved after training
    csv_file : str
        Currently unused; kept for API compatibility.
    probabilistic : bool
        Whether the model is probabilistic (forwarded to ``Config``).
    validation_split : float
        Currently unused; kept for API compatibility.
    patch_size : tuple
        Currently unused; kept for API compatibility.
    **kwargs
        Further ``Config`` parameters, e.g. ``train_steps_per_epoch`` or
        ``train_epochs``. ``allow_new_parameters=True`` is passed so Config
        may accept parameter names it does not define by default.

    Returns
    -------
    history
        Keras history object holding the loss/metric values per epoch.
    """
    config = Config(axes, n_channel_in=1, n_channel_out=1,
                    probabilistic=probabilistic,
                    allow_new_parameters=True, **kwargs)
    model = CARE(config, model_name, basedir='models')

    # Training progress can be followed with TensorBoard
    # (https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard):
    # run `tensorboard --logdir=.` from the current working directory, then
    # open http://localhost:6006/ in a browser.
    history = model.train(X, Y, validation_data=(X_val, Y_val))

    model.export_TF()
    return history
validation_split=0.1, verbose=True)  # tail of a call (presumably load_training_data) that begins before this chunk
# Index of the channel axis, and the channel counts of inputs and targets.
c = axes_dict(axes)['C']
n_channel_in, n_channel_out = X.shape[c], Y.shape[c]
# In[3]:
# In[4]:
# 4-level U-Net, 50 epochs x 400 steps, batch size 16. train_reduce_lr
# presumably configures halving the learning rate after 5 epochs without
# improvement.
config = Config(axes, n_channel_in, n_channel_out, unet_n_depth=4, train_epochs=50, train_steps_per_epoch=400, train_batch_size=16, train_reduce_lr={
    'patience': 5,
    'factor': 0.5
})
print(config)
vars(config)  # result discarded; only displays in an interactive session
# In[5]:
model = ProjectionCARE(
    config,
    'DrosophilaDenoisingProjection',
    basedir='/local/u934/private/v_kapoor/CurieDeepLearningModels')
def train(self, channels=None, **config_args):
    """Train, preview and export one CARE model per channel, timed as a whole.

    For every channel ``ch`` the patches in ``CH_<ch>_training_patches.npz``
    (under ``self.get_training_patch_path()``) are loaded with 10% held out
    for validation. A model ``CH_<ch>_model`` is trained under
    ``<self.out_dir>/models`` with ``probabilistic=self.probabilistic``. Its
    learning curves and five validation predictions are plotted, and it is
    exported with ``export_TF()``. The whole loop runs inside
    ``Timer("Training")``.

    Args:
        channels: Channels to train; ``None`` means ``self.train_channels``.
        **config_args: Extra keyword arguments forwarded to ``Config``.

    Returns:
        None. On ``tf.errors.ResourceExhaustedError`` or
        ``tf.errors.UnknownError`` a message is printed and the remaining
        channels are skipped.
    """
    # limit_gpu_memory(fraction=1)
    requested = self.train_channels if channels is None else channels

    with Timer("Training"):
        for channel in requested:
            print("-- Training channel {}...".format(channel))
            patch_file = (self.get_training_patch_path()
                          / "CH_{}_training_patches.npz".format(channel))
            (x_trn, y_trn), (x_hold, y_hold), axes = load_training_data(
                patch_file, validation_split=0.1, verbose=False)

            chan_axis = axes_dict(axes)["C"]
            model = CARE(
                Config(
                    axes,
                    x_trn.shape[chan_axis],
                    y_trn.shape[chan_axis],
                    train_epochs=self.train_epochs,
                    train_steps_per_epoch=self.train_steps_per_epoch,
                    train_batch_size=self.train_batch_size,
                    probabilistic=self.probabilistic,
                    **config_args,
                ),
                "CH_{}_model".format(channel),
                basedir=pathlib.Path(self.out_dir) / "models",
            )

            try:
                history = model.train(x_trn, y_trn, validation_data=(x_hold, y_hold))
            except tf.errors.ResourceExhaustedError:
                print(
                    " >> ResourceExhaustedError: Aborting...\n Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?"
                )
                return
            except tf.errors.UnknownError:
                print(
                    " >> UnknownError: Aborting...\n No enough memory available on GPU... are other GPU jobs running?"
                )
                return

            # Learning curves
            plt.figure(figsize=(16, 5))
            plot_history(history, ["loss", "val_loss"],
                         ["mse", "val_mse", "mae", "val_mae"])

            # Five validation patches: input, ground truth, prediction
            plt.figure(figsize=(12, 7))
            preview_in, preview_gt = x_hold[:5], y_hold[:5]
            raw_pred = model.keras_model.predict(preview_in)
            # For probabilistic models only the first output channel is shown.
            shown = raw_pred[..., 0] if self.probabilistic else raw_pred
            plot_some(preview_in, preview_gt, shown, pmax=99.5, cmap="gray")
            plt.suptitle("5 example validation patches\n"
                         "top row: input (source), "
                         "middle row: target (ground truth), "
                         "bottom row: predicted from source")
            plt.show()

            print("-- Export model for use in Fiji...")
            model.export_TF()
            print("Done")