Esempio n. 1
0
def train(
    image,
    patch_shape=(16, 32, 32),
    ratio=0.7,
    name="untitled",
    model_dir=".",
    view_history=True,
):
    """Train (and export) a Noise2Void model on patches cut from one image.

    Parameters
    ----------
    image : numpy.ndarray
        Image to train on, without sample or channel axis; both are added
        here (a leading sample axis and a trailing channel axis). Its rank
        presumably has to match ``patch_shape`` (e.g. ZYX for the default
        3-D shape) — TODO confirm.
    patch_shape : tuple of int
        Shape of the extracted patches; also used as ``n2v_patch_shape``.
    ratio : float
        Fraction of the shuffled patches used for training; the remaining
        patches form the validation set. Must lie strictly between 0 and 1.
    name : str
        Model name passed to ``N2V``.
    model_dir : str
        Base directory passed to ``N2V``.
    view_history : bool
        If true, plot the training and validation loss after training.

    Returns
    -------
    The history object returned by ``N2V.train``.

    Raises
    ------
    ValueError
        If ``ratio`` is not in (0, 1), or if the split leaves the training
        or the validation set empty.
    """
    # Validate before any expensive work is done.
    if not 0 < ratio < 1:
        raise ValueError(f"ratio must be in (0, 1), got {ratio!r}")

    # create data generator object
    datagen = N2V_DataGenerator()

    # patch additional axes: sample axis in front, channel axis at the end
    image = image[np.newaxis, ..., np.newaxis]

    patches = datagen.generate_patches_from_list([image], shape=patch_shape)
    logger.info(f"patch shape: {patches.shape}")
    n_patches = patches.shape[0]
    logger.info(f"{n_patches} patches generated")

    # The patches come back in extraction order (and, if the generator
    # augments them, presumably followed by rotated/flipped copies — TODO
    # confirm), so a plain head/tail split would give a spatially biased
    # validation set that may duplicate training patches. Shuffle first.
    np.random.shuffle(patches)

    # split training set and validation set
    i = int(n_patches * ratio)
    X, X_val = patches[:i], patches[i:]
    if len(X) == 0 or len(X_val) == 0:
        raise ValueError(
            f"cannot split {n_patches} patches with ratio={ratio}: "
            f"{len(X)} training / {len(X_val)} validation patches"
        )

    # create training config
    config = N2VConfig(
        X,
        unet_kern_size=3,
        train_steps_per_epoch=100,
        train_epochs=100,
        train_loss="mse",
        batch_norm=True,
        train_batch_size=4,
        n2v_perc_pix=1.6,
        n2v_patch_shape=patch_shape,
        n2v_manipulator="uniform_withCP",
        n2v_neighborhood_radius=5,
    )

    model = N2V(config=config, name=name, basedir=model_dir)

    # train and save the model
    history = model.train(X, X_val)
    model.export_TF()

    # Only plot when requested (the flag used to be ignored).
    if view_history:
        plot_history(history, ["loss", "val_loss"])
    return history
Esempio n. 2
0
File: n2v.py Progetto: Natuy/dl4mic
 def __prepareData(self):
     """Extract patches, split them into training/validation sets and build the model.

     Patches of shape ``(patchSize, patchSize)`` are generated from
     ``self.getTrainingImages()`` (augmented if ``dataAugmentationActivated``).
     The first ``percentValidation`` percent of them become
     ``self.validationData`` and the rest ``self.trainingData``. If
     ``self.numberOfSteps`` is falsy it is set to
     ``int(n_training / batchSize) + 1``. Finally ``self.config``
     (``N2VConfig``) and ``self.model`` (``N2V``) are created.
     """
     # split patches from the training images
     data = self.dataGen.generate_patches_from_list(
         self.getTrainingImages(),
         shape=(self.patchSize, self.patchSize),
         augment=self.dataAugmentationActivated)
     # number of patches held out for validation (percentValidation % of all)
     threshold = int(data.shape[0] * (self.percentValidation / 100))
     # split the patches into training patches and validation patches
     # NOTE(review): the patches are not shuffled before the split, so the
     # validation set is the head of the generated list — confirm this is
     # intended, especially with augmentation enabled.
     self.trainingData = data[threshold:]
     self.validationData = data[:threshold]
     nTraining = self.trainingData.shape[0]
     if not self.numberOfSteps:
         self.numberOfSteps = int(nTraining / self.batchSize) + 1
     print(data.shape[0], "patches created.")
     print(threshold, "patch images for validation (",
           self.percentValidation, "%).")
     # trainingData already excludes the validation patches, so its length
     # is the training count (subtracting threshold again double-counted).
     print(nTraining, "patch images for training.")
     print(self.numberOfSteps)
     # noinspection PyTypeChecker
     self.config = N2VConfig(self.trainingData,
                             unet_kern_size=self.kernelSize,
                             train_steps_per_epoch=self.numberOfSteps,
                             train_epochs=self.numberOfEpochs,
                             train_loss='mse',
                             batch_norm=True,
                             train_batch_size=self.batchSize,
                             n2v_patch_shape=(self.patchSize,
                                              self.patchSize),
                             n2v_perc_pix=self.percentPixel,
                             n2v_manipulator='uniform_withCP',
                             unet_n_depth=self.netDepth,
                             unet_n_first=self.uNetNFirst,
                             train_learning_rate=self.initialLearningRate,
                             n2v_neighborhood_radius=5)
     print(vars(self.config))
     self.model = N2V(self.config, self.name, basedir=self.path)
     print("Setup done.")
     print(self.config)
Esempio n. 3
0
def train_predict(n_tiles=(1,4,4), params=params, files=None, headless=False, **unet_config):
    """
    Train one Noise2Void model per channel and use it to denoise the input images.

    For every channel ``c`` in ``params["train_channels"]``:

    1. patches are extracted from that channel of all input images, shuffled
       and split 90 % / 10 % into training and validation patches;
    2. an ``N2V`` model named ``"<params['name']>_ch<c>"`` is trained, with
       ``params["in_dir"]`` as its base directory;
    3. unless ``headless``, a validation patch, its prediction and the loss
       history are plotted;
    4. every time point of every input image is predicted and the stack is
       written as ``<input path without extension>_n2v_pred_ch<c>.tiff``
       (ImageJ hyperstack with axes ``TZCYX``).

    Parameters
    ----------
    n_tiles : tuple(int)
        Number of tiles to tile the image into, if it is too large for memory.
        When ``params["axes"]`` contains no ``Z``, the first entry is dropped.
    params : dict
        Run settings. Keys read here: ``in_dir``, ``glob``, ``train_channels``,
        ``n_patches_per_image``, ``patch_size``, ``augment``,
        ``train_steps_per_epoch``, ``train_epochs``, ``train_batch_size``,
        ``n2v_perc_pix``, ``n2v_neighborhood_radius``, ``name`` and ``axes``.
    files : list or None
        Input files. If ``None``, the files matching ``params["glob"]`` in
        ``params["in_dir"]`` are used.
    headless : bool
        If true, nothing is plotted (matplotlib is not imported).

    Raises
    ------
    ValueError
        If the 90/10 split of a channel's patches leaves the training or the
        validation set empty.

    These advanced options can be set by keyword arguments. They are passed
    to ``N2VConfig`` and take precedence over the values chosen here:

    unet_residual : bool
        Parameter `residual` of :func:`n2v_old.nets.common_unet`. Default: ``n_channel_in == n_channel_out``
    unet_n_depth : int
        Parameter `n_depth` of :func:`n2v_old.nets.common_unet`. Default: ``2``
    unet_kern_size : int
        Parameter `kern_size` of :func:`n2v_old.nets.common_unet`. Default: ``5 if n_dim==2 else 3``
    unet_n_first : int
        Parameter `n_first` of :func:`n2v_old.nets.common_unet`. Default: ``32``
    batch_norm : bool
        Activate batch norm
    unet_last_activation : str
        Parameter `last_activation` of :func:`n2v_old.nets.common_unet`. Default: ``linear``
    train_learning_rate : float
        Learning rate for training. Default: ``0.0004``
    n2v_patch_shape : tuple
        Random patches of this shape are extracted from the given training data. Default here: ``params['patch_size']``
    n2v_manipulator : str
        Noise2Void pixel value manipulator. Default: ``uniform_withCP``
    train_reduce_lr : dict
        Parameter :class:`dict` of ReduceLROnPlateau_ callback; set to ``None`` to disable. Default: ``{'factor': 0.5, 'patience': 10}``
    """
    import os

    from n2v.models import N2VConfig, N2V
    from n2v.internals.N2V_DataGenerator import N2V_DataGenerator

    if not headless:
        from csbdeep.utils import plot_history
        from matplotlib import pyplot as plt

    has_z = "Z" in params["axes"]
    # Tiles used per predicted frame; 2-D data has no Z entry to tile.
    nt = n_tiles if has_z else n_tiles[1:]

    # Init reader
    datagen = BFListReader()
    print("Loading images ...")
    if files is None:
        datagen.from_glob(params["in_dir"], params["glob"])
        files = datagen.get_file_names()
    else:
        datagen.from_file_list(files)

    print("Training ...")
    for c in params["train_channels"]:
        print("  -- Channel {}".format(c))

        # Two separate generators: one feeds patch extraction, the other the
        # prediction pass after training.
        imgs_for_patches = datagen.load_imgs_generator()
        imgs_for_predict = datagen.load_imgs_generator()

        # c:c+1 keeps a trailing channel axis of length 1
        img_ch = (im[..., c:c+1] for im in imgs_for_patches)
        img_ch_predict = (im[..., c:c+1] for im in imgs_for_predict)

        # values <= 1 are passed to the generator as None
        npatches = params["n_patches_per_image"] if params["n_patches_per_image"] > 1 else None

        patches = N2V_DataGenerator().generate_patches_from_list(img_ch, num_patches_per_img=npatches, shape=params['patch_size'], augment=params['augment'])

        # shuffle so the validation patches are not just the last images' patches
        numpy.random.shuffle(patches)

        sep = int(len(patches)*0.9)
        X     = patches[:sep]
        X_val = patches[ sep:]
        if len(X) == 0 or len(X_val) == 0:
            raise ValueError(
                "Channel {}: {} patches are not enough for a 90/10 "
                "training/validation split".format(c, len(patches)))

        config_kwargs = dict(
            train_steps_per_epoch=params["train_steps_per_epoch"],
            train_epochs=params["train_epochs"],
            train_loss='mse',
            train_batch_size=params["train_batch_size"],
            n2v_perc_pix=params["n2v_perc_pix"],
            n2v_patch_shape=params['patch_size'],
            n2v_manipulator='uniform_withCP',
            n2v_neighborhood_radius=params["n2v_neighborhood_radius"],
        )
        # Caller-supplied options override the defaults above. Passing them
        # next to the explicit keywords used to raise a TypeError for
        # duplicate keyword arguments (e.g. ``n2v_manipulator``).
        config_kwargs.update(unet_config)
        config = N2VConfig(X, **config_kwargs)

        # a name used to identify the model
        model_name = '{}_ch{}'.format(params['name'], c)
        # the input directory is used as the model's base directory
        model = N2V(config=config, name=model_name, basedir=params["in_dir"])

        history = model.train(X, X_val)

        # quick check: denoise the first validation patch (channel axis dropped)
        val_patch = X_val[0,..., 0]
        val_patch_pred = model.predict(val_patch, axes=params["axes"])

        if has_z:
            # show maximum projections along the first (Z) axis
            val_patch      = val_patch.max(0)
            val_patch_pred = val_patch_pred.max(0)

        if not headless:
            _, ax = plt.subplots(1,2, figsize=(14,7))
            ax[0].imshow(val_patch,cmap='gray')
            ax[0].set_title('Validation Patch')
            ax[1].imshow(val_patch_pred,cmap='gray')
            ax[1].set_title('Validation Patch N2V')

            plt.figure(figsize=(16,5))
            plot_history(history,['loss','val_loss'])

        print("  -- Predicting channel {}".format(c))
        for fname, im in zip(files, img_ch_predict):
            print("  -- {}".format(fname))
            pixel_reso = get_space_time_resolution(str(fname))
            res_img = []
            # predict each time point separately
            for frame in im:
                pred = model.predict(frame[..., 0], axes=params["axes"], n_tiles=nt)

                if has_z:
                    # insert a channel axis after Z to match the TZCYX layout
                    pred = pred[:, None, ...]
                res_img.append(pred)

            pred = numpy.stack(res_img)
            if not has_z:
                # insert singleton Z and channel axes to match TZCYX
                pred = pred[:, None, None, ...]

            reso      = (1 / pixel_reso.X, 1 / pixel_reso.Y )
            spacing   = pixel_reso.Z
            unit      = pixel_reso.Xunit
            finterval = pixel_reso.T

            # splitext strips extensions of any length (".tif", ".tiff", ...);
            # the old [:-4] slice left a trailing "." for ".tiff" inputs.
            out_stem = os.path.splitext(str(fname))[0]
            tifffile.imsave("{}_n2v_pred_ch{}.tiff".format(out_stem, c), pred, imagej=True, resolution=reso,
                            metadata={'axes': 'TZCYX',
                                      'finterval': finterval,
                                      'spacing'  : spacing,
                                      'unit'     : unit})
Esempio n. 4
0
# Split the patches: the first 600 for training, the remainder for validation.
X = patches[:600]
X_val = patches[600:]
if len(X_val) == 0:
    # Fail early with a clear message instead of inside training.
    raise ValueError(
        f"Only {len(patches)} patches were generated; more than 600 are "
        "needed to keep a validation set."
    )

# Let's look at two patches
#plt.figure(figsize=(14,7))
#plt.subplot(1,2,1)
#plt.imshow(X[0,16,...,0],cmap='magma')
#plt.title('Training Patch')
#plt.subplot(1,2,2)
#plt.imshow(X_val[0,16,...,0],cmap='magma')
#plt.title('Validation Patch')
#plt.show()

# You can increase "train_steps_per_epoch" to get even better results at the price of longer computation.
config = N2VConfig(X, unet_kern_size=3,
    train_steps_per_epoch=100, train_epochs=10, train_loss='mse', batch_norm=True,
    train_batch_size=4, n2v_perc_pix=1.6, n2v_patch_shape=(32, 64, 64),
    n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5)
# Let's look at the parameters stored in the config-object.
# (A bare ``vars(config)`` only displays something in a notebook.)
print(vars(config))

# a name used to identify the model
model_name = 'n2v_3D'
# the base directory in which our model will live
basedir = 'models'
# We are now creating our network model
model = N2V(config=config, name=model_name, basedir=basedir)

history = model.train(X, X_val)
# names of the recorded metrics (iterating a dict yields its keys)
print(sorted(history.history))
#plt.figure(figsize=(16,5))
#plot_history(history,['loss','val_loss'])
Esempio n. 5
0
# Plotting of one training and one validation patch (disabled):
#plt.subplot(1,2,1)
#plt.imshow(X[0,...])
#plt.title('Training Patch')
#plt.subplot(1,2,2)
#plt.imshow(X_val[0,...])
#plt.title('Validation Patch')
#plt.show()

# Noise2Void settings. A larger "train_steps_per_epoch" tends to give better
# results but makes training take longer.
n2v_settings = dict(
    # network
    unet_kern_size=3,
    unet_n_first=64,
    unet_n_depth=3,
    batch_norm=True,
    # optimisation
    train_steps_per_epoch=5,
    train_epochs=25,
    train_loss='mse',
    train_batch_size=128,
    # Noise2Void pixel manipulation
    n2v_perc_pix=5,
    n2v_patch_shape=(64, 64),
    n2v_manipulator='uniform_withCP',
    n2v_neighborhood_radius=5,
)
config = N2VConfig(X, **n2v_settings)
vars(config)

# Identifier of the model and the base directory in which it will live.
model_name = 'n2v_2D_RGB'
basedir = 'models'
# Build the network and train it.
model = N2V(config, model_name, basedir=basedir)
history = model.train(X, X_val)
Esempio n. 6
0
                                             augment=(not args.noAugment))

# The patches do not overlap, so they can be split directly into
# training and validation data.
n_total = len(patches)
# args.validationFraction is the percentage of patches used for validation
frac = int(n_total * float(args.validationFraction) / 100.0)
print(f"total no. of patches: {n_total}\ttraining patches: {n_total - frac}\tvalidation patches: {frac}")
# the first `frac` patches validate, the rest train
X, X_val = patches[frac:], patches[:frac]

config = N2VConfig(
    X,
    # network architecture
    unet_kern_size=args.netKernelSize,
    unet_n_depth=args.netDepth,
    unet_n_first=args.unet_n_first,
    batch_norm=True,
    # optimisation
    train_steps_per_epoch=int(args.stepsPerEpoch),
    train_epochs=int(args.epochs),
    train_loss='mse',
    train_batch_size=args.batchSize,
    train_learning_rate=args.learningRate,
    # Noise2Void pixel manipulation
    n2v_perc_pix=args.n2vPercPix,
    n2v_patch_shape=pshape,
    n2v_manipulator='uniform_withCP',
    n2v_neighborhood_radius=5,
)

# The parameters stored in the config object.
vars(config)

# Identifier of the model and the base directory in which it will live.
model_name = args.name
basedir = args.baseDir
# Next: create the network model.
Esempio n. 7
0
# Give the validation patches a trailing channel axis.
X_val = X_val[..., None]
print(X_val.shape)

# In[5]:

# IMPORTANT: both sets are clipped to [0, 255]. The training patches are
# additionally cast to uint8 (dropping the fractional part) and back to
# float32; the validation patches are only clipped.
X = np.clip(X, 0, 255.).astype(np.uint8)
X = X.astype(np.float32)
X_val = np.clip(X_val, 0, 255.)

config = N2VConfig(
    X,
    # network
    unet_kern_size=3,
    unet_n_first=96,
    unet_residual=True,
    batch_norm=True,
    # optimisation
    train_steps_per_epoch=400,
    train_epochs=200,
    train_loss='mse',
    train_batch_size=args.batch_size,
    # Noise2Void pixel manipulation
    n2v_perc_pix=args.perc_pix,
    n2v_patch_shape=(args.patch_size,) * 2,
    n2v_manipulator='uniform_withCP',
    n2v_neighborhood_radius=2,
)
vars(config)

model = N2V(config, args.exp_name, basedir=args.ckpt_dir)
# Prepare the model for training with an empty metrics tuple.
model.prepare_for_training(metrics=())

# Start training.
history = model.train(X, X_val)
Esempio n. 8
0
datagen = N2V_DataGenerator()
imgs = datagen.load_imgs_from_directory(directory=image_path, dims='ZYX')
print(imgs[0].shape)
# default = 32,64,64
# 3-D (Z, Y, X) patches, taken from the first image only
patches = datagen.generate_patches_from_list(imgs[:1], shape=(32, 64, 64))

# default = :600
# first 600 patches for training, the remainder for validation
X = patches[:600]
X_val = patches[600:]
if len(X_val) == 0:
    raise ValueError(
        f"Only {len(patches)} patches were generated; more than 600 are "
        "needed to keep a validation set."
    )

numberEpochs = 20
batchSize = 4
# Derive the steps per epoch from the batch size actually used for training
# (dividing by 128 with a batch size of 4 gave only 4 steps, i.e. 16 of 600
# patches per epoch, and 0 steps for fewer than 128 patches).
stepsPerEpoch = max(1, X.shape[0] // batchSize)
config = N2VConfig(X,
                   unet_kern_size=3,
                   train_steps_per_epoch=stepsPerEpoch,
                   train_epochs=numberEpochs,
                   train_loss='mse',
                   batch_norm=True,
                   train_batch_size=batchSize,
                   n2v_perc_pix=0.198,
                   n2v_patch_shape=(32, 64, 64),
                   n2v_manipulator='uniform_withCP',
                   n2v_neighborhood_radius=5)
vars(config)
model_name = '20epoch'
# image_path also serves as the model's base directory
model = N2V(config=config, name=model_name, basedir=image_path)
history = model.train(X, X_val)
# names of the recorded metrics (iterating a dict yields its keys)
print(sorted(history.history))
model.export_TF()

# Load the image, and predict the denoised image.
img = imread(os.path.join(image_path, image_name))
pred = model.predict(img, axes='ZYX', n_tiles=(2, 4, 4))
print(X.shape)
# Give the validation patches a trailing channel axis.
X_val = X_val[..., None]
print(X_val.shape)

# Show one training patch and one validation patch side by side.
plt.figure(figsize=(14,7))
for panel_idx, (panel_data, panel_title) in enumerate(
        ((X, 'Training Patch'), (X_val, 'Validation Patch')), start=1):
    plt.subplot(1, 2, panel_idx)
    plt.imshow(panel_data[0, ..., 0], cmap='gray')
    plt.title(panel_title)
plt.show()

# Steps per epoch / epochs are 200 / 100 here; the author noted 400, 200 as
# alternative values (presumably for the same two settings).
config = N2VConfig(X,
                   # network
                   unet_kern_size=3,
                   unet_n_first=96,
                   unet_residual=True,
                   batch_norm=True,
                   # optimisation
                   train_steps_per_epoch=200,
                   train_epochs=100,
                   train_loss='mse',
                   train_batch_size=128,
                   # Noise2Void pixel manipulation
                   n2v_perc_pix=0.198,
                   n2v_patch_shape=(64, 64),
                   n2v_manipulator='uniform_withCP',
                   n2v_neighborhood_radius=2)

# To inspect the stored parameters: print(vars(config))

# Identifier of the model and the base directory in which it will live.
model_name = 'BSD68_reproducability_5x5'
basedir = 'models'
# Build the network model and prepare it with an empty metrics tuple.
model = N2V(config, model_name, basedir=basedir)
model.prepare_for_training(metrics=())

# Training can start now.