Пример #1
0
def Show_loss_function(history, model_path):
    '''
    Plot the training/validation loss and save the curves to disk.

    Two files are written into ``model_path``:
    'training evaluation.tif' (the figure) and 'training evaluation.csv'
    (one row per recorded metric: the metric name followed by its
    per-epoch values).

    Parameters
    ----------
    history : keras object
        Must expose a ``history`` dict mapping metric name to per-epoch values.
    model_path : (str) path to the model

    Return
    ----------
    void
    '''
    # Create figure framesize
    errorfigure = plt.figure(figsize=(16, 5))

    # Choose the values you wish to compare.
    # For example, to see other values, replace 'loss' with 'dist_loss'
    plot_history(history, ['loss', 'val_loss'])
    errorfigure.savefig(model_path + '/training evaluation.tif')

    # convert the history.history dict to a pandas DataFrame:
    hist_df = pd.DataFrame(history.history)

    # The table is saved into model_path as 'training evaluation.csv'.
    # Each column is expanded into plain values: formatting the Series
    # object itself would write its multi-line repr (index column plus the
    # "Name/dtype" footer), which is not parseable as CSV.
    RESULTS = model_path + '/training evaluation.csv'
    with open(RESULTS, 'w') as f:
        for key in hist_df.keys():
            values = ','.join(str(value) for value in hist_df[key].tolist())
            f.write("%s,%s\n" % (key, values))
    print("Remember you can also plot:")
    print(history.history.keys())
Пример #2
0
    def train(self, channels=None, **config_args):
        """Train, visualise and export one CARE model per channel.

        Parameters
        ----------
        channels : iterable, optional
            Channels to train; falls back to ``self.train_channels`` when
            ``None``.
        **config_args
            Extra keyword arguments forwarded to ``Config``.

        Returns
        -------
        None
            The method also returns early (without processing the remaining
            channels) when TensorFlow raises ``ResourceExhaustedError``.
        """
        #limit_gpu_memory(fraction=1)
        selected_channels = self.train_channels if channels is None else channels

        for ch in selected_channels:
            print("-- Training channel {}...".format(ch))

            # Patches for this channel, with 10% held out for validation.
            patch_file = (self.get_training_patch_path() /
                          'CH_{}_training_patches.npz'.format(ch))
            train_pair, val_pair, axes = load_training_data(
                patch_file, validation_split=0.1, verbose=False)
            X, Y = train_pair
            X_val, Y_val = val_pair

            # Channel counts are read off the 'C' axis of source and target.
            channel_axis = axes_dict(axes)['C']
            n_channel_in = X.shape[channel_axis]
            n_channel_out = Y.shape[channel_axis]

            config = Config(
                axes,
                n_channel_in,
                n_channel_out,
                train_epochs=self.train_epochs,
                train_steps_per_epoch=self.train_steps_per_epoch,
                train_batch_size=self.train_batch_size,
                **config_args)

            model_dir = pathlib.Path(self.out_dir) / 'models'
            model = CARE(config, 'CH_{}_model'.format(ch), basedir=model_dir)

            # Run the training; give up entirely if the GPU runs out of memory.
            try:
                history = model.train(X, Y, validation_data=(X_val, Y_val))
            except tf.errors.ResourceExhaustedError:
                print(
                    "ResourceExhaustedError: Aborting...\n Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?"
                )
                return

            # Learning curves.
            plt.figure(figsize=(16, 5))
            loss_keys = ['loss', 'val_loss']
            error_keys = ['mse', 'val_mse', 'mae', 'val_mae']
            plot_history(history, loss_keys, error_keys)

            # Predictions for the first five validation patches.
            plt.figure(figsize=(12, 7))
            _P = model.keras_model.predict(X_val[:5])

            plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5, cmap="gray")
            caption = ('5 example validation patches\n'
                       'top row: input (source),  '
                       'middle row: target (ground truth),  '
                       'bottom row: predicted from source')
            plt.suptitle(caption)

            plt.show()

            print("-- Export model for use in Fiji...")
            model.export_TF()
            print("Done")
Пример #3
0
def save_results(history):
    """Plot the training curves and dump every recorded metric to ``loss.csv``.

    Each line written is ``<metric name>,<value>`` where the value is the
    string form of ``history.history[<metric name>]``. The file is created in
    the current working directory and overwritten if it already exists.
    """
    loss_keys = ['loss', 'val_loss']
    error_keys = ['mse', 'val_mse', 'mae', 'val_mae']
    plot_history(history, loss_keys, error_keys)

    # One row per metric recorded during training.
    with open('loss.csv', 'w') as csv_file:
        for metric, values in history.history.items():
            csv_file.write("%s,%s\n" % (metric, values))

    # Run a garbage-collection pass once the results are on disk.
    gc.collect()
Пример #4
0
def train(
    image,
    patch_shape=(16, 32, 32),
    ratio=0.7,
    name="untitled",
    model_dir=".",
    view_history=True,
):
    """Train a Noise2Void model on patches cut from a single image and export it.

    Parameters
    ----------
    image : array
        Source image. A leading and a trailing singleton axis are added before
        patch extraction (presumably sample and channel axes; confirm against
        ``N2V_DataGenerator``).
    patch_shape : tuple of int
        Shape of the extracted patches; also passed as ``n2v_patch_shape``.
    ratio : float
        Fraction of the generated patches used for training; the remainder
        is used for validation.
    name : str
        Model name passed to ``N2V``.
    model_dir : str
        Base directory passed to ``N2V``.
    view_history : bool
        When true (default), plot the training and validation loss after
        training. Previously this flag was accepted but ignored and the plot
        was always drawn.

    Returns
    -------
    None
    """
    # create data generator object
    datagen = N2V_DataGenerator()

    # patch additional axes
    image = image[np.newaxis, ..., np.newaxis]

    patches = datagen.generate_patches_from_list([image], shape=patch_shape)
    logger.info(f"patch shape: {patches.shape}")
    n_patches = patches.shape[0]
    logger.info(f"{n_patches} patches generated")

    # split training set and validation set
    i = int(n_patches * ratio)
    X, X_val = patches[:i], patches[i:]

    # create training config
    config = N2VConfig(
        X,
        unet_kern_size=3,
        train_steps_per_epoch=100,
        train_epochs=100,
        train_loss="mse",
        batch_norm=True,
        train_batch_size=4,
        n2v_perc_pix=1.6,
        n2v_patch_shape=patch_shape,
        n2v_manipulator="uniform_withCP",
        n2v_neighborhood_radius=5,
    )

    model = N2V(config=config, name=name, basedir=model_dir)

    # train and save the model
    history = model.train(X, X_val)
    model.export_TF()

    # Only draw the loss curves when the caller asked for them.
    if view_history:
        plot_history(history, ["loss", "val_loss"])
Пример #5
0
def train_predict(n_tiles=(1,4,4), params=params, files=None, headless=False, **unet_config):
    """
    Train one Noise2Void model per channel and save the denoised images as TIFF.

    For every channel in ``params["train_channels"]`` the function extracts
    patches from the input images, trains an N2V model on a 90/10
    train/validation split, optionally displays a validation patch and the
    loss curves, then predicts every input file time point by time point and
    writes ``<file name minus its last 4 characters>_n2v_pred_ch<channel>.tiff``.

    Parameters
    ----------
    n_tiles : tuple(int)
        Number of tiles to tile the image into, if it is too large for memory.
        The first entry is dropped when ``params["axes"]`` contains no ``Z``.
    params : dict
        Settings read by key: ``in_dir``, ``glob``, ``train_channels``,
        ``n_patches_per_image``, ``patch_size``, ``augment``,
        ``train_steps_per_epoch``, ``train_epochs``, ``train_batch_size``,
        ``n2v_perc_pix``, ``n2v_neighborhood_radius``, ``name`` and ``axes``.
        The default is the module-level ``params`` object as it was bound when
        this function was defined.
    files : list, optional
        Files to process. When ``None`` they are discovered from
        ``params["in_dir"]`` and ``params["glob"]``.
    headless : bool
        When true, the plotting imports and all figures are skipped.
    **unet_config
        Forwarded unchanged to ``N2VConfig``; see the options below.

    These advanced options can be set by keyword arguments:

    unet_residual : bool
        Parameter `residual` of :func:`n2v_old.nets.common_unet`. Default: ``n_channel_in == n_channel_out``
    unet_n_depth : int
        Parameter `n_depth` of :func:`n2v_old.nets.common_unet`. Default: ``2``
    unet_kern_size : int
        Parameter `kern_size` of :func:`n2v_old.nets.common_unet`. Default: ``5 if n_dim==2 else 3``
    unet_n_first : int
        Parameter `n_first` of :func:`n2v_old.nets.common_unet`. Default: ``32``
    batch_norm : bool
        Activate batch norm
    unet_last_activation : str
        Parameter `last_activation` of :func:`n2v_old.nets.common_unet`. Default: ``linear``
    train_learning_rate : float
        Learning rate for training. Default: ``0.0004``
    n2v_patch_shape : tuple
        Random patches of this shape are extracted from the given training data. Default: ``(64, 64) if n_dim==2 else (64, 64, 64)``
    n2v_manipulator : str
        Noise2Void pixel value manipulator. Default: ``uniform_withCP``
    train_reduce_lr : dict
        Parameter :class:`dict` of ReduceLROnPlateau_ callback; set to ``None`` to disable. Default: ``{'factor': 0.5, 'patience': 10}``
    """
    # Heavy dependencies are imported lazily, inside the function.
    from n2v.models import N2VConfig, N2V

    # NOTE(review): manipulate_val_data is imported but not used in this function.
    from n2v.utils.n2v_utils import manipulate_val_data
    from n2v.internals.N2V_DataGenerator import N2V_DataGenerator

    # Plotting libraries are only needed when figures are going to be shown.
    if not headless:
        from csbdeep.utils import plot_history
        from matplotlib import pyplot as plt

    # NOTE(review): this alias is not referenced anywhere below.
    np = numpy

    # Init reader
    datagen = BFListReader()
    print("Loading images ...")
    if files is None:
        # No explicit list given: discover the files and remember their names
        # so they can be paired with the predictions further down.
        datagen.from_glob(params["in_dir"], params["glob"])
        files = datagen.get_file_names()
    else:
        datagen.from_file_list(files)


    print("Training ...")
    for c in params["train_channels"]:
        print("  -- Channel {}".format(c))

        # Two separate generators: one is consumed while extracting patches,
        # the other while predicting.
        imgs_for_patches = datagen.load_imgs_generator()
        imgs_for_predict = datagen.load_imgs_generator()

        # Keep only channel c (sliced as c:c+1 so the last axis is preserved).
        img_ch = (im[..., c:c+1] for im in imgs_for_patches)
        img_ch_predict = (im[..., c:c+1] for im in imgs_for_predict)

        # A setting of 1 or less is passed on as None.
        npatches = params["n_patches_per_image"] if params["n_patches_per_image"] > 1 else None

        patches = N2V_DataGenerator().generate_patches_from_list(img_ch, num_patches_per_img=npatches, shape=params['patch_size'], augment=params['augment'])

        # Shuffle in place before splitting.
        numpy.random.shuffle(patches)

        # First 90% of the patches for training, the rest for validation.
        sep = int(len(patches)*0.9)
        X     = patches[:sep]
        X_val = patches[ sep:]

        config = N2VConfig(X,
                        train_steps_per_epoch=params["train_steps_per_epoch"],
                        train_epochs=params["train_epochs"],
                        train_loss='mse',
                        train_batch_size=params["train_batch_size"],
                        n2v_perc_pix=params["n2v_perc_pix"],
                        n2v_patch_shape=params['patch_size'],
                        n2v_manipulator='uniform_withCP',
                        n2v_neighborhood_radius=params["n2v_neighborhood_radius"], **unet_config)


        # a name used to identify the model
        model_name = '{}_ch{}'.format(params['name'], c)
        # the base directory in which our model will live
        # NOTE(review): this variable is never used; the model below is
        # created with basedir=params["in_dir"] instead.
        basedir = 'models'
        # We are now creating our network model.
        model = N2V(config=config, name=model_name, basedir=params["in_dir"])

        history = model.train(X, X_val)


        # Denoise the first validation patch (last-axis index 0) as a sanity check.
        val_patch = X_val[0,..., 0]
        val_patch_pred = model.predict(val_patch, axes=params["axes"])


        # With a Z axis, reduce both patches to a maximum projection along axis 0.
        if "Z" in params["axes"]:
            val_patch      = val_patch.max(0)
            val_patch_pred = val_patch_pred.max(0)

        if not headless:
            # Side-by-side view of the validation patch and its prediction.
            f, ax = plt.subplots(1,2, figsize=(14,7))
            ax[0].imshow(val_patch,cmap='gray')
            ax[0].set_title('Validation Patch')
            ax[1].imshow(val_patch_pred,cmap='gray')
            ax[1].set_title('Validation Patch N2V')


            plt.figure(figsize=(16,5))
            plot_history(history,['loss','val_loss'])

        print("  -- Predicting channel {}".format(c))
        # Pair each file name with its image; `f` is rebound here (it held the
        # figure handle above when not headless).
        for f, im in zip(files, img_ch_predict):
            print("  -- {}".format(f))
            pixel_reso = get_space_time_resolution(str(f))
            res_img = []
            # Predict each entry along the first axis separately
            # (presumably time points, given the 'TZCYX' metadata below).
            for t in range(len(im)):
                # Drop the leading tile count when the data has no Z axis.
                nt = n_tiles if "Z" in params["axes"] else n_tiles[1:]
                pred = model.predict(im[t,..., 0], axes=params["axes"], n_tiles=nt)

                # Insert a singleton axis at position 1 (the C slot of ZCYX).
                if "Z" in params["axes"]:
                    pred = pred[:, None, ...]
                res_img.append(pred)

            pred = numpy.stack(res_img)
            # Without Z, add two singleton axes after the first one so the
            # stack has five dimensions, matching the 'TZCYX' axes string.
            if "Z" not in params["axes"]:
                    pred = pred[:, None,     None, ...]

            # Resolution is passed as the reciprocal of the X / Y pixel size.
            reso      = (1 / pixel_reso.X, 1 / pixel_reso.Y )
            spacing   = pixel_reso.Z
            unit      = pixel_reso.Xunit
            finterval = pixel_reso.T

            # NOTE(review): str(f)[:-4] removes the last four characters, which
            # assumes a three-letter extension such as ".tif" - confirm.
            tifffile.imsave("{}_n2v_pred_ch{}.tiff".format(str(f)[:-4], c), pred, imagej=True, resolution=reso, metadata={'axes': 'TZCYX',
                                                                                                'finterval': finterval,
                                                                                                'spacing'  : spacing,
                                                                                                'unit'     : unit})
Пример #6
0
# Fifth column of the 2x5 preview grid: X[n + 4] in the top row and
# Y[n + 4] in the bottom row, last-axis index 0 only.
pl5_x = fig.add_subplot(2, 5, 5)
pl5_x.imshow(X[n + 4][..., 0])
pl5_y = fig.add_subplot(2, 5, 10)
pl5_y.imshow(Y[n + 4][..., 0])

# Build the U-Net for 128x128 single-channel input with a sigmoid output.
model = unet(input_shape=(128, 128, 1),
             conv_kernel_size=3,
             dropout=0.5,
             filters=64,
             last_activation='sigmoid')
# NOTE(review): newer Keras releases spell this argument `learning_rate`;
# confirm against the installed version.
model.compile(optimizer=Adam(lr=1e-4), loss='binary_crossentropy')
model.summary()

# Fit on the full arrays for 3 epochs, validating on (X_val, Y_val) after
# every epoch. The commented-out arguments are alternative setups.
history = model.fit(
    X,
    Y,
    batch_size=16,
    epochs=3,
    #steps_per_epoch=50,
    validation_data=(X_val, Y_val),
    #validation_steps=10,
    #validation_split=0.1,
    shuffle=True,
    validation_freq=1)

# List the recorded metrics, then plot training vs. validation loss.
print(sorted(list(history.history.keys())))
plt.figure(figsize=(16, 5))
plot_history(history, ['loss', 'val_loss'])

# Persist the trained model to the current working directory.
model.save('myunet_2.h5')
Пример #7
0
# In[6]:

# Create the CARE model called ModelName inside ModelDir.
model = CARE(config=config, name=ModelName, basedir=ModelDir)
# Uncomment to start from previously saved weights:
#input_weights = ModelDir + ModelName + '/' +'weights_best.h5'
#model.load_weights(input_weights)

# In[ ]:

history = model.train(X, Y, validation_data=(X_val, Y_val))

# In[ ]:

# List the recorded metrics and plot the loss and error curves.
print(sorted(list(history.history.keys())))
plt.figure(figsize=(16, 5))
plot_history(history, ['loss', 'val_loss'],
             ['mse', 'val_mse', 'mae', 'val_mae'])

# In[ ]:

plt.figure(figsize=(12, 7))
_P = model.keras_model.predict(X_val[25:30])
if config.probabilistic:
    # Keep only the first half of the output channels.
    _P = _P[..., :(_P.shape[-1] // 2)]
# Show the same five patches that were predicted, so the input, target and
# prediction rows line up (previously the rows used X_val[0] / Y_val[0],
# which did not correspond to the predictions from X_val[25:30]).
plot_some(X_val[25:30], Y_val[25:30], _P, pmax=99.5)
plt.suptitle('5 example validation patches\n'
             'top row: input (source),  '
             'middle row: target (ground truth),  '
             'bottom row: predicted from source')

# In[ ]:
Пример #8
0
    def Train(self):
        """Prepare the mask folders and train the UNET and/or StarDist models.

        Steps, each driven by attributes on ``self``:

        1. Collect the ``*.tif`` files under ``Raw/``, ``RealMask/``,
           ``ValRaw/`` and ``ValRealMask/`` inside ``self.BaseDir`` and create
           ``BinaryMask/`` and ``RealMask/`` if they are missing.
        2. If ``RealMask/`` is empty, write a labelled copy of every
           ``BinaryMask/`` image into it; if ``BinaryMask/`` is empty, write a
           thresholded (``> 0``) copy of every ``RealMask/`` image into it.
        3. ``self.GenerateNPZ``: cut patches from ``Raw/`` and ``BinaryMask/``
           and save them as an ``.npz`` file.
        4. ``self.TrainUNET``: load that ``.npz`` file, build a CARE model,
           load any existing weights, train and plot the history.
        5. ``self.TrainSTAR``: build the training and validation data (in
           memory, or through ``self.DataSequencer`` when ``self.CroppedLoad``
           is set), build a ``StarDist3D`` model, load any existing weights,
           train and plot the history.

        Returns
        -------
        None
        """

        BinaryName = 'BinaryMask/'
        RealName = 'RealMask/'
        Raw = sorted(glob.glob(self.BaseDir + '/Raw/' + '*.tif'))
        # mkdir is called without parents=True, so self.BaseDir must already exist.
        Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
        Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
        RealMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*.tif'))
        ValRaw = sorted(glob.glob(self.BaseDir + '/ValRaw/' + '*.tif'))
        ValRealMask = sorted(
            glob.glob(self.BaseDir + '/ValRealMask/' + '*.tif'))

        print('Instance segmentation masks:', len(RealMask))
        if len(RealMask) == 0:

            # No instance masks yet: derive them from the binary masks.
            # NOTE(review): RealMask is not re-globbed afterwards, so the list
            # used by the StarDist section below stays empty in this run.
            print('Making labels')
            Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))

            for fname in Mask:

                image = imread(fname)

                # File name without directory and extension.
                Name = os.path.basename(os.path.splitext(fname)[0])

                Binaryimage = label(image)

                imwrite((self.BaseDir + '/' + RealName + Name + '.tif'),
                        Binaryimage.astype('uint16'))

        Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
        print('Semantic segmentation masks:', len(Mask))
        if len(Mask) == 0:
            # No binary masks yet: derive them from the instance masks.
            print('Generating Binary images')

            # The pattern here is '*tif' (no dot), unlike the other globs.
            RealfilesMask = sorted(
                glob.glob(self.BaseDir + '/' + RealName + '*tif'))

            for fname in RealfilesMask:

                image = imread(fname)

                Name = os.path.basename(os.path.splitext(fname)[0])

                # Every non-zero pixel becomes foreground.
                Binaryimage = image > 0

                imwrite((self.BaseDir + '/' + BinaryName + Name + '.tif'),
                        Binaryimage.astype('uint16'))

        if self.GenerateNPZ:

            raw_data = RawData.from_folder(
                basepath=self.BaseDir,
                source_dirs=['Raw/'],
                target_dir='BinaryMask/',
                axes='ZYX',
            )

            # NOTE(review): BaseDir and NPZfilename are joined without a
            # separator here, unlike the '/'-joined paths above - confirm that
            # one of them carries the slash.
            X, Y, XY_axes = create_patches(
                raw_data=raw_data,
                patch_size=(self.PatchZ, self.PatchY, self.PatchX),
                n_patches_per_image=self.n_patches_per_image,
                save_file=self.BaseDir + self.NPZfilename + '.npz',
            )

        # Training UNET model
        if self.TrainUNET:
            print('Training UNET model')
            load_path = self.BaseDir + self.NPZfilename + '.npz'

            # Reload the patches from disk, holding out 10% for validation.
            (X, Y), (X_val,
                     Y_val), axes = load_training_data(load_path,
                                                       validation_split=0.1,
                                                       verbose=True)
            # Channel counts are read from the 'C' axis of source and target.
            c = axes_dict(axes)['C']
            n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            unet_n_depth=self.depth,
                            train_epochs=self.epochs,
                            train_batch_size=self.batch_size,
                            unet_n_first=self.startfilter,
                            train_loss='mse',
                            unet_kern_size=self.kern_size,
                            train_learning_rate=self.learning_rate,
                            train_reduce_lr={
                                'patience': 5,
                                'factor': 0.5
                            })
            print(config)
            # The result of vars() is discarded; this line has no effect here.
            vars(config)

            model = CARE(config,
                         name='UNET' + self.model_name,
                         basedir=self.model_dir)

            # Start from the copy model's weights only when this model has no
            # 'weights_now.h5' of its own. `==` binds tighter than `and`, so the
            # condition reads: copy exists and (own checkpoint missing).
            if self.copy_model_dir is not None:
                if os.path.exists(self.copy_model_dir + 'UNET' +
                                  self.copy_model_name + '/' +
                                  'weights_now.h5') and os.path.exists(
                                      self.model_dir + 'UNET' +
                                      self.model_name + '/' +
                                      'weights_now.h5') == False:
                    print('Loading copy model')
                    model.load_weights(self.copy_model_dir + 'UNET' +
                                       self.copy_model_name + '/' +
                                       'weights_now.h5')

            # Own checkpoints are loaded in the order now -> last -> best; each
            # file that exists is loaded, so the last one found wins.
            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_now.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_now.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_last.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_last.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_best.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_best.h5')

            history = model.train(X, Y, validation_data=(X_val, Y_val))

            print(sorted(list(history.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(history, ['loss', 'val_loss'],
                         ['mse', 'val_mse', 'mae', 'val_mae'])

        if self.TrainSTAR:
            print('Training StarDistModel model with', self.backbone,
                  'backbone')
            # Normalisation is done jointly over the first three axes.
            self.axis_norm = (0, 1, 2)
            if self.CroppedLoad == False:
                # In-memory path: read every image and mask up front.
                assert len(Raw) > 1, "not enough training data"
                print(len(Raw))
                # Fixed seed so the train/validation split is reproducible.
                rng = np.random.RandomState(42)
                ind = rng.permutation(len(Raw))

                X_train = list(map(ReadFloat, Raw))
                Y_train = list(map(ReadInt, RealMask))
                self.Y = [
                    label(DownsampleData(y, self.DownsampleFactor))
                    for y in tqdm(Y_train)
                ]
                self.X = [
                    normalize(DownsampleData(x, self.DownsampleFactor),
                              1,
                              99.8,
                              axis=self.axis_norm) for x in tqdm(X_train)
                ]
                # 15% of the images (at least one) are held out for validation.
                n_val = max(1, int(round(0.15 * len(ind))))
                ind_train, ind_val = ind[:-n_val], ind[-n_val:]

                self.X_val, self.Y_val = [self.X[i] for i in ind_val
                                          ], [self.Y[i] for i in ind_val]
                self.X_trn, self.Y_trn = [self.X[i] for i in ind_train
                                          ], [self.Y[i] for i in ind_train]

                print('number of images: %3d' % len(self.X))
                print('- training:       %3d' % len(self.X_trn))
                print('- validation:     %3d' % len(self.X_val))

            if self.CroppedLoad:
                # Sequencer path: training data comes from Raw/ and RealMask/,
                # validation data from ValRaw/ and ValRealMask/.
                self.X_trn = self.DataSequencer(Raw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_trn = self.DataSequencer(RealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)

                self.X_val = self.DataSequencer(ValRaw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_val = self.DataSequencer(ValRealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)
                # NOTE(review): this attribute is not read when building the
                # configs below; the unet config hard-codes
                # train_sample_cache=False and the resnet config omits it.
                self.train_sample_cache = False

            print(Config3D.__doc__)

            anisotropy = (1, 1, 1)
            rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)

            # NOTE(review): `conf` is only assigned for the 'resnet' and 'unet'
            # backbones; any other value reaches print(conf) with it undefined.
            # NOTE(review): train_patch_size is given as (PatchZ, PatchX, PatchY)
            # in both configs, whereas create_patches above uses
            # (PatchZ, PatchY, PatchX) - confirm the intended order.
            if self.backbone == 'resnet':

                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    resnet_n_blocks=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    resnet_kernel_size=(self.kern_size, self.kern_size,
                                        self.kern_size),
                    train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                    train_batch_size=self.batch_size,
                    resnet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1)

            if self.backbone == 'unet':

                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    unet_n_depth=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    unet_kernel_size=(self.kern_size, self.kern_size,
                                      self.kern_size),
                    train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                    train_batch_size=self.batch_size,
                    unet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1,
                    train_sample_cache=False)

            print(conf)
            # The result of vars() is discarded; this line has no effect here.
            vars(conf)

            Starmodel = StarDist3D(conf,
                                   name=self.model_name,
                                   basedir=self.model_dir)
            # Diagnostic output: tile overlap for 'ZYX' and whether a
            # 'weights_now.h5' checkpoint already exists for this model.
            print(
                Starmodel._axes_tile_overlap('ZYX'),
                os.path.exists(self.model_dir + self.model_name + '/' +
                               'weights_now.h5'))

            # For each of now / last / best: load the copy model's file when it
            # exists and this model has no file of the same name yet.
            if self.copy_model_dir is not None:
                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_now.h5') and os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_now.h5') == False:
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_now.h5')
                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_last.h5') and os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_last.h5') == False:
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_last.h5')

                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_best.h5') and os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_best.h5') == False:
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_best.h5')

            # Own checkpoints are loaded in the order now -> last -> best; each
            # file that exists is loaded, so the last one found wins.
            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_now.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_now.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_last.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_last.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_best.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_best.h5')

            historyStar = Starmodel.train(self.X_trn,
                                          self.Y_trn,
                                          validation_data=(self.X_val,
                                                           self.Y_val),
                                          epochs=self.epochs)
            print(sorted(list(historyStar.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(historyStar, ['loss', 'val_loss'], [
                'dist_relevant_mae', 'val_dist_relevant_mae',
                'dist_relevant_mse', 'val_dist_relevant_mse'
            ])
Пример #9
0
                n_channel_out=1,
                probabilistic=int(args.probabilistic),
                train_batch_size=int(args.batch_size),
                unet_kern_size=3)

# Train args.num_models CARE models from the same config; each one is created
# with the name "model_<i>" and args.models_dir as its base directory.
# After the loop, `model` and `train_history` refer to the last iteration only.
for i in range(int(args.num_models)):
    model = CARE(config, f"model_{i}", args.models_dir)
    train_history = model.train(tr_data[0],
                                tr_data[1],
                                validation_data=val_data,
                                epochs=int(args.epochs),
                                steps_per_epoch=int(args.steps_per_epoch))

# The script terminates here: no statement below this call is executed.
exit()

# Unreachable while the exit() call above is in place.
plot_history(train_history, ["loss", "val_loss"],
             ["mse", "val_mse", "mae", "val_mae"])


def crop_to_even(img):
    """Trim the first two axes of *img* so that both have an even length.

    An axis of odd length loses its last element; an axis of even length is
    left untouched. Any further axes are returned in full.
    """
    n_rows, n_cols = img.shape[0], img.shape[1]
    even_rows = n_rows - n_rows % 2
    even_cols = n_cols - n_cols % 2
    return img[:even_rows, :even_cols]


# Run the model on every full-size validation image and display the result.
# NOTE(review): an exit() call earlier in this script stops execution before
# these statements are reached.
val_data_full_ims = ImageCollection(
    "training_data/val/low_snr_extracted_z/*.tif")
for i in range(len(val_data_full_ims)):
    # Convert to float and trim both image axes to an even length.
    x = crop_to_even(img_as_float(val_data_full_ims[i]))
    # Add a leading and a trailing singleton axis before predicting.
    x = np.reshape(x, (1, *x.shape, 1))
    # Keep the first batch entry and last-axis index 1 of the output.
    # NOTE(review): presumably the second output channel of a probabilistic
    # model - confirm which channel is meant to be shown.
    y_ = model.keras_model.predict(x)[0, :, :, 1]
    imshow(y_)
Пример #10
0
    def train(self, channels=None, **config_args):
        """Train, visualise and export one CARE model per channel.

        The whole run is wrapped in a ``Timer("Training")`` context.

        Parameters
        ----------
        channels : iterable, optional
            Channels to train; falls back to ``self.train_channels`` when
            ``None``.
        **config_args
            Extra keyword arguments forwarded to ``Config``.

        Returns
        -------
        None
            The method also returns early (without processing the remaining
            channels) when TensorFlow raises ``ResourceExhaustedError`` or
            ``UnknownError`` during training.
        """
        # limit_gpu_memory(fraction=1)
        selected_channels = self.train_channels if channels is None else channels

        with Timer("Training"):

            for ch in selected_channels:
                print("-- Training channel {}...".format(ch))

                # Patches for this channel, with 10% held out for validation.
                patch_file = (self.get_training_patch_path() /
                              "CH_{}_training_patches.npz".format(ch))
                train_pair, val_pair, axes = load_training_data(
                    patch_file,
                    validation_split=0.1,
                    verbose=False,
                )
                X, Y = train_pair
                X_val, Y_val = val_pair

                # Channel counts are read off the 'C' axis of source and target.
                channel_axis = axes_dict(axes)["C"]
                n_channel_in = X.shape[channel_axis]
                n_channel_out = Y.shape[channel_axis]

                config = Config(
                    axes,
                    n_channel_in,
                    n_channel_out,
                    train_epochs=self.train_epochs,
                    train_steps_per_epoch=self.train_steps_per_epoch,
                    train_batch_size=self.train_batch_size,
                    probabilistic=self.probabilistic,
                    **config_args,
                )

                model_dir = pathlib.Path(self.out_dir) / "models"
                model = CARE(config, "CH_{}_model".format(ch), basedir=model_dir)

                # Run the training; give up entirely on GPU memory problems.
                try:
                    history = model.train(X, Y, validation_data=(X_val, Y_val))
                except tf.errors.ResourceExhaustedError:
                    print(
                        " >> ResourceExhaustedError: Aborting...\n  Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?"
                    )
                    return
                except tf.errors.UnknownError:
                    print(
                        " >> UnknownError: Aborting...\n  No enough memory available on GPU... are other GPU jobs running?"
                    )
                    return

                # Learning curves.
                plt.figure(figsize=(16, 5))
                loss_keys = ["loss", "val_loss"]
                error_keys = ["mse", "val_mse", "mae", "val_mae"]
                plot_history(history, loss_keys, error_keys)

                # Predictions for the first five validation patches.
                plt.figure(figsize=(12, 7))
                _P = model.keras_model.predict(X_val[:5])

                # For probabilistic models only last-axis index 0 is shown.
                if self.probabilistic:
                    _P = _P[..., 0]

                plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5, cmap="gray")
                caption = ("5 example validation patches\n"
                           "top row: input (source),  "
                           "middle row: target (ground truth),  "
                           "bottom row: predicted from source")
                plt.suptitle(caption)

                plt.show()

                print("-- Export model for use in Fiji...")
                model.export_TF()
                print("Done")