Ejemplo n.º 1
0
def load_data(npz_data, valid_only=True, valid_split=0.2, axes='SCZXY'):
    """Load patches from ``npz_data`` through ``load_training_data``.

    Parameters
    ----------
    npz_data
        File handed to ``load_training_data`` as its first argument.
    valid_only : bool
        When true, only the validation pair is returned.
    valid_split : float
        Forwarded as ``validation_split``.
    axes : str
        Forwarded as ``axes``.

    Returns
    -------
    ``(X_val, Y_val)`` when ``valid_only`` is true, otherwise
    ``(X, Y), (X_val, Y_val), axes`` as unpacked from the loader result.
    """
    loaded = load_training_data(npz_data,
                                validation_split=valid_split,
                                axes=axes,
                                verbose=True)

    if valid_only:
        # Element 1 of the loader result is the validation pair.
        val_inputs, val_targets = loaded[1]
        return val_inputs, val_targets

    (train_inputs, train_targets), (val_inputs, val_targets), data_axes = loaded
    return (train_inputs, train_targets), (val_inputs, val_targets), data_axes
Ejemplo n.º 2
0
def load(path_data, axes, validation_split, patch_size, data_name):
    """Generate the patches for ``data_name`` and load them for training.

    ``data_generation`` is run first with the given arguments; the patches
    are then read from ``'data/' + data_name + '.npz'`` (presumably the file
    ``data_generation`` writes -- confirm against its implementation).

    Parameters
    ----------
    path_data : str
        Path to input data
    axes : str
        Semantic order of the axis in the image
    validation_split : float
        Ratio of data kept for validation
    patch_size : tuple
        Size of the patches
    data_name : str
        Base name (without extension) of the ``.npz`` file under ``data/``

    Returns
    -------
    X, Y : np.array
        Input data X for training, Ground truth Y for training
    X_val, Y_val : np.array
        Input data X_val for validation, Ground truth Y_val for validation
    """
    # GPU memory limiting is currently disabled:
    # limit_gpu_memory(fraction=1/2)

    data_generation(path_data, axes, patch_size, data_name)

    npz_file = 'data/' + data_name + '.npz'
    # The axes string reported by the loader is not needed by callers.
    (x_train, y_train), (x_val, y_val), _ = load_training_data(
        npz_file, validation_split, verbose=True)
    return (x_train, y_train), (x_val, y_val)
Ejemplo n.º 3
0
def original():
    """Train an ``IsotropicCARE`` model on ``isonet_psf_1/my_training_data.npz``.

    Loads the patches with a 10% validation split, prints a short summary,
    trains for 200 epochs, saves three figures (``train_1.png``,
    ``train_history.png``, ``train_2.png``) into the ``isonet_psf_1`` folder
    and finally exports the model with ``export_TF``.
    """
    out_dir = Path('isonet_psf_1')
    out_dir.mkdir(exist_ok=True)

    # sys.stdout = open(out_dir / 'train_stdout.txt', 'w')
    # sys.stderr = open(out_dir / 'train_stderr.txt', 'w')

    npz_file = out_dir / 'my_training_data.npz'
    (X, Y), (X_val, Y_val), data_axes = load_training_data(
        npz_file, validation_split=0.1)
    ax = axes_dict(data_axes)

    n_train = len(X)
    n_val = len(X_val)

    # Spatial axes are Z, Y, X when a Z axis exists, otherwise Y, X.
    if ax['Z'] is not None:
        spatial_axes = (ax['Z'], ax['Y'], ax['X'])
    else:
        spatial_axes = (ax['Y'], ax['X'])
    image_size = tuple(X.shape[i] for i in spatial_axes)
    n_dim = len(image_size)

    channel_axis = ax['C']
    n_channel_in = X.shape[channel_axis]
    n_channel_out = Y.shape[channel_axis]

    print('number of training images:\t', n_train)
    print('number of validation images:\t', n_val)
    print('image size (%dD):\t\t' % n_dim, image_size)
    print('Channels in / out:\t\t', n_channel_in, '/', n_channel_out)

    # Figure 1: a few validation patches before training.
    val_inputs, val_targets = X_val[:5], Y_val[:5]
    plt.figure(figsize=(10, 4))
    plot_some(val_inputs, val_targets)
    plt.suptitle(
        '5 example validation patches (top row: source, bottom row: target)')
    plt.savefig(out_dir / 'train_1.png')

    config = Config(data_axes, n_channel_in, n_channel_out, train_epochs=200)
    print(config)
    vars(config)

    model = IsotropicCARE(config, str(out_dir / 'my_model'))

    history = model.train(X, Y, validation_data=(X_val, Y_val))

    # Figure 2: training curves.
    print(sorted(history.history.keys()))
    plt.figure(figsize=(16, 5))
    plot_history(history, ['loss', 'val_loss'],
                 ['mse', 'val_mse', 'mae', 'val_mae'])
    plt.savefig(out_dir / 'train_history.png')

    model.load_weights()  # load best weights according to validation loss

    # Figure 3: predictions on the same validation patches.
    plt.figure(figsize=(12, 7))
    predicted = model.keras_model.predict(X_val[:5])
    if config.probabilistic:
        # Keep only the first half of the last axis of the prediction.
        half = predicted.shape[-1] // 2
        predicted = predicted[..., :half]
    plot_some(X_val[:5], Y_val[:5], predicted, pmax=99.5)
    plt.suptitle('5 example validation patches\n'
                 'top row: input (source),  '
                 'middle row: target (ground truth),  '
                 'bottom row: predicted from source')
    plt.tight_layout()
    plt.savefig(out_dir / 'train_2.png')

    model.export_TF()
Ejemplo n.º 4
0
def load_data_and_model(npz_data, model_name, valid_split=0.2, axes='SCZXY'):
    """Return the validation patches of ``npz_data`` together with a CARE model.

    The model is built first, with ``config=None``, ``name=model_name`` and
    ``basedir='models'`` (presumably this reloads an already saved model --
    confirm against the CARE API). The validation pair is element 1 of the
    ``load_training_data`` result.

    Returns
    -------
    X_val, Y_val, model
    """
    care_model = CARE(config=None, name=model_name, basedir='models')

    splits = load_training_data(npz_data,
                                validation_split=valid_split,
                                axes=axes,
                                verbose=True)
    val_inputs, val_targets = splits[1]
    return val_inputs, val_targets, care_model
Ejemplo n.º 5
0
 def _create(img_size,img_axes,patch_size,patch_axes):
     """Round-trip check: create patches, save them, and reload them.

     Two random image stacks are turned into patches with ``create_patches``
     (written to ``save_file``), then read back with ``load_training_data``
     and compared against the in-memory patches.

     Uses ``rng``, ``n_images``, ``n_patches_per_image``, ``save_file`` and
     ``backend_channels_last`` from the enclosing scope.
     """
     U,V = (rng.uniform(size=(n_images,)+img_size) for _ in range(2))
     X,Y,XYaxes = create_patches (
         raw_data            = RawData.from_arrays(U,V,img_axes),
         patch_size          = patch_size,
         patch_axes          = patch_axes,
         n_patches_per_image = n_patches_per_image,
         save_file           = save_file
     )
     (_X,_Y), val_data, _XYaxes = load_training_data(save_file,verbose=True)
     # Without a validation split no validation data is returned.
     assert val_data is None
     # The channel axis sits last or at index 1 depending on the backend layout.
     assert _XYaxes[-1 if backend_channels_last else 1] == 'C'
     # Bring the reloaded arrays back to the axis order create_patches returned.
     _X,_Y = (move_image_axes(u,fr=_XYaxes,to=XYaxes) for u in (_X,_Y))
     assert np.allclose(X,_X,atol=1e-6)
     assert np.allclose(Y,_Y,atol=1e-6)
     assert set(XYaxes) == set(_XYaxes)
     # With a validation split the validation pair (element 1) must be present.
     # The previous check looked at element 2 (the axes), which is never None,
     # so it could not fail.
     assert load_training_data(save_file,validation_split=0.5)[1] is not None
     assert all(len(x)==3 for x in load_training_data(save_file,n_images=3)[0])
Ejemplo n.º 6
0
    def train(self, channels=None, **config_args):
        """Train and export one CARE model per channel.

        Parameters
        ----------
        channels
            Iterable of channel identifiers; ``self.train_channels`` is used
            when ``None``.
        **config_args
            Extra keyword arguments forwarded to ``Config``.

        For every channel the patch file ``CH_<ch>_training_patches.npz`` is
        loaded with a 10% validation split, a model named ``CH_<ch>_model`` is
        trained under ``<self.out_dir>/models``, the learning curve and five
        validation predictions are shown, and the model is exported with
        ``export_TF``. If training raises
        ``tf.errors.ResourceExhaustedError`` a message is printed and the
        method returns immediately, skipping any remaining channels.
        """
        #limit_gpu_memory(fraction=1)
        selected = self.train_channels if channels is None else channels

        for ch in selected:
            print("-- Training channel {}...".format(ch))

            patch_file = (self.get_training_patch_path() /
                          'CH_{}_training_patches.npz'.format(ch))
            (X, Y), (X_val, Y_val), axes = load_training_data(
                patch_file, validation_split=0.1, verbose=False)

            channel_axis = axes_dict(axes)['C']
            n_channel_in = X.shape[channel_axis]
            n_channel_out = Y.shape[channel_axis]

            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            train_epochs=self.train_epochs,
                            train_steps_per_epoch=self.train_steps_per_epoch,
                            train_batch_size=self.train_batch_size,
                            **config_args)

            # Training
            models_dir = pathlib.Path(self.out_dir) / 'models'
            model = CARE(config, 'CH_{}_model'.format(ch), basedir=models_dir)

            try:
                history = model.train(X, Y, validation_data=(X_val, Y_val))
            except tf.errors.ResourceExhaustedError:
                print(
                    "ResourceExhaustedError: Aborting...\n Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?"
                )
                return

            # Show learning curve and example validation results
            #print(sorted(list(history.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(history, ['loss', 'val_loss'],
                         ['mse', 'val_mse', 'mae', 'val_mae'])

            plt.figure(figsize=(12, 7))
            predicted = model.keras_model.predict(X_val[:5])

            plot_some(X_val[:5], Y_val[:5], predicted, pmax=99.5, cmap="gray")
            caption = ('5 example validation patches\n'
                       'top row: input (source),  '
                       'middle row: target (ground truth),  '
                       'bottom row: predicted from source')
            plt.suptitle(caption)

            plt.show()

            print("-- Export model for use in Fiji...")
            model.export_TF()
            print("Done")
Ejemplo n.º 7
0
def dev(args):
    """Train a CARE model from parsed command-line arguments and export it.

    Attributes read from ``args``: ``train_data``, ``valid_split``, ``axes``,
    ``resume``, ``config``, ``prob``, ``steps``, ``epochs``, ``model_name``.

    The model configuration is chosen in this order:

    1. ``args.resume`` truthy -- ``config=None`` is passed to ``CARE`` so the
       saved config is reloaded;
    2. ``args.config`` truthy -- it is opened as a JSON file whose content is
       passed as keyword arguments to ``Config``;
    3. otherwise a ``Config`` is built from the loaded data and the
       ``prob`` / ``steps`` / ``epochs`` arguments.

    The training curves are saved to ``<model_name>_training.png``.
    """
    import json

    # Load and parse training data
    (X,
     Y), (X_val,
          Y_val), axes = load_training_data(args.train_data,
                                            validation_split=args.valid_split,
                                            axes=args.axes,
                                            verbose=True)

    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    # Model config
    print('args.resume: ', args.resume)
    if args.resume:
        # If resuming, config=None will reload the saved config
        config = None
        print('Attempting to resume')
    elif args.config:
        print('loading config from args')
        # Close the file deterministically instead of leaking the handle.
        with open(args.config) as config_file:
            config_args = json.load(config_file)
        config = Config(**config_args)
    else:
        config = Config(axes,
                        n_channel_in,
                        n_channel_out,
                        probabilistic=args.prob,
                        train_steps_per_epoch=args.steps,
                        train_epochs=args.epochs)
        print(vars(config))

    # Load or init model
    model = CARE(config, args.model_name, basedir='models')

    # Training, tensorboard available
    history = model.train(X, Y, validation_data=(X_val, Y_val))

    # Plot training results
    print(sorted(history.history.keys()))
    plt.figure(figsize=(16, 5))
    plot_history(history, ['loss', 'val_loss'],
                 ['mse', 'val_mse', 'mae', 'val_mae'])
    plt.savefig(args.model_name + '_training.png')

    # Export model to be used w/ csbdeep fiji plugins and KNIME flows
    model.export_TF()
Ejemplo n.º 8
0
def train(Training_source=".",
          Training_target=".",
          model_name="No_name",
          model_path=".",
          Visual_validation_after_training=True,
          number_of_epochs=100,
          patch_size=64,
          number_of_patches=10,
          Use_Default_Advanced_Parameters=True,
          number_of_steps=300,
          batch_size=32,
          percentage_validation=15):
    """Main function of the script. Train the model and save it in model_path.

    Parameters
    ----------
    Training_source : (str) Path to the noisy images
    Training_target : (str) Path to the GT images
    model_name : (str) name of the model
    model_path : (str) path of the model
    Visual_validation_after_training : (bool) Predict a random image after training
    number_of_epochs : (int) epochs
    patch_size : (int) patch size (square patches of patch_size x patch_size)
    number_of_patches : (int) number of patches for each image
    Use_Default_Advanced_Parameters : (bool) Use default parameters for the
        training. When true, batch_size is forced to 64,
        percentage_validation to 10, and number_of_steps is derived from the
        amount of training data.
    number_of_steps : (int) number of steps per epoch
    batch_size : (int) batch size
    percentage_validation : (int) percentage of patches kept for validation

    Return
    ------
    None

    Notes
    -----
    An existing folder ``model_path/model_name`` is deleted before training.
    The patches are written to ``model_path/rawdata.npz``.
    """
    import time

    OutputFile = Training_target + "/*.tif"
    InputFile = Training_source + "/*.tif"
    base = "/content/"
    if Use_Default_Advanced_Parameters:
        print("Default advanced parameters enabled")
        batch_size = 64
        percentage_validation = 10

    percentage = percentage_validation / 100

    # Check that no model with the same name already exists; if so, delete it.
    model_dir = model_path + '/' + model_name
    if os.path.exists(model_dir):
        shutil.rmtree(model_dir)

    # The shape of the images.
    x = imread(InputFile)
    y = imread(OutputFile)

    print('Loaded Input images (number, width, length) =', x.shape)
    print('Loaded Output images (number, width, length) =', y.shape)
    print("Parameters initiated.")

    # RawData object
    # This object holds the image pairs (GT and low), ensuring that CARE
    # compares corresponding images.
    raw_data = data.RawData.from_folder(basepath=base,
                                        source_dirs=[Training_source],
                                        target_dir=Training_target,
                                        axes='CYX',
                                        pattern='*.tif*')

    X, Y, XY_axes = data.create_patches(raw_data,
                                        patch_filter=None,
                                        patch_size=(patch_size, patch_size),
                                        n_patches_per_image=number_of_patches)

    # The patches are saved in .npz format and read back when loading the
    # training data (np.savez appends the ".npz" extension itself).
    print('Creating 2D training dataset')
    training_path = model_path + "/rawdata"
    npz_file = training_path + ".npz"
    np.savez(training_path, X=X, Y=Y, axes=XY_axes)

    # Load Training Data
    (X, Y), (X_val,
             Y_val), axes = load_training_data(npz_file,
                                               validation_split=percentage,
                                               verbose=True)
    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]
    #Show_patches(X,Y)

    # Derive number_of_steps from the training data size and the batch size.
    if Use_Default_Advanced_Parameters:
        number_of_steps = int(X.shape[0] / batch_size) + 1

    print(number_of_steps)

    # Create the configuration
    config = Config(axes,
                    n_channel_in,
                    n_channel_out,
                    probabilistic=False,
                    train_steps_per_epoch=number_of_steps,
                    train_epochs=number_of_epochs,
                    unet_kern_size=5,
                    unet_n_depth=3,
                    train_batch_size=batch_size,
                    train_learning_rate=0.0004)

    print(config)

    # Compile the CARE model for network training
    model_training = CARE(config, model_name, basedir=model_path)

    # Start Training
    start = time.time()
    history = model_training.train(X, Y, validation_data=(X_val, Y_val))
    # Stop the clock here so the reported time covers training only, not the
    # optional visual validation below.
    elapsed = time.time() - start

    print("Training, done.")
    if Visual_validation_after_training:
        Predict_a_image(Training_source, Training_target, model_path,
                        model_training)

    # Displaying the time elapsed for training
    minutes, seconds = divmod(elapsed, 60)
    hours, minutes = divmod(minutes, 60)
    print("Time elapsed:", int(hours), "hour(s)", int(minutes), "min(s)",
          round(seconds), "sec(s)")

    Show_loss_function(history, model_path)
Ejemplo n.º 9
0
import matplotlib.pyplot as plt
# Expose only device index 1 to CUDA (standard CUDA_VISIBLE_DEVICES semantics).
# NOTE(review): this only takes effect if it runs before the deep-learning
# backend initialises CUDA -- confirm the import order above this point.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# In[2]:

# Location of the training patch file (BaseDir + NPZdata) and of the model.
# ModelDir / ModelName are not used in the visible part of this script;
# presumably they are consumed further down -- verify.
BaseDir = '/home/sancere/Kepler/CurieTrainingDatasets/Dalmiro_Laura/'
NPZdata = 'WingVeinModelUNET.npz'

ModelDir = '/home/sancere/Kepler/CurieDeepLearningModels/Dalmiro_Laura/'
ModelName = 'WingVeinUNET'
load_path = BaseDir + NPZdata

# In[3]:

# Load the patches, keeping 5% of them for validation.
(X, Y), (X_val, Y_val), axes = load_training_data(load_path,
                                                  validation_split=0.05,
                                                  verbose=True)

# Channel counts are read from the 'C' axis position of the loaded arrays.
c = axes_dict(axes)['C']
n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

# In[4]:

# Visual sanity check on the first five validation patches.
plt.figure(figsize=(12, 5))
plot_some(X_val[:5], Y_val[:5])
plt.suptitle(
    '5 example validation patches (top row: source, bottom row: target)')

# In[5]:

config = config = Config(axes,
Ejemplo n.º 10
0
    def Train(self):
        """Prepare the mask folders and train the UNET and/or StarDist3D model.

        Steps, in order:

        1. Make sure ``BinaryMask/`` and ``RealMask/`` exist under
           ``self.BaseDir`` and fill whichever one is empty from the other.
        2. If ``self.GenerateNPZ``: create patches from ``Raw/`` and
           ``BinaryMask/`` and save them to the NPZ file.
        3. If ``self.TrainUNET``: train a CARE model named
           ``'UNET' + self.model_name`` on the NPZ file.
        4. If ``self.TrainSTAR``: train a ``StarDist3D`` model named
           ``self.model_name`` on the raw images and instance masks.

        Existing ``weights_now.h5`` / ``weights_last.h5`` / ``weights_best.h5``
        files are loaded before training, so training resumes from a
        checkpoint when one is present.
        """

        BinaryName = 'BinaryMask/'
        RealName = 'RealMask/'
        Raw = sorted(glob.glob(self.BaseDir + '/Raw/' + '*.tif'))
        Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
        Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
        RealMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*.tif'))
        ValRaw = sorted(glob.glob(self.BaseDir + '/ValRaw/' + '*.tif'))
        ValRealMask = sorted(
            glob.glob(self.BaseDir + '/ValRealMask/' + '*.tif'))

        print('Instance segmentation masks:', len(RealMask))
        # No instance masks yet: build them by running label() on every binary
        # mask (presumably connected-component labelling -- confirm import).
        # NOTE(review): RealMask is not re-globbed afterwards, so the list
        # used for StarDist training below stays empty in this run -- confirm.
        if len(RealMask) == 0:

            print('Making labels')
            Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))

            for fname in Mask:

                image = imread(fname)

                Name = os.path.basename(os.path.splitext(fname)[0])

                Binaryimage = label(image)

                imwrite((self.BaseDir + '/' + RealName + Name + '.tif'),
                        Binaryimage.astype('uint16'))

        Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
        print('Semantic segmentation masks:', len(Mask))
        # No binary masks yet: derive them from the instance masks (pixel > 0).
        if len(Mask) == 0:
            print('Generating Binary images')

            # NOTE(review): this pattern is '*tif' (no dot), unlike the
            # '*.tif' patterns used everywhere else -- confirm it is intended.
            RealfilesMask = sorted(
                glob.glob(self.BaseDir + '/' + RealName + '*tif'))

            for fname in RealfilesMask:

                image = imread(fname)

                Name = os.path.basename(os.path.splitext(fname)[0])

                Binaryimage = image > 0

                imwrite((self.BaseDir + '/' + BinaryName + Name + '.tif'),
                        Binaryimage.astype('uint16'))

        if self.GenerateNPZ:

            raw_data = RawData.from_folder(
                basepath=self.BaseDir,
                source_dirs=['Raw/'],
                target_dir='BinaryMask/',
                axes='ZYX',
            )

            # NOTE(review): BaseDir and NPZfilename are joined without a
            # separator here (and again below), while the paths above insert
            # '/'. This needs BaseDir to end with '/' -- confirm.
            X, Y, XY_axes = create_patches(
                raw_data=raw_data,
                patch_size=(self.PatchZ, self.PatchY, self.PatchX),
                n_patches_per_image=self.n_patches_per_image,
                save_file=self.BaseDir + self.NPZfilename + '.npz',
            )

        # Training UNET model
        if self.TrainUNET:
            print('Training UNET model')
            load_path = self.BaseDir + self.NPZfilename + '.npz'

            # 10% of the patches are kept for validation.
            (X, Y), (X_val,
                     Y_val), axes = load_training_data(load_path,
                                                       validation_split=0.1,
                                                       verbose=True)
            c = axes_dict(axes)['C']
            n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            unet_n_depth=self.depth,
                            train_epochs=self.epochs,
                            train_batch_size=self.batch_size,
                            unet_n_first=self.startfilter,
                            train_loss='mse',
                            unet_kern_size=self.kern_size,
                            train_learning_rate=self.learning_rate,
                            train_reduce_lr={
                                'patience': 5,
                                'factor': 0.5
                            })
            print(config)
            vars(config)

            model = CARE(config,
                         name='UNET' + self.model_name,
                         basedir=self.model_dir)

            # Warm start from another model's 'weights_now.h5', but only when
            # this model has no 'weights_now.h5' of its own yet.
            if self.copy_model_dir is not None:
                if os.path.exists(self.copy_model_dir + 'UNET' +
                                  self.copy_model_name + '/' +
                                  'weights_now.h5') and os.path.exists(
                                      self.model_dir + 'UNET' +
                                      self.model_name + '/' +
                                      'weights_now.h5') == False:
                    print('Loading copy model')
                    model.load_weights(self.copy_model_dir + 'UNET' +
                                       self.copy_model_name + '/' +
                                       'weights_now.h5')

            # Resume from this model's own checkpoints. The three loads run in
            # the order now -> last -> best and each one replaces the weights
            # loaded before it, so 'weights_best.h5' wins when it exists.
            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_now.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_now.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_last.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_last.h5')

            if os.path.exists(self.model_dir + 'UNET' + self.model_name + '/' +
                              'weights_best.h5'):
                print('Loading checkpoint model')
                model.load_weights(self.model_dir + 'UNET' + self.model_name +
                                   '/' + 'weights_best.h5')

            history = model.train(X, Y, validation_data=(X_val, Y_val))

            print(sorted(list(history.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(history, ['loss', 'val_loss'],
                         ['mse', 'val_mse', 'mae', 'val_mae'])

        if self.TrainSTAR:
            print('Training StarDistModel model with', self.backbone,
                  'backbone')
            self.axis_norm = (0, 1, 2)
            # Full in-memory load: read every image, downsample, normalise the
            # raw data and label the masks, then split into train/validation.
            if self.CroppedLoad == False:
                assert len(Raw) > 1, "not enough training data"
                print(len(Raw))
                # Fixed seed so the train/validation split is reproducible.
                rng = np.random.RandomState(42)
                ind = rng.permutation(len(Raw))

                X_train = list(map(ReadFloat, Raw))
                Y_train = list(map(ReadInt, RealMask))
                self.Y = [
                    label(DownsampleData(y, self.DownsampleFactor))
                    for y in tqdm(Y_train)
                ]
                self.X = [
                    normalize(DownsampleData(x, self.DownsampleFactor),
                              1,
                              99.8,
                              axis=self.axis_norm) for x in tqdm(X_train)
                ]
                # The last 15% of the permutation (at least one image) is
                # used for validation.
                n_val = max(1, int(round(0.15 * len(ind))))
                ind_train, ind_val = ind[:-n_val], ind[-n_val:]

                self.X_val, self.Y_val = [self.X[i] for i in ind_val
                                          ], [self.Y[i] for i in ind_val]
                self.X_trn, self.Y_trn = [self.X[i] for i in ind_train
                                          ], [self.Y[i] for i in ind_train]

                print('number of images: %3d' % len(self.X))
                print('- training:       %3d' % len(self.X_trn))
                print('- validation:     %3d' % len(self.X_val))

            # Sequencer-based load: training data comes from Raw/ and
            # RealMask/, validation data from ValRaw/ and ValRealMask/.
            if self.CroppedLoad:
                self.X_trn = self.DataSequencer(Raw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_trn = self.DataSequencer(RealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)

                self.X_val = self.DataSequencer(ValRaw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_val = self.DataSequencer(ValRealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)
                # NOTE(review): this attribute is not read again in this
                # method; only the 'unet' config below passes
                # train_sample_cache, and as a literal False -- confirm.
                self.train_sample_cache = False

            print(Config3D.__doc__)

            anisotropy = (1, 1, 1)
            rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)

            # NOTE(review): 'conf' is only assigned for backbone 'resnet' or
            # 'unet'; any other value reaches print(conf) with conf unbound.
            # NOTE(review): train_patch_size is (PatchZ, PatchX, PatchY) in
            # both configs, whereas create_patches above uses
            # (PatchZ, PatchY, PatchX) -- confirm the intended axis order.
            if self.backbone == 'resnet':

                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    resnet_n_blocks=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    resnet_kernel_size=(self.kern_size, self.kern_size,
                                        self.kern_size),
                    train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                    train_batch_size=self.batch_size,
                    resnet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1)

            if self.backbone == 'unet':

                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    unet_n_depth=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    unet_kernel_size=(self.kern_size, self.kern_size,
                                      self.kern_size),
                    train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                    train_batch_size=self.batch_size,
                    unet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1,
                    train_sample_cache=False)

            print(conf)
            vars(conf)

            Starmodel = StarDist3D(conf,
                                   name=self.model_name,
                                   basedir=self.model_dir)
            print(
                Starmodel._axes_tile_overlap('ZYX'),
                os.path.exists(self.model_dir + self.model_name + '/' +
                               'weights_now.h5'))

            # Warm start from another model: each of the copy model's weight
            # files is loaded only if this model lacks the file of that name.
            if self.copy_model_dir is not None:
                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_now.h5') and os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_now.h5') == False:
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_now.h5')
                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_last.h5') and os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_last.h5') == False:
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_last.h5')

                if os.path.exists(self.copy_model_dir + self.copy_model_name +
                                  '/' + 'weights_best.h5') and os.path.exists(
                                      self.model_dir + self.model_name + '/' +
                                      'weights_best.h5') == False:
                    print('Loading copy model')
                    Starmodel.load_weights(self.copy_model_dir +
                                           self.copy_model_name + '/' +
                                           'weights_best.h5')

            # Resume from this model's own checkpoints, in the order
            # now -> last -> best; the last file found is the one in effect.
            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_now.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_now.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_last.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_last.h5')

            if os.path.exists(self.model_dir + self.model_name + '/' +
                              'weights_best.h5'):
                print('Loading checkpoint model')
                Starmodel.load_weights(self.model_dir + self.model_name + '/' +
                                       'weights_best.h5')

            historyStar = Starmodel.train(self.X_trn,
                                          self.Y_trn,
                                          validation_data=(self.X_val,
                                                           self.Y_val),
                                          epochs=self.epochs)
            print(sorted(list(historyStar.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(historyStar, ['loss', 'val_loss'], [
                'dist_relevant_mae', 'val_dist_relevant_mae',
                'dist_relevant_mse', 'val_dist_relevant_mse'
            ])
Ejemplo n.º 11
0
from csbdeep.models import Config, CARE
from csbdeep.data import RawData

def _str2bool(value):
    """Convert a command-line flag value to a bool.

    ``bool("False")`` and ``bool("0")`` are both ``True`` because any
    non-empty string is truthy, so ``--is_3d False`` used to enable 3D mode.
    This parser maps the usual spellings explicitly and rejects anything
    else; argparse reports the ``ValueError`` as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y"):
        return True
    if text in ("0", "false", "f", "no", "n", ""):
        return False
    raise ValueError("expected a boolean value, got {!r}".format(value))


# Command-line interface: two positional .npz files plus training options.
# Numeric options are parsed as int so that bad input fails at parse time.
args = ArgumentParser()
args.add_argument("train_data")
args.add_argument("val_data")
args.add_argument("--num_models", type=int, default=5)
args.add_argument("--epochs", type=int, default=100)
args.add_argument("--steps_per_epoch", type=int, default=400)
args.add_argument("--batch_size", type=int, default=32)
args.add_argument("--probabilistic", type=int, default=1)
args.add_argument("--is_3d", type=_str2bool, default=False)
args.add_argument("--models_dir", default="care_probabilistic_models")
args = args.parse_args()

# Both files are loaded without a validation split; the second element of each
# result (the validation pair) is therefore discarded.
tr_data, _, tr_axes = load_training_data(args.train_data, validation_split=0)
val_data, _, val_axes = load_training_data(args.val_data, validation_split=0)
#val_data = np.load(args.val_data)
#val_data = (val_data["X"], val_data["Y"])
#val_axes = tr_axes # we assume that both training and validation data are saved in the same format

is_3d = bool(args.is_3d)
axes = "ZYX" if is_3d else "YX"
n_dim = 3 if is_3d else 2
config = Config(axes,
                n_dim=n_dim,
                n_channel_in=1,
                n_channel_out=1,
                probabilistic=int(args.probabilistic),
                train_batch_size=int(args.batch_size),
                unet_kern_size=3)
Ejemplo n.º 12
0
        type=int,
        default=100)
    parser.add_argument(
        "--model_dir",
        help="the directory to store the models in",
        default="models")
    parser.add_argument(
        "--model_name",
        help="the name of your model (will be stored in model_dir)",
        default="my_model")
    parser.add_argument("-p", "--plot", action="store_true", default=True)
    args = parser.parse_args()

    # Read the npz file, split into training and validation sets
    (X, Y), (X_val, Y_val), axes = load_training_data(
        os.path.join(args.base_dir, args.input_filename),
        validation_split=0.1,
        verbose=True)

    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    if args.plot:
        plt.figure(figsize=(12, 5))
        plot_some(X_val[:5], Y_val[:5])
        plt.suptitle(
            '5 example validation patches (top row: source, bottom row: target)'
        )
        plt.show()

    # Construct a CARE model, defining its configuration via a Config object
    config = Config(
Ejemplo n.º 13
0
##############
# User settings. The value after "test:" is the one used for the test data set.
folderName = "data_test_beads"  # folder containing the patch file (test: 'data_test')
filename = "patchTest.npz"  # patch file inside folderName (test: "patchTest.npz")
modelName = "modelTest"  # name of the model (test: "modelTest")
baseDir = "models"  # directory containing the model (test: "models")
stepPerEpoch = 100  # test: 100; can be increased considerably for a well-trained model (e.g. 400)
validationSplit = 0.1  # fraction of the patches kept for validation (test: 0.1)

#################
# TRAINING DATA #
#################
# Load the patches; a fraction `validationSplit` of them is kept for validation.
(X_train,
 Y_train), (X_val,
            Y_val), axes = load_training_data(folderName + '/' + filename,
                                              validation_split=validationSplit,
                                              verbose=True)
# Alternative data set, currently disabled:
#(X_train, Y_train), (X_val,Y_val), axes = load_training_data('data/synthetic_disks/data.npz', validation_split=0.1, verbose=True)

print("axes : ", axes)

# Channel counts are read from the 'C' axis position of the loaded arrays.
c = axes_dict(axes)['C']
n_channel_in, n_channel_out = X_train.shape[c], Y_train.shape[c]

# Visual sanity check on the first five validation patches.
plt.figure(figsize=(12, 5))
plot_some(X_val[:5], Y_val[:5])
plt.suptitle(
    '5 example validation patches (top row: source, bottom row: target)')

#################
# Configuration #
Ejemplo n.º 14
0
    def training(self):
        """Build patch data, train a CARE model round by round and export it.

        Everything is read from and written to hard-coded locations under
        ``data/`` and ``models/``:

        1. The first TIFF of each training/validation folder is read, its
           shape reported and the four images plotted.
        2. ``create_patches`` writes ``training.npz`` and ``validation.npz``,
           which are loaded back with ``load_training_data``.
        3. Source and target arrays are wrapped in augmenting Keras
           ``ImageDataGenerator`` streams.
        4. The network is trained for ``epochs`` rounds of one Keras epoch
           each; the learning rate decays every round and the weights are
           check-pointed whenever the validation MAE improves.
        5. The check-pointed weights are loaded into a fresh model, which is
           exported with ``export_TF``.

        Returns
        -------
        None
        """
        import glob
        import os

        from skimage import io
        from matplotlib import pyplot as plt
        from csbdeep.data import RawData, create_patches, no_background_patches
        from csbdeep.io import load_training_data
        from csbdeep.utils import axes_dict
        from csbdeep.models import Config, CARE
        from keras.preprocessing.image import ImageDataGenerator
        from keras import backend as K

        # ------------------------------------------------------------------
        # Locate the files and show one example of each kind
        # ------------------------------------------------------------------
        basepath = 'data/'
        training_original_dir = 'training/original/'
        training_ground_truth_dir = 'training/ground_truth/'
        validation_original_dir = 'validation/original/'
        validation_ground_truth_dir = 'validation/ground_truth/'

        training_original_files = sorted(
            glob.glob(basepath + training_original_dir + '*.tif'))
        training_original_file = io.imread(training_original_files[0])
        training_ground_truth_files = sorted(
            glob.glob(basepath + training_ground_truth_dir + '*.tif'))
        training_ground_truth_file = io.imread(training_ground_truth_files[0])
        print("Training dataset's number of files and dimensions: ",
              len(training_original_files), training_original_file.shape,
              len(training_ground_truth_files),
              training_ground_truth_file.shape)

        validation_original_files = sorted(
            glob.glob(basepath + validation_original_dir + '*.tif'))
        validation_original_file = io.imread(validation_original_files[0])
        validation_ground_truth_files = sorted(
            glob.glob(basepath + validation_ground_truth_dir + '*.tif'))
        validation_ground_truth_file = io.imread(
            validation_ground_truth_files[0])
        print("Validation dataset's number of files and dimensions: ",
              len(validation_original_files), validation_original_file.shape,
              len(validation_ground_truth_files),
              validation_ground_truth_file.shape)

        # Compare the full shapes rather than only the length of the first
        # axis, so a mismatch along any dimension is reported. This stays a
        # warning, as before; it does not abort the run.
        if training_original_file.shape != validation_original_file.shape:
            print(
                'Training and validation images should be of the same dimensions!'
            )

        plt.figure(figsize=(16, 4))
        plt.subplot(141)
        plt.imshow(training_original_file)
        plt.subplot(142)
        plt.imshow(training_ground_truth_file)
        plt.subplot(143)
        plt.imshow(validation_original_file)
        plt.subplot(144)
        plt.imshow(validation_ground_truth_file)

        # ------------------------------------------------------------------
        # Prepare the network inputs from pairs of 32-bit TIFF files with
        # intensities in the range 0..1 (only needed once per new dataset)
        # ------------------------------------------------------------------
        training_data = RawData.from_folder(
            basepath=basepath,
            source_dirs=[training_original_dir],
            target_dir=training_ground_truth_dir,
            axes='YX',
        )
        validation_data = RawData.from_folder(
            basepath=basepath,
            source_dirs=[validation_original_dir],
            target_dir=validation_ground_truth_dir,
            axes='YX',
        )

        # Further patches are produced by the augmentation step below, so a
        # single patch per image is extracted here. The return values are not
        # bound: the arrays are read back from the saved .npz files instead.
        size1 = 64
        create_patches(
            raw_data=training_data,
            patch_size=(size1, size1),
            patch_filter=no_background_patches(0),
            n_patches_per_image=1,
            save_file=basepath + 'training.npz',
        )
        create_patches(
            raw_data=validation_data,
            patch_size=(size1, size1),
            patch_filter=no_background_patches(0),
            n_patches_per_image=1,
            save_file=basepath + 'validation.npz',
        )

        # Load both archives into memory; the second element returned by
        # load_training_data is not used for either of them.
        (X, Y), _, axes = load_training_data(basepath + 'training.npz',
                                             verbose=False)
        (X_val,
         Y_val), _, axes = load_training_data(basepath + 'validation.npz',
                                              verbose=False)
        c = axes_dict(axes)['C']
        n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

        # ------------------------------------------------------------------
        # Data augmentation
        # ------------------------------------------------------------------
        # One batch holds the whole training set; lower this if memory is
        # tight.
        batch = len(X)
        seed = 1
        data_gen_args = dict(samplewise_center=False,
                             samplewise_std_normalization=False,
                             zca_whitening=False,
                             fill_mode='reflect',
                             rotation_range=30,
                             shear_range=0.2,
                             zoom_range=0.2,
                             horizontal_flip=True,
                             vertical_flip=True)

        # Source and target generators are built with identical arguments and
        # the same seed -- presumably so both streams apply the same random
        # transform to matching samples; verify against the Keras docs.
        image_datagen = ImageDataGenerator(**data_gen_args)
        mask_datagen = ImageDataGenerator(**data_gen_args)
        image_datagen.fit(X, augment=True, seed=seed)
        mask_datagen.fit(Y, augment=True, seed=seed)
        image_generator = image_datagen.flow(X, batch_size=batch, seed=seed)
        mask_generator = mask_datagen.flow(Y, batch_size=batch, seed=seed)
        generator = zip(image_generator, mask_generator)

        image_datagen_val = ImageDataGenerator(**data_gen_args)
        mask_datagen_val = ImageDataGenerator(**data_gen_args)
        image_datagen_val.fit(X_val, augment=True, seed=seed)
        mask_datagen_val.fit(Y_val, augment=True, seed=seed)
        image_generator_val = image_datagen_val.flow(X_val,
                                                     batch_size=batch,
                                                     seed=seed)
        mask_generator_val = mask_datagen_val.flow(Y_val,
                                                   batch_size=batch,
                                                   seed=seed)
        generator_val = zip(image_generator_val, mask_generator_val)

        # Plot one augmented example from each stream.
        x, y = next(generator)
        x_val, y_val = next(generator_val)

        plt.figure(figsize=(16, 4))
        plt.subplot(141)
        plt.imshow(x[0, :, :, 0])
        plt.subplot(142)
        plt.imshow(y[0, :, :, 0])
        plt.subplot(143)
        plt.imshow(x_val[0, :, :, 0])
        plt.subplot(144)
        plt.imshow(y_val[0, :, :, 0])

        # ------------------------------------------------------------------
        # Training
        # ------------------------------------------------------------------
        blocks = 2
        channels = 16
        learning_rate = 0.0004
        learning_rate_decay_factor = 0.95
        epoch_size_multiplicator = 20  # scales the number of steps per round
        epochs = 10
        kernel_size = 3
        model_path = 'models/CSBDeep/model.h5'

        def build_config(lr):
            # Network settings shared by every training round and by the
            # final export; only the learning rate varies.
            return Config(axes,
                          n_channel_in,
                          n_channel_out,
                          unet_kern_size=kernel_size,
                          train_learning_rate=lr,
                          unet_n_depth=blocks,
                          unet_n_first=channels)

        # Start from scratch: a checkpoint left by an earlier run is removed.
        # (The old "del model" guarded by "'model' in globals()" was dropped:
        # ``model`` is a local name here, so that statement could only raise
        # UnboundLocalError.)
        if os.path.isfile(model_path):
            print('Your model will be overwritten in the next cell')
            os.remove(model_path)

        steps_per_epoch = len(X) * epoch_size_multiplicator
        validation_steps = len(X_val) * epoch_size_multiplicator

        # Start at infinity so the first finite validation MAE is always
        # check-pointed, whatever the intensity scale of the data.
        best_mae = float('inf')
        for i in range(epochs):
            print('Epoch:', i + 1)
            learning_rate = learning_rate * learning_rate_decay_factor
            model = CARE(build_config(learning_rate), '.', basedir='models')
            # Continue from the best weights so far. The file does not exist
            # in the first round, nor if no round has produced a checkpoint.
            if os.path.isfile(model_path):
                model.keras_model.load_weights(model_path)
            model.prepare_for_training()
            history = model.keras_model.fit_generator(
                generator,
                validation_data=generator_val,
                validation_steps=validation_steps,
                epochs=1,
                verbose=0,
                shuffle=True,
                steps_per_epoch=steps_per_epoch)
            val_mae = history.history['val_mae'][0]
            if val_mae < best_mae:
                best_mae = val_mae
                # Create the checkpoint's own folder, not just 'models/'.
                os.makedirs(os.path.dirname(model_path), exist_ok=True)
                model.keras_model.save(model_path)
                print(f'Validation MAE:{best_mae:.3E}')
            # Drop the model and reset the Keras session before the next
            # round builds a fresh one.
            del model
            K.clear_session()

        # ------------------------------------------------------------------
        # Export the best checkpoint
        # ------------------------------------------------------------------
        model = CARE(build_config(learning_rate), '.', basedir='models')
        model.keras_model.load_weights(model_path)
        model.export_TF()
Ejemplo n.º 15
0
# coding: utf-8

# In[1]:

from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
from tifffile import imread
from csbdeep.utils import axes_dict, plot_some, plot_history
from csbdeep.utils.tf import limit_gpu_memory
from csbdeep.io import load_training_data
from csbdeep.models import Config, ProjectionCARE

# In[2]:

# Load the patch archive, holding out a validation split of 0.1.
(X, Y), (X_val, Y_val), axes = load_training_data(
    '/local/u934/private/v_kapoor/CurieTrainingDatasets/Drosophilla/DenoisingProjection.npz',
    validation_split=0.1,
    verbose=True)

# Position of the channel axis within ``axes``; it is used to read the number
# of input and output channels from the array shapes.
c = axes_dict(axes)['C']
n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

# In[3]:

# In[4]:

config = Config(axes,
                n_channel_in,
                n_channel_out,
                unet_n_depth=4,
                train_epochs=50,
                train_steps_per_epoch=400,
Ejemplo n.º 16
0
# Print the shape of the first image and of the first label array.
print(imgArray[0].shape)
print(labelArray[0].shape)

# Pair the image and label arrays as source/target data with 'YX' axes.
raw_data = RawData.from_arrays(imgArray, labelArray, axes='YX')

# Extract 25 patches of 128x128x1 per image and save them to an .npz archive.
# NOTE(review): the raw data is declared as 'YX' while the patches use 'YXC'
# -- confirm that create_patches adds the channel axis as intended.
X, Y, XY_axes = create_patches(
    raw_data=raw_data,
    patch_size=(128, 128, 1),
    patch_axes='YXC',
    n_patches_per_image=25,
    save_file=
    '/mnt/AE3205C73205958D/Data/3dliver_local/pc_adult/2d_slices/imagesXY/image_full/mydata_128x128patch.npz'
)

# Read the archive back with a validation split of 0.1; this rebinds the X and
# Y returned by create_patches above.
(X, Y), (X_val, Y_val), axes = load_training_data(
    '/mnt/AE3205C73205958D/Data/3dliver_local/pc_adult/2d_slices/imagesXY/image_full/mydata_128x128patch.npz',
    validation_split=0.1,
    verbose=True)

# Show consecutive training patches starting at index n on a 2x5 grid:
# X patches in the top row (cells 1, 2, ...), the matching Y patches directly
# below them (cells 6, 7, ...). Only the first channel is displayed.
n = 10
print(X[n].shape)
fig = plt.figure()
pl1_x = fig.add_subplot(2, 5, 1)
pl1_x.imshow(X[n][..., 0])
pl1_y = fig.add_subplot(2, 5, 6)
pl1_y.imshow(Y[n][..., 0])
pl2_x = fig.add_subplot(2, 5, 2)
pl2_x.imshow(X[n + 1][..., 0])
pl2_y = fig.add_subplot(2, 5, 7)
pl2_y.imshow(Y[n + 1][..., 0])
pl3_x = fig.add_subplot(2, 5, 3)
pl3_x.imshow(X[n + 2][..., 0])
Ejemplo n.º 17
0
    def train(self, channels=None, **config_args):
        """Train, preview and export one CARE model per channel.

        For every channel the matching ``CH_<ch>_training_patches.npz`` file
        is loaded (validation split 0.1), a model named ``CH_<ch>_model`` is
        created under ``<out_dir>/models`` and trained, the learning curves
        and five validation predictions are plotted, and the model is
        exported with ``export_TF``.

        Parameters
        ----------
        channels : iterable, optional
            Channels to train; ``self.train_channels`` is used when omitted.
        **config_args
            Extra keyword arguments passed on to ``Config``.

        Returns
        -------
        None
            The method returns early, skipping the remaining channels, when
            training raises ``tf.errors.ResourceExhaustedError`` or
            ``tf.errors.UnknownError``.
        """
        oom_message = (
            " >> ResourceExhaustedError: Aborting...\n  Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?"
        )
        unknown_message = (
            " >> UnknownError: Aborting...\n  No enough memory available on GPU... are other GPU jobs running?"
        )
        preview_title = ("5 example validation patches\n"
                         "top row: input (source),  "
                         "middle row: target (ground truth),  "
                         "bottom row: predicted from source")

        selected = self.train_channels if channels is None else channels

        with Timer("Training"):
            for ch in selected:
                print("-- Training channel {}...".format(ch))

                # Load this channel's patches and split off validation data.
                patch_file = (self.get_training_patch_path() /
                              "CH_{}_training_patches.npz".format(ch))
                train_pair, val_pair, axes = load_training_data(
                    patch_file,
                    validation_split=0.1,
                    verbose=False,
                )
                X, Y = train_pair
                X_val, Y_val = val_pair

                # Channel counts are read along the 'C' axis of the arrays.
                channel_axis = axes_dict(axes)["C"]
                n_channel_in = X.shape[channel_axis]
                n_channel_out = Y.shape[channel_axis]

                config = Config(
                    axes,
                    n_channel_in,
                    n_channel_out,
                    train_epochs=self.train_epochs,
                    train_steps_per_epoch=self.train_steps_per_epoch,
                    train_batch_size=self.train_batch_size,
                    probabilistic=self.probabilistic,
                    **config_args,
                )

                model_dir = pathlib.Path(self.out_dir) / "models"
                model = CARE(config,
                             "CH_{}_model".format(ch),
                             basedir=model_dir)

                # Either GPU failure aborts the whole run, not just this
                # channel.
                try:
                    history = model.train(X, Y, validation_data=(X_val, Y_val))
                except tf.errors.ResourceExhaustedError:
                    print(oom_message)
                    return
                except tf.errors.UnknownError:
                    print(unknown_message)
                    return

                # Learning curves.
                plt.figure(figsize=(16, 5))
                plot_history(history, ["loss", "val_loss"],
                             ["mse", "val_mse", "mae", "val_mae"])

                # Predictions for the first five validation patches.
                plt.figure(figsize=(12, 7))
                sample_x = X_val[:5]
                predicted = model.keras_model.predict(sample_x)
                if self.probabilistic:
                    # Keep only the first entry of the last axis.
                    predicted = predicted[..., 0]

                plot_some(sample_x, Y_val[:5], predicted, pmax=99.5,
                          cmap="gray")
                plt.suptitle(preview_title)
                plt.show()

                print("-- Export model for use in Fiji...")
                model.export_TF()
                print("Done")