def test_model_train(tmpdir, config):
    """Smoke-test training a CARE model for ``config`` on random data.

    Four random samples with 32 pixels along every spatial axis are drawn for
    both input and target (channel axis last, sized from the config), and the
    same arrays double as validation data.

    Parameters
    ----------
    tmpdir : path-like
        Temporary directory (e.g. pytest's ``tmpdir`` fixture) used as the
        model's ``basedir``.
    config : Config
        CARE configuration to instantiate and train.
    """
    rng = np.random.RandomState(42)
    K.clear_session()
    spatial = (32,) * config.n_dim
    inputs = rng.uniform(size=(4,) + spatial + (config.n_channel_in,))
    targets = rng.uniform(size=(4,) + spatial + (config.n_channel_out,))
    care = CARE(config, basedir=str(tmpdir))
    care.train(inputs, targets, (inputs, targets))
def test_model_train():
    """Smoke-test CARE training over a grid of tiny configurations.

    Every configuration produced by ``config_generator`` is visited; those that
    report themselves valid are trained for two epochs of two steps on random
    data inside a temporary directory, invalid combinations are skipped.
    """
    rng = np.random.RandomState(42)
    grid = dict(
        axes=['YX', 'ZYX'],
        n_channel_in=[1, 2],
        n_channel_out=[1, 2],
        probabilistic=[False, True],
        unet_n_depth=[1],
        unet_kern_size=[3],
        unet_n_first=[4],
        unet_last_activation=['linear'],
        train_loss=['mae', 'laplace'],
        train_epochs=[2],
        train_steps_per_epoch=[2],
        train_batch_size=[2],
    )
    # Grid entries that were previously listed but disabled:
    # unet_residual=[False, True], unet_input_shape=[(None, None, 1)],
    # train_learning_rate=[0.0004], train_tensorboard=[True],
    # train_checkpoint=['weights_best.h5'],
    # train_reduce_lr=[{'factor': 0.5, 'patience': 10}]
    configs = config_generator(**grid)
    with tempfile.TemporaryDirectory() as tmpdir:
        for cfg in configs:
            K.clear_session()
            if not cfg.is_valid():
                continue
            spatial = (32,) * cfg.n_dim
            X = rng.uniform(size=(4,) + spatial + (cfg.n_channel_in,))
            Y = rng.uniform(size=(4,) + spatial + (cfg.n_channel_out,))
            CARE(cfg, basedir=tmpdir).train(X, Y, (X, Y))
def train_care_generated_data(model: "CARE", epochs: int, X: np.ndarray, Y: np.ndarray,
                              X_val: np.ndarray, Y_val: np.ndarray,
                              steps_per_epoch: int = 400) -> None:
    """
    Train a CARE model on a generated dataset whose channel axis is axis 1.

    Axis 1 of every array is moved to the end (the channels-last layout that
    is passed to ``model.train``). For the targets only the first channel is
    kept. Five-dimensional arrays (presumably ``(S, C, Z, Y, X)``) are the
    original use case; any array with at least two axes is handled the same
    way, e.g. ``(S, C, Y, X)`` for 2D data.

    :param model: The CARE model to train
    :param epochs: The number of epochs to train for
    :param X: The training data input, channel axis at position 1
    :param Y: The training data expected output, channel axis at position 1
    :param X_val: The validation data input, channel axis at position 1
    :param Y_val: The validation data expected output, channel axis at position 1
    :param steps_per_epoch: The number of training steps per epoch
    """
    def channels_last(arr: np.ndarray) -> np.ndarray:
        # (S, C, ...) -> (S, ..., C). For 5D input this is exactly the
        # permutation [0, 1, 2, 3, 4] -> [0, 4, 1, 2, 3].
        return np.moveaxis(arr, 1, -1)

    def first_channel_last(arr: np.ndarray) -> np.ndarray:
        # Channels-last layout, keeping only the first channel (as a size-1 axis).
        return channels_last(arr)[..., :1]

    model.train(channels_last(X), first_channel_last(Y),
                validation_data=(channels_last(X_val), first_channel_last(Y_val)),
                epochs=epochs, steps_per_epoch=steps_per_epoch)
def train(self, channels=None, **config_args):
    """Train, inspect and export one CARE model per channel.

    For every channel ``ch`` the patches in
    ``<training patch path>/CH_<ch>_training_patches.npz`` are loaded with 10 %
    held out for validation. A model named ``CH_<ch>_model`` is trained in
    ``<out_dir>/models``. The learning curve and five validation predictions
    are plotted, and the model is exported with ``export_TF``.

    Parameters
    ----------
    channels : iterable, optional
        Channels to train; defaults to ``self.train_channels``.
    **config_args
        Extra keyword arguments forwarded to ``Config``.

    If training raises ``tf.errors.ResourceExhaustedError``, a message is
    printed and the method returns without training the remaining channels.
    """
    if channels is None:
        channels = self.train_channels

    for ch in channels:
        print("-- Training channel {}...".format(ch))
        (X, Y), (X_val, Y_val), axes = load_training_data(
            self.get_training_patch_path() / 'CH_{}_training_patches.npz'.format(ch),
            validation_split=0.1,
            verbose=False)

        c = axes_dict(axes)['C']
        n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

        config = Config(axes, n_channel_in, n_channel_out,
                        train_epochs=self.train_epochs,
                        train_steps_per_epoch=self.train_steps_per_epoch,
                        train_batch_size=self.train_batch_size,
                        **config_args)

        # Training
        model = CARE(config, 'CH_{}_model'.format(ch),
                     basedir=pathlib.Path(self.out_dir) / 'models')

        try:
            history = model.train(X, Y, validation_data=(X_val, Y_val))
        except tf.errors.ResourceExhaustedError:
            print(
                "ResourceExhaustedError: Aborting...\n Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?"
            )
            return

        # Show learning curve and example validation results
        plt.figure(figsize=(16, 5))
        plot_history(history, ['loss', 'val_loss'],
                     ['mse', 'val_mse', 'mae', 'val_mae'])

        plt.figure(figsize=(12, 7))
        _P = model.keras_model.predict(X_val[:5])
        if config.probabilistic:
            # ``config_args`` may request a probabilistic model, whose output
            # holds the Laplace mean channels followed by the scale channels;
            # only the means are comparable to input and target.
            _P = _P[..., :(_P.shape[-1] // 2)]
        plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5, cmap="gray")
        plt.suptitle('5 example validation patches\n'
                     'top row: input (source), '
                     'middle row: target (ground truth), '
                     'bottom row: predicted from source')
        plt.show()

        print("-- Export model for use in Fiji...")
        model.export_TF()
        print("Done")
def dev(args):
    """Train (or resume training) a CARE model configured from command-line ``args``.

    Training patches are read from ``args.train_data``, with a fraction
    ``args.valid_split`` held out for validation. The config is chosen as
    follows:

    * reloaded from the saved model when ``args.resume`` is set,
    * read from the JSON file ``args.config`` when given,
    * otherwise built from ``args.prob``, ``args.steps`` and ``args.epochs``.

    After training, the learning curve is saved as
    ``<args.model_name>_training.png`` and the model is exported with
    ``export_TF``.
    """
    import json

    # Load and parse training data
    (X, Y), (X_val, Y_val), axes = load_training_data(
        args.train_data, validation_split=args.valid_split, axes=args.axes, verbose=True)

    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    # Model config
    print('args.resume: ', args.resume)
    if args.resume:
        # If resuming, config=None will reload the saved config
        config = None
        print('Attempting to resume')
    elif args.config:
        print('loading config from args')
        with open(args.config) as config_file:
            config_args = json.load(config_file)
        config = Config(**config_args)
    else:
        config = Config(axes, n_channel_in, n_channel_out,
                        probabilistic=args.prob,
                        train_steps_per_epoch=args.steps,
                        train_epochs=args.epochs)

    # Load or init model
    model = CARE(config, args.model_name, basedir='models')
    # Print the effective config. When resuming, ``config`` is None (and
    # vars(None) raises TypeError); the config in use is the one the model loaded.
    print(vars(model.config))

    # Training, tensorboard available
    history = model.train(X, Y, validation_data=(X_val, Y_val))

    # Plot training results
    print(sorted(history.history.keys()))
    plt.figure(figsize=(16, 5))
    plot_history(history, ['loss', 'val_loss'], ['mse', 'val_mse', 'mae', 'val_mae'])
    plt.savefig(args.model_name + '_training.png')

    # Export model to be used w/ csbdeep fiji plugins and KNIME flows
    model.export_TF()
def train(Training_source=".", Training_target=".", model_name="No_name", model_path=".",
          Visual_validation_after_training=True, number_of_epochs=100, patch_size=64,
          number_of_patches=10, Use_Default_Advanced_Parameters=True, number_of_steps=300,
          batch_size=32, percentage_validation=15, base_path="/content/"):
    '''
    Main function of the script. Train a 2D CARE model and save it in model_path.

    Parameters
    ----------
    Training_source : (str) folder with the noisy (input) .tif images
    Training_target : (str) folder with the ground-truth .tif images
    model_name : (str) name of the model; an existing folder of that name in
        model_path is deleted first
    model_path : (str) folder receiving the model folder and the patch file rawdata.npz
    Visual_validation_after_training : (bool) predict an image after training
    number_of_epochs : (int) number of training epochs
    patch_size : (int) edge length of the square training patches
    number_of_patches : (int) number of patches extracted per image
    Use_Default_Advanced_Parameters : (bool) if True, batch_size is set to 64,
        percentage_validation to 10, and number_of_steps is derived from the
        number of patches and the batch size
    number_of_steps : (int) training steps per epoch (overridden when
        Use_Default_Advanced_Parameters is True)
    batch_size : (int) batch size
    percentage_validation : (int) percentage of the patches used for validation
    base_path : (str) base folder passed to RawData.from_folder, i.e. the folder
        relative Training_source / Training_target are resolved against when
        building the patches (default "/content/", as on Google Colab)

    Return
    ------
    void
    '''
    import time

    if Use_Default_Advanced_Parameters:
        print("Default advanced parameters enabled")
        batch_size = 64
        percentage_validation = 10
    validation_split = percentage_validation / 100

    # Here we check that no model with the same name already exists; if so, delete it.
    model_dir = os.path.join(model_path, model_name)
    if os.path.exists(model_dir):
        shutil.rmtree(model_dir)

    # The images are loaded here only to report their shapes.
    x = imread(Training_source + "/*.tif")
    y = imread(Training_target + "/*.tif")
    print('Loaded Input images (number, width, length) =', x.shape)
    print('Loaded Output images (number, width, length) =', y.shape)
    del x, y  # not used afterwards; release the memory before patch extraction
    print("Parameters initiated.")

    # RawData object: holds the image pairs (GT and low), ensuring that CARE
    # compares corresponding images. The extracted patches are saved in .npz
    # format and loaded back as the training data below.
    raw_data = data.RawData.from_folder(basepath=base_path,
                                        source_dirs=[Training_source],
                                        target_dir=Training_target,
                                        axes='CYX',
                                        pattern='*.tif*')
    X, Y, XY_axes = data.create_patches(raw_data,
                                        patch_filter=None,
                                        patch_size=(patch_size, patch_size),
                                        n_patches_per_image=number_of_patches)
    print('Creating 2D training dataset')
    training_path = model_path + "/rawdata"
    # np.savez appends the ".npz" extension itself.
    np.savez(training_path, X=X, Y=Y, axes=XY_axes)

    # Load Training Data
    (X, Y), (X_val, Y_val), axes = load_training_data(training_path + ".npz",
                                                      validation_split=validation_split,
                                                      verbose=True)
    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    # With default advanced parameters, derive number_of_steps from the
    # number of training patches and the batch size.
    if Use_Default_Advanced_Parameters:
        number_of_steps = int(X.shape[0] / batch_size) + 1
    print(number_of_steps)

    # Here we create the configuration
    config = Config(axes, n_channel_in, n_channel_out,
                    probabilistic=False,
                    train_steps_per_epoch=number_of_steps,
                    train_epochs=number_of_epochs,
                    unet_kern_size=5,
                    unet_n_depth=3,
                    train_batch_size=batch_size,
                    train_learning_rate=0.0004)
    print(config)

    # Compile the CARE model for network training
    model_training = CARE(config, model_name, basedir=model_path)

    # Start Training
    start = time.time()
    history = model_training.train(X, Y, validation_data=(X_val, Y_val))
    print("Training, done.")

    if Visual_validation_after_training:
        Predict_a_image(Training_source, Training_target, model_path, model_training)

    # Displaying the time elapsed for training (and optional prediction)
    elapsed = time.time() - start
    minutes, sec = divmod(elapsed, 60)
    hours, minutes = divmod(minutes, 60)
    print("Time elapsed:", int(hours), "hour(s)", int(minutes), "min(s)", round(sec), "sec(s)")

    Show_loss_function(history, model_path)
train_reduce_lr={ 'patience': 5, 'factor': 0.5 }) print(config) vars(config) # In[6]: model = CARE(config=config, name=ModelName, basedir=ModelDir) #input_weights = ModelDir + ModelName + '/' +'weights_best.h5' #model.load_weights(input_weights) # In[ ]: history = model.train(X, Y, validation_data=(X_val, Y_val)) # In[ ]: print(sorted(list(history.history.keys()))) plt.figure(figsize=(16, 5)) plot_history(history, ['loss', 'val_loss'], ['mse', 'val_mse', 'mae', 'val_mae']) # In[ ]: plt.figure(figsize=(12, 7)) _P = model.keras_model.predict(X_val[25:30]) if config.probabilistic: _P = _P[..., :(_P.shape[-1] // 2)] plot_some(X_val[0], Y_val[0], _P, pmax=99.5)
def Train(self):
    """Prepare masks and patches, then train the UNET and/or StarDist3D models.

    Folders read below ``self.BaseDir``:

    * ``Raw/``: images
    * ``RealMask/``: label images
    * ``BinaryMask/``: binary masks
    * ``ValRaw/`` and ``ValRealMask/``: validation data, used only when
      ``self.CroppedLoad`` is set

    Steps:

    1. If ``RealMask/`` is empty, it is filled with ``label(...)`` of the binary
       masks. Then, if ``BinaryMask/`` is empty, it is filled with
       ``label > 0`` of the label images.
    2. If ``self.GenerateNPZ``, ZYX patches of the Raw/BinaryMask pairs are
       written to ``self.BaseDir + self.NPZfilename + '.npz'``.
    3. If ``self.TrainUNET``, a CARE model named ``'UNET' + self.model_name``
       is trained on that patch file.
    4. If ``self.TrainSTAR``, a StarDist3D model named ``self.model_name`` is
       trained with the ``self.backbone`` backbone (``'resnet'`` or ``'unet'``).

    Before training, existing checkpoints (``weights_now/last/best.h5``) of the
    model are loaded. When ``self.copy_model_dir`` is set, checkpoints from the
    copy model are loaded first for files the model does not have itself. For
    the UNET model this applies to ``weights_now.h5`` only.

    Raises
    ------
    ValueError
        If ``self.TrainSTAR`` is set and ``self.backbone`` is neither
        ``'resnet'`` nor ``'unet'``.
    """
    BinaryName = 'BinaryMask/'
    RealName = 'RealMask/'
    # Checkpoint files, in loading order: every existing one is loaded, so the
    # last one found overrides the previous ones.
    checkpoints = ('weights_now.h5', 'weights_last.h5', 'weights_best.h5')

    def list_files(subdir, pattern='*.tif'):
        # Sorted paths matching ``<BaseDir>/<subdir><pattern>``.
        return sorted(glob.glob(self.BaseDir + '/' + subdir + pattern))

    def load_copy_weights(net, copy_dir, own_dir, names):
        # Initialise ``net`` from the copy model, only for checkpoint files
        # that the model being trained does not have itself.
        for weights_name in names:
            if (os.path.exists(copy_dir + weights_name)
                    and not os.path.exists(own_dir + weights_name)):
                print('Loading copy model')
                net.load_weights(copy_dir + weights_name)

    def load_own_weights(net, own_dir):
        # Resume from the model's own checkpoints.
        for weights_name in checkpoints:
            if os.path.exists(own_dir + weights_name):
                print('Loading checkpoint model')
                net.load_weights(own_dir + weights_name)

    Raw = list_files('Raw/')
    Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
    Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
    RealMask = list_files(RealName)
    ValRaw = list_files('ValRaw/')
    ValRealMask = list_files('ValRealMask/')

    print('Instance segmentation masks:', len(RealMask))
    if len(RealMask) == 0:
        print('Making labels')
        for fname in list_files(BinaryName):
            image = imread(fname)
            Name = os.path.basename(os.path.splitext(fname)[0])
            Labelimage = label(image)
            imwrite((self.BaseDir + '/' + RealName + Name + '.tif'),
                    Labelimage.astype('uint16'))
        # Refresh the listing: otherwise the StarDist branch below would pair
        # the Raw images with the (empty) pre-generation list of label images.
        RealMask = list_files(RealName)

    Mask = list_files(BinaryName)
    print('Semantic segmentation masks:', len(Mask))
    if len(Mask) == 0:
        print('Generating Binary images')
        for fname in list_files(RealName, '*tif'):
            image = imread(fname)
            Name = os.path.basename(os.path.splitext(fname)[0])
            Binaryimage = image > 0
            imwrite((self.BaseDir + '/' + BinaryName + Name + '.tif'),
                    Binaryimage.astype('uint16'))

    npz_path = self.BaseDir + self.NPZfilename + '.npz'
    if self.GenerateNPZ:
        raw_data = RawData.from_folder(
            basepath=self.BaseDir,
            source_dirs=['Raw/'],
            target_dir='BinaryMask/',
            axes='ZYX',
        )
        # Only the saved file is used afterwards.
        create_patches(
            raw_data=raw_data,
            patch_size=(self.PatchZ, self.PatchY, self.PatchX),
            n_patches_per_image=self.n_patches_per_image,
            save_file=npz_path,
        )

    # Training UNET model
    if self.TrainUNET:
        print('Training UNET model')
        (X, Y), (X_val, Y_val), axes = load_training_data(npz_path,
                                                          validation_split=0.1,
                                                          verbose=True)
        c = axes_dict(axes)['C']
        n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

        config = Config(axes, n_channel_in, n_channel_out,
                        unet_n_depth=self.depth,
                        train_epochs=self.epochs,
                        train_batch_size=self.batch_size,
                        unet_n_first=self.startfilter,
                        train_loss='mse',
                        unet_kern_size=self.kern_size,
                        train_learning_rate=self.learning_rate,
                        train_reduce_lr={'patience': 5, 'factor': 0.5})
        print(config)

        unet_name = 'UNET' + self.model_name
        model = CARE(config, name=unet_name, basedir=self.model_dir)
        unet_dir = self.model_dir + unet_name + '/'
        if self.copy_model_dir is not None:
            copy_dir = self.copy_model_dir + 'UNET' + self.copy_model_name + '/'
            load_copy_weights(model, copy_dir, unet_dir, ('weights_now.h5',))
        load_own_weights(model, unet_dir)

        history = model.train(X, Y, validation_data=(X_val, Y_val))
        print(sorted(history.history.keys()))
        plt.figure(figsize=(16, 5))
        plot_history(history, ['loss', 'val_loss'],
                     ['mse', 'val_mse', 'mae', 'val_mae'])

    if self.TrainSTAR:
        print('Training StarDistModel model with', self.backbone, 'backbone')
        self.axis_norm = (0, 1, 2)
        if not self.CroppedLoad:
            assert len(Raw) > 1, "not enough training data"
            print(len(Raw))
            rng = np.random.RandomState(42)
            ind = rng.permutation(len(Raw))

            X_train = list(map(ReadFloat, Raw))
            Y_train = list(map(ReadInt, RealMask))
            self.Y = [label(DownsampleData(y, self.DownsampleFactor))
                      for y in tqdm(Y_train)]
            self.X = [normalize(DownsampleData(x, self.DownsampleFactor),
                                1, 99.8, axis=self.axis_norm)
                      for x in tqdm(X_train)]

            # Hold out ~15 % (at least one image) for validation.
            n_val = max(1, int(round(0.15 * len(ind))))
            ind_train, ind_val = ind[:-n_val], ind[-n_val:]
            self.X_val, self.Y_val = [self.X[i] for i in ind_val], [self.Y[i] for i in ind_val]
            self.X_trn, self.Y_trn = [self.X[i] for i in ind_train], [self.Y[i] for i in ind_train]
            print('number of images: %3d' % len(self.X))
            print('- training: %3d' % len(self.X_trn))
            print('- validation: %3d' % len(self.X_val))
        else:
            self.X_trn = self.DataSequencer(Raw, self.axis_norm, Normalize=True, labelMe=False)
            self.Y_trn = self.DataSequencer(RealMask, self.axis_norm, Normalize=False, labelMe=True)
            self.X_val = self.DataSequencer(ValRaw, self.axis_norm, Normalize=True, labelMe=False)
            self.Y_val = self.DataSequencer(ValRealMask, self.axis_norm, Normalize=False, labelMe=True)
            self.train_sample_cache = False

        print(Config3D.__doc__)
        anisotropy = (1, 1, 1)
        rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)
        kernel = (self.kern_size, self.kern_size, self.kern_size)
        common = dict(
            rays=rays,
            anisotropy=anisotropy,
            backbone=self.backbone,
            train_epochs=self.epochs,
            train_learning_rate=self.learning_rate,
            train_checkpoint=self.model_dir + self.model_name + '.h5',
            # The volumes are ZYX (axes='ZYX' above), so the patch size is
            # (Z, Y, X), matching the patch_size used for create_patches.
            train_patch_size=(self.PatchZ, self.PatchY, self.PatchX),
            train_batch_size=self.batch_size,
            train_dist_loss='mse',
            grid=(1, 1, 1),
            use_gpu=self.use_gpu,
            n_channel_in=1,
        )
        if self.backbone == 'resnet':
            conf = Config3D(resnet_n_blocks=self.depth,
                            resnet_kernel_size=kernel,
                            resnet_n_filter_base=self.startfilter,
                            **common)
        elif self.backbone == 'unet':
            conf = Config3D(unet_n_depth=self.depth,
                            unet_kernel_size=kernel,
                            unet_n_filter_base=self.startfilter,
                            train_sample_cache=False,
                            **common)
        else:
            raise ValueError("Unsupported backbone {!r}; expected 'resnet' or 'unet'"
                             .format(self.backbone))
        print(conf)

        Starmodel = StarDist3D(conf, name=self.model_name, basedir=self.model_dir)
        star_dir = self.model_dir + self.model_name + '/'
        print(Starmodel._axes_tile_overlap('ZYX'),
              os.path.exists(star_dir + 'weights_now.h5'))
        if self.copy_model_dir is not None:
            copy_dir = self.copy_model_dir + self.copy_model_name + '/'
            load_copy_weights(Starmodel, copy_dir, star_dir, checkpoints)
        load_own_weights(Starmodel, star_dir)

        historyStar = Starmodel.train(self.X_trn, self.Y_trn,
                                      validation_data=(self.X_val, self.Y_val),
                                      epochs=self.epochs)
        print(sorted(historyStar.history.keys()))
        plt.figure(figsize=(16, 5))
        plot_history(historyStar, ['loss', 'val_loss'], [
            'dist_relevant_mae', 'val_dist_relevant_mae', 'dist_relevant_mse',
            'val_dist_relevant_mse'
        ])
# Script fragment whose enclosing context is not visible here. It builds a CARE
# config from the parsed command-line ``args`` and trains ``args.num_models``
# separately named models on ``tr_data`` (inputs, targets) with ``val_data``
# for validation.
# NOTE(review): the original indentation of this fragment was lost; the
# placement of ``exit()`` relative to the loop below is a reconstruction.
is_3d = bool(args.is_3d)
axes = "ZYX" if is_3d else "YX"
n_dim = 3 if is_3d else 2
# NOTE(review): n_dim is passed explicitly although it follows from ``axes``;
# presumably redundant — confirm.
config = Config(axes, n_dim=n_dim, n_channel_in=1, n_channel_out=1,
                probabilistic=int(args.probabilistic),
                train_batch_size=int(args.batch_size), unet_kern_size=3)
for i in range(int(args.num_models)):
    # Models are named "model_0", "model_1", ... inside args.models_dir.
    model = CARE(config, f"model_{i}", args.models_dir)
    train_history = model.train(tr_data[0], tr_data[1], validation_data=val_data,
                                epochs=int(args.epochs),
                                steps_per_epoch=int(args.steps_per_epoch))
# NOTE(review): exit() terminates the script here, so the plot_history call
# below is unreachable (and would only show the last model's history). Looks
# like a debugging leftover — confirm intent.
exit()
plot_history(train_history, ["loss", "val_loss"], ["mse", "val_mse", "mae", "val_mae"])


def crop_to_even(img):
    """Crop the first two axes of ``img`` to even lengths.

    A trailing odd row and/or column (along axes 0 and 1) is dropped; any
    further axes are kept whole.

    Parameters
    ----------
    img : array-like
        Object with ``.shape`` and numpy-style slicing, at least 2-D.

    Returns
    -------
    The sliced ``img`` (for numpy arrays this is a view, not a copy).
    """
    shape = img.shape
    # Subtract the remainder mod 2 to round each length down to an even number.
    new_shape = (shape[0] - (shape[0] % 2), shape[1] - (shape[1] % 2))
    return img[0:new_shape[0], 0:new_shape[1]]
def train(X, Y, X_val, Y_val, axes, model_name, csv_file, probabilistic, validation_split, patch_size, **kwargs):
    """Trains a single-channel CARE model with patches previously created.

    The CARE model parameters are configured via a ``Config`` object built from
    ``axes`` and ``probabilistic`` plus any extra ``Config`` parameters given in
    ``kwargs``. Examples of such parameters:

    * parameters of the underlying neural network,
    * the learning rate (``train_learning_rate``),
    * the number of parameter updates per epoch (``train_steps_per_epoch``),
    * the number of epochs (``train_epochs``),
    * the loss function (``train_loss``).

    The model is created with name ``model_name`` in basedir ``'models'`` and
    exported with ``export_TF`` after training.

    Parameters
    ----------
    X : np.array
        Input data X for training
    Y : np.array
        Ground truth data Y for training
    X_val : np.array
        Input data X for validation
    Y_val : np.array
        Ground truth Y for validation
    axes : str
        Semantic order of the axes in the image
    model_name : str
        Name of the model to be saved after training
    csv_file : str
        Unused; kept for interface compatibility
    probabilistic : bool
        Whether to train a probabilistic model
    validation_split : float
        Unused; kept for interface compatibility
    patch_size : tuple
        Unused; kept for interface compatibility
    **kwargs
        Further ``Config`` parameters, passed with ``allow_new_parameters=True``

    Returns
    -------
    history
        Object with the loss values saved
    """
    config = Config(axes, n_channel_in=1, n_channel_out=1, probabilistic=probabilistic,
                    allow_new_parameters=True, **kwargs)
    model = CARE(config, model_name, basedir='models')

    # Training progress can be followed with TensorBoard
    # (https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard):
    # run `tensorboard --logdir=.` from the current working directory, then
    # connect to http://localhost:6006/ with your browser.
    history = model.train(X, Y, validation_data=(X_val, Y_val))

    model.export_TF()
    return history
def n2v_flim(project, n2v_num_pix=32):
    """Denoise the FLIM fit results of ``project`` with Noise2Void.

    The fit results in ``<project>/fit_results.hdf5`` are normalised. A
    Noise2Void CARE model (``n2v_model``, stored in ``project``) is trained on
    them and reloaded with its best weights. The project is then copied to a
    folder whose ``.flimfit`` is replaced by ``-n2v.flimfit``, and the
    denormalised predictions are written into that copy's
    ``fit_results.hdf5``. Masked entries are set to NaN.

    Parameters
    ----------
    project : str
        Path of the project folder; must contain ``.flimfit``.
    n2v_num_pix : int
        Passed to the config as ``n2v_num_pix`` and used as the edge length of
        the square ``n2v_patch_shape``.

    Raises
    ------
    ValueError
        If ``project`` does not contain ``.flimfit``. In that case the output
        folder would be the project itself, which would be deleted before
        being copied.
    """
    output_project = project.replace('.flimfit', '-n2v.flimfit')
    if output_project == project:
        raise ValueError(
            "project path must contain '.flimfit' so that a separate output "
            "project can be created: {!r}".format(project))

    results_file = os.path.join(project, 'fit_results.hdf5')
    X, groups, mask = extract_results(results_file)
    print(np.shape(X))

    mean, std = np.mean(X), np.std(X)
    X = normalize(X, mean, std)

    XA = X  # augment_data(X)
    # manipulate_val_data modifies X_val in place. Take a copy so that the
    # first 10 training inputs (and later prediction inputs) stay untouched.
    X_val = X[0:10, ...].copy()

    # We concatenate an extra channel filled with zeros. It will be internally used for the masking.
    Y = np.concatenate((XA, np.zeros(XA.shape)), axis=-1)
    Y_val = np.concatenate((X_val, np.zeros(X_val.shape)), axis=-1)

    n_x = X.shape[1]
    n_chan = X.shape[-1]
    manipulate_val_data(X_val, Y_val, num_pix=n_x * n_x * 2 / n2v_num_pix, shape=(n_x, n_x))

    # You can increase "train_steps_per_epoch" to get even better results at the price of longer computation.
    config = Config('SYXC', n_channel_in=n_chan, n_channel_out=n_chan,
                    unet_kern_size=5, unet_n_depth=2,
                    train_steps_per_epoch=200, train_loss='mae', train_epochs=35,
                    batch_norm=False, train_scheme='Noise2Void', train_batch_size=128,
                    n2v_num_pix=n2v_num_pix,
                    n2v_patch_shape=(n2v_num_pix, n2v_num_pix),
                    n2v_manipulator='uniform_withCP',
                    n2v_neighborhood_radius='5')

    model = CARE(config, 'n2v_model', basedir=project)
    model.train(XA, Y, validation_data=(X_val, Y_val))
    model.load_weights(name='weights_best.h5')

    if os.path.exists(output_project):
        shutil.rmtree(output_project)
    shutil.copytree(project, output_project)
    output_file = os.path.join(output_project, 'fit_results.hdf5')

    X_pred = np.zeros(X.shape)
    for i in range(X.shape[0]):
        X_pred[i, ...] = denormalize(model.predict(X[i], axes='YXC', normalizer=None), mean, std)
    # np.NaN was removed in NumPy 2.0; np.nan works on every version.
    X_pred[mask] = np.nan
    insert_results(output_file, X_pred, groups)
# Restrict inputs and targets to the channels selected by keep_idx (axis 3).
x_train, x_test, y_train, y_test = (
    arr[:, :, :, keep_idx] for arr in (x_train, x_test, y_train, y_test)
)

# Taken from CARE FAQ, runs the model
axes = 'SYXC'
c = axes_dict(axes)['C']
n_channel_in = x_train.shape[c]
n_channel_out = y_train.shape[c]
config = Config(
    axes,
    n_channel_in,
    n_channel_out,
    probabilistic=True,
    train_epochs=num_epochs,
    train_steps_per_epoch=30,
)
model = CARE(config, model_name, basedir='models')
history = model.train(x_train, y_train, validation_data=(x_test, y_test))

# The probabilistic model predicts the Laplace means followed by the Laplace
# scales along the last axis; split the prediction into those two halves.
fig = plt.figure(figsize=(30, 30))
_P = model.keras_model.predict(x_test[:5, :, :, :])
_P_mean, _P_scale = _P[..., :(_P.shape[-1] // 2)], _P[..., (_P.shape[-1] // 2):]
plot_some(x_test[:5, :, :, 0], y_test[:5, :, :, :], _P_mean, _P_scale, pmax=99.5)
fig.suptitle(
    '5 example validation patches\n'
    'first row: input (source), '
    'second row: target (ground truth), '
    'third row: predicted Laplace mean, '
    'forth row: predicted Laplace scale'
)
def train(self, channels=None, **config_args):
    """Train, inspect and export one CARE model per channel (timed as "Training").

    For every channel ``ch`` the patches in
    ``<training patch path>/CH_<ch>_training_patches.npz`` are loaded with 10 %
    held out for validation. A model named ``CH_<ch>_model`` is trained in
    ``<out_dir>/models`` (probabilistic if ``self.probabilistic``). The
    learning curve and five validation predictions are plotted, and the model
    is exported with ``export_TF``.

    Parameters
    ----------
    channels : iterable, optional
        Channels to train; defaults to ``self.train_channels``.
    **config_args
        Extra keyword arguments forwarded to ``Config``.

    If training raises ``tf.errors.ResourceExhaustedError`` or
    ``tf.errors.UnknownError``, a message is printed and the method returns
    without training the remaining channels.
    """
    if channels is None:
        channels = self.train_channels

    with Timer("Training"):
        for ch in channels:
            print("-- Training channel {}...".format(ch))
            (X, Y), (X_val, Y_val), axes = load_training_data(
                self.get_training_patch_path() / "CH_{}_training_patches.npz".format(ch),
                validation_split=0.1,
                verbose=False,
            )

            c = axes_dict(axes)["C"]
            n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

            config = Config(
                axes,
                n_channel_in,
                n_channel_out,
                train_epochs=self.train_epochs,
                train_steps_per_epoch=self.train_steps_per_epoch,
                train_batch_size=self.train_batch_size,
                probabilistic=self.probabilistic,
                **config_args,
            )

            # Training
            model = CARE(
                config,
                "CH_{}_model".format(ch),
                basedir=pathlib.Path(self.out_dir) / "models",
            )

            try:
                history = model.train(X, Y, validation_data=(X_val, Y_val))
            except tf.errors.ResourceExhaustedError:
                print(
                    " >> ResourceExhaustedError: Aborting...\n Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?"
                )
                return
            except tf.errors.UnknownError:
                print(
                    " >> UnknownError: Aborting...\n No enough memory available on GPU... are other GPU jobs running?"
                )
                return

            # Show learning curve and example validation results
            plt.figure(figsize=(16, 5))
            plot_history(history, ["loss", "val_loss"],
                         ["mse", "val_mse", "mae", "val_mae"])

            plt.figure(figsize=(12, 7))
            _P = model.keras_model.predict(X_val[:5])
            if config.probabilistic:
                # Output = Laplace mean channels followed by scale channels.
                # Keep all mean channels (and the channel axis) so the
                # prediction lines up with input and target.
                _P = _P[..., :(_P.shape[-1] // 2)]
            plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5, cmap="gray")
            plt.suptitle("5 example validation patches\n"
                         "top row: input (source), "
                         "middle row: target (ground truth), "
                         "bottom row: predicted from source")
            plt.show()

            print("-- Export model for use in Fiji...")
            model.export_TF()
            print("Done")