def train(
    image,
    patch_shape=(16, 32, 32),
    ratio=0.7,
    name="untitled",
    model_dir=".",
    view_history=True,
):
    """Train a Noise2Void model on patches cut from a single image.

    Patches are generated from ``image``, split into a leading training
    part and a trailing validation part, used to train an ``N2V`` model,
    and the trained model is exported with ``export_TF``.

    Parameters
    ----------
    image
        Array to train on. A new leading and a new trailing axis are added
        before patch generation.
        NOTE(review): presumably these are the sample and channel axes
        expected by the generator -- confirm.
    patch_shape
        Shape of the generated patches; also passed as ``n2v_patch_shape``.
    ratio
        Fraction of the patches (taken from the front) used for training;
        the remainder is used for validation.
    name
        Model name handed to ``N2V``.
    model_dir
        Base directory handed to ``N2V``.
    view_history
        When true, plot the ``loss`` / ``val_loss`` curves after training.
    """
    import logging

    # NOTE(review): the source had lost the call in front of the two message
    # strings below; they are restored as INFO log records -- confirm the
    # logger the rest of the file uses.
    log = logging.getLogger(__name__)

    # create data generator object
    datagen = N2V_DataGenerator()

    # patch additional axes
    image = image[np.newaxis, ..., np.newaxis]
    patches = datagen.generate_patches_from_list([image], shape=patch_shape)
    log.info(f"patch shape: {patches.shape}")
    n_patches = patches.shape[0]
    log.info(f"{n_patches} patches generated")

    # split training set and validation set
    i = int(n_patches * ratio)
    X, X_val = patches[:i], patches[i:]

    # create training config
    config = N2VConfig(
        X,
        unet_kern_size=3,
        train_steps_per_epoch=100,
        train_epochs=100,
        train_loss="mse",
        batch_norm=True,
        train_batch_size=4,
        n2v_perc_pix=1.6,
        n2v_patch_shape=patch_shape,
        n2v_manipulator="uniform_withCP",
        n2v_neighborhood_radius=5,
    )
    model = N2V(config=config, name=name, basedir=model_dir)

    # train and save the model
    history = model.train(X, X_val)
    model.export_TF()

    # The flag was previously accepted but ignored; the default (True)
    # keeps the old always-plot behaviour.
    if view_history:
        plot_history(history, ["loss", "val_loss"])
def __prepareData(self):
    """Generate patches, split them, and build the N2V config and model.

    Reads ``dataGen``, ``patchSize``, ``dataAugmentationActivated``,
    ``percentValidation``, ``batchSize``, ``numberOfSteps``, the network
    hyper-parameters and ``path`` from ``self``; writes ``trainingData``,
    ``validationData``, ``numberOfSteps`` (only if it was falsy),
    ``config`` and ``model``.
    """
    # split patches from the training images
    patch_shape = (self.patchSize, self.patchSize)
    data = self.dataGen.generate_patches_from_list(
        self.getTrainingImages(),
        shape=patch_shape,
        augment=self.dataAugmentationActivated)

    # number of patches reserved for validation (percentValidation % of all)
    threshold = int(data.shape[0] * (self.percentValidation / 100))

    # the first `threshold` patches validate, the rest train
    self.trainingData = data[threshold:]
    self.validationData = data[:threshold]

    # derive a step count from the data only when none was configured
    if not self.numberOfSteps:
        self.numberOfSteps = int(
            self.trainingData.shape[0] / self.batchSize) + 1

    print(data.shape[0], "patches created.")
    print(threshold, "patch images for validation (",
          self.percentValidation, "%).")
    # trainingData already excludes the validation patches, so its length is
    # the training count (subtracting threshold again under-reported it).
    print(self.trainingData.shape[0], "patch images for training.")
    print(self.numberOfSteps)

    # noinspection PyTypeChecker
    self.config = N2VConfig(self.trainingData,
                            unet_kern_size=self.kernelSize,
                            train_steps_per_epoch=self.numberOfSteps,
                            train_epochs=self.numberOfEpochs,
                            train_loss='mse',
                            batch_norm=True,
                            train_batch_size=self.batchSize,
                            n2v_patch_shape=patch_shape,
                            n2v_perc_pix=self.percentPixel,
                            n2v_manipulator='uniform_withCP',
                            unet_n_depth=self.netDepth,
                            unet_n_first=self.uNetNFirst,
                            train_learning_rate=self.initialLearningRate,
                            n2v_neighborhood_radius=5)
    print(vars(self.config))

    # NOTE(review): the model-name argument was missing in the source
    # (`N2V(self.config,, ...)`). It is passed explicitly as None here;
    # replace with the intended name attribute once confirmed.
    self.model = N2V(self.config, name=None, basedir=self.path)
    print("Setup done.")
    print(self.config)
def train_predict(n_tiles=(1,4,4), params=params, files=None, headless=False, **unet_config):
    """Train one N2V model per channel and write denoised TIFF stacks.

    For every channel in ``params["train_channels"]`` the images are loaded,
    patches are generated and shuffled, 90 % of them train a model and the
    remaining 10 % validate it, and each input file is then predicted frame
    by frame and saved next to the input as ``<stem>_n2v_pred_ch<c>.tiff``.

    Parameters
    ----------
    n_tiles : tuple(int)
        Number of tiles to tile the image into, if it is too large for memory.
        The first entry is dropped when ``"Z"`` is not in ``params["axes"]``.
    params : dict
        Settings; the keys read here are ``in_dir``, ``glob``,
        ``train_channels``, ``n_patches_per_image``, ``patch_size``,
        ``augment``, ``train_steps_per_epoch``, ``train_epochs``,
        ``train_batch_size``, ``n2v_perc_pix``, ``n2v_neighborhood_radius``,
        ``name`` and ``axes``.
    files : list, optional
        Files to process. When ``None`` they are discovered from
        ``params["in_dir"]`` / ``params["glob"]``.
    headless : bool
        When true, matplotlib is not imported and no figures are drawn.

    These advanced options can be set by keyword arguments:

    unet_residual : bool
        Parameter `residual` of :func:`n2v_old.nets.common_unet`. Default: ``n_channel_in == n_channel_out``
    unet_n_depth : int
        Parameter `n_depth` of :func:`n2v_old.nets.common_unet`. Default: ``2``
    unet_kern_size : int
        Parameter `kern_size` of :func:`n2v_old.nets.common_unet`. Default: ``5 if n_dim==2 else 3``
    unet_n_first : int
        Parameter `n_first` of :func:`n2v_old.nets.common_unet`. Default: ``32``
    batch_norm : bool
        Activate batch norm
    unet_last_activation : str
        Parameter `last_activation` of :func:`n2v_old.nets.common_unet`. Default: ``linear``
    train_learning_rate : float
        Learning rate for training. Default: ``0.0004``
    n2v_patch_shape : tuple
        Random patches of this shape are extracted from the given training data.
        Default: ``(64, 64) if n_dim==2 else (64, 64, 64)``
    n2v_manipulator : str
        Noise2Void pixel value manipulator. Default: ``uniform_withCP``
    train_reduce_lr : dict
        Parameter :class:`dict` of ReduceLROnPlateau_ callback; set to ``None`` to disable.
        Default: ``{'factor': 0.5, 'patience': 10}``
    """
    import os

    from n2v.models import N2VConfig, N2V
    from n2v.utils.n2v_utils import manipulate_val_data
    from n2v.internals.N2V_DataGenerator import N2V_DataGenerator

    if not headless:
        from csbdeep.utils import plot_history
        from matplotlib import pyplot as plt

    has_z = "Z" in params["axes"]
    # Without a Z axis each frame is 2D, so the leading tile count is dropped.
    tiles = n_tiles if has_z else n_tiles[1:]

    # Init reader
    datagen = BFListReader()
    print("Loading images ...")
    if files is None:
        datagen.from_glob(params["in_dir"], params["glob"])
        files = datagen.get_file_names()
    else:
        datagen.from_file_list(files)

    print("Training ...")
    for c in params["train_channels"]:
        print(" -- Channel {}".format(c))

        # Two independent generators: one is consumed by patch extraction,
        # the other by prediction further down.
        imgs_for_patches = datagen.load_imgs_generator()
        imgs_for_predict = datagen.load_imgs_generator()

        img_ch = (im[..., c:c+1] for im in imgs_for_patches)
        img_ch_predict = (im[..., c:c+1] for im in imgs_for_predict)

        npatches = params["n_patches_per_image"] if params["n_patches_per_image"] > 1 else None

        patches = N2V_DataGenerator().generate_patches_from_list(img_ch,
                                                                 num_patches_per_img=npatches,
                                                                 shape=params['patch_size'],
                                                                 augment=params['augment'])

        numpy.random.shuffle(patches)

        # 90 % training / 10 % validation
        sep = int(len(patches)*0.9)
        X = patches[:sep]
        X_val = patches[sep:]

        config = N2VConfig(X,
                           train_steps_per_epoch=params["train_steps_per_epoch"],
                           train_epochs=params["train_epochs"],
                           train_loss='mse',
                           train_batch_size=params["train_batch_size"],
                           n2v_perc_pix=params["n2v_perc_pix"],
                           n2v_patch_shape=params['patch_size'],
                           n2v_manipulator='uniform_withCP',
                           n2v_neighborhood_radius=params["n2v_neighborhood_radius"],
                           **unet_config)

        # a name used to identify the model
        model_name = '{}_ch{}'.format(params['name'], c)

        # We are now creating our network model; it lives under the input directory.
        model = N2V(config=config, name=model_name, basedir=params["in_dir"])

        history = model.train(X, X_val)

        # Denoise the first validation patch as a quick visual sanity check.
        val_patch = X_val[0,..., 0]
        val_patch_pred = model.predict(val_patch, axes=params["axes"])

        if has_z:
            # maximum projection along the first axis for display
            val_patch = val_patch.max(0)
            val_patch_pred = val_patch_pred.max(0)

        if not headless:
            _fig, ax = plt.subplots(1,2, figsize=(14,7))
            ax[0].imshow(val_patch,cmap='gray')
            ax[0].set_title('Validation Patch')
            ax[1].imshow(val_patch_pred,cmap='gray')
            ax[1].set_title('Validation Patch N2V')

            plt.figure(figsize=(16,5))
            plot_history(history,['loss','val_loss'])

        print(" -- Predicting channel {}".format(c))
        for f, im in zip(files, img_ch_predict):
            print(" -- {}".format(f))
            pixel_reso = get_space_time_resolution(str(f))

            # Predict every entry along the first axis separately.
            res_img = []
            for frame in im:
                pred = model.predict(frame[..., 0], axes=params["axes"], n_tiles=tiles)

                if has_z:
                    # insert a singleton axis after the first one
                    pred = pred[:, None, ...]
                res_img.append(pred)

            pred = numpy.stack(res_img)

            if not has_z:
                # insert two singleton axes after the first one
                pred = pred[:, None, None, ...]

            reso = (1 / pixel_reso.X, 1 / pixel_reso.Y)
            spacing = pixel_reso.Z
            unit = pixel_reso.Xunit
            finterval = pixel_reso.T

            # Strip the real extension instead of assuming it is 3 characters.
            out_stem = os.path.splitext(str(f))[0]

            # NOTE(review): newer tifffile releases name this function
            # `imwrite` -- confirm against the pinned version.
            tifffile.imsave("{}_n2v_pred_ch{}.tiff".format(out_stem, c),
                            pred,
                            imagej=True,
                            resolution=reso,
                            metadata={'axes': 'TZCYX',
                                      'finterval': finterval,
                                      'spacing' : spacing,
                                      'unit' : unit})
# The first 600 patches train the network; whatever follows validates it.
X, X_val = patches[:600], patches[600:]

# Let's look at two patches
#plt.figure(figsize=(14,7))
#plt.subplot(1,2,1)
#plt.imshow(X[0,16,...,0],cmap='magma')
#plt.title('Training Patch')
#plt.subplot(1,2,2)
#plt.imshow(X_val[0,16,...,0],cmap='magma')
#plt.title('Validation Patch')

# a name used to identify the model
model_name = 'n2v_3D'
# the base directory in which our model will live
basedir = 'models'

# You can increase "train_steps_per_epoch" to get even better results
# at the price of longer computation.
config = N2VConfig(
    X,
    unet_kern_size=3,
    train_steps_per_epoch=100,
    train_epochs=10,
    train_loss='mse',
    batch_norm=True,
    train_batch_size=4,
    n2v_perc_pix=1.6,
    n2v_patch_shape=(32, 64, 64),
    n2v_manipulator='uniform_withCP',
    n2v_neighborhood_radius=5,
)

# Let's look at the parameters stored in the config-object.
vars(config)

# We are now creating our network model
model = N2V(config, model_name, basedir=basedir)

history = model.train(X, X_val)

# List which quantities were recorded during training.
print(sorted(list(history.history.keys())))

#plt.figure(figsize=(16,5))
#plot_history(history,['loss','val_loss'])
#plt.subplot(1,2,1)
#plt.imshow(X[0,...])
#plt.title('Training Patch')
#plt.subplot(1,2,2)
#plt.imshow(X_val[0,...])
#plt.title('Validation Patch')

# name used to identify the model
model_name = 'n2v_2D_RGB'
# the base directory in which our model will live
basedir = 'models'

# You can increase "train_steps_per_epoch" to get even better results
# at the price of longer computation
config = N2VConfig(
    X,
    unet_kern_size=3,
    unet_n_first=64,
    unet_n_depth=3,
    train_steps_per_epoch=5,
    train_epochs=25,
    train_loss='mse',
    batch_norm=True,
    train_batch_size=128,
    n2v_perc_pix=5,
    n2v_patch_shape=(64, 64),
    n2v_manipulator='uniform_withCP',
    n2v_neighborhood_radius=5,
)

vars(config)

# We are now creating our network model
model = N2V(config, model_name, basedir=basedir)

# Train on X, validating against X_val.
history = model.train(X, X_val)
augment=(not args.noAugment))  # NOTE(review): tail of a call whose opening lies outside this excerpt

# The patches are non-overlapping, so we can split them into train and validation data.
# `frac` is the number of validation patches: validationFraction percent of all patches.
frac = int((len(patches)) * float(args.validationFraction) / 100.0)
print("total no. of patches: " + str(len(patches)) + "\ttraining patches: " + str(len(patches) - frac) + "\tvalidation patches: " + str(frac))

# The first `frac` patches validate, the remainder train.
X = patches[frac:]
X_val = patches[:frac]

# Network and training hyper-parameters come from the command-line arguments.
config = N2VConfig(X, unet_kern_size=args.netKernelSize, train_steps_per_epoch=int(args.stepsPerEpoch), train_epochs=int(args.epochs), train_loss='mse', batch_norm=True, train_batch_size=args.batchSize, n2v_perc_pix=args.n2vPercPix, n2v_patch_shape=pshape, n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate=args.learningRate, unet_n_depth=args.netDepth, unet_n_first=args.unet_n_first)

# Let's look at the parameters stored in the config-object.
vars(config)

# a name used to identify the model
# NOTE(review): the right-hand side of this assignment is missing in this
# excerpt (syntax error as written) -- restore the intended value.
model_name =
# the base directory in which our model will live
basedir = args.baseDir
# We are now creating our network model.
# Append a trailing axis to the validation data.
X_val = np.expand_dims(X_val, axis=-1)
print(X_val.shape)

# In[5]:

# IMPORTANT!! I add clip
# Training data: clipped to [0, 255], truncated to uint8, then cast to float32.
X_clipped = np.clip(X, 0, 255.)
X = X_clipped.astype(np.uint8).astype(np.float32)
# Validation data: clipped only.
# NOTE(review): unlike X it is not passed through uint8 -- confirm this is intended.
X_val = np.clip(X_val, 0, 255.)

config = N2VConfig(
    X,
    unet_kern_size=3,
    train_steps_per_epoch=400,
    train_epochs=200,
    train_loss='mse',
    batch_norm=True,
    train_batch_size=args.batch_size,
    n2v_perc_pix=args.perc_pix,
    n2v_patch_shape=(args.patch_size, args.patch_size),
    unet_n_first=96,
    unet_residual=True,
    n2v_manipulator='uniform_withCP',
    n2v_neighborhood_radius=2,
)

vars(config)

# Experiment name and checkpoint directory come from the command line.
model = N2V(config, args.exp_name, basedir=args.ckpt_dir)
model.prepare_for_training(metrics=())

# We are ready to start training now.
history = model.train(X, X_val)
# Load the ZYX stacks found in image_path.
datagen = N2V_DataGenerator()
imgs = datagen.load_imgs_from_directory(directory=image_path, dims='ZYX')
print(imgs[0].shape)

# Patches are generated from the first loaded image only.
# default = 32,64,64
patches = datagen.generate_patches_from_list(imgs[:1], shape=(32, 64, 64))

# default = :600
X = patches[:600]
X_val = patches[600:]

numberEpochs = 20

config = N2VConfig(
    X,
    unet_kern_size=3,
    # With fewer than 128 training patches the plain division gives 0 steps
    # per epoch; clamp to at least one step.
    # NOTE(review): the divisor (128) differs from train_batch_size (4) --
    # confirm which one is intended.
    train_steps_per_epoch=max(1, int(X.shape[0] / 128)),
    train_epochs=numberEpochs,
    train_loss='mse',
    batch_norm=True,
    train_batch_size=4,
    n2v_perc_pix=0.198,
    n2v_patch_shape=(32, 64, 64),
    n2v_manipulator='uniform_withCP',
    n2v_neighborhood_radius=5,
)

vars(config)

# The model is stored alongside the images.
model_name = '20epoch'
model = N2V(config=config, name=model_name, basedir=image_path)

history = model.train(X, X_val)
print(sorted(list(history.history.keys())))

model.export_TF()

# Load the image, and predict the denoised image.
img = imread(os.path.join(image_path, image_name))
pred = model.predict(img, axes='ZYX', n_tiles=(2, 4, 4))
print(X.shape)

# Append a trailing axis to the validation data.
X_val = X_val[..., np.newaxis]
print(X_val.shape)

# Let's look at one of our training and validation patches.
plt.figure(figsize=(14,7))

plt.subplot(1,2,1)
plt.imshow(X[0,...,0], cmap='gray')
plt.title('Training Patch')

plt.subplot(1,2,2)
plt.imshow(X_val[0,...,0], cmap='gray')
plt.title('Validation Patch')

# a name used to identify the model
model_name = 'BSD68_reproducability_5x5'
# the base directory in which our model will live
basedir = 'models'

config = N2VConfig(
    X,
    unet_kern_size=3,
    train_steps_per_epoch=200,
    train_epochs=100,
    train_loss='mse',
    batch_norm=True,  #400 , 200
    train_batch_size=128,
    n2v_perc_pix=0.198,
    n2v_patch_shape=(64, 64),
    unet_n_first=96,
    unet_residual=True,
    n2v_manipulator='uniform_withCP',
    n2v_neighborhood_radius=2,
)

# Let's look at the parameters stored in the config-object.
#print(vars(config))

# We are now creating our network model.
model = N2V(config=config, name=model_name, basedir=basedir)
model.prepare_for_training(metrics=())

# We are ready to start training now.