Example #1
0
def test_model_train(tmpdir, config):
    """Smoke-test ``CARE.train`` for a single ``config``.

    Four random samples with 32 pixels along every spatial axis are drawn and
    used both as training data and as validation data.
    """
    random_state = np.random.RandomState(42)
    K.clear_session()
    batch_shape = (4,) + (32,) * config.n_dim
    inputs = random_state.uniform(size=batch_shape + (config.n_channel_in,))
    targets = random_state.uniform(size=batch_shape + (config.n_channel_out,))
    care = CARE(config, basedir=str(tmpdir))
    care.train(inputs, targets, (inputs, targets))
Example #2
0
def test_model_train():
    """Smoke-test ``CARE.train`` over a small grid of configurations.

    Every valid configuration produced by ``config_generator`` is trained for a
    couple of tiny epochs on random data inside a throw-away directory.
    Invalid combinations are skipped.
    """
    random_state = np.random.RandomState(42)
    configs = config_generator(
        axes=['YX', 'ZYX'],
        n_channel_in=[1, 2],
        n_channel_out=[1, 2],
        probabilistic=[False, True],
        # further options that can be added to the grid:
        # unet_residual         = [False,True],
        # unet_input_shape      = [(None, None, 1)],
        # train_learning_rate   = [0.0004],
        # train_tensorboard     = [True],
        # train_checkpoint      = ['weights_best.h5'],
        # train_reduce_lr       = [{'factor': 0.5, 'patience': 10}],
        unet_n_depth=[1],
        unet_kern_size=[3],
        unet_n_first=[4],
        unet_last_activation=['linear'],
        train_loss=['mae', 'laplace'],
        train_epochs=[2],
        train_steps_per_epoch=[2],
        train_batch_size=[2],
    )
    with tempfile.TemporaryDirectory() as workdir:
        for config in configs:
            K.clear_session()
            if not config.is_valid():
                continue
            batch_shape = (4,) + (32,) * config.n_dim
            inputs = random_state.uniform(size=batch_shape +
                                          (config.n_channel_in,))
            targets = random_state.uniform(size=batch_shape +
                                           (config.n_channel_out,))
            CARE(config, basedir=workdir).train(inputs, targets,
                                                (inputs, targets))
Example #3
0
def train_care_generated_data(model: CARE,
                              epochs: int,
                              X: np.ndarray,
                              Y: np.ndarray,
                              X_val: np.ndarray,
                              Y_val: np.ndarray,
                              steps_per_epoch: int = 400) -> None:
    """
    Train a CARE model on a channels-first dataset

    The arrays are expected to have the channel axis at position 1
    (presumably ``(N, C, Z, Y, X)``). They are converted to channels-last
    (``(N, Z, Y, X, C)``) before training. For the targets only the first
    channel is kept.

    :param model: The CARE model to train
    :param epochs: The number of epochs to train for
    :param X: The training data input
    :param Y: The training data expected output
    :param X_val: The validation data input
    :param Y_val: The validation data expected output
    :param steps_per_epoch: The number of parameter updates per epoch
    """

    def channels_last(arr: np.ndarray) -> np.ndarray:
        # Move axis 1 (channels) to the end; the other axes keep their order.
        return np.moveaxis(arr, 1, -1)

    def first_channel_last(arr: np.ndarray) -> np.ndarray:
        # Channels-last, keeping only the first channel as a size-1 axis.
        return channels_last(arr)[..., :1]

    model.train(channels_last(X),
                first_channel_last(Y),
                validation_data=(channels_last(X_val),
                                 first_channel_last(Y_val)),
                epochs=epochs,
                steps_per_epoch=steps_per_epoch)
Example #4
0
    def train(self, channels=None, **config_args):
        """Train, inspect and export one CARE model per channel.

        For every channel ``ch`` the method:

        - loads ``CH_<ch>_training_patches.npz`` from
          ``self.get_training_patch_path()``, holding out 10% for validation;
        - trains a model named ``CH_<ch>_model`` under ``<out_dir>/models``;
        - plots the learning curves and 5 validation predictions;
        - exports the model (``export_TF``) for use in Fiji.

        Parameters
        ----------
        channels : iterable or None
            Channels to train; defaults to ``self.train_channels``.
        **config_args
            Extra keyword arguments forwarded to ``Config``.

        If the GPU runs out of memory, a message is printed and the method
        returns without training the remaining channels.
        """
        #limit_gpu_memory(fraction=1)
        if channels is None:
            channels = self.train_channels

        for ch in channels:
            print("-- Training channel {}...".format(ch))
            (X, Y), (X_val, Y_val), axes = load_training_data(
                self.get_training_patch_path() /
                'CH_{}_training_patches.npz'.format(ch),
                validation_split=0.1,
                verbose=False)

            c = axes_dict(axes)['C']
            n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            train_epochs=self.train_epochs,
                            train_steps_per_epoch=self.train_steps_per_epoch,
                            train_batch_size=self.train_batch_size,
                            **config_args)
            # Training
            model = CARE(config,
                         'CH_{}_model'.format(ch),
                         basedir=pathlib.Path(self.out_dir) / 'models')

            try:
                history = model.train(X, Y, validation_data=(X_val, Y_val))
            except tf.errors.ResourceExhaustedError:
                print(
                    "ResourceExhaustedError: Aborting...\n Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?"
                )
                return

            # Show learning curve and example validation results
            plt.figure(figsize=(16, 5))
            plot_history(history, ['loss', 'val_loss'],
                         ['mse', 'val_mse', 'mae', 'val_mae'])

            plt.figure(figsize=(12, 7))
            _P = model.keras_model.predict(X_val[:5])
            if config.probabilistic:
                # A probabilistic CARE network outputs twice as many channels
                # (distribution parameters); keep the first half so the
                # plotted prediction matches the target's channel count.
                _P = _P[..., :(_P.shape[-1] // 2)]

            plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5, cmap="gray")
            plt.suptitle('5 example validation patches\n'
                         'top row: input (source),  '
                         'middle row: target (ground truth),  '
                         'bottom row: predicted from source')

            plt.show()

            print("-- Export model for use in Fiji...")
            model.export_TF()
            print("Done")
Example #5
0
def dev(args):
    """Train a CARE model from command-line arguments, plot and export it.

    Attributes read from ``args``: ``train_data``, ``valid_split``, ``axes``,
    ``resume``, ``config``, ``prob``, ``steps``, ``epochs``, ``model_name``.

    The config comes from one of three sources, in priority order:

    - ``args.resume``: ``None``, so CARE reloads the config saved with the model;
    - ``args.config``: a JSON file whose keys are passed to ``Config``;
    - otherwise, built from the data axes/channels and the CLI options.

    The learning curves are saved as ``<model_name>_training.png``. The model
    is created under ``models/`` and exported with ``export_TF``.
    """
    import json

    # Load and parse training data
    (X, Y), (X_val, Y_val), axes = load_training_data(
        args.train_data,
        validation_split=args.valid_split,
        axes=args.axes,
        verbose=True)

    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    # Model config
    print('args.resume: ', args.resume)
    if args.resume:
        # If resuming, config=None will reload the saved config
        config = None
        print('Attempting to resume')
    elif args.config:
        print('loading config from args')
        # Close the file deterministically; JSON is UTF-8 by specification.
        with open(args.config, encoding='utf-8') as config_file:
            config_args = json.load(config_file)
        config = Config(**config_args)
    else:
        config = Config(axes,
                        n_channel_in,
                        n_channel_out,
                        probabilistic=args.prob,
                        train_steps_per_epoch=args.steps,
                        train_epochs=args.epochs)
        print(vars(config))

    # Load or init model
    model = CARE(config, args.model_name, basedir='models')

    # Training, tensorboard available
    history = model.train(X, Y, validation_data=(X_val, Y_val))

    # Plot training results
    print(sorted(history.history.keys()))
    plt.figure(figsize=(16, 5))
    plot_history(history, ['loss', 'val_loss'],
                 ['mse', 'val_mse', 'mae', 'val_mae'])
    plt.savefig(args.model_name + '_training.png')

    # Export model to be used w/ csbdeep fiji plugins and KNIME flows
    model.export_TF()
Example #6
0
def train(Training_source=".",
          Training_target=".",
          model_name="No_name",
          model_path=".",
          Visual_validation_after_training=True,
          number_of_epochs=100,
          patch_size=64,
          number_of_patches=10,
          Use_Default_Advanced_Parameters=True,
          number_of_steps=300,
          batch_size=32,
          percentage_validation=15):
    """Train a 2D CARE model and save it under ``model_path/model_name``.

    Parameters
    ----------
    Training_source : str
        Folder with the noisy (input) ``.tif`` images.
    Training_target : str
        Folder with the ground-truth ``.tif`` images.
    model_name : str
        Name of the model. An existing folder ``model_path/model_name`` is
        deleted before training.
    model_path : str
        Folder where the model and the patch file ``rawdata.npz`` are written.
    Visual_validation_after_training : bool
        If true, call ``Predict_a_image`` after training.
    number_of_epochs : int
        Number of training epochs.
    patch_size : int
        Side length of the square training patches.
    number_of_patches : int
        Number of patches extracted from each image.
    Use_Default_Advanced_Parameters : bool
        If true, force ``batch_size=64`` and ``percentage_validation=10``.
        ``number_of_steps`` is then derived so that one epoch covers the
        training patches once.
    number_of_steps : int
        Training steps per epoch (ignored when defaults are used).
    batch_size : int
        Training batch size.
    percentage_validation : int
        Percentage of patches held out for validation.

    Returns
    -------
    None
    """
    import math
    import time

    OutputFile = Training_target + "/*.tif"
    InputFile = Training_source + "/*.tif"
    base = "/content/"
    if Use_Default_Advanced_Parameters:
        print("Default advanced parameters enabled")
        batch_size = 64
        percentage_validation = 10

    percentage = percentage_validation / 100

    # Start from a clean slate: delete a previous model with the same name.
    model_dir = model_path + '/' + model_name
    if os.path.exists(model_dir):
        shutil.rmtree(model_dir)

    # The stacks are only loaded to report their shapes.
    x = imread(InputFile)
    y = imread(OutputFile)

    print('Loaded Input images (number, width, length) =', x.shape)
    print('Loaded Output images (number, width, length) =', y.shape)
    print("Parameters initiated.")

    # RawData pairs each source image with the target image of the same
    # name, so CARE compares corresponding images.
    raw_data = data.RawData.from_folder(basepath=base,
                                        source_dirs=[Training_source],
                                        target_dir=Training_target,
                                        axes='CYX',
                                        pattern='*.tif*')

    X, Y, XY_axes = data.create_patches(raw_data,
                                        patch_filter=None,
                                        patch_size=(patch_size, patch_size),
                                        n_patches_per_image=number_of_patches)

    print('Creating 2D training dataset')
    training_path = model_path + "/rawdata"
    # np.savez appends ".npz" to a file name that lacks it.
    np.savez(training_path, X=X, Y=Y, axes=XY_axes)

    # Load Training Data
    (X, Y), (X_val, Y_val), axes = load_training_data(
        training_path + ".npz", validation_split=percentage, verbose=True)
    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    if Use_Default_Advanced_Parameters:
        # Ceiling division: enough batches to visit every patch once per
        # epoch, without an extra step when the count divides evenly.
        number_of_steps = max(1, math.ceil(X.shape[0] / batch_size))

    print(number_of_steps)

    config = Config(axes,
                    n_channel_in,
                    n_channel_out,
                    probabilistic=False,
                    train_steps_per_epoch=number_of_steps,
                    train_epochs=number_of_epochs,
                    unet_kern_size=5,
                    unet_n_depth=3,
                    train_batch_size=batch_size,
                    train_learning_rate=0.0004)

    print(config)

    # Compile the CARE model for network training
    model_training = CARE(config, model_name, basedir=model_path)

    start = time.time()

    # Start Training
    history = model_training.train(X, Y, validation_data=(X_val, Y_val))

    print("Training, done.")
    if Visual_validation_after_training:
        Predict_a_image(Training_source, Training_target, model_path,
                        model_training)

    # Displaying the time elapsed for training
    elapsed = time.time() - start
    minutes, seconds = divmod(elapsed, 60)
    hours, minutes = divmod(minutes, 60)
    print("Time elapsed:", hours, "hour(s)", minutes, "min(s)",
          round(seconds), "sec(s)")

    Show_loss_function(history, model_path)
Example #7
0
                         train_reduce_lr={
                             'patience': 5,
                             'factor': 0.5
                         })
print(config)

# In[6]:

model = CARE(config=config, name=ModelName, basedir=ModelDir)
# To continue from previously saved weights, uncomment:
#input_weights = ModelDir + ModelName + '/' +'weights_best.h5'
#model.load_weights(input_weights)

# In[ ]:

history = model.train(X, Y, validation_data=(X_val, Y_val))

# In[ ]:

print(sorted(list(history.history.keys())))
plt.figure(figsize=(16, 5))
plot_history(history, ['loss', 'val_loss'],
             ['mse', 'val_mse', 'mae', 'val_mae'])

# In[ ]:

# Use the same validation patches for every row (input, target, prediction)
# so the rows of the figure are directly comparable.
examples = slice(25, 30)
plt.figure(figsize=(12, 7))
_P = model.keras_model.predict(X_val[examples])
if config.probabilistic:
    # Probabilistic output has twice the channels; keep the first half.
    _P = _P[..., :(_P.shape[-1] // 2)]
plot_some(X_val[examples], Y_val[examples], _P, pmax=99.5)
Example #8
0
    def Train(self):
        """Prepare masks/patches, then train the UNET and/or StarDist3D model.

        Steps, each gated by an attribute on ``self``:

        1. If ``RealMask/`` is empty, write labelled instance masks there from
           ``BinaryMask/``. If ``BinaryMask/`` is empty, write ``> 0``
           foreground masks there from ``RealMask/``.
        2. ``GenerateNPZ``: cut ``Raw/`` / ``BinaryMask/`` into ZYX patches
           saved to ``BaseDir + NPZfilename + '.npz'``.
        3. ``TrainUNET``: train a CARE UNET on that patch file.
        4. ``TrainSTAR``: train a StarDist3D model on ``Raw/`` + ``RealMask/``.
           With ``CroppedLoad``, data comes from ``DataSequencer`` and
           ``ValRaw/`` + ``ValRealMask/`` are used for validation.

        Before each training run, existing ``weights_now/last/best.h5`` files
        are loaded so training continues from them. Optionally, a copy model's
        weights are loaded where no own checkpoint exists yet.

        Raises
        ------
        ValueError
            If ``TrainSTAR`` is set and ``backbone`` is not ``'resnet'`` or
            ``'unet'``, or if there are too few raw images.
        """

        def load_copy_weights(model, src_dir, dst_dir, names):
            # Seed from another model, only where this model has no checkpoint.
            for name in names:
                if os.path.exists(src_dir + name) and not os.path.exists(
                        dst_dir + name):
                    print('Loading copy model')
                    model.load_weights(src_dir + name)

        def load_own_weights(model, own_dir):
            # Loaded in this order, so the last existing file wins.
            for name in ('weights_now.h5', 'weights_last.h5',
                         'weights_best.h5'):
                if os.path.exists(own_dir + name):
                    print('Loading checkpoint model')
                    model.load_weights(own_dir + name)

        BinaryName = 'BinaryMask/'
        RealName = 'RealMask/'
        Raw = sorted(glob.glob(self.BaseDir + '/Raw/' + '*.tif'))
        Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
        Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
        RealMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*.tif'))
        ValRaw = sorted(glob.glob(self.BaseDir + '/ValRaw/' + '*.tif'))
        ValRealMask = sorted(
            glob.glob(self.BaseDir + '/ValRealMask/' + '*.tif'))

        print('Instance segmentation masks:', len(RealMask))
        if len(RealMask) == 0:
            print('Making labels')
            Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
            for fname in Mask:
                image = imread(fname)
                Name = os.path.basename(os.path.splitext(fname)[0])
                labels = label(image)
                imwrite((self.BaseDir + '/' + RealName + Name + '.tif'),
                        labels.astype('uint16'))
            # Re-scan: the StarDist training below must see the labels that
            # were just written, not the empty list from before.
            RealMask = sorted(
                glob.glob(self.BaseDir + '/' + RealName + '*.tif'))

        Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
        print('Semantic segmentation masks:', len(Mask))
        if len(Mask) == 0:
            print('Generating Binary images')
            RealfilesMask = sorted(
                glob.glob(self.BaseDir + '/' + RealName + '*tif'))
            for fname in RealfilesMask:
                image = imread(fname)
                Name = os.path.basename(os.path.splitext(fname)[0])
                Binaryimage = image > 0
                imwrite((self.BaseDir + '/' + BinaryName + Name + '.tif'),
                        Binaryimage.astype('uint16'))

        if self.GenerateNPZ:
            raw_data = RawData.from_folder(
                basepath=self.BaseDir,
                source_dirs=['Raw/'],
                target_dir='BinaryMask/',
                axes='ZYX',
            )
            # The patches are written to save_file and re-loaded below, so
            # the returned arrays are not kept here.
            create_patches(
                raw_data=raw_data,
                patch_size=(self.PatchZ, self.PatchY, self.PatchX),
                n_patches_per_image=self.n_patches_per_image,
                save_file=self.BaseDir + self.NPZfilename + '.npz',
            )

        # Training UNET model
        if self.TrainUNET:
            print('Training UNET model')
            load_path = self.BaseDir + self.NPZfilename + '.npz'

            (X, Y), (X_val, Y_val), axes = load_training_data(
                load_path, validation_split=0.1, verbose=True)
            c = axes_dict(axes)['C']
            n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            unet_n_depth=self.depth,
                            train_epochs=self.epochs,
                            train_batch_size=self.batch_size,
                            unet_n_first=self.startfilter,
                            train_loss='mse',
                            unet_kern_size=self.kern_size,
                            train_learning_rate=self.learning_rate,
                            train_reduce_lr={
                                'patience': 5,
                                'factor': 0.5
                            })
            print(config)

            model = CARE(config,
                         name='UNET' + self.model_name,
                         basedir=self.model_dir)

            unet_dir = self.model_dir + 'UNET' + self.model_name + '/'
            if self.copy_model_dir is not None:
                load_copy_weights(
                    model,
                    self.copy_model_dir + 'UNET' + self.copy_model_name + '/',
                    unet_dir, ('weights_now.h5',))
            load_own_weights(model, unet_dir)

            history = model.train(X, Y, validation_data=(X_val, Y_val))

            print(sorted(list(history.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(history, ['loss', 'val_loss'],
                         ['mse', 'val_mse', 'mae', 'val_mae'])

        if self.TrainSTAR:
            print('Training StarDistModel model with', self.backbone,
                  'backbone')
            # Fail fast, before any (expensive) data loading.
            if self.backbone not in ('resnet', 'unet'):
                raise ValueError(
                    f"Unknown backbone {self.backbone!r}; "
                    "expected 'resnet' or 'unet'")
            self.axis_norm = (0, 1, 2)
            if self.CroppedLoad:
                self.X_trn = self.DataSequencer(Raw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_trn = self.DataSequencer(RealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)

                self.X_val = self.DataSequencer(ValRaw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_val = self.DataSequencer(ValRealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)
                self.train_sample_cache = False
            else:
                if len(Raw) <= 1:
                    raise ValueError("not enough training data")
                print(len(Raw))
                rng = np.random.RandomState(42)
                ind = rng.permutation(len(Raw))

                X_train = list(map(ReadFloat, Raw))
                Y_train = list(map(ReadInt, RealMask))
                self.Y = [
                    label(DownsampleData(y, self.DownsampleFactor))
                    for y in tqdm(Y_train)
                ]
                self.X = [
                    normalize(DownsampleData(x, self.DownsampleFactor),
                              1,
                              99.8,
                              axis=self.axis_norm) for x in tqdm(X_train)
                ]
                # Hold out ~15% (at least one image) for validation.
                n_val = max(1, int(round(0.15 * len(ind))))
                ind_train, ind_val = ind[:-n_val], ind[-n_val:]

                self.X_val = [self.X[i] for i in ind_val]
                self.Y_val = [self.Y[i] for i in ind_val]
                self.X_trn = [self.X[i] for i in ind_train]
                self.Y_trn = [self.Y[i] for i in ind_train]

                print('number of images: %3d' % len(self.X))
                print('- training:       %3d' % len(self.X_trn))
                print('- validation:     %3d' % len(self.X_val))

            print(Config3D.__doc__)

            anisotropy = (1, 1, 1)
            rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)

            if self.backbone == 'resnet':
                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    resnet_n_blocks=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    resnet_kernel_size=(self.kern_size, self.kern_size,
                                        self.kern_size),
                    train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                    train_batch_size=self.batch_size,
                    resnet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1)
            else:  # 'unet' (validated above)
                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    unet_n_depth=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    unet_kernel_size=(self.kern_size, self.kern_size,
                                      self.kern_size),
                    train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                    train_batch_size=self.batch_size,
                    unet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1,
                    train_sample_cache=False)

            print(conf)

            Starmodel = StarDist3D(conf,
                                   name=self.model_name,
                                   basedir=self.model_dir)
            star_dir = self.model_dir + self.model_name + '/'
            print(Starmodel._axes_tile_overlap('ZYX'),
                  os.path.exists(star_dir + 'weights_now.h5'))

            if self.copy_model_dir is not None:
                load_copy_weights(
                    Starmodel,
                    self.copy_model_dir + self.copy_model_name + '/',
                    star_dir,
                    ('weights_now.h5', 'weights_last.h5', 'weights_best.h5'))
            load_own_weights(Starmodel, star_dir)

            historyStar = Starmodel.train(self.X_trn,
                                          self.Y_trn,
                                          validation_data=(self.X_val,
                                                           self.Y_val),
                                          epochs=self.epochs)
            print(sorted(list(historyStar.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(historyStar, ['loss', 'val_loss'], [
                'dist_relevant_mae', 'val_dist_relevant_mae',
                'dist_relevant_mse', 'val_dist_relevant_mse'
            ])
Example #9
0
import sys

is_3d = bool(args.is_3d)
axes = "ZYX" if is_3d else "YX"
n_dim = 3 if is_3d else 2
config = Config(axes,
                n_dim=n_dim,
                n_channel_in=1,
                n_channel_out=1,
                # Config validation requires a real bool here, not 0/1.
                probabilistic=bool(int(args.probabilistic)),
                train_batch_size=int(args.batch_size),
                unet_kern_size=3)

# Train args.num_models independent models (model_0, model_1, ...).
for i in range(int(args.num_models)):
    model = CARE(config, f"model_{i}", args.models_dir)
    train_history = model.train(tr_data[0],
                                tr_data[1],
                                validation_data=val_data,
                                epochs=int(args.epochs),
                                steps_per_epoch=int(args.steps_per_epoch))

# Training is this script's only job; stop here. sys.exit is used because the
# bare exit() helper from the site module is meant for interactive sessions.
sys.exit()


def crop_to_even(img):
    """Trim the first two axes of ``img`` to even lengths.

    At most one trailing row and one trailing column are dropped; any further
    axes are kept whole. The result is a slice (view) of ``img``.
    """
    height, width = img.shape[0], img.shape[1]
    return img[:height - height % 2, :width - width % 2]

Example #10
0
def train(X, Y, X_val, Y_val, axes, model_name, csv_file,probabilistic, validation_split, patch_size, **kwargs):
    """Train a single-channel CARE model on existing patches and export it.

    The model is configured through a ``Config`` object built from ``axes``,
    ``probabilistic`` and ``**kwargs``. It is created with
    ``basedir='models'``, trained, and exported with ``export_TF``.

    Parameters
    ----------
    X : np.ndarray
        Input data X for training.
    Y : np.ndarray
        Ground truth data Y for training.
    X_val : np.ndarray
        Input data X for validation.
    Y_val : np.ndarray
        Ground truth Y for validation.
    axes : str
        Semantic order of the axes in the image (e.g. ``'SYXC'``).
    model_name : str
        Name of the model to be saved after training.
    csv_file : str
        Currently unused (kept for interface compatibility).
    probabilistic : bool
        Whether to train a probabilistic CARE model.
    validation_split : float
        Currently unused (kept for interface compatibility).
    patch_size : tuple
        Currently unused (kept for interface compatibility).
    **kwargs
        Further ``Config`` parameters, e.g. ``train_steps_per_epoch``,
        ``train_epochs``, ``train_batch_size``, ``train_learning_rate`` or
        ``unet_n_depth``. Unknown names are accepted
        (``allow_new_parameters=True``).

    Returns
    -------
    history
        The object returned by ``CARE.train`` (loss values per epoch).

    Notes
    -----
    Progress can be followed with TensorBoard: run ``tensorboard --logdir=.``
    from the current working directory and open http://localhost:6006/.
    """
    # Single input / output channel is assumed.
    config = Config(axes, n_channel_in=1, n_channel_out=1,
                    probabilistic=probabilistic, allow_new_parameters=True,
                    **kwargs)
    model = CARE(config, model_name, basedir='models')

    history = model.train(X, Y, validation_data=(X_val, Y_val))
    model.export_TF()
    return history
Example #11
0
def n2v_flim(project, n2v_num_pix=32):
    """Denoise the fit results of a FLIMfit project with Noise2Void.

    Reads ``fit_results.hdf5`` from ``project`` and normalises it. It then
    trains an ``n2v_model`` (saved under ``project``) and reloads its
    ``weights_best.h5``. The denoised images are written into a fresh copy of
    the project, whose path is ``project`` with ``.flimfit`` replaced by
    ``-n2v.flimfit``. Masked pixels are set to NaN.

    Parameters
    ----------
    project : str
        Path of the project directory; must contain ``.flimfit`` in its name.
    n2v_num_pix : int
        Passed to the config as ``n2v_num_pix``. Also used as the side length
        of the square ``n2v_patch_shape``, and to scale the number of
        manipulated validation pixels.

    Raises
    ------
    ValueError
        If ``project`` does not contain ``.flimfit``. The output path would
        then equal the input path, and removing the existing output directory
        would delete the input project.
    """
    # Validate the output location before doing any (expensive) training.
    output_project = project.replace('.flimfit', '-n2v.flimfit')
    if output_project == project:
        raise ValueError(
            "project path must contain '.flimfit' so that a separate "
            "'-n2v.flimfit' output project can be created: {!r}".format(project))

    results_file = os.path.join(project, 'fit_results.hdf5')

    X, groups, mask = extract_results(results_file)
    print(np.shape(X))

    mean, std = np.mean(X), np.std(X)
    X = normalize(X, mean, std)

    XA = X  # augment_data(X)

    # manipulate_val_data is called only for its side effects on X_val/Y_val
    # (its return value is discarded). A slice of X would be a view, so the
    # manipulation would leak into the training data and into the images
    # denoised below. Take an independent copy instead.
    X_val = np.copy(X[0:10, ...])

    # We concatenate an extra channel filled with zeros. It will be internally used for the masking.
    Y = np.concatenate((XA, np.zeros(XA.shape)), axis=-1)
    Y_val = np.concatenate((X_val, np.zeros(X_val.shape)), axis=-1)

    n_x = X.shape[1]
    n_chan = X.shape[-1]

    manipulate_val_data(X_val, Y_val, num_pix=n_x * n_x * 2 / n2v_num_pix, shape=(n_x, n_x))

    # You can increase "train_steps_per_epoch" to get even better results at the price of longer computation.
    config = Config('SYXC',
                    n_channel_in=n_chan,
                    n_channel_out=n_chan,
                    unet_kern_size=5,
                    unet_n_depth=2,
                    train_steps_per_epoch=200,
                    train_loss='mae',
                    train_epochs=35,
                    batch_norm=False,
                    train_scheme='Noise2Void',
                    train_batch_size=128,
                    n2v_num_pix=n2v_num_pix,
                    n2v_patch_shape=(n2v_num_pix, n2v_num_pix),
                    n2v_manipulator='uniform_withCP',
                    # NOTE(review): passed as a string, as in the original;
                    # confirm the N2V Config expects this type.
                    n2v_neighborhood_radius='5')

    model = CARE(config, 'n2v_model', basedir=project)
    model.train(XA, Y, validation_data=(X_val, Y_val))

    # Predict with the best checkpoint rather than the last epoch's weights.
    model.load_weights(name='weights_best.h5')

    if os.path.exists(output_project):
        shutil.rmtree(output_project)
    shutil.copytree(project, output_project)

    output_file = os.path.join(output_project, 'fit_results.hdf5')

    X_pred = np.zeros(X.shape)
    for i, image in enumerate(X):
        X_pred[i, ...] = denormalize(model.predict(image, axes='YXC', normalizer=None), mean, std)

    # np.NaN was removed in NumPy 2.0; np.nan works on all versions.
    X_pred[mask] = np.nan

    insert_results(output_file, X_pred, groups)
Example #12
0
# Keep only the selected channels.
x_train, x_test = x_train[:, :, :, keep_idx], x_test[:, :, :, keep_idx]
y_train, y_test = y_train[:, :, :, keep_idx], y_test[:, :, :, keep_idx]

# Taken from CARE FAQ, runs the model: configure and train a probabilistic
# CARE model on the selected channels.
axes = 'SYXC'
c = axes_dict(axes)['C']
n_channel_in, n_channel_out = x_train.shape[c], y_train.shape[c]
config = Config(axes,
                n_channel_in,
                n_channel_out,
                probabilistic=True,
                train_epochs=num_epochs,
                train_steps_per_epoch=30)

model = CARE(config, model_name, basedir='models')
history = model.train(x_train, y_train, validation_data=(x_test, y_test))

# Preview 5 validation patches. The probabilistic output stacks the predicted
# Laplace means in the first half of the channels and the scales in the
# second half.
fig = plt.figure(figsize=(30, 30))
_P = model.keras_model.predict(x_test[:5, :, :, :])
_n_half = _P.shape[-1] // 2
_P_mean = _P[..., :_n_half]
_P_scale = _P[..., _n_half:]
plot_some(x_test[:5, :, :, 0],
          y_test[:5, :, :, :],
          _P_mean,
          _P_scale,
          pmax=99.5)
fig.suptitle('5 example validation patches\n'
             'first row: input (source),  '
             'second row: target (ground truth),  '
             'third row: predicted Laplace mean,  '
             'fourth row: predicted Laplace scale')
Example #13
0
    def train(self, channels=None, **config_args):
        """Train, preview and export one CARE model per channel.

        For every channel ``ch`` this method:

        - loads ``CH_<ch>_training_patches.npz`` from
          ``self.get_training_patch_path()``, holding out 10% for validation;
        - trains a model named ``CH_<ch>_model`` under ``<self.out_dir>/models``;
        - plots the learning curve and five example validation predictions;
        - exports the model with ``export_TF``.

        Parameters
        ----------
        channels : iterable, optional
            Channels to train; defaults to ``self.train_channels``.
        **config_args
            Extra keyword arguments forwarded to ``Config``. They may also
            override the instance defaults ``train_epochs``,
            ``train_steps_per_epoch``, ``train_batch_size`` and
            ``probabilistic``.

        Returns
        -------
        None
            If ``model.train`` raises TensorFlow's ``ResourceExhaustedError``
            or ``UnknownError``, a message is printed and the method returns
            early, skipping the remaining channels.
        """
        # limit_gpu_memory(fraction=1)
        if channels is None:
            channels = self.train_channels

        with Timer("Training"):
            models_dir = pathlib.Path(self.out_dir) / "models"

            for ch in channels:
                print("-- Training channel {}...".format(ch))
                (X, Y), (X_val, Y_val), axes = load_training_data(
                    self.get_training_patch_path() /
                    "CH_{}_training_patches.npz".format(ch),
                    validation_split=0.1,
                    verbose=False,
                )

                c = axes_dict(axes)["C"]
                n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

                # Instance defaults first; caller-supplied config_args take
                # precedence instead of clashing as duplicate keywords.
                config_kwargs = {
                    "train_epochs": self.train_epochs,
                    "train_steps_per_epoch": self.train_steps_per_epoch,
                    "train_batch_size": self.train_batch_size,
                    "probabilistic": self.probabilistic,
                }
                config_kwargs.update(config_args)

                config = Config(axes, n_channel_in, n_channel_out, **config_kwargs)

                model = CARE(config, "CH_{}_model".format(ch), basedir=models_dir)

                # Training
                try:
                    history = model.train(X, Y, validation_data=(X_val, Y_val))
                except tf.errors.ResourceExhaustedError:
                    print(
                        " >> ResourceExhaustedError: Aborting...\n  Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?"
                    )
                    return
                except tf.errors.UnknownError:
                    print(
                        " >> UnknownError: Aborting...\n  Not enough memory available on GPU... are other GPU jobs running?"
                    )
                    return

                # Show learning curve
                plt.figure(figsize=(16, 5))
                plot_history(history, ["loss", "val_loss"],
                             ["mse", "val_mse", "mae", "val_mae"])

                # Show example validation results
                plt.figure(figsize=(12, 7))
                _P = model.keras_model.predict(X_val[:5])

                if config_kwargs["probabilistic"]:
                    # Probabilistic CARE output holds the predicted means in
                    # the first half of the channels and the scales in the
                    # second half; preview the means only.
                    _P = _P[..., :_P.shape[-1] // 2]

                plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5, cmap="gray")
                plt.suptitle("5 example validation patches\n"
                             "top row: input (source),  "
                             "middle row: target (ground truth),  "
                             "bottom row: predicted from source")

                plt.show()

                print("-- Export model for use in Fiji...")
                model.export_TF()
                print("Done")