def test_model_train():
    """Smoke-test ``CARE.train`` over a grid of small 2D and 3D configurations.

    Every configuration yielded by ``config_generator`` is visited. The Keras
    session is reset first; invalid configurations are then skipped. Valid
    ones get a tiny random dataset (4 samples, 32 px per spatial axis) and are
    trained briefly, using the training pair itself as validation data.
    """
    rng = np.random.RandomState(42)
    configs = config_generator(
        axes=['YX', 'ZYX'],
        n_channel_in=[1, 2],
        n_channel_out=[1, 2],
        probabilistic=[False, True],
        unet_n_depth=[1],
        unet_kern_size=[3],
        unet_n_first=[4],
        unet_last_activation=['linear'],
        train_loss=['mae', 'laplace'],
        train_epochs=[2],
        train_steps_per_epoch=[2],
        train_batch_size=[2],
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        for config in configs:
            K.clear_session()
            if not config.is_valid():
                continue
            batch_shape = (4,) + (32,) * config.n_dim
            X = rng.uniform(size=batch_shape + (config.n_channel_in,))
            Y = rng.uniform(size=batch_shape + (config.n_channel_out,))
            CARE(config, basedir=tmpdir).train(X, Y, (X, Y))
def train_care_generated_data(model: CARE, epochs: int, X: np.ndarray, Y: np.ndarray,
                              X_val: np.ndarray, Y_val: np.ndarray,
                              steps_per_epoch: int = 400) -> None:
    """
    Train a CARE model on a dataset stored with its channel axis at position 1.

    Axis 1 of every array is moved to the last of the first five positions
    (axes 2-4 shift down by one), presumably turning a channel-first layout
    into the channels-last layout CARE expects. For the targets only the
    first entry along that moved axis is kept.

    :param model: The CARE model to train
    :param epochs: The number of epochs to train for
    :param X: The training data input (at least 5D)
    :param Y: The training data expected output (at least 5D)
    :param X_val: The validation data input
    :param Y_val: The validation data expected output
    :param steps_per_epoch: The number of training steps per epoch
    """
    def to_channels_last(arr):
        # axis 1 -> axis 4; original axes 2, 3, 4 become 1, 2, 3
        return np.moveaxis(arr, 1, 4)

    def first_channel_last(arr):
        return to_channels_last(arr)[:, :, :, :, :1]

    train_in, train_out = to_channels_last(X), first_channel_last(Y)
    val_in, val_out = to_channels_last(X_val), first_channel_last(Y_val)
    model.train(train_in, train_out,
                validation_data=(val_in, val_out),
                epochs=epochs,
                steps_per_epoch=steps_per_epoch)
def test_model_train(tmpdir,config):
    """Train a CARE model built from the parametrized ``config`` on tiny random data."""
    rng = np.random.RandomState(42)
    K.clear_session()
    batch_shape = (4,) + (32,) * config.n_dim
    X = rng.uniform(size=batch_shape + (config.n_channel_in,))
    Y = rng.uniform(size=batch_shape + (config.n_channel_out,))
    model = CARE(config, basedir=str(tmpdir))
    # the training pair doubles as validation data
    model.train(X, Y, validation_data=(X, Y))
def run_z(y, x, models):
    """Compare CARE model predictions on a ZYX stack against ground truth.

    The input ``x`` and each model's mean probabilistic prediction are
    percentile-normalized and scored (RMSE, SSIM) against the normalized
    ground truth ``y``. One random z-slice of every image is plotted, then a
    bar plot of the metrics is shown.

    :param y: ground-truth stack, axes 'ZYX'
    :param x: input stack, axes 'ZYX'
    :param models: names of CARE models stored under ``models/``
    """
    # model prediction and normalizations output float64
    x = np.array(x, dtype='float64')
    y = np.array(y, dtype='float64')
    axes = 'ZYX'

    def score(name, y, x):
        return [name, np.sqrt(mse(x, y)), ssim(x, y)]

    # Normalize GT dynamic range to enable comparisons w/ numbers
    yn = util.percentile_norm(y, axes)
    xn = util.percentile_norm(x, axes)

    rows = [score('input', yn, xn)]
    labels = ['input', 'N(GT)']
    images = [xn, yn]

    for name in models:
        model = CARE(config=None, name=name, basedir='models')
        restored = model.predict_probabilistic(x, axes, n_tiles=(1, 4, 4))
        normed = util.percentile_norm(restored.mean(), axes)
        rows.append(score(name, yn, normed))
        labels.append(name)
        images.append(normed)

    # Plot the same random z-slice of every stack side by side
    zix = util.get_randint_ixs(1, len(y))
    plt.figure(figsize=(16, 10))
    plot_some(np.stack([[im[zix] for im in images]]), title_list=[labels])
    plt.show()

    # Construct a long-format df for the bar plot
    df = pd.DataFrame(rows, columns=['output', 'rmse', 'ssim'])
    df = pd.melt(df, id_vars=['output'], var_name='metric')
    g = sns.catplot(x='metric', y='value', hue='output', kind='bar',
                    sharey=False, data=df)
    g.ax.set_ylim(0, 1)
    plt.show()
def train(self, channels=None, **config_args):
    """Train one CARE model per channel, show diagnostics, and export it for Fiji.

    :param channels: channels to train; defaults to ``self.train_channels``
    :param config_args: extra keyword arguments forwarded to ``Config``

    Training patches are read from
    ``<training patch path>/CH_<ch>_training_patches.npz`` (10% validation)
    and each model is saved as ``CH_<ch>_model`` under ``<out_dir>/models``.
    On a TensorFlow ``ResourceExhaustedError`` the whole method returns,
    skipping any remaining channels.
    """
    for ch in (self.train_channels if channels is None else channels):
        print("-- Training channel {}...".format(ch))
        patch_file = self.get_training_patch_path() / 'CH_{}_training_patches.npz'.format(ch)
        (X, Y), (X_val, Y_val), axes = load_training_data(
            patch_file, validation_split=0.1, verbose=False)

        c = axes_dict(axes)['C']
        config = Config(axes, X.shape[c], Y.shape[c],
                        train_epochs=self.train_epochs,
                        train_steps_per_epoch=self.train_steps_per_epoch,
                        train_batch_size=self.train_batch_size,
                        **config_args)
        model = CARE(config, 'CH_{}_model'.format(ch),
                     basedir=pathlib.Path(self.out_dir) / 'models')

        try:
            history = model.train(X, Y, validation_data=(X_val, Y_val))
        except tf.errors.ResourceExhaustedError:
            print(
                "ResourceExhaustedError: Aborting...\n Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?"
            )
            return

        # learning curve
        plt.figure(figsize=(16, 5))
        plot_history(history, ['loss', 'val_loss'],
                     ['mse', 'val_mse', 'mae', 'val_mae'])

        # a handful of validation examples: input / target / prediction
        plt.figure(figsize=(12, 7))
        predicted = model.keras_model.predict(X_val[:5])
        plot_some(X_val[:5], Y_val[:5], predicted, pmax=99.5, cmap="gray")
        plt.suptitle('5 example validation patches\n'
                     'top row: input (source), '
                     'middle row: target (ground truth), '
                     'bottom row: predicted from source')
        plt.show()

        print("-- Export model for use in Fiji...")
        model.export_TF()
        print("Done")
def gen_care_single_model(name: str) -> CARE:
    """
    Load the single-channel CARE model called ``name``, or create it.

    :param name: The name of the model
    :return: The saved model if loading it succeeds; if loading raises
        ``FileNotFoundError``, a new model configured with axes 'xyzc',
        one input channel and one output channel
    """
    try:
        return CARE(None, name)
    except FileNotFoundError:
        return CARE(Config('xyzc', n_channel_in=1, n_channel_out=1), name)
def dev(args):
    """Train a CARE model on ``args.train_data`` and export it for Fiji/KNIME.

    The model configuration comes from one of three sources:

    * ``args.resume`` set: ``config=None``, so CARE reloads the config saved
      with the model ``models/<args.model_name>``;
    * ``args.config`` set: a JSON file holding ``Config`` keyword arguments;
    * otherwise: built from the data axes/channels and ``args.prob``,
      ``args.steps`` and ``args.epochs``.

    The training history plot is written to ``<args.model_name>_training.png``.
    """
    import json

    # Load and parse training data
    (X, Y), (X_val, Y_val), axes = load_training_data(args.train_data,
                                                      validation_split=args.valid_split,
                                                      axes=args.axes,
                                                      verbose=True)
    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    # Model config
    print('args.resume: ', args.resume)
    if args.resume:
        # If resuming, config=None will reload the saved config
        config = None
        print('Attempting to resume')
    elif args.config:
        print('loading config from args')
        # use a context manager so the JSON file handle is closed
        with open(args.config) as fh:
            config_args = json.load(fh)
        config = Config(**config_args)
    else:
        config = Config(axes,
                        n_channel_in,
                        n_channel_out,
                        probabilistic=args.prob,
                        train_steps_per_epoch=args.steps,
                        train_epochs=args.epochs)

    # vars(None) raises TypeError, which crashed every resume run
    if config is not None:
        print(vars(config))

    # Load or init model
    model = CARE(config, args.model_name, basedir='models')

    # Training, tensorboard available
    history = model.train(X, Y, validation_data=(X_val, Y_val))

    # Plot training results
    print(sorted(list(history.history.keys())))
    plt.figure(figsize=(16, 5))
    plot_history(history, ['loss', 'val_loss'],
                 ['mse', 'val_mse', 'mae', 'val_mae'])
    plt.savefig(args.model_name + '_training.png')

    # Export model to be used w/ csbdeep fiji plugins and KNIME flows
    model.export_TF()
def load_data_and_model(npz_data, model_name, valid_split=0.2, axes='SCZXY'):
    """Load a saved CARE model and the validation split of a training file.

    :param npz_data: path of the ``.npz`` training data
    :param model_name: name of the model stored under ``models/``
    :param valid_split: fraction of the data held out for validation
    :param axes: axes of the stored arrays
    :returns: ``(X_val, Y_val, model)``
    """
    model = CARE(config=None, name=model_name, basedir='models')
    validation = load_training_data(npz_data,
                                    validation_split=valid_split,
                                    axes=axes,
                                    verbose=True)[1]
    X_val, Y_val = validation
    return X_val, Y_val, model
def test_model_predict_tiled(tmpdir,config):
    """
    Test that tiled prediction yields the same
    or similar result as compared to predicting
    the whole image at once.

    Also checks that invalid ``n_tiles`` values raise ``ValueError``, and that
    the legacy scalar ``n_tiles`` form emits a ``UserWarning``.
    """
    rng = np.random.RandomState(42)
    normalizer, resizer = NoNormalizer(), NoResizer()

    K.clear_session()
    model = CARE(config,basedir=str(tmpdir))

    def _predict(imdims,axes,n_tiles):
        # Predict a random image untiled and tiled; shapes must match and the
        # max absolute difference of the means must stay below 1e-3.
        img = rng.uniform(size=imdims)
        # print(img.shape, axes)
        mean, scale = model._predict_mean_and_scale(img, axes, normalizer, resizer, n_tiles=None)
        mean_tiled, scale_tiled = model._predict_mean_and_scale(img, axes, normalizer, resizer, n_tiles=n_tiles)
        assert mean.shape == mean_tiled.shape
        if config.probabilistic:
            assert scale.shape == scale_tiled.shape
        error_max = np.max(np.abs(mean-mean_tiled))
        # print('n, k, err = {0}, {1}x{1}, {2}'.format(model.config.unet_n_depth, model.config.unet_kern_size, error_max))
        assert error_max < 1e-3
        return mean, mean_tiled

    imdims = list(rng.randint(50,70,size=config.n_dim))
    if config.n_dim == 3:
        imdims[0] = 16 # make one dim small, otherwise test takes too long
    # round every spatial size down to a multiple of 2**unet_n_depth
    div_n = 2**config.unet_n_depth
    imdims = [(d//div_n)*div_n for d in imdims]

    # channel axis goes first
    imdims.insert(0,config.n_channel_in)
    axes = 'C'+config.axes.replace('C','')

    # malformed n_tiles values must be rejected
    for n_tiles in (
        -1, 1.2,
        [1]+[1.2]*config.n_dim,
        [1]*config.n_dim, # missing value for channel axis
        [2]+[1]*config.n_dim, # >1 tiles for channel axis
    ):
        with pytest.raises(ValueError):
            _predict(imdims,axes,n_tiles)

    # random per-axis tile counts (1..4 per spatial axis)
    for n_tiles in [list(rng.randint(1,5,size=config.n_dim)) for _ in range(3)]:
        # print(imdims,axes,[1]+n_tiles)
        if config.n_channel_in == 1:
            _predict(imdims[1:],axes[1:],n_tiles)
        _predict(imdims,axes,[1]+n_tiles)

    # legacy api: tile only largest dimension
    n_blocks = np.max(imdims) // div_n
    for n_tiles in (2,5,n_blocks+1):
        with pytest.warns(UserWarning):
            if config.n_channel_in == 1:
                _predict(imdims[1:],axes[1:],n_tiles)
            _predict(imdims,axes,n_tiles)
def test_model_predict():
    """Check output shapes of ``CARE._predict_mean_and_scale`` for several configs.

    For each valid configuration, random images are predicted with and without
    a channel axis (the channel axis is inserted at a random position). The
    mean must have the input's shape with the channel dimension set to
    ``n_channel_out``. ``scale`` must match ``mean`` for probabilistic models
    and be ``None`` otherwise.
    """
    rng = np.random.RandomState(42)
    configs = config_generator(
        axes=['YX', 'ZYX'],
        n_channel_in=[1, 2],
        n_channel_out=[1, 2],
        probabilistic=[False, True],
        # unet_residual = [False,True],
        unet_n_depth=[2],
        unet_kern_size=[3],
        unet_n_first=[4],
        unet_last_activation=['linear'],
        # unet_input_shape = [(None, None, 1)],
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        normalizer, resizer = NoNormalizer(), NoResizer()

        for config in filter(lambda c: c.is_valid(), configs):
            K.clear_session()
            model = CARE(config, basedir=tmpdir)
            axes = config.axes

            def _predict(imdims, axes):
                # Note: mutates imdims[channel] in the channel-axis case.
                img = rng.uniform(size=imdims)
                # print(img.shape, axes, config.n_channel_out)
                mean, scale = model._predict_mean_and_scale(
                    img, axes, normalizer, resizer)
                if config.probabilistic:
                    assert mean.shape == scale.shape
                else:
                    assert scale is None

                if 'C' not in axes:
                    # no channel axis in the input: a channel axis is appended
                    # only when there is more than one output channel
                    if config.n_channel_out == 1:
                        assert mean.shape == img.shape
                    else:
                        assert mean.shape == img.shape + (
                            config.n_channel_out, )
                else:
                    channel = axes_dict(axes)['C']
                    imdims[channel] = config.n_channel_out
                    assert mean.shape == tuple(imdims)

            # spatial sizes rounded down to a multiple of 2**unet_n_depth
            imdims = list(rng.randint(20, 40, size=config.n_dim))
            div_n = 2**config.unet_n_depth
            imdims = [(d // div_n) * div_n for d in imdims]

            if config.n_channel_in == 1:
                _predict(imdims, axes=axes.replace('C', ''))

            # insert the channel axis at a random position
            channel = rng.randint(0, config.n_dim)
            imdims.insert(channel, config.n_channel_in)
            _axes = axes.replace('C', '')
            _axes = _axes[:channel] + 'C' + _axes[channel:]
            _predict(imdims, axes=_axes)
def gen_care_dual_model(name: str, batch_size: int = 16, **kwargs):
    """
    Load the dual-channel CARE model called ``name``, or create it.

    :param name: The name of the model
    :param batch_size: The training batch size (only used when a new model is created)
    :param kwargs: Extra ``Config`` parameters (only used when a new model is created)
    :return: The saved model if loading it succeeds; if loading raises
        ``FileNotFoundError``, a new model with axes 'xyzc', two input
        channels and one output channel
    """
    try:
        return CARE(None, name)
    except FileNotFoundError:
        pass
    config = Config('xyzc', n_channel_in=2, n_channel_out=1,
                    train_batch_size=batch_size, **kwargs)
    return CARE(config, name)
def _run_multi(y, x, models, d, ix, args):
    """
    Score each model's prediction of ``x`` against ground truth ``y``.

    This variant tries an alternative normalization, to check whether the
    normalization causes the inconsistency with the paper results. The
    ground truth is percentile-normalized (0.1-99.9). The input is either
    ``normalize_minmse``-rescaled to it (``args.rescale_x``) or
    percentile-normalized. Each prediction is ``normalize_minmse``-rescaled
    to the ground truth.

    One row ``[name, rmse, ssim]`` per image is appended to ``d['output']``.
    For the first test volume (``ix == 0``) with ``args.save_predictions`` set,
    predictions, input and ground truth are saved via ``save_prediction``.

    :returns: the updated ``d``
    """
    axes = 'ZYX'

    def score(label, target, candidate):
        return [label, np.sqrt(mse(candidate, target)), ssim(candidate, target)]

    yn = normalize(y, pmin=0.1, pmax=99.9)
    xn = normalize_minmse(x, yn) if args.rescale_x else normalize(x)
    d['output'].append(score('input', yn, xn))

    save_first_volume = args.save_predictions and ix == 0

    for name in models:
        model = CARE(config=None, name=name, basedir='models')
        raw_pred = model.predict(xn, axes, n_tiles=(1, 4, 4), normalizer=None)
        pred = normalize_minmse(raw_pred, yn)
        d['output'].append(score(name, yn, pred))
        if save_first_volume:
            save_prediction(pred, name, d['condition'])

    if save_first_volume:
        save_prediction(xn, 'input', d['condition'])
        save_prediction(yn, None, d['condition'])
    return d
def _run_multi(y, x, models, d):
    """Score each model's prediction of a ZYX stack against ground truth.

    The ground truth is percentile-normalized (0.1-99.9) and the input with
    the default ``normalize`` percentiles. Each model predicts on the
    normalized input with ``normalizer=None``, so the exact pmin/pmax used
    are the ones reported here. The prediction is then
    ``normalize_minmse``-rescaled to the ground truth.

    One row ``[name, rmse, ssim]`` per image is appended to ``d['output']``.

    :returns: the updated ``d``
    """
    axes = 'ZYX'

    def score(label, target, candidate):
        return [label, np.sqrt(mse(candidate, target)), ssim(candidate, target)]

    yn = normalize(y, pmin=0.1, pmax=99.9)
    xn = normalize(x)
    d['output'].append(score('input', yn, xn))

    for name in models:
        model = CARE(config=None, name=name, basedir='models')
        raw_pred = model.predict(xn, n_tiles=(1, 4, 4), axes=axes, normalizer=None)
        d['output'].append(score(name, yn, normalize_minmse(raw_pred, yn)))
    return d
def tst_care(care_model_path: Path, lfd_path: Path, overwrite: Optional[bool] = False):
    """Restore every ``*tif`` stack in ``lfd_path`` with a saved CARE model.

    ``lfd_path`` must be a ``.../lfd/pred`` directory. Results are written
    (via ``save_tensor``, with a leading singleton axis) to the sibling
    ``.../care/pred`` directory, which is created if needed. Existing results
    are skipped unless ``overwrite`` is true.

    :param care_model_path: directory of the saved model (its name is the model name)
    :param lfd_path: directory holding the input ``*tif`` stacks (axes 'ZYX')
    :param overwrite: recompute results that already exist
    """
    assert lfd_path.name == "pred"
    assert lfd_path.parent.name == "lfd"
    from csbdeep.models import CARE

    model = CARE(config=None, name=care_model_path.name, basedir=care_model_path.parent)

    care_result_root = lfd_path.parent.parent / "care" / "pred"
    care_result_root.mkdir(parents=True, exist_ok=True)
    print("saving CARE reconstructions to", care_result_root)

    for file_path in tqdm(list(lfd_path.glob("*tif"))):
        result_path = care_result_root / file_path.name
        if not result_path.exists() or overwrite:
            stack = imread(str(file_path)).squeeze()
            restored = model.predict(stack, "ZYX")
            save_tensor(result_path, restored.squeeze()[None, ...])
def _build():
    """Exercise the CARE constructor's loading/creation paths.

    NOTE(review): ``config`` and ``tmpdir`` are free variables taken from an
    enclosing test function that is not visible here.

    Checks, in order:
    - loading a model (``config=None``) from a directory without one raises
      ``FileNotFoundError``;
    - a model with a config may be created with ``basedir=None``, but loading
      with ``config=None`` and ``basedir=None`` raises ``ValueError``;
    - a freshly created model can be exported with ``export_TF()``;
    - creating the named model ``'model'`` in ``tmpdir`` (twice) emits a
      ``UserWarning``;
    - the saved ``'model'`` can then be reloaded with ``config=None``.
    """
    with pytest.raises(FileNotFoundError):
        CARE(None, basedir=str(tmpdir))

    CARE(config, name='model', basedir=None)
    with pytest.raises(ValueError):
        CARE(None, basedir=None)

    CARE(config, basedir=str(tmpdir)).export_TF()

    with pytest.warns(UserWarning):
        CARE(config, name='model', basedir=str(tmpdir))
        CARE(config, name='model', basedir=str(tmpdir))
    CARE(None, name='model', basedir=str(tmpdir))
def test_model_predict(tmpdir,config):
    """Check output shapes of ``CARE.predict`` / ``CARE.predict_probabilistic``.

    Random images are predicted without a channel axis (if the model has one
    input channel) and with a channel axis inserted at a random position.
    """
    rng = np.random.RandomState(42)
    normalizer, resizer = NoNormalizer(), NoResizer()

    K.clear_session()
    model = CARE(config,basedir=str(tmpdir))
    axes = config.axes

    def _predict(imdims,axes):
        # Note: mutates imdims[channel] in the channel-axis case.
        img = rng.uniform(size=imdims)
        # print(img.shape, axes, config.n_channel_out)
        if config.probabilistic:
            prob = model.predict_probabilistic(img, axes, normalizer, resizer)
            mean, scale = prob.mean(), prob.scale()
            assert mean.shape == scale.shape
        else:
            mean = model.predict(img, axes, normalizer, resizer)

        if 'C' not in axes:
            # a channel axis is appended only for multiple output channels
            if config.n_channel_out == 1:
                assert mean.shape == img.shape
            else:
                assert mean.shape == img.shape + (config.n_channel_out,)
        else:
            channel = axes_dict(axes)['C']
            imdims[channel] = config.n_channel_out
            assert mean.shape == tuple(imdims)

    # spatial sizes rounded down to a multiple of 2**unet_n_depth
    imdims = list(rng.randint(20,40,size=config.n_dim))
    div_n = 2**config.unet_n_depth
    imdims = [(d//div_n)*div_n for d in imdims]

    if config.n_channel_in == 1:
        _predict(imdims,axes=axes.replace('C',''))

    # insert the channel axis at a random position
    channel = rng.randint(0,config.n_dim)
    imdims.insert(channel,config.n_channel_in)
    _axes = axes.replace('C','')
    _axes = _axes[:channel]+'C'+_axes[channel:]
    _predict(imdims,axes=_axes)
def predict(path_data, model_name, n_tiles, axes, plot_prediction, stack_nb,
            filter_data, folder_name_save):
    """Predict the output of the network using a pre-trained model.

    Every ``.tif`` file in ``path_data`` (in sorted order) is passed to
    ``reconstruction``. A file is skipped if a file of the same name already
    exists in ``<parent of cwd>/predicted/`` or if it does not match
    ``filter_data``.

    Parameters
    ----------
    path_data : str
        Path to input data to predict
    model_name : str
        '/'-separated path of the pre-trained model; the last component is
        the model name, the rest its base directory
    n_tiles : tuple
        Number of tiles to split the input into (helps avoid out-of-memory
        problems when predicting)
    axes : str
        Semantic order of the axes
    plot_prediction : bool
        Whether the prediction is plotted
    stack_nb :
        Not used by this function
    filter_data : str
        Only files whose name contains this substring are processed;
        ``'all'`` processes every file
    folder_name_save :
        Passed through to ``reconstruction``
    """
    sep = '/'
    *dir_parts, model_name = model_name.split(sep)
    model_path = sep.join(dir_parts)
    print('-------', model_path)
    model = CARE(config=None, name=model_name, basedir=model_path)

    for file_ in sorted(os.listdir(path_data)):
        if not file_.endswith('.tif'):
            continue
        already_done = pathlib.Path(
            os.path.dirname(os.getcwd()) + '/predicted/' + file_).exists()
        if already_done:
            continue
        if filter_data in file_ or filter_data == 'all':
            reconstruction(model, file_, path_data, axes, n_tiles,
                           plot_prediction, folder_name_save)
import glob
import os

from tifffile import imread
from csbdeep.utils import Path, download_and_extract_zip_file, plot_some
from csbdeep.io import save_tiff_imagej_compatible
from csbdeep.models import CARE

# In[3]:

# Input stacks, output folder for 3D results, and the saved model to apply
basedirLow = '/local/u934/private/v_kapoor/ProjectionTraining/MasterLow/VeryLow/'
basedirResults3D = '/local/u934/private/v_kapoor/ProjectionTraining/MasterLow/NotsoLow/'
ModelName = 'BorialisS1S2FlorisMidNoiseModel'
BaseDir = '/data/u934/service_imagerie/v_kapoor/CurieDeepLearningModels/'

Path(basedirResults3D).mkdir(exist_ok=True)

# In[4]:

model = CARE(config=None, name=ModelName, basedir=BaseDir)

# In[6]:

Raw_path = os.path.join(basedirLow, '*tif')
axes = 'ZYX'
smallaxes = 'YX'
filesRaw = glob.glob(Raw_path)
# was `filesRaw.sort` (no call), which never sorted the list;
# glob order is arbitrary, so sort for a reproducible processing order
filesRaw.sort()
print(len(filesRaw))
for fname in filesRaw:
    x = imread(fname)
    print(x.shape)
# Input folder with the raw stacks, plus output folders for the restored (3D)
# and projected (2D) results.
basedir = '/run/user/1000/gvfs/smb-share:server=isiserver.curie.net,share=u934/equipe_bellaiche/m_gracia/20210316/partie1'
basedirResults3D = basedir + '/Restored'
basedirResults2D = basedir + '/Projected'
basedirResults3Dextended = basedirResults3D + '/Restored'
basedirResults2Dextended = basedirResults2D + '/Projected'
# Directory containing both saved models
Model_Dir = '/run/media/sancere/DATA/Lucas_Model_to_use/CARE/'

# In[3]:

RestorationModel = 'CARE_restoration_Borealis_Bin1'
ProjectionModel = 'CARE_projection_Borealis_Bin1'

# The name strings above are rebound to the loaded model objects.
RestorationModel = CARE(config=None, name=RestorationModel, basedir=Model_Dir)
ProjectionModel = ProjectionCARE(config=None, name=ProjectionModel, basedir=Model_Dir)

# In[5]:

# only the first-level output folders are created here
Path(basedirResults3D).mkdir(exist_ok=True)
Path(basedirResults2D).mkdir(exist_ok=True)

Raw_path = os.path.join(basedir, '*TIF') #tif or TIF be careful: the glob pattern is case-sensitive on case-sensitive filesystems
axes = 'ZYX' #projection axes : 'YX'
filesRaw = glob.glob(Raw_path)
def train(Training_source=".",
          Training_target=".",
          model_name="No_name",
          model_path=".",
          Visual_validation_after_training=True,
          number_of_epochs=100,
          patch_size=64,
          number_of_patches=10,
          Use_Default_Advanced_Parameters=True,
          number_of_steps=300,
          batch_size=32,
          percentage_validation=15):
    '''
    Main function of the script. Train a 2D CARE model and save it in model_path.

    Parameters
    ----------
    Training_source : (str) folder with the noisy (input) .tif images
    Training_target : (str) folder with the ground-truth .tif images
    model_name : (str) name of the model; an existing model folder with this
        name inside model_path is deleted first
    model_path : (str) folder where the model and the generated training
        patches (``rawdata.npz``) are saved
    Visual_validation_after_training : (bool) predict an image after training
    number_of_epochs : (int) number of training epochs
    patch_size : (int) side length of the square training patches
    number_of_patches : (int) number of patches extracted per image
    Use_Default_Advanced_Parameters : (bool) if True, batch_size is set to 64,
        percentage_validation to 10 and number_of_steps is derived from the
        number of training patches
    number_of_steps : (int) training steps per epoch
    batch_size : (int) batch size
    percentage_validation : (int) percentage of patches held out for validation

    Return
    ------
    None
    '''
    import time

    OutputFile = Training_target + "/*.tif"
    InputFile = Training_source + "/*.tif"
    base = "/content/"

    if Use_Default_Advanced_Parameters:
        print("Default advanced parameters enabled")
        batch_size = 64
        percentage_validation = 10

    percentage = percentage_validation / 100

    # here we check that no model with the same name already exists; if so, delete it
    if os.path.exists(model_path + '/' + model_name):
        shutil.rmtree(model_path + '/' + model_name)

    # The shape of the images.
    x = imread(InputFile)
    y = imread(OutputFile)
    print('Loaded Input images (number, width, length) =', x.shape)
    print('Loaded Output images (number, width, length) =', y.shape)
    print("Parameters initiated.")

    # RawData object: pairs each source image with its corresponding target
    # image; the extracted patches are saved as .npz and reloaded below.
    raw_data = data.RawData.from_folder(basepath=base,
                                        source_dirs=[Training_source],
                                        target_dir=Training_target,
                                        axes='CYX',
                                        pattern='*.tif*')
    X, Y, XY_axes = data.create_patches(raw_data,
                                        patch_filter=None,
                                        patch_size=(patch_size, patch_size),
                                        n_patches_per_image=number_of_patches)

    print('Creating 2D training dataset')
    training_path = model_path + "/rawdata"
    rawdata1 = training_path + ".npz"
    np.savez(training_path, X=X, Y=Y, axes=XY_axes)

    # Load Training Data
    (X, Y), (X_val, Y_val), axes = load_training_data(rawdata1,
                                                      validation_split=percentage,
                                                      verbose=True)
    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    # derive number_of_steps from the training data and batch size
    if Use_Default_Advanced_Parameters:
        number_of_steps = int(X.shape[0] / batch_size) + 1
        print(number_of_steps)

    # Here we create the configuration file
    config = Config(axes,
                    n_channel_in,
                    n_channel_out,
                    probabilistic=False,
                    train_steps_per_epoch=number_of_steps,
                    train_epochs=number_of_epochs,
                    unet_kern_size=5,
                    unet_n_depth=3,
                    train_batch_size=batch_size,
                    train_learning_rate=0.0004)
    print(config)

    # Compile the CARE model for network training
    model_training = CARE(config, model_name, basedir=model_path)

    # start is always defined, so the elapsed-time report below can't hit an unbound name
    start = time.time()

    # Start Training
    history = model_training.train(X, Y, validation_data=(X_val, Y_val))
    print("Training, done.")

    if Visual_validation_after_training:
        Predict_a_image(Training_source, Training_target, model_path, model_training)

    # Displaying the time elapsed for training (avoid shadowing builtin min)
    dt = time.time() - start
    minutes, sec = divmod(dt, 60)
    hour, minutes = divmod(minutes, 60)
    print("Time elapsed:", hour, "hour(s)", minutes, "min(s)", round(sec), "sec(s)")

    Show_loss_function(history, model_path)
from skimage.io import ImageCollection, imread, imsave
from skimage import img_as_float, img_as_ubyte


def _parse_bool_flag(value):
    """Interpret a ``--is_3d`` value such as 'True', 'false', '1' or '0' as a bool.

    argparse passes the raw string, and ``bool("False")`` / ``bool("0")`` are
    ``True``, so the string must be parsed rather than fed to ``bool()``.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ("1", "true", "t", "yes", "y")


args = ArgumentParser()
args.add_argument("--base_dir")
args.add_argument("--model")
args.add_argument("--out_dir")
args.add_argument("--data")
args.add_argument("--is_3d", type=_parse_bool_flag, default=False)
args = args.parse_args()

model = CARE(config=None, name=args.model, basedir=args.base_dir)
data = ImageCollection(args.data)
axes = "ZYX" if bool(args.is_3d) else "YX"

if not exists(args.out_dir):
    makedirs(args.out_dir)

for i in range(len(data)):
    im = img_as_float(data[i])
    r = model.predict(im, axes)
    # NOTE(review): img_as_ubyte rejects float images outside [-1, 1]; the
    # prediction is not rescaled here (min-max rescaling was left disabled):
    #r = (r - r.min()) / (r.max() - r.min())
    r = img_as_ubyte(r)
    imsave(join(args.out_dir, f"{args.model}_{basename(data.files[i])}"), r)
#val_data = (val_data["X"], val_data["Y"]) #val_axes = tr_axes # we assume that both training and validation data are saved in the same format is_3d = bool(args.is_3d) axes = "ZYX" if is_3d else "YX" n_dim = 3 if is_3d else 2 config = Config(axes, n_dim=n_dim, n_channel_in=1, n_channel_out=1, probabilistic=int(args.probabilistic), train_batch_size=int(args.batch_size), unet_kern_size=3) for i in range(int(args.num_models)): model = CARE(config, f"model_{i}", args.models_dir) train_history = model.train(tr_data[0], tr_data[1], validation_data=val_data, epochs=int(args.epochs), steps_per_epoch=int(args.steps_per_epoch)) exit() plot_history(train_history, ["loss", "val_loss"], ["mse", "val_mse", "mae", "val_mae"]) def crop_to_even(img): shape = img.shape new_shape = (shape[0] - (shape[0] % 2), shape[1] - (shape[1] % 2))
f.write('{:.2f}\t{:.2f}\t{:.2f}\n'.format(x0 + dlx, y0 + dly, z0 + dlz)) def load_position_log(log_fname): try: xyz_array = np.loadtxt(log_fname, skiprows=3, delimiter='\t') print(xyz_array.shape) except: print('Unable to open ', log_fname) xyz_array = np.zeros((0, 3)) if xyz_array.shape == (3, ): xyz_array = np.zeros((0, 3)) return xyz_array model = CARE(config=None, name='bactn-gfp', basedir='CARE_Models') axes = 'ZYX' # fig, ax = pl.subplots() # pl.ion() # pl.show() print('Ready to run') while 1 != 0: log_list = [f for f in os.listdir('.') if ('.LOG' in f)] #print(log_list) for logfile in log_list: try: dirlog = open(logfile, 'r') except: break success = True
def test_model_predict_tiled():
    """
    Test that tiled prediction yields the same
    or similar result as compared to predicting
    the whole image at once.

    Tiling is checked along one axis (scalar n_tiles; values outside
    1..n_blocks must emit a UserWarning), along two axes, and along three
    axes for 3D configurations.
    """
    rng = np.random.RandomState(42)
    configs = config_generator(
        axes=['YX', 'ZYX'],
        n_channel_in=[1],
        n_channel_out=[1],
        probabilistic=[False],
        # unet_residual = [False,True],
        unet_n_depth=[1, 2, 3],
        unet_kern_size=[3, 5],
        unet_n_first=[4],
        unet_last_activation=['linear'],
        # unet_input_shape = [(None, None, 1)],
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        normalizer, resizer = NoNormalizer(), NoResizer()

        for config in filter(lambda c: c.is_valid(), configs):
            K.clear_session()
            model = CARE(config, basedir=tmpdir)

            def _predict(imdims, axes, n_tiles):
                # Untiled (n_tiles=1) vs tiled prediction of the same random
                # image; shapes must match and the max abs error stay < 1e-3.
                img = rng.uniform(size=imdims)
                # print(img.shape, axes)
                mean, scale = model._predict_mean_and_scale(img, axes,
                                                            normalizer,
                                                            resizer,
                                                            n_tiles=1)
                mean_tiled, scale_tiled = model._predict_mean_and_scale(
                    img, axes, normalizer, resizer, n_tiles=n_tiles)
                assert mean.shape == mean_tiled.shape
                if config.probabilistic:
                    assert scale.shape == scale_tiled.shape
                error_max = np.max(np.abs(mean - mean_tiled))
                # print('n, k, err = {0}, {1}x{1}, {2}'.format(model.config.unet_n_depth, model.config.unet_kern_size, error_max))
                assert error_max < 1e-3
                return mean, mean_tiled

            imdims = list(rng.randint(100, 130, size=config.n_dim))
            if config.n_dim == 3:
                imdims[
                    0] = 32  # make one dim small, otherwise test takes too long
            # round spatial sizes down to a multiple of 2**unet_n_depth
            div_n = 2**config.unet_n_depth
            imdims = [(d // div_n) * div_n for d in imdims]
            # number of network-divisible blocks along the largest axis
            n_blocks = np.max(imdims) // div_n

            def _predict_wrapped(imdims, axes, n_tiles):
                # tile counts outside 1..n_blocks are expected to warn
                if 0 < n_tiles <= n_blocks:
                    _predict(imdims, axes, n_tiles=n_tiles)
                else:
                    with pytest.warns(UserWarning):
                        _predict(imdims, axes, n_tiles=n_tiles)

            imdims.insert(0, config.n_channel_in)
            axes = config.axes.replace('C', '')
            # return _predict(imdims,'C'+axes,n_tiles=(3,4))

            # tile one dimension
            for n_tiles in (0, 2, 3, 6, n_blocks + 1):
                if config.n_channel_in == 1:
                    _predict_wrapped(imdims[1:], axes, n_tiles)
                _predict_wrapped(imdims, 'C' + axes, n_tiles)

            # tile two dimensions
            for n_tiles in product((2, 4), (3, 5)):
                _predict(imdims, 'C' + axes, n_tiles)

            # tile three dimensions
            if config.n_dim == 3:
                _predict(imdims, 'C' + axes, (2, 3, 4))
def _build():
    """Construct a CARE model from ``config`` inside ``tmpdir``.

    NOTE(review): ``config`` and ``tmpdir`` are free variables taken from an
    enclosing test function that is not visible here; presumably this is
    passed to a timing/exception check — confirm at the call site.
    """
    CARE(config, basedir=tmpdir)
def Train(self):
    """Prepare masks/patches, then train a CARE U-Net and/or a StarDist3D model.

    Each stage is controlled by an instance flag:

    1. Mask preparation (always): ``<BaseDir>/RealMask/`` (instance labels)
       and ``<BaseDir>/BinaryMask/`` (binary masks) are created if missing.
       If no instance labels exist, they are computed from the binary masks
       with ``label``. If no binary masks exist, they are computed as
       ``labels > 0``.
    2. ``GenerateNPZ``: ZYX patches from ``Raw/`` -> ``BinaryMask/`` are saved
       to ``<BaseDir><NPZfilename>.npz``.
    3. ``TrainUNET``: a CARE model ``'UNET' + model_name`` is trained on that
       npz file.
    4. ``TrainSTAR``: a StarDist3D model (``backbone`` 'resnet' or 'unet') is
       trained on ``Raw/`` + ``RealMask/``. With ``CroppedLoad``, data is
       streamed through ``DataSequencer`` and validation uses ``ValRaw/`` +
       ``ValRealMask/``.

    Both trainings can start from a copy model (``copy_model_dir`` /
    ``copy_model_name``) and resume from existing ``weights_now.h5``,
    ``weights_last.h5`` and ``weights_best.h5`` files.

    Raises
    ------
    ValueError
        If ``TrainSTAR`` is set and ``backbone`` is neither 'resnet' nor 'unet'.
    """
    BinaryName = 'BinaryMask/'
    RealName = 'RealMask/'

    def _restore_weights(model, own_prefix, copy_prefix, copy_kinds):
        """Load copy-model weights (only for kinds missing locally), then own checkpoints."""
        if copy_prefix is not None:
            for kind in copy_kinds:
                if os.path.exists(copy_prefix + kind) and not os.path.exists(own_prefix + kind):
                    print('Loading copy model')
                    model.load_weights(copy_prefix + kind)
        # loaded in this order, so the last existing file (best > last > now) wins
        for kind in ('weights_now.h5', 'weights_last.h5', 'weights_best.h5'):
            if os.path.exists(own_prefix + kind):
                print('Loading checkpoint model')
                model.load_weights(own_prefix + kind)

    Raw = sorted(glob.glob(self.BaseDir + '/Raw/' + '*.tif'))
    Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
    Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
    RealMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*.tif'))
    ValRaw = sorted(glob.glob(self.BaseDir + '/ValRaw/' + '*.tif'))
    ValRealMask = sorted(glob.glob(self.BaseDir + '/ValRealMask/' + '*.tif'))
    print('Instance segmentation masks:', len(RealMask))
    if len(RealMask) == 0:
        print('Making labels')
        Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
        for fname in Mask:
            image = imread(fname)
            Name = os.path.basename(os.path.splitext(fname)[0])
            Binaryimage = label(image)
            imwrite((self.BaseDir + '/' + RealName + Name + '.tif'),
                    Binaryimage.astype('uint16'))
        # Re-scan: the list was empty above. Without this, the StarDist branch
        # would train on an empty list of instance labels.
        RealMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*.tif'))

    Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
    print('Semantic segmentation masks:', len(Mask))
    if len(Mask) == 0:
        print('Generating Binary images')
        RealfilesMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*tif'))
        for fname in RealfilesMask:
            image = imread(fname)
            Name = os.path.basename(os.path.splitext(fname)[0])
            Binaryimage = image > 0
            imwrite((self.BaseDir + '/' + BinaryName + Name + '.tif'),
                    Binaryimage.astype('uint16'))

    if self.GenerateNPZ:
        raw_data = RawData.from_folder(
            basepath=self.BaseDir,
            source_dirs=['Raw/'],
            target_dir='BinaryMask/',
            axes='ZYX',
        )
        X, Y, XY_axes = create_patches(
            raw_data=raw_data,
            patch_size=(self.PatchZ, self.PatchY, self.PatchX),
            n_patches_per_image=self.n_patches_per_image,
            save_file=self.BaseDir + self.NPZfilename + '.npz',
        )

    # Training UNET model
    if self.TrainUNET:
        print('Training UNET model')
        load_path = self.BaseDir + self.NPZfilename + '.npz'
        (X, Y), (X_val, Y_val), axes = load_training_data(load_path,
                                                          validation_split=0.1,
                                                          verbose=True)
        c = axes_dict(axes)['C']
        n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

        config = Config(axes,
                        n_channel_in,
                        n_channel_out,
                        unet_n_depth=self.depth,
                        train_epochs=self.epochs,
                        train_batch_size=self.batch_size,
                        unet_n_first=self.startfilter,
                        train_loss='mse',
                        unet_kern_size=self.kern_size,
                        train_learning_rate=self.learning_rate,
                        train_reduce_lr={
                            'patience': 5,
                            'factor': 0.5
                        })
        print(config)

        model = CARE(config, name='UNET' + self.model_name, basedir=self.model_dir)
        copy_prefix = (self.copy_model_dir + 'UNET' + self.copy_model_name + '/'
                       if self.copy_model_dir is not None else None)
        _restore_weights(model,
                         self.model_dir + 'UNET' + self.model_name + '/',
                         copy_prefix,
                         ('weights_now.h5',))

        history = model.train(X, Y, validation_data=(X_val, Y_val))

        print(sorted(list(history.history.keys())))
        plt.figure(figsize=(16, 5))
        plot_history(history, ['loss', 'val_loss'],
                     ['mse', 'val_mse', 'mae', 'val_mae'])

    if self.TrainSTAR:
        print('Training StarDistModel model with', self.backbone, 'backbone')
        self.axis_norm = (0, 1, 2)
        if not self.CroppedLoad:
            assert len(Raw) > 1, "not enough training data"
            print(len(Raw))
            rng = np.random.RandomState(42)
            ind = rng.permutation(len(Raw))

            X_train = list(map(ReadFloat, Raw))
            Y_train = list(map(ReadInt, RealMask))
            self.Y = [
                label(DownsampleData(y, self.DownsampleFactor))
                for y in tqdm(Y_train)
            ]
            self.X = [
                normalize(DownsampleData(x, self.DownsampleFactor),
                          1,
                          99.8,
                          axis=self.axis_norm) for x in tqdm(X_train)
            ]
            # hold out ~15% (at least one image) for validation
            n_val = max(1, int(round(0.15 * len(ind))))
            ind_train, ind_val = ind[:-n_val], ind[-n_val:]

            self.X_val, self.Y_val = [self.X[i] for i in ind_val], [self.Y[i] for i in ind_val]
            self.X_trn, self.Y_trn = [self.X[i] for i in ind_train], [self.Y[i] for i in ind_train]

            print('number of images: %3d' % len(self.X))
            print('- training:       %3d' % len(self.X_trn))
            print('- validation:     %3d' % len(self.X_val))

        if self.CroppedLoad:
            self.X_trn = self.DataSequencer(Raw, self.axis_norm, Normalize=True, labelMe=False)
            self.Y_trn = self.DataSequencer(RealMask, self.axis_norm, Normalize=False, labelMe=True)

            self.X_val = self.DataSequencer(ValRaw, self.axis_norm, Normalize=True, labelMe=False)
            self.Y_val = self.DataSequencer(ValRealMask, self.axis_norm, Normalize=False, labelMe=True)
            self.train_sample_cache = False

        print(Config3D.__doc__)

        anisotropy = (1, 1, 1)
        rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)
        # train_patch_size follows the data axes 'ZYX' (same order as the CARE
        # patches above); it was previously passed as (Z, X, Y).
        patch_size = (self.PatchZ, self.PatchY, self.PatchX)
        kernel = (self.kern_size, self.kern_size, self.kern_size)
        if self.backbone == 'resnet':
            conf = Config3D(
                rays=rays,
                anisotropy=anisotropy,
                backbone=self.backbone,
                train_epochs=self.epochs,
                train_learning_rate=self.learning_rate,
                resnet_n_blocks=self.depth,
                train_checkpoint=self.model_dir + self.model_name + '.h5',
                resnet_kernel_size=kernel,
                train_patch_size=patch_size,
                train_batch_size=self.batch_size,
                resnet_n_filter_base=self.startfilter,
                train_dist_loss='mse',
                grid=(1, 1, 1),
                use_gpu=self.use_gpu,
                n_channel_in=1)
        elif self.backbone == 'unet':
            conf = Config3D(
                rays=rays,
                anisotropy=anisotropy,
                backbone=self.backbone,
                train_epochs=self.epochs,
                train_learning_rate=self.learning_rate,
                unet_n_depth=self.depth,
                train_checkpoint=self.model_dir + self.model_name + '.h5',
                unet_kernel_size=kernel,
                train_patch_size=patch_size,
                train_batch_size=self.batch_size,
                unet_n_filter_base=self.startfilter,
                train_dist_loss='mse',
                grid=(1, 1, 1),
                use_gpu=self.use_gpu,
                n_channel_in=1,
                train_sample_cache=False)
        else:
            # previously fell through to a NameError on `conf`
            raise ValueError("backbone must be 'resnet' or 'unet', got %r" % (self.backbone,))

        print(conf)

        Starmodel = StarDist3D(conf, name=self.model_name, basedir=self.model_dir)
        print(Starmodel._axes_tile_overlap('ZYX'),
              os.path.exists(self.model_dir + self.model_name + '/' + 'weights_now.h5'))

        copy_prefix = (self.copy_model_dir + self.copy_model_name + '/'
                       if self.copy_model_dir is not None else None)
        _restore_weights(Starmodel,
                         self.model_dir + self.model_name + '/',
                         copy_prefix,
                         ('weights_now.h5', 'weights_last.h5', 'weights_best.h5'))

        historyStar = Starmodel.train(self.X_trn,
                                      self.Y_trn,
                                      validation_data=(self.X_val, self.Y_val),
                                      epochs=self.epochs)
        print(sorted(list(historyStar.history.keys())))
        plt.figure(figsize=(16, 5))
        plot_history(historyStar, ['loss', 'val_loss'], [
            'dist_relevant_mae', 'val_dist_relevant_mae', 'dist_relevant_mse',
            'val_dist_relevant_mse'
        ])
def main():
    """Command-line entry point: restore TIFF images with a trained CARE model.

    Parses the command-line arguments (via ``parse_args``) and checks that all
    required options were given. Loads the CARE model named ``args.model_name``
    from ``args.model_basedir``, optionally with the weights file
    ``args.model_weights``. Applies the model, with percentile normalization, to
    every file in ``args.input_dir`` that matches ``args.input_pattern``.
    Restored images are written below ``args.output_dir`` using the
    ``args.output_name`` template, unless ``args.dry_run`` is set. Finally, a
    summary of the processed files is logged.

    Exits with status 0 when running interactively, when no arguments are given
    (after printing help) or when no input file matches. Exits with status 1
    when a required argument is missing.

    Raises:
        ValueError: for non-TIFF input/output files (through the ``_raise``
            helper), or when the model output has an unexpected number of
            dimensions for the given input axes.
    """
    # refuse to run when there is no script file (interactive session)
    if not ('__file__' in locals() or '__file__' in globals()):
        print('running interactively, exiting.')
        sys.exit(0)

    # parse arguments
    parser, args = parse_args()
    args_dict = vars(args)

    # exit and show help if no arguments provided at all
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)

    # check for required arguments manually (because of argparse issue)
    required = ('--input-dir', '--input-axes', '--norm-pmin', '--norm-pmax',
                '--model-basedir', '--model-name', '--output-dir')
    for r in required:
        # '--input-dir' -> 'input_dir', the key used in args_dict
        dest = r[2:].replace('-', '_')
        if args_dict[dest] is None:
            parser.print_usage(file=sys.stderr)
            print("%s: error: the following arguments are required: %s" % (parser.prog, r),
                  file=sys.stderr)
            sys.exit(1)

    # show effective arguments (including defaults)
    if not args.quiet:
        print('Arguments')
        print('---------')
        pprint(args_dict)
        print()
        sys.stdout.flush()

    # logging function: a no-op with --quiet, otherwise tqdm.write
    log = (lambda *a, **k: None) if args.quiet else tqdm.write

    # get list of input files (sorted, so that the processing order and the
    # final summary are deterministic) and exit if there are none
    file_list = sorted(Path(args.input_dir).glob(args.input_pattern))
    if not file_list:
        log("No files to process in '%s' with pattern '%s'." % (args.input_dir, args.input_pattern))
        sys.exit(0)

    # delay imports after checking to all required arguments are provided
    from tifffile import imread, imsave
    from csbdeep.utils.tf import keras_import
    K = keras_import('backend')
    from csbdeep.models import CARE
    from csbdeep.data import PercentileNormalizer
    sys.stdout.flush()
    sys.stderr.flush()

    # limit gpu memory
    if args.gpu_memory_limit is not None:
        from csbdeep.utils.tf import limit_gpu_memory
        limit_gpu_memory(args.gpu_memory_limit)

    # create CARE model and load weights, create normalizer
    K.clear_session()
    model = CARE(config=None, name=args.model_name, basedir=args.model_basedir)
    if args.model_weights is not None:
        # use log (not print) so that --quiet also silences this message
        log("Loading network weights from '%s'." % args.model_weights)
        model.load_weights(args.model_weights)
    normalizer = PercentileNormalizer(pmin=args.norm_pmin, pmax=args.norm_pmax,
                                      do_after=args.norm_undo)

    # a single tile count is passed on as a scalar, not as a 1-element sequence
    n_tiles = args.n_tiles
    if n_tiles is not None and len(n_tiles) == 1:
        n_tiles = n_tiles[0]

    processed = []

    # process all files; the progress bar is hidden with --quiet and when
    # predicting with more than one tile
    for file_in in tqdm(file_list,
                        disable=args.quiet or (n_tiles is not None and np.prod(n_tiles) > 1)):
        # construct output file name
        file_out = Path(args.output_dir) / args.output_name.format(
            file_path=str(file_in.relative_to(args.input_dir).parent),
            file_name=file_in.stem,
            file_ext=file_in.suffix,
            model_name=args.model_name,
            model_weights=Path(args.model_weights).stem if args.model_weights is not None else None,
        )

        # checks
        (file_in.suffix.lower() in ('.tif', '.tiff') and
         file_out.suffix.lower() in ('.tif', '.tiff')) or _raise(ValueError('only tiff files supported.'))

        # load and predict restored image
        img = imread(str(file_in))
        restored = model.predict(img, axes=args.input_axes, normalizer=normalizer, n_tiles=n_tiles)

        # restored image could be multi-channel even if input image is not
        axes_out = axes_check_and_normalize(args.input_axes)
        if restored.ndim > img.ndim:
            # explicit check instead of `assert`, which is stripped under `python -O`
            if restored.ndim != img.ndim + 1 or 'C' in axes_out:
                raise ValueError(
                    "unexpected restored image of shape %s for input of shape %s with axes '%s'"
                    % (restored.shape, img.shape, axes_out))
            axes_out += 'C'

        # convert data type (if necessary)
        restored = restored.astype(np.dtype(args.output_dtype), copy=False)

        # save to disk
        if not args.dry_run:
            file_out.parent.mkdir(parents=True, exist_ok=True)
            if args.imagej_tiff:
                save_tiff_imagej_compatible(str(file_out), restored, axes_out)
            else:
                imsave(str(file_out), restored)

        # recorded even with --dry-run, so the summary lists what would be written
        processed.append((file_in, file_out))

    # print summary of processed files
    if not args.quiet:
        sys.stdout.flush()
        sys.stderr.flush()
        n_processed = len(processed)
        len_processed = len(str(n_processed))
        log('Finished processing %d %s' % (n_processed, 'files' if n_processed > 1 else 'file'))
        # underline has the same length as the header line logged just above
        log('-' * (26 + len_processed if n_processed > 1 else 26))
        for i, (file_in, file_out) in enumerate(processed):
            len_file = max(len(str(file_in)), len(str(file_out)))
            log(('{:>%d}. in : {:>%d}' % (len_processed, len_file)).format(1 + i, str(file_in)))
            log(('{:>%d} out: {:>%d}' % (len_processed, len_file)).format('', str(file_out)))
unet_n_first=48, unet_kern_size=7, train_loss='mae', train_epochs=150, train_learning_rate=1.0E-4, train_batch_size=1, train_reduce_lr={ 'patience': 5, 'factor': 0.5 }) print(config) vars(config) # In[6]: model = CARE(config=config, name=ModelName, basedir=ModelDir) #input_weights = ModelDir + ModelName + '/' +'weights_best.h5' #model.load_weights(input_weights) # In[ ]: history = model.train(X, Y, validation_data=(X_val, Y_val)) # In[ ]: print(sorted(list(history.history.keys()))) plt.figure(figsize=(16, 5)) plot_history(history, ['loss', 'val_loss'], ['mse', 'val_mse', 'mae', 'val_mae']) # In[ ]:
help="the name of your model (will be stored in model_dir)", default="my_model") parser.add_argument( "--path_to_test_image", help="the path to the test high/low pair, if available", default="test") # parser.add_argument( # "-o", # "--output_filename", # help="the name of the out" # ) parser.add_argument("-p", "--plot", action="store_true", default=True) args = parser.parse_args() # Load the trained model model = CARE(config=None, name=args.model_name, basedir=args.model_dir) # Read the test images x = imread( os.path.join(args.base_dir, args.path_to_test_image, "low_snr.tif")) y = imread( os.path.join(args.base_dir, args.path_to_test_image, "high_snr.tif")) restored = model.predict(x, axes="YX") # , n_tiles=(1,4,4)) # Save the restored image save_tiff_imagej_compatible( os.path.join(args.base_dir, args.path_to_test_image, "predicted.tif"), restored, axes="YX")
# Experiment folder on the SMB share (mounted through GVFS); results go into sub-folders of it.
basedir = '/run/user/1000/gvfs/smb-share:server=isiserver.curie.net,share=u934/equipe_bellaiche/el_alpar/210609_ON_ActTolloRNAi_lateral'
basedirResults3D, basedirResults2D = basedir + '/Restored', basedir + '/Projected'
# Nested "extended" folders: <basedir>/Restored/Restored and <basedir>/Projected/Projected.
basedirResults3Dextended, basedirResults2Dextended = (
    basedirResults3D + '/Restored',
    basedirResults2D + '/Projected',
)

# Directory holding the pretrained CARE models.
Model_Dir = '/run/media/sancere/DATA/Lucas_Model_to_use/CARE/'

# Load the restoration (CARE) and projection (ProjectionCARE) networks by name from Model_Dir.
RestorationModel = CARE(
    config=None, name='CARE_restoration_SpinWideFRAP4_Bin1_3Gfp', basedir=Model_Dir)
ProjectionModel = ProjectionCARE(
    config=None, name='CARE_projection_SpinWideFRAP4_Bin1_3Gfp', basedir=Model_Dir)

# Only the 2D (projection) output folder is created here; basedirResults3D is not created.
Path(basedirResults2D).mkdir(exist_ok=True)

# NOTE: the extension in the pattern is matched literally ('*TIF' vs '*tif'),
# so check which spelling the raw files actually use.
Raw_path = os.path.join(basedir, '*TIF')
axes = 'ZYX'  # presumably the raw-stack axes; original note: projection axes are 'YX'
filesRaw = glob.glob(Raw_path)