Exemple #1
0
    def create_patches(self, normalization=None):
        for ch in self.train_channels:
            n_images = len(
                list((pathlib.Path(self.out_dir) / "train_data" / "raw" /
                      "CH_{}".format(ch) / "GT").glob("*.tif")))
            print("-- Creating {} patches for channel: {}".format(
                n_images * self.n_patches_per_image, ch))
            raw_data = RawData.from_folder(
                basepath=pathlib.Path(self.out_dir) / "train_data" / "raw" /
                "CH_{}".format(ch),
                source_dirs=["low"],
                target_dir="GT",
                axes=self.axes,
            )

            if normalization is not None:
                X, Y, XY_axes = create_patches(
                    raw_data=raw_data,
                    patch_size=self.patch_size,
                    n_patches_per_image=self.n_patches_per_image,
                    save_file=self.get_training_patch_path() /
                    "CH_{}_training_patches.npz".format(ch),
                    verbose=False,
                    normalization=normalization,
                )
            else:

                X, Y, XY_axes = create_patches(
                    raw_data=raw_data,
                    patch_size=self.patch_size,
                    n_patches_per_image=self.n_patches_per_image,
                    save_file=self.get_training_patch_path() /
                    "CH_{}_training_patches.npz".format(ch),
                    verbose=False,
                )

            plt.figure(figsize=(16, 4))

            rand_sel = numpy.random.randint(low=0, high=len(X), size=6)
            plot_some(X[rand_sel, 0],
                      Y[rand_sel, 0],
                      title_list=[range(6)],
                      cmap="gray")

            plt.show()

        print("Done")
        return
Exemple #2
0
def test_rawdata_from_folder(tmpdir):
    rng = np.random.RandomState(42)
    tmpdir = Path(str(tmpdir))

    n_images, img_size, img_axes = 3, (64,64), 'YX'
    data = {'X' : rng.uniform(size=(n_images,)+img_size).astype(np.float32),
            'Y' : rng.uniform(size=(n_images,)+img_size).astype(np.float32)}

    for name,images in data.items():
        (tmpdir/name).mkdir(exist_ok=True)
        for i,img in enumerate(images):
            imsave(str(tmpdir/name/('img_%02d.tif'%i)),img)

    raw_data = RawData.from_folder(str(tmpdir),['X'],'Y',img_axes)
    assert raw_data.size == n_images
    for i,(x,y,axes,mask) in enumerate(raw_data.generator()):
        assert mask is None
        assert axes == img_axes
        assert any(np.allclose(x,u) for u in data['X'])
        assert any(np.allclose(y,u) for u in data['Y'])
Exemple #3
0
def generate_2D_patch_training_data(BaseDirectory,
                                    SaveNpzDirectory,
                                    SaveName,
                                    patch_size=(512, 512),
                                    n_patches_per_image=64,
                                    transforms=None):

    raw_data = RawData.from_folder(
        basepath=BaseDirectory,
        source_dirs=['Original'],
        target_dir='BinaryMask',
        axes='YX',
    )

    X, Y, XY_axes = create_patches(
        raw_data=raw_data,
        patch_size=patch_size,
        n_patches_per_image=n_patches_per_image,
        transforms=transforms,
        save_file=SaveNpzDirectory + SaveName,
    )
Exemple #4
0
def data_generation(data_path, axes, patch_size, data_name):
    """Generates training data for training CARE network. `RawData` object defines how to get the pairs of low/high SNR stacks and the semantics of each axis. We have two folders "noisy" and "clean" where corresponding low and high-SNR stacks are TIFF images with identical filenames. 

	Parameters
	----------
	data_path : str
		Path of the input data containing 'noisy' and 'clean' folder
	axes : str
		Semantic order each axes
	patch_size : tuple 
		Size of the patches to crop the images
	data_name : str
		Name of the .npz file containing the pairs of images.
	
		
 
	"""
    raw_data = RawData.from_folder(
        basepath=data_path,
        source_dirs=['noisy'],
        target_dir='clean',
        axes=axes,
    )

    # Patch size that is a power of two along XYZT, or at least divisible by 8.
    # By convention, the variable name `X` (or `x`) refers to an input variable for a machine learning model, whereas `Y` (or `y`) indicates an output variable.

    X, Y, XY_axes = create_patches(
        raw_data=raw_data,
        patch_size=patch_size,
        n_patches_per_image=100,
        save_file='data/' + data_name + '.npz',
    )

    assert X.shape == Y.shape
    print("shape of X,Y =", X.shape)
    print("axes  of X,Y =", XY_axes)
Exemple #5
0
    def Train(self):

        BinaryName = 'BinaryMask/'
        RealName = 'RealMask/'
        Raw = sorted(glob.glob(self.BaseDir + '/Raw/' + '*.tif'))
        Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
        Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
        RealMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*.tif'))
        ValRaw = sorted(glob.glob(self.BaseDir + '/ValRaw/' + '*.tif'))
        ValRealMask = sorted(
            glob.glob(self.BaseDir + '/ValRealMask/' + '*.tif'))

        print('Instance segmentation masks:', len(RealMask))
        if len(RealMask) == 0:

            print('Making labels')
            Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))

            for fname in Mask:

                image = imread(fname)

                Name = os.path.basename(os.path.splitext(fname)[0])

                Binaryimage = label(image)

                imwrite((self.BaseDir + '/' + RealName + Name + '.tif'),
                        Binaryimage.astype('uint16'))

        Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
        print('Semantic segmentation masks:', len(Mask))
        if len(Mask) == 0:
            print('Generating Binary images')

            RealfilesMask = sorted(
                glob.glob(self.BaseDir + '/' + RealName + '*tif'))

            for fname in RealfilesMask:

                image = imread(fname)

                Name = os.path.basename(os.path.splitext(fname)[0])

                Binaryimage = image > 0

                imwrite((self.BaseDir + '/' + BinaryName + Name + '.tif'),
                        Binaryimage.astype('uint16'))

        if self.GenerateNPZ:

            raw_data = RawData.from_folder(
                basepath=self.BaseDir,
                source_dirs=['Raw/'],
                target_dir='BinaryMask/',
                axes='ZYX',
            )

            X, Y, XY_axes = create_patches(
                raw_data=raw_data,
                patch_size=(self.PatchZ, self.PatchY, self.PatchX),
                n_patches_per_image=self.n_patches_per_image,
                save_file=self.BaseDir + self.NPZfilename + '.npz',
            )

        # Training UNET model
        if self.TrainUNET:
            print('Training UNET model')
            load_path = self.BaseDir + self.NPZfilename + '.npz'

            (X, Y), (X_val,
                     Y_val), axes = load_training_data(load_path,
                                                       validation_split=0.1,
                                                       verbose=True)
            c = axes_dict(axes)['C']
            n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            unet_n_depth=self.depth,
                            train_epochs=self.epochs,
                            train_batch_size=self.batch_size,
                            unet_n_first=self.startfilter,
                            train_loss='mse',
                            unet_kern_size=self.kern_size,
                            train_learning_rate=self.learning_rate,
                            train_reduce_lr={
                                'patience': 5,
                                'factor': 0.5
                            })
            print(config)
            vars(config)

            model = CARE(config,
                         name='UNET' + self.model_name,
                         basedir=self.model_dir)

            if self.copy_model_dir is not None:
                if os.path.exists(self.copy_model_dir + 'UNET' +
                                  self.copy_model_name + '/' +
                                  'weights_now.h5') and os.path.exists(
                                      self.model_dir + 'UNET' +
                                      self.model_name + '/' +
                                      'weights_now.h5') == False:
                    print('Loading copy model')
                    model.load_weights(self.copy_model_dir + 'UNET' +
                                       self.copy_model_name + '/' +
                                       'weights_now.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_now.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_now.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_last.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_last.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_best.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_best.h5')

            history = model.train(X, Y, validation_data=(X_val, Y_val))

            print(sorted(list(history.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(history, ['loss', 'val_loss'],
                         ['mse', 'val_mse', 'mae', 'val_mae'])

        if self.TrainSTAR:
            print('Training StarDistModel model with', self.backbone,
                  'backbone')
            self.axis_norm = (0, 1, 2)
            if self.CroppedLoad == False:
                assert len(Raw) > 1, "not enough training data"
                print(len(Raw))
                rng = np.random.RandomState(42)
                ind = rng.permutation(len(Raw))

                X_train = list(map(ReadFloat, Raw))
                Y_train = list(map(ReadInt, RealMask))
                self.Y = [
                    label(DownsampleData(y, self.DownsampleFactor))
                    for y in tqdm(Y_train)
                ]
                self.X = [
                    normalize(DownsampleData(x, self.DownsampleFactor),
                              1,
                              99.8,
                              axis=self.axis_norm) for x in tqdm(X_train)
                ]
                n_val = max(1, int(round(0.15 * len(ind))))
                ind_train, ind_val = ind[:-n_val], ind[-n_val:]

                self.X_val, self.Y_val = [self.X[i] for i in ind_val
                                          ], [self.Y[i] for i in ind_val]
                self.X_trn, self.Y_trn = [self.X[i] for i in ind_train
                                          ], [self.Y[i] for i in ind_train]

                print('number of images: %3d' % len(self.X))
                print('- training:       %3d' % len(self.X_trn))
                print('- validation:     %3d' % len(self.X_val))

            if self.CroppedLoad:
                self.X_trn = self.DataSequencer(Raw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_trn = self.DataSequencer(RealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)

                self.X_val = self.DataSequencer(ValRaw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_val = self.DataSequencer(ValRealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)
                self.train_sample_cache = False

            print(Config3D.__doc__)

            anisotropy = (1, 1, 1)
            rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)

            if self.backbone == 'resnet':

                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    resnet_n_blocks=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    resnet_kernel_size=(self.kern_size, self.kern_size,
                                        self.kern_size),
                    train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                    train_batch_size=self.batch_size,
                    resnet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1)

            if self.backbone == 'unet':

                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    unet_n_depth=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    unet_kernel_size=(self.kern_size, self.kern_size,
                                      self.kern_size),
                    train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                    train_batch_size=self.batch_size,
                    unet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1,
                    train_sample_cache=False)

            print(conf)
            vars(conf)

            Starmodel = StarDist3D(conf,
                                   name=self.model_name,
                                   basedir=self.model_dir)
            print(
                Starmodel._axes_tile_overlap('ZYX'),
                os.path.exists(self.model_dir + self.model_name + '/' +
                               'weights_now.h5'))

            if self.copy_model_dir is not None:
                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_now.h5') and os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_now.h5') == False:
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_now.h5')
                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_last.h5') and os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_last.h5') == False:
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_last.h5')

                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_best.h5') and os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_best.h5') == False:
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_best.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_now.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_now.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_last.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_last.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_best.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_best.h5')

            historyStar = Starmodel.train(self.X_trn,
                                          self.Y_trn,
                                          validation_data=(self.X_val,
                                                           self.Y_val),
                                          epochs=self.epochs)
            print(sorted(list(historyStar.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(historyStar, ['loss', 'val_loss'], [
                'dist_relevant_mae', 'val_dist_relevant_mae',
                'dist_relevant_mse', 'val_dist_relevant_mse'
            ])
Exemple #6
0
from argparse import ArgumentParser
from csbdeep.data import RawData, create_patches

args = ArgumentParser()
args.add_argument("input_basepath")
args.add_argument("input_x")
args.add_argument("input_y")
args.add_argument("output_file")
args.add_argument("--patch_size_xy", default=32)
args.add_argument("--patch_size_z", default=16)
args.add_argument("--n_patches_per_image", default=750)
args.add_argument("--axes", default="ZYX")
args = args.parse_args()

data = RawData.from_folder(basepath=args.input_basepath,
                           source_dirs=[args.input_x],
                           target_dir=args.input_y,
                           axes=args.axes)

ps_xy = int(args.patch_size_xy)
ps_z = int(args.patch_size_z)

X, Y, axes = create_patches(data,
                            patch_size=(ps_z, ps_xy, ps_xy, 1),
                            n_patches_per_image=int(args.n_patches_per_image),
                            save_file=args.output_file,
                            patch_axes=args.axes + "C")
        y = imread(
            os.path.join(args.base_dir, "high_snr", f"img_{img_number}.tif"))
        print(f"image size: {x.shape}")
        plt.figure(figsize=(16, 10))
        plot_some(
            np.stack([x, y]),
            title_list=[['low snr', 'high snr']],
        )
        plt.show()

    if not os.path.exists(os.path.join(args.base_dir, "test")):
        os.mkdir(os.path.join(args.base_dir, "test"))
    shutil.move(
        os.path.join(args.base_dir, "low_snr", f"img_{img_number}.tif"),
        os.path.join(args.base_dir, "test", "low_snr.tif"))
    shutil.move(
        os.path.join(args.base_dir, "high_snr", f"img_{img_number}.tif"),
        os.path.join(args.base_dir, "test", "high_snr.tif"))

    # Read the pairs, passing in the axis semantics
    raw_data = RawData.from_folder(basepath=args.base_dir,
                                   source_dirs=['low_snr'],
                                   target_dir='high_snr',
                                   axes='YX')

    # From the stacks, generate 2D patches, and save to the output_filename
    create_patches(raw_data=raw_data,
                   patch_size=(128, 128),
                   n_patches_per_image=512,
                   save_file=os.path.join(args.base_dir, args.output_filename))
Exemple #8
0
    def training(self):
        #Loading Files and Plot examples
        basepath = 'data/'
        training_original_dir = 'training/original/'
        training_ground_truth_dir = 'training/ground_truth/'
        validation_original_dir = 'validation/original/'
        validation_ground_truth_dir = 'validation/ground_truth/'
        import glob
        from skimage import io
        from matplotlib import pyplot as plt
        training_original_files = sorted(
            glob.glob(basepath + training_original_dir + '*.tif'))
        training_original_file = io.imread(training_original_files[0])
        training_ground_truth_files = sorted(
            glob.glob(basepath + training_ground_truth_dir + '*.tif'))
        training_ground_truth_file = io.imread(training_ground_truth_files[0])
        print("Training dataset's number of files and dimensions: ",
              len(training_original_files), training_original_file.shape,
              len(training_ground_truth_files),
              training_ground_truth_file.shape)
        training_size = len(training_original_file)

        validation_original_files = sorted(
            glob.glob(basepath + validation_original_dir + '*.tif'))
        validation_original_file = io.imread(validation_original_files[0])
        validation_ground_truth_files = sorted(
            glob.glob(basepath + validation_ground_truth_dir + '*.tif'))
        validation_ground_truth_file = io.imread(
            validation_ground_truth_files[0])
        print("Validation dataset's number of files and dimensions: ",
              len(validation_original_files), validation_original_file.shape,
              len(validation_ground_truth_files),
              validation_ground_truth_file.shape)
        validation_size = len(validation_original_file)

        if training_size == validation_size:
            size = training_size
        else:
            print(
                'Training and validation images should be of the same dimensions!'
            )

        plt.figure(figsize=(16, 4))
        plt.subplot(141)
        plt.imshow(training_original_file)
        plt.subplot(142)
        plt.imshow(training_ground_truth_file)
        plt.subplot(143)
        plt.imshow(validation_original_file)
        plt.subplot(144)
        plt.imshow(validation_ground_truth_file)

        #preparing inputs for NN  from pairs of 32bit TIFF image files with intensities in range 0..1. . Run it only once for each new dataset
        from csbdeep.data import RawData, create_patches, no_background_patches
        training_data = RawData.from_folder(
            basepath=basepath,
            source_dirs=[training_original_dir],
            target_dir=training_ground_truth_dir,
            axes='YX',
        )

        validation_data = RawData.from_folder(
            basepath=basepath,
            source_dirs=[validation_original_dir],
            target_dir=validation_ground_truth_dir,
            axes='YX',
        )

        # pathces will be created further in "data augmentation" step,
        # that's why patch size here is the dimensions of images and number of pathes per image is 1
        size1 = 64
        X, Y, XY_axes = create_patches(
            raw_data=training_data,
            patch_size=(size1, size1),
            patch_filter=no_background_patches(0),
            n_patches_per_image=1,
            save_file=basepath + 'training.npz',
        )

        X_val, Y_val, XY_axes = create_patches(
            raw_data=validation_data,
            patch_size=(size1, size1),
            patch_filter=no_background_patches(0),
            n_patches_per_image=1,
            save_file=basepath + 'validation.npz',
        )

        #Loading training and validation data into memory
        from csbdeep.io import load_training_data
        (X, Y), _, axes = load_training_data(basepath + 'training.npz',
                                             verbose=False)
        (X_val,
         Y_val), _, axes = load_training_data(basepath + 'validation.npz',
                                              verbose=False)
        X.shape, Y.shape, X_val.shape, Y_val.shape
        from csbdeep.utils import axes_dict
        c = axes_dict(axes)['C']
        n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

        batch = len(
            X
        )  # You should define number of batches according to the available memory
        #batch=1
        seed = 1
        from keras.preprocessing.image import ImageDataGenerator
        data_gen_args = dict(samplewise_center=False,
                             samplewise_std_normalization=False,
                             zca_whitening=False,
                             fill_mode='reflect',
                             rotation_range=30,
                             shear_range=0.2,
                             zoom_range=0.2,
                             horizontal_flip=True,
                             vertical_flip=True
                             #width_shift_range=0.2,
                             #height_shift_range=0.2,
                             )

        # training
        image_datagen = ImageDataGenerator(**data_gen_args)
        mask_datagen = ImageDataGenerator(**data_gen_args)
        image_datagen.fit(X, augment=True, seed=seed)
        mask_datagen.fit(Y, augment=True, seed=seed)
        image_generator = image_datagen.flow(X, batch_size=batch, seed=seed)
        mask_generator = mask_datagen.flow(Y, batch_size=batch, seed=seed)
        generator = zip(image_generator, mask_generator)

        # validation
        image_datagen_val = ImageDataGenerator(**data_gen_args)
        mask_datagen_val = ImageDataGenerator(**data_gen_args)
        image_datagen_val.fit(X_val, augment=True, seed=seed)
        mask_datagen_val.fit(Y_val, augment=True, seed=seed)
        image_generator_val = image_datagen_val.flow(X_val,
                                                     batch_size=batch,
                                                     seed=seed)
        mask_generator_val = mask_datagen_val.flow(Y_val,
                                                   batch_size=batch,
                                                   seed=seed)
        generator_val = zip(image_generator_val, mask_generator_val)

        # plot examples
        x, y = generator.__next__()
        x_val, y_val = generator_val.__next__()

        plt.figure(figsize=(16, 4))
        plt.subplot(141)
        plt.imshow(x[0, :, :, 0])
        plt.subplot(142)
        plt.imshow(y[0, :, :, 0])
        plt.subplot(143)
        plt.imshow(x_val[0, :, :, 0])
        plt.subplot(144)
        plt.imshow(y_val[0, :, :, 0])

        import os
        blocks = 2
        channels = 16
        learning_rate = 0.0004
        learning_rate_decay_factor = 0.95
        epoch_size_multiplicator = 20  # basically, how often do we decrease learning rate
        epochs = 10
        #comment='_side' # adds to model_name
        import datetime
        #model path
        model_path = f'models/CSBDeep/model.h5'
        if os.path.isfile(model_path):
            print('Your model will be overwritten in the next cell')
        kernel_size = 3

        from csbdeep.models import Config, CARE
        from keras import backend as K
        best_mae = 1
        steps_per_epoch = len(X) * epoch_size_multiplicator
        validation_steps = len(X_val) * epoch_size_multiplicator
        if 'model' in globals():
            del model
        if os.path.isfile(model_path):
            os.remove(model_path)
        for i in range(epochs):
            print('Epoch:', i + 1)
            learning_rate = learning_rate * learning_rate_decay_factor
            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            unet_kern_size=kernel_size,
                            train_learning_rate=learning_rate,
                            unet_n_depth=blocks,
                            unet_n_first=channels)
            model = CARE(config, '.', basedir='models')
            #os.remove('models/config.json')
            if i > 0:
                model.keras_model.load_weights(model_path)
            model.prepare_for_training()
            history = model.keras_model.fit_generator(
                generator,
                validation_data=generator_val,
                validation_steps=validation_steps,
                epochs=1,
                verbose=0,
                shuffle=True,
                steps_per_epoch=steps_per_epoch)
            if history.history['val_mae'][0] < best_mae:
                best_mae = history.history['val_mae'][0]
                if not os.path.exists('models/'):
                    os.makedirs('models/')
                model.keras_model.save(model_path)
                print(f'Validation MAE:{best_mae:.3E}')
            del model
            K.clear_session()

        #model path
        model_path = 'models/CSBDeep/model.h5'
        config = Config(axes,
                        n_channel_in,
                        n_channel_out,
                        unet_kern_size=kernel_size,
                        train_learning_rate=learning_rate,
                        unet_n_depth=blocks,
                        unet_n_first=channels)
        model = CARE(config, '.', basedir='models')
        model.keras_model.load_weights(model_path)
        model.export_TF()