Example #1
0
def dev(args):
    """Train a CARE model from an ``.npz`` training file and export it.

    Loads training/validation data, builds (or reloads) a ``Config``, trains a
    ``CARE`` model in ``models/``, saves a plot of the training curves to
    ``<args.model_name>_training.png`` and finally calls ``export_TF``.

    :param args: Parsed command-line namespace. Attributes read here:
        ``train_data``, ``valid_split``, ``axes``, ``resume``, ``config``,
        ``prob``, ``steps``, ``epochs`` and ``model_name``.
    """
    import json

    # Load and parse training data
    (X, Y), (X_val, Y_val), axes = load_training_data(
        args.train_data,
        validation_split=args.valid_split,
        axes=args.axes,
        verbose=True)

    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    # Model config
    print('args.resume: ', args.resume)
    if args.resume:
        # If resuming, config=None will reload the saved config
        config = None
        print('Attempting to resume')
    elif args.config:
        print('loading config from args')
        # Context manager so the JSON file handle is closed deterministically.
        with open(args.config) as config_file:
            config_args = json.load(config_file)
        config = Config(**config_args)
    else:
        config = Config(axes,
                        n_channel_in,
                        n_channel_out,
                        probabilistic=args.prob,
                        train_steps_per_epoch=args.steps,
                        train_epochs=args.epochs)
        print(vars(config))

    # Load or init model
    model = CARE(config, args.model_name, basedir='models')

    # Training, tensorboard available
    history = model.train(X, Y, validation_data=(X_val, Y_val))

    # Plot training results
    print(sorted(list(history.history.keys())))
    plt.figure(figsize=(16, 5))
    plot_history(history, ['loss', 'val_loss'],
                 ['mse', 'val_mse', 'mae', 'val_mae'])
    plt.savefig(args.model_name + '_training.png')

    # Export model to be used w/ csbdeep fiji plugins and KNIME flows
    model.export_TF()
Example #2
0
    def train(self, channels=None, **config_args):
        """Train, preview and export one CARE model per channel.

        For each channel ``ch`` the patches in
        ``<training patch path>/CH_<ch>_training_patches.npz`` are loaded with a
        10 % validation split, a model named ``CH_<ch>_model`` is trained under
        ``<out_dir>/models``, the learning curve and five validation
        predictions are plotted, and the model is exported for Fiji.

        :param channels: Channels to train; ``None`` means ``self.train_channels``.
        :param config_args: Extra keyword arguments forwarded to ``Config``.
        :return: None. On a GPU out-of-memory error the method prints a hint
            and returns immediately, skipping any remaining channels.
        """
        # limit_gpu_memory(fraction=1)
        selected_channels = self.train_channels if channels is None else channels

        for ch in selected_channels:
            print("-- Training channel {}...".format(ch))
            patch_file = (self.get_training_patch_path() /
                          'CH_{}_training_patches.npz'.format(ch))
            train_pair, val_pair, axes = load_training_data(
                patch_file, validation_split=0.1, verbose=False)
            X, Y = train_pair
            X_val, Y_val = val_pair

            channel_axis = axes_dict(axes)['C']
            n_channel_in = X.shape[channel_axis]
            n_channel_out = Y.shape[channel_axis]

            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            train_epochs=self.train_epochs,
                            train_steps_per_epoch=self.train_steps_per_epoch,
                            train_batch_size=self.train_batch_size,
                            **config_args)
            model = CARE(config,
                         'CH_{}_model'.format(ch),
                         basedir=pathlib.Path(self.out_dir) / 'models')

            try:
                history = model.train(X, Y, validation_data=(X_val, Y_val))
            except tf.errors.ResourceExhaustedError:
                print(
                    "ResourceExhaustedError: Aborting...\n Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?"
                )
                return

            # Learning curve.
            plt.figure(figsize=(16, 5))
            plot_history(history, ['loss', 'val_loss'],
                         ['mse', 'val_mse', 'mae', 'val_mae'])

            # Five validation patches: input, ground truth and prediction.
            x_preview, y_preview = X_val[:5], Y_val[:5]
            plt.figure(figsize=(12, 7))
            predicted = model.keras_model.predict(x_preview)
            plot_some(x_preview, y_preview, predicted, pmax=99.5, cmap="gray")
            plt.suptitle('5 example validation patches\n'
                         'top row: input (source),  '
                         'middle row: target (ground truth),  '
                         'bottom row: predicted from source')
            plt.show()

            print("-- Export model for use in Fiji...")
            model.export_TF()
            print("Done")
Example #3
0
def gen_care_single_model(name: str) -> CARE:
    """
    Return the single-channel CARE model stored under ``name``, creating it if absent.

    The existing model is loaded first. Only if that raises ``FileNotFoundError``
    is a new ``'xyzc'`` model with one input and one output channel created.

    :param name: The name of the model
    :return: The CARE model
    """
    try:
        return CARE(None, name)
    except FileNotFoundError:
        pass
    fresh_config = Config('xyzc', n_channel_in=1, n_channel_out=1)
    return CARE(fresh_config, name)
Example #4
0
def gen_care_dual_model(name: str, batch_size: int = 16, **kwargs):
    """
    Return the dual-channel CARE model stored under ``name``, creating it if absent.

    The existing model is loaded first. Only if that raises ``FileNotFoundError``
    is a new ``'xyzc'`` model with two input channels and one output channel
    created.

    :param name: The name of the model
    :param batch_size: The training batch size (only used if the model doesn't exist yet)
    :param kwargs: Extra ``Config`` parameters (only used if the model doesn't exist yet)
    :return: The CARE model
    """
    try:
        return CARE(None, name)
    except FileNotFoundError:
        pass
    new_config = Config('xyzc',
                        n_channel_in=2,
                        n_channel_out=1,
                        train_batch_size=batch_size,
                        **kwargs)
    return CARE(new_config, name)
Example #5
0
def test_config():
    """Check ``Config`` axis normalisation and rejection of invalid axis layouts."""
    assert K.image_data_format() in ('channels_first', 'channels_last')
    channels_last = K.image_data_format() == 'channels_last'

    def _with_channel(axes):
        """Upper-case *axes* and add a channel axis on the backend's side."""
        axes = axes.upper()
        if 'C' in axes:
            return axes
        return axes + 'C' if channels_last else 'C' + axes

    axes_list = [
        ('yx', _with_channel('YX')),
        ('ytx', _with_channel('YTX')),
        ('zyx', _with_channel('ZYX')),
        ('YX', _with_channel('YX')),
        ('XYZ', _with_channel('XYZ')),
        ('XYT', _with_channel('XYT')),
        ('SYX', _with_channel('YX')),
        ('SXYZ', _with_channel('XYZ')),
        ('SXTY', _with_channel('XTY')),
        (_with_channel('YX'), _with_channel('YX')),
        (_with_channel('XYZ'), _with_channel('XYZ')),
        (_with_channel('XTY'), _with_channel('XTY')),
        (_with_channel('SYX'), _with_channel('YX')),
        (_with_channel('STYX'), _with_channel('TYX')),
        (_with_channel('SXYZ'), _with_channel('XYZ')),
    ]

    for (axes, axes_ref) in axes_list:
        assert Config(axes).axes == axes_ref

    # Each pair spells one layout channels-last / channels-first. Only the
    # spelling that matches the backend is valid: build it, and require the
    # other to raise. Each invalid call gets its own pytest.raises block,
    # because statements after the first raising one in a block never run.
    for last_form, first_form in (('XYC', 'CXY'), ('XYZC', 'CXYZ'),
                                  ('XTYC', 'CXTY')):
        if channels_last:
            valid, invalid = last_form, first_form
        else:
            valid, invalid = first_form, last_form
        Config(valid)
        with pytest.raises(ValueError):
            Config(invalid)

    # Layouts that must be rejected under either channel convention.
    for bad_axes in ('XYZT', 'tXYZ', 'XYS', 'XSYZ'):
        with pytest.raises(ValueError):
            Config(bad_axes)
Example #6
0
def config_generator(**kwargs):
    """Yield one ``Config`` for every combination of the given parameter values.

    A keyword whose value is a list or tuple contributes each of its elements
    in turn. Any other value is held fixed. ``axes`` must be among the keywords.
    The check runs lazily, on the first ``next()``.
    """
    assert 'axes' in kwargs
    names = list(kwargs)
    choices = [value if isinstance(value, (list, tuple)) else [value]
               for value in kwargs.values()]
    for combo in product(*choices):
        yield Config(**{name: option for name, option in zip(names, combo)})
Example #7
0
plt.figure(figsize=(12, 5))
plot_some(X_val[:5], Y_val[:5])
plt.suptitle(
    '5 example validation patches (top row: source, bottom row: target)')

# In[5]:

# Non-probabilistic CARE U-Net: depth 5, 48 first-layer filters, kernel size 7,
# MAE loss, batch size 1 and a learning-rate reduction schedule
# (patience 5, factor 0.5).
config = Config(axes,
                n_channel_in,
                n_channel_out,
                probabilistic=False,
                unet_n_depth=5,
                unet_n_first=48,
                unet_kern_size=7,
                train_loss='mae',
                train_epochs=150,
                train_learning_rate=1.0E-4,
                train_batch_size=1,
                train_reduce_lr={
                    'patience': 5,
                    'factor': 0.5
                })
print(config)
vars(config)  # notebook cell display; has no effect when run as a script

# In[6]:

model = CARE(config=config, name=ModelName, basedir=ModelDir)
# To start from previously saved weights instead:
#input_weights = ModelDir + ModelName + '/' +'weights_best.h5'
#model.load_weights(input_weights)
Example #8
0
    def Train(self):
        """Prepare masks/patches, then train the UNET and/or StarDist3D models.

        Folders used under ``self.BaseDir``:

        * ``Raw/``: training images.
        * ``BinaryMask/``: foreground masks.
        * ``RealMask/``: label images.
        * ``ValRaw/`` and ``ValRealMask/``: validation data for ``CroppedLoad``.

        ``BinaryMask/`` and ``RealMask/`` are created if missing. If either
        holds no ``.tif`` files, it is derived from the other: ``label`` of the
        binary masks, or ``> 0`` of the label images.

        Attribute flags select the remaining stages:

        * ``GenerateNPZ``: cut ``Raw``/``BinaryMask`` pairs into
          ``(PatchZ, PatchY, PatchX)`` patches saved to
          ``BaseDir + NPZfilename + '.npz'``.
        * ``TrainUNET``: train a CARE model named ``'UNET' + model_name`` on
          that file.
        * ``TrainSTAR``: train a ``StarDist3D`` model with the ``'resnet'``
          or ``'unet'`` backbone.

        Before training, each trainer loads any existing checkpoints. If
        ``copy_model_dir`` is set, it first loads weights from a copy model for
        checkpoint files that do not exist locally yet.

        :raises ValueError: if ``TrainSTAR`` is set and ``self.backbone`` is
            neither ``'resnet'`` nor ``'unet'``.
        """
        BinaryName = 'BinaryMask/'
        RealName = 'RealMask/'
        # Checkpoint files in load order. Later files override earlier ones,
        # so weights_best.h5 wins whenever it exists.
        weight_files = ('weights_now.h5', 'weights_last.h5', 'weights_best.h5')

        Raw = sorted(glob.glob(self.BaseDir + '/Raw/' + '*.tif'))
        Path(self.BaseDir + '/' + BinaryName).mkdir(exist_ok=True)
        Path(self.BaseDir + '/' + RealName).mkdir(exist_ok=True)
        RealMask = sorted(glob.glob(self.BaseDir + '/' + RealName + '*.tif'))
        ValRaw = sorted(glob.glob(self.BaseDir + '/ValRaw/' + '*.tif'))
        ValRealMask = sorted(
            glob.glob(self.BaseDir + '/ValRealMask/' + '*.tif'))

        print('Instance segmentation masks:', len(RealMask))
        if not RealMask:
            print('Making labels')
            Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
            for fname in Mask:
                image = imread(fname)
                Name = os.path.basename(os.path.splitext(fname)[0])
                # label() turns the binary mask into instance ids
                # (presumably skimage.measure.label; confirm the import).
                labels = label(image)
                imwrite((self.BaseDir + '/' + RealName + Name + '.tif'),
                        labels.astype('uint16'))
            # Re-scan. The list above was taken before these label images were
            # written, and StarDist training below reads from it.
            RealMask = sorted(
                glob.glob(self.BaseDir + '/' + RealName + '*.tif'))

        Mask = sorted(glob.glob(self.BaseDir + '/' + BinaryName + '*.tif'))
        print('Semantic segmentation masks:', len(Mask))
        if not Mask:
            print('Generating Binary images')
            RealfilesMask = sorted(
                glob.glob(self.BaseDir + '/' + RealName + '*tif'))
            for fname in RealfilesMask:
                image = imread(fname)
                Name = os.path.basename(os.path.splitext(fname)[0])
                Binaryimage = image > 0
                imwrite((self.BaseDir + '/' + BinaryName + Name + '.tif'),
                        Binaryimage.astype('uint16'))

        if self.GenerateNPZ:
            raw_data = RawData.from_folder(
                basepath=self.BaseDir,
                source_dirs=['Raw/'],
                target_dir='BinaryMask/',
                axes='ZYX',
            )
            # Called for its side effect: the patches are saved to the .npz
            # file that the UNET stage below loads.
            create_patches(
                raw_data=raw_data,
                patch_size=(self.PatchZ, self.PatchY, self.PatchX),
                n_patches_per_image=self.n_patches_per_image,
                save_file=self.BaseDir + self.NPZfilename + '.npz',
            )

        # NOTE(review): full paths are passed to load_weights below. If
        # load_weights resolves its argument relative to the model's own
        # directory, this only works when model_dir / copy_model_dir are
        # absolute paths. Confirm.

        # Training UNET model
        if self.TrainUNET:
            print('Training UNET model')
            load_path = self.BaseDir + self.NPZfilename + '.npz'
            (X, Y), (X_val, Y_val), axes = load_training_data(
                load_path, validation_split=0.1, verbose=True)
            c = axes_dict(axes)['C']
            n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            unet_n_depth=self.depth,
                            train_epochs=self.epochs,
                            train_batch_size=self.batch_size,
                            unet_n_first=self.startfilter,
                            train_loss='mse',
                            unet_kern_size=self.kern_size,
                            train_learning_rate=self.learning_rate,
                            train_reduce_lr={
                                'patience': 5,
                                'factor': 0.5
                            })
            print(config)

            model = CARE(config,
                         name='UNET' + self.model_name,
                         basedir=self.model_dir)

            unet_dir = self.model_dir + 'UNET' + self.model_name + '/'
            if self.copy_model_dir is not None:
                copy_dir = (self.copy_model_dir + 'UNET' +
                            self.copy_model_name + '/')
                if (os.path.exists(copy_dir + 'weights_now.h5')
                        and not os.path.exists(unet_dir + 'weights_now.h5')):
                    print('Loading copy model')
                    model.load_weights(copy_dir + 'weights_now.h5')

            for weights in weight_files:
                if os.path.exists(unet_dir + weights):
                    print('Loading checkpoint model')
                    model.load_weights(unet_dir + weights)

            history = model.train(X, Y, validation_data=(X_val, Y_val))

            print(sorted(list(history.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(history, ['loss', 'val_loss'],
                         ['mse', 'val_mse', 'mae', 'val_mae'])

        if self.TrainSTAR:
            print('Training StarDistModel model with', self.backbone,
                  'backbone')
            self.axis_norm = (0, 1, 2)
            if not self.CroppedLoad:
                assert len(Raw) > 1, "not enough training data"
                print(len(Raw))
                rng = np.random.RandomState(42)
                ind = rng.permutation(len(Raw))

                X_train = list(map(ReadFloat, Raw))
                Y_train = list(map(ReadInt, RealMask))
                self.Y = [
                    label(DownsampleData(y, self.DownsampleFactor))
                    for y in tqdm(Y_train)
                ]
                self.X = [
                    normalize(DownsampleData(x, self.DownsampleFactor),
                              1,
                              99.8,
                              axis=self.axis_norm) for x in tqdm(X_train)
                ]
                # Hold out ~15 % of the images (at least one) for validation.
                n_val = max(1, int(round(0.15 * len(ind))))
                ind_train, ind_val = ind[:-n_val], ind[-n_val:]

                self.X_val = [self.X[i] for i in ind_val]
                self.Y_val = [self.Y[i] for i in ind_val]
                self.X_trn = [self.X[i] for i in ind_train]
                self.Y_trn = [self.Y[i] for i in ind_train]

                print('number of images: %3d' % len(self.X))
                print('- training:       %3d' % len(self.X_trn))
                print('- validation:     %3d' % len(self.X_val))
            else:
                self.X_trn = self.DataSequencer(Raw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_trn = self.DataSequencer(RealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)

                self.X_val = self.DataSequencer(ValRaw,
                                                self.axis_norm,
                                                Normalize=True,
                                                labelMe=False)
                self.Y_val = self.DataSequencer(ValRealMask,
                                                self.axis_norm,
                                                Normalize=False,
                                                labelMe=True)
                self.train_sample_cache = False

            print(Config3D.__doc__)

            anisotropy = (1, 1, 1)
            rays = Rays_GoldenSpiral(self.n_rays, anisotropy=anisotropy)

            if self.backbone == 'resnet':
                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    resnet_n_blocks=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    resnet_kernel_size=(self.kern_size, self.kern_size,
                                        self.kern_size),
                    train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                    train_batch_size=self.batch_size,
                    resnet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1)
            elif self.backbone == 'unet':
                conf = Config3D(
                    rays=rays,
                    anisotropy=anisotropy,
                    backbone=self.backbone,
                    train_epochs=self.epochs,
                    train_learning_rate=self.learning_rate,
                    unet_n_depth=self.depth,
                    train_checkpoint=self.model_dir + self.model_name + '.h5',
                    unet_kernel_size=(self.kern_size, self.kern_size,
                                      self.kern_size),
                    train_patch_size=(self.PatchZ, self.PatchX, self.PatchY),
                    train_batch_size=self.batch_size,
                    unet_n_filter_base=self.startfilter,
                    train_dist_loss='mse',
                    grid=(1, 1, 1),
                    use_gpu=self.use_gpu,
                    n_channel_in=1,
                    train_sample_cache=False)
            else:
                raise ValueError(
                    "Unsupported backbone %r: expected 'resnet' or 'unet'" %
                    (self.backbone,))

            print(conf)

            Starmodel = StarDist3D(conf,
                                   name=self.model_name,
                                   basedir=self.model_dir)
            star_dir = self.model_dir + self.model_name + '/'
            print(Starmodel._axes_tile_overlap('ZYX'),
                  os.path.exists(star_dir + 'weights_now.h5'))

            if self.copy_model_dir is not None:
                copy_dir = self.copy_model_dir + self.copy_model_name + '/'
                for weights in weight_files:
                    if (os.path.exists(copy_dir + weights)
                            and not os.path.exists(star_dir + weights)):
                        print('Loading copy model')
                        Starmodel.load_weights(copy_dir + weights)

            for weights in weight_files:
                if os.path.exists(star_dir + weights):
                    print('Loading checkpoint model')
                    Starmodel.load_weights(star_dir + weights)

            historyStar = Starmodel.train(self.X_trn,
                                          self.Y_trn,
                                          validation_data=(self.X_val,
                                                           self.Y_val),
                                          epochs=self.epochs)
            print(sorted(list(historyStar.history.keys())))
            plt.figure(figsize=(16, 5))
            plot_history(historyStar, ['loss', 'val_loss'], [
                'dist_relevant_mae', 'val_dist_relevant_mae',
                'dist_relevant_mse', 'val_dist_relevant_mse'
            ])
Example #9
0
import sys

args.add_argument("--models_dir", default="care_probabilistic_models")
args = args.parse_args()

# Both files are loaded in full (validation_split=0). The first supplies the
# (X, Y) training pair; the second is passed as validation_data.
# NOTE(review): tr_axes / val_axes are not checked. The model axes are taken
# from --is_3d below, so both files are assumed to use that layout.
tr_data, _, tr_axes = load_training_data(args.train_data, validation_split=0)
val_data, _, val_axes = load_training_data(args.val_data, validation_split=0)

# NOTE(review): if --is_3d arrives as a string, bool("0") is True. Confirm the
# argument's type (other options are converted with int()).
is_3d = bool(args.is_3d)
axes = "ZYX" if is_3d else "YX"
n_dim = 3 if is_3d else 2
config = Config(axes,
                n_dim=n_dim,
                n_channel_in=1,
                n_channel_out=1,
                probabilistic=int(args.probabilistic),
                train_batch_size=int(args.batch_size),
                unet_kern_size=3)

# Train num_models independent models with the same configuration.
for i in range(int(args.num_models)):
    model = CARE(config, f"model_{i}", args.models_dir)
    train_history = model.train(tr_data[0],
                                tr_data[1],
                                validation_data=val_data,
                                epochs=int(args.epochs),
                                steps_per_epoch=int(args.steps_per_epoch))

# sys.exit() rather than exit(): the exit() helper is installed by the site
# module for interactive use and is not guaranteed to exist in programs.
sys.exit()

plot_history(train_history, ["loss", "val_loss"],
    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    if args.plot:
        plt.figure(figsize=(12, 5))
        plot_some(X_val[:5], Y_val[:5])
        plt.suptitle(
            '5 example validation patches (top row: source, bottom row: target)'
        )
        plt.show()

    # Construct a CARE model, defining its configuration via a Config object
    config = Config(
        axes,
        n_channel_in,
        n_channel_out,
        probabilistic=False,  # We don't need detailed stats just yet
        train_steps_per_epoch=args.train_steps_per_epoch,
        train_epochs=args.num_epochs)
    print(config)
    vars(config)

    model = CARE(config, args.model_name, basedir=args.model_dir)

    # Use tensorboard to check the training progress with logdir = basedir
    history = model.train(X, Y, validation_data=(X_val, Y_val))

    if args.plot:
        plt.figure(figsize=(16, 5))
        plot_history(history, ['loss', 'val_loss'],
                     ['mse', 'val_mse', 'mae', 'val_mae'])
Example #11
0
n_channel_in = X_train.shape[c]
n_channel_out = Y_train.shape[c]

# Preview a few validation pairs before training.
plt.figure(figsize=(12, 5))
plot_some(X_val[:5], Y_val[:5])
plt.suptitle(
    '5 example validation patches (top row: source, bottom row: target)')

#################
# Configuration #
#################

# The Config object holds the parameters of the underlying network, the
# learning rate, the number of parameter updates per epoch, the loss function
# and whether the model is probabilistic.
config = Config(axes=axes,
                n_channel_in=n_channel_in,
                n_channel_out=n_channel_out,
                probabilistic=True,
                train_steps_per_epoch=stepPerEpoch)
print(config)
vars(config)

############
# TRAINING #
############
# Training progress can be monitored with TensorBoard
# (see https://www.tensorflow.org/guide/summaries_and_tensorboard).

# Model instantiation. To load an existing model instead, use
# CARE(config=None, name='my_model', basedir='models').
model = CARE(config, modelName, basedir=baseDir)  # used to train a new model

# training model
Example #12
0
    def training(self):
        """Train a 2-D CARE U-Net with Keras augmentation, then export it.

        Reads TIFF pairs from ``data/training/{original,ground_truth}/`` and
        ``data/validation/{original,ground_truth}/``. It writes one 64x64 patch
        per image to ``data/training.npz`` / ``data/validation.npz``, then
        trains for ``epochs`` rounds of one Keras epoch each, decaying the
        learning rate every round. The weights with the lowest validation MAE
        are kept in ``models/CSBDeep/model.h5``, reloaded at the end and
        exported with ``export_TF``.
        """
        import glob
        import os

        from csbdeep.data import RawData, create_patches, no_background_patches
        from csbdeep.io import load_training_data
        from csbdeep.models import Config, CARE
        from csbdeep.utils import axes_dict
        from keras import backend as K
        from keras.preprocessing.image import ImageDataGenerator
        from matplotlib import pyplot as plt
        from skimage import io

        # Loading files and plotting examples
        basepath = 'data/'
        training_original_dir = 'training/original/'
        training_ground_truth_dir = 'training/ground_truth/'
        validation_original_dir = 'validation/original/'
        validation_ground_truth_dir = 'validation/ground_truth/'

        training_original_files = sorted(
            glob.glob(basepath + training_original_dir + '*.tif'))
        training_original_file = io.imread(training_original_files[0])
        training_ground_truth_files = sorted(
            glob.glob(basepath + training_ground_truth_dir + '*.tif'))
        training_ground_truth_file = io.imread(training_ground_truth_files[0])
        print("Training dataset's number of files and dimensions: ",
              len(training_original_files), training_original_file.shape,
              len(training_ground_truth_files),
              training_ground_truth_file.shape)

        validation_original_files = sorted(
            glob.glob(basepath + validation_original_dir + '*.tif'))
        validation_original_file = io.imread(validation_original_files[0])
        validation_ground_truth_files = sorted(
            glob.glob(basepath + validation_ground_truth_dir + '*.tif'))
        validation_ground_truth_file = io.imread(
            validation_ground_truth_files[0])
        print("Validation dataset's number of files and dimensions: ",
              len(validation_original_files), validation_original_file.shape,
              len(validation_ground_truth_files),
              validation_ground_truth_file.shape)

        # Advisory only: training proceeds even when the shapes differ.
        if training_original_file.shape != validation_original_file.shape:
            print(
                'Training and validation images should be of the same dimensions!'
            )

        plt.figure(figsize=(16, 4))
        plt.subplot(141)
        plt.imshow(training_original_file)
        plt.subplot(142)
        plt.imshow(training_ground_truth_file)
        plt.subplot(143)
        plt.imshow(validation_original_file)
        plt.subplot(144)
        plt.imshow(validation_ground_truth_file)

        # Prepare network inputs from pairs of TIFF images (intended for 32-bit
        # images with intensities in 0..1). Only needs to run once per dataset.
        training_data = RawData.from_folder(
            basepath=basepath,
            source_dirs=[training_original_dir],
            target_dir=training_ground_truth_dir,
            axes='YX',
        )

        validation_data = RawData.from_folder(
            basepath=basepath,
            source_dirs=[validation_original_dir],
            target_dir=validation_ground_truth_dir,
            axes='YX',
        )

        # Variation between samples comes from the augmentation step below, so
        # each image contributes a single size1 x size1 patch here.
        size1 = 64
        # Called for their side effect of writing the .npz files loaded next.
        create_patches(
            raw_data=training_data,
            patch_size=(size1, size1),
            patch_filter=no_background_patches(0),
            n_patches_per_image=1,
            save_file=basepath + 'training.npz',
        )
        create_patches(
            raw_data=validation_data,
            patch_size=(size1, size1),
            patch_filter=no_background_patches(0),
            n_patches_per_image=1,
            save_file=basepath + 'validation.npz',
        )

        # Loading training and validation data into memory
        (X, Y), _, axes = load_training_data(basepath + 'training.npz',
                                             verbose=False)
        (X_val, Y_val), _, axes = load_training_data(basepath + 'validation.npz',
                                                     verbose=False)
        c = axes_dict(axes)['C']
        n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

        # One batch holds the whole training set; lower this if memory is tight.
        batch = len(X)
        seed = 1
        data_gen_args = dict(samplewise_center=False,
                             samplewise_std_normalization=False,
                             zca_whitening=False,
                             fill_mode='reflect',
                             rotation_range=30,
                             shear_range=0.2,
                             zoom_range=0.2,
                             horizontal_flip=True,
                             vertical_flip=True
                             #width_shift_range=0.2,
                             #height_shift_range=0.2,
                             )

        # Training generators. The shared seed keeps image and mask
        # transformations in step.
        image_datagen = ImageDataGenerator(**data_gen_args)
        mask_datagen = ImageDataGenerator(**data_gen_args)
        image_datagen.fit(X, augment=True, seed=seed)
        mask_datagen.fit(Y, augment=True, seed=seed)
        image_generator = image_datagen.flow(X, batch_size=batch, seed=seed)
        mask_generator = mask_datagen.flow(Y, batch_size=batch, seed=seed)
        generator = zip(image_generator, mask_generator)

        # Validation generators
        image_datagen_val = ImageDataGenerator(**data_gen_args)
        mask_datagen_val = ImageDataGenerator(**data_gen_args)
        image_datagen_val.fit(X_val, augment=True, seed=seed)
        mask_datagen_val.fit(Y_val, augment=True, seed=seed)
        image_generator_val = image_datagen_val.flow(X_val,
                                                     batch_size=batch,
                                                     seed=seed)
        mask_generator_val = mask_datagen_val.flow(Y_val,
                                                   batch_size=batch,
                                                   seed=seed)
        generator_val = zip(image_generator_val, mask_generator_val)

        # Plot one augmented example of each
        x, y = next(generator)
        x_val, y_val = next(generator_val)

        plt.figure(figsize=(16, 4))
        plt.subplot(141)
        plt.imshow(x[0, :, :, 0])
        plt.subplot(142)
        plt.imshow(y[0, :, :, 0])
        plt.subplot(143)
        plt.imshow(x_val[0, :, :, 0])
        plt.subplot(144)
        plt.imshow(y_val[0, :, :, 0])

        blocks = 2
        channels = 16
        learning_rate = 0.0004
        learning_rate_decay_factor = 0.95
        epoch_size_multiplicator = 20  # basically, how often do we decrease learning rate
        epochs = 10
        kernel_size = 3
        model_path = 'models/CSBDeep/model.h5'
        if os.path.isfile(model_path):
            print('Your model will be overwritten in the next cell')

        # Start at +inf so the first epoch always writes model_path. Later
        # epochs and the final export reload from it.
        best_mae = float('inf')
        steps_per_epoch = len(X) * epoch_size_multiplicator
        validation_steps = len(X_val) * epoch_size_multiplicator
        if os.path.isfile(model_path):
            os.remove(model_path)
        for i in range(epochs):
            print('Epoch:', i + 1)
            learning_rate = learning_rate * learning_rate_decay_factor
            # The model is rebuilt every round so the decayed learning rate
            # takes effect through a fresh Config.
            config = Config(axes,
                            n_channel_in,
                            n_channel_out,
                            unet_kern_size=kernel_size,
                            train_learning_rate=learning_rate,
                            unet_n_depth=blocks,
                            unet_n_first=channels)
            model = CARE(config, '.', basedir='models')
            if i > 0:
                # Continue from the best weights saved so far.
                model.keras_model.load_weights(model_path)
            model.prepare_for_training()
            history = model.keras_model.fit_generator(
                generator,
                validation_data=generator_val,
                validation_steps=validation_steps,
                epochs=1,
                verbose=0,
                shuffle=True,
                steps_per_epoch=steps_per_epoch)
            val_mae = history.history['val_mae'][0]
            if val_mae < best_mae:
                best_mae = val_mae
                # The checkpoint lives in a subdirectory of models/, so
                # create that directory itself.
                os.makedirs(os.path.dirname(model_path), exist_ok=True)
                model.keras_model.save(model_path)
                print(f'Validation MAE:{best_mae:.3E}')
            del model
            # Reset Keras' global state before building the next model.
            K.clear_session()

        # Rebuild the model (with the final learning rate), load the best
        # weights and export them.
        config = Config(axes,
                        n_channel_in,
                        n_channel_out,
                        unet_kern_size=kernel_size,
                        train_learning_rate=learning_rate,
                        unet_n_depth=blocks,
                        unet_n_first=channels)
        model = CARE(config, '.', basedir='models')
        model.keras_model.load_weights(model_path)
        model.export_TF()
Example #13
0
def n2v_flim(project, n2v_num_pix=32):
    """Denoise FLIM fit results with Noise2Void and write them to a new project.

    Loads the fit results from ``<project>/fit_results.hdf5``, normalises them
    with their global mean/std, trains a Noise2Void CARE model (stored under
    ``project`` as ``n2v_model``), predicts every image, and writes the
    de-normalised predictions into a copy of the project whose path has
    ``.flimfit`` replaced by ``-n2v.flimfit``.

    Parameters
    ----------
    project : str
        Path to the ``.flimfit`` project directory. It must contain
        ``'.flimfit'`` so that the output path differs from the input.
    n2v_num_pix : int
        Side length of the square N2V training patches. It is also passed
        to the config as ``n2v_num_pix``.

    Raises
    ------
    ValueError
        If ``project`` does not contain ``'.flimfit'``. Otherwise the output
        project would be the input project itself, and it would be deleted
        before being copied.
    """
    # Resolve (and validate) the output location before any expensive work,
    # so the input project can never be removed by the rmtree further down.
    output_project = project.replace('.flimfit', '-n2v.flimfit')
    if output_project == project:
        raise ValueError(
            f"project path must contain '.flimfit' (got {project!r}); "
            "refusing to overwrite the input project")

    results_file = os.path.join(project, 'fit_results.hdf5')

    X, groups, mask = extract_results(results_file)
    print(np.shape(X))

    mean, std = np.mean(X), np.std(X)
    X = normalize(X, mean, std)

    XA = X  # augment_data(X)

    # Copy, not a view: manipulate_val_data modifies X_val in place, and a view
    # would silently corrupt the first images of X (used for training and for
    # the final prediction).
    X_val = X[0:10, ...].copy()

    # We concatenate an extra channel filled with zeros. It will be internally
    # used for the masking.
    Y = np.concatenate((XA, np.zeros(XA.shape)), axis=-1)
    Y_val = np.concatenate((X_val, np.zeros(X_val.shape)), axis=-1)

    n_x = X.shape[1]
    n_chan = X.shape[-1]

    manipulate_val_data(X_val, Y_val, num_pix=n_x * n_x * 2 / n2v_num_pix,
                        shape=(n_x, n_x))

    # You can increase "train_steps_per_epoch" to get even better results at
    # the price of longer computation.
    # NOTE(review): n2v_neighborhood_radius is passed as the string '5'; confirm
    # the installed n2v version accepts a string here rather than an int.
    config = Config('SYXC',
                    n_channel_in=n_chan,
                    n_channel_out=n_chan,
                    unet_kern_size=5,
                    unet_n_depth=2,
                    train_steps_per_epoch=200,
                    train_loss='mae',
                    train_epochs=35,
                    batch_norm=False,
                    train_scheme='Noise2Void',
                    train_batch_size=128,
                    n2v_num_pix=n2v_num_pix,
                    n2v_patch_shape=(n2v_num_pix, n2v_num_pix),
                    n2v_manipulator='uniform_withCP',
                    n2v_neighborhood_radius='5')

    model = CARE(config, 'n2v_model', basedir=project)
    model.train(XA, Y, validation_data=(X_val, Y_val))
    model.load_weights(name='weights_best.h5')

    if os.path.exists(output_project):
        shutil.rmtree(output_project)
    shutil.copytree(project, output_project)

    output_file = os.path.join(output_project, 'fit_results.hdf5')

    # Predict each image separately and map it back to the original value range.
    X_pred = np.zeros(X.shape)
    for i, image in enumerate(X):
        X_pred[i, ...] = denormalize(
            model.predict(image, axes='YXC', normalizer=None), mean, std)

    X_pred[mask] = np.nan

    insert_results(output_file, X_pred, groups)
Example #14
0
# Restrict the data to the requested channels; abort if none of them match.
keep_idx = np.isin(chans, keepers)
if not np.any(keep_idx):
    raise ValueError("Did not supply valid channel name")

print(f'analyzing the following channels: {chans[keep_idx]}')

x_train = x_train[:, :, :, keep_idx]
x_test = x_test[:, :, :, keep_idx]
y_train = y_train[:, :, :, keep_idx]
y_test = y_test[:, :, :, keep_idx]

# Configure and train a probabilistic CARE model (recipe from the CARE FAQ).
axes = 'SYXC'
c = axes_dict(axes)['C']
n_channel_in = x_train.shape[c]
n_channel_out = y_train.shape[c]
config = Config(
    axes,
    n_channel_in,
    n_channel_out,
    probabilistic=True,
    train_epochs=num_epochs,
    train_steps_per_epoch=30,
)

model = CARE(config, model_name, basedir='models')
history = model.train(x_train, y_train, validation_data=(x_test, y_test))

# Predict the first five test patches. The first half of the output's last
# axis is plotted as the mean, the second half as the scale.
fig = plt.figure(figsize=(30, 30))
_P = model.keras_model.predict(x_test[:5, :, :, :])
_n_half = _P.shape[-1] // 2
_P_mean = _P[..., :_n_half]
_P_scale = _P[..., _n_half:]
plot_some(
    x_test[:5, :, :, 0],
    y_test[:5, :, :, :],
    _P_mean,
    _P_scale,
    pmax=99.5,
)
Example #15
0
def train(Training_source=".",
          Training_target=".",
          model_name="No_name",
          model_path=".",
          Visual_validation_after_training=True,
          number_of_epochs=100,
          patch_size=64,
          number_of_patches=10,
          Use_Default_Advanced_Parameters=True,
          number_of_steps=300,
          batch_size=32,
          percentage_validation=15):
    """Train a 2D CARE model on paired noisy/ground-truth TIFFs and save it.

    Patches are extracted from the image pairs and written to
    ``<model_path>/rawdata.npz``. They are then reloaded with a validation
    split and used to train a CARE model stored as
    ``<model_path>/<model_name>``. An existing model directory with the same
    name is deleted first.

    Parameters
    ----------
    Training_source : str
        Folder containing the noisy (input) ``.tif`` images.
    Training_target : str
        Folder containing the ground-truth (target) ``.tif`` images.
    model_name : str
        Name of the model.
    model_path : str
        Folder in which the model and the patch file are saved.
    Visual_validation_after_training : bool
        If True, predict an image with ``Predict_a_image`` after training.
    number_of_epochs : int
        Number of training epochs.
    patch_size : int
        Side length of the square training patches.
    number_of_patches : int
        Number of patches extracted from each image.
    Use_Default_Advanced_Parameters : bool
        If True, force ``batch_size=64`` and ``percentage_validation=10``.
        ``number_of_steps`` is then derived from the number of patches.
    number_of_steps : int
        Training steps per epoch. Ignored when the defaults are enabled.
    batch_size : int
        Training batch size.
    percentage_validation : int
        Percentage of patches held out for validation.

    Returns
    -------
    None
    """
    import math
    import time

    InputFile = Training_source + "/*.tif"
    OutputFile = Training_target + "/*.tif"
    base = "/content/"

    if Use_Default_Advanced_Parameters:
        print("Default advanced parameters enabled")
        batch_size = 64
        percentage_validation = 10

    validation_fraction = percentage_validation / 100

    # Check that no model with the same name already exists; if so, delete it.
    model_dir = os.path.join(model_path, model_name)
    if os.path.exists(model_dir):
        shutil.rmtree(model_dir)

    # The stacks are loaded only to report their shapes.
    x = imread(InputFile)
    y = imread(OutputFile)

    print('Loaded Input images (number, width, length) =', x.shape)
    print('Loaded Output images (number, width, length) =', y.shape)
    print("Parameters initiated.")

    # RawData holds the image pairs (GT and low), so that CARE compares
    # corresponding images.
    raw_data = data.RawData.from_folder(basepath=base,
                                        source_dirs=[Training_source],
                                        target_dir=Training_target,
                                        axes='CYX',
                                        pattern='*.tif*')

    X, Y, XY_axes = data.create_patches(raw_data,
                                        patch_filter=None,
                                        patch_size=(patch_size, patch_size),
                                        n_patches_per_image=number_of_patches)

    print('Creating 2D training dataset')
    training_path = model_path + "/rawdata"
    # np.savez appends ".npz" to the name, so the data is reloaded from that file.
    np.savez(training_path, X=X, Y=Y, axes=XY_axes)

    # Load the training data back with a validation split.
    (X, Y), (X_val, Y_val), axes = load_training_data(
        training_path + ".npz",
        validation_split=validation_fraction,
        verbose=True)
    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

    if Use_Default_Advanced_Parameters:
        # Enough steps for one pass over the training patches per epoch
        # (ceiling division, never fewer than one step).
        number_of_steps = max(1, math.ceil(X.shape[0] / batch_size))

    print(number_of_steps)

    config = Config(axes,
                    n_channel_in,
                    n_channel_out,
                    probabilistic=False,
                    train_steps_per_epoch=number_of_steps,
                    train_epochs=number_of_epochs,
                    unet_kern_size=5,
                    unet_n_depth=3,
                    train_batch_size=batch_size,
                    train_learning_rate=0.0004)

    print(config)

    # Compile the CARE model for network training.
    model_training = CARE(config, model_name, basedir=model_path)

    start = time.time()
    history = model_training.train(X, Y, validation_data=(X_val, Y_val))
    # Time only the training itself, not the optional prediction below.
    elapsed = time.time() - start

    print("Training, done.")
    if Visual_validation_after_training:
        Predict_a_image(Training_source, Training_target, model_path,
                        model_training)

    # Display the time elapsed for training.
    minutes, seconds = divmod(elapsed, 60)
    hours, minutes = divmod(minutes, 60)
    print("Time elapsed:", hours, "hour(s)", minutes, "min(s)", round(seconds),
          "sec(s)")

    Show_loss_function(history, model_path)
Example #16
0
def train(X, Y, X_val, Y_val, axes, model_name, csv_file,probabilistic, validation_split, patch_size, **kwargs):
    """Train a CARE model on previously created patches and export it.

    The CARE model is configured through a ``Config`` object, which covers:
      * the parameters of the underlying neural network,
      * the learning rate,
      * the number of parameter updates per epoch,
      * the loss function, and
      * whether the model is probabilistic or not.

    The trained model is saved under ``models/<model_name>`` and exported
    with ``model.export_TF()``.

    Parameters
    ----------
    X : np.array
        Input data X for training.
    Y : np.array
        Ground truth data Y for training.
    X_val : np.array
        Input data X for validation.
    Y_val : np.array
        Ground truth Y for validation.
    axes : str
        Semantic order of the axes of ``X``/``Y`` (e.g. ``'SYXC'``). If it
        contains a ``'C'`` axis, the channel counts are read from that axis.
        Otherwise one channel is assumed.
    model_name : str
        Name of the model to be saved after training.
    csv_file : str
        Currently unused; kept for interface compatibility.
    probabilistic : bool
        Whether to train a probabilistic CARE model.
    validation_split : float
        Currently unused; kept for interface compatibility.
    patch_size : tuple
        Currently unused; kept for interface compatibility.
    **kwargs
        Further ``Config`` parameters, e.g. ``train_steps_per_epoch`` or
        ``train_epochs``. Unknown keys are accepted because the config is
        built with ``allow_new_parameters=True``.

    Returns
    -------
    history
        Training history object holding the recorded loss values.
    """
    # Derive the channel counts from the data instead of assuming a single
    # channel; single-channel inputs still yield 1 as before.
    c = axes_dict(axes)['C']
    n_channel_in = 1 if c is None else X.shape[c]
    n_channel_out = 1 if c is None else Y.shape[c]

    config = Config(axes, n_channel_in=n_channel_in, n_channel_out=n_channel_out,
                    probabilistic=probabilistic, allow_new_parameters=True, **kwargs)
    model = CARE(config, model_name, basedir='models')

    # Training progress can be followed with TensorBoard
    # (`tensorboard --logdir=.` from the current working directory, then
    # http://localhost:6006/).
    history = model.train(X, Y, validation_data=(X_val, Y_val))

    model.export_TF()
    return history
Example #17
0
    validation_split=0.1,
    verbose=True)

# Index of the channel axis and channel counts of the training data.
c = axes_dict(axes)['C']
n_channel_in = X.shape[c]
n_channel_out = Y.shape[c]

# Depth-4 U-Net: 50 epochs x 400 steps, batch size 16, with learning-rate
# reduction settings (patience 5, factor 0.5) passed via train_reduce_lr.
config = Config(
    axes, n_channel_in, n_channel_out,
    unet_n_depth=4,
    train_epochs=50,
    train_steps_per_epoch=400,
    train_batch_size=16,
    train_reduce_lr=dict(patience=5, factor=0.5),
)
print(config)
vars(config)

# Projection model stored under the shared CurieDeepLearningModels folder.
model = ProjectionCARE(config, 'DrosophilaDenoisingProjection',
                       basedir='/local/u934/private/v_kapoor/CurieDeepLearningModels')
Example #18
0
    def train(self, channels=None, **config_args):
        # limit_gpu_memory(fraction=1)
        if channels is None:
            channels = self.train_channels

        with Timer("Training"):

            for ch in channels:
                print("-- Training channel {}...".format(ch))
                (X, Y), (X_val, Y_val), axes = load_training_data(
                    self.get_training_patch_path() /
                    "CH_{}_training_patches.npz".format(ch),
                    validation_split=0.1,
                    verbose=False,
                )

                c = axes_dict(axes)["C"]
                n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

                config = Config(
                    axes,
                    n_channel_in,
                    n_channel_out,
                    train_epochs=self.train_epochs,
                    train_steps_per_epoch=self.train_steps_per_epoch,
                    train_batch_size=self.train_batch_size,
                    probabilistic=self.probabilistic,
                    **config_args,
                )
                # Training

                # if (
                #     pathlib.Path(self.out_dir) / "models" / "CH_{}_model".format(ch)
                # ).exists():
                #     print("config there already")
                #     config = None

                model = CARE(
                    config,
                    "CH_{}_model".format(ch),
                    basedir=pathlib.Path(self.out_dir) / "models",
                )

                # Show learning curve and example validation results
                try:
                    history = model.train(X, Y, validation_data=(X_val, Y_val))
                except tf.errors.ResourceExhaustedError:
                    print(
                        " >> ResourceExhaustedError: Aborting...\n  Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?"
                    )
                    return
                except tf.errors.UnknownError:
                    print(
                        " >> UnknownError: Aborting...\n  No enough memory available on GPU... are other GPU jobs running?"
                    )
                    return

                # print(sorted(list(history.history.keys())))
                plt.figure(figsize=(16, 5))
                plot_history(history, ["loss", "val_loss"],
                             ["mse", "val_mse", "mae", "val_mae"])

                plt.figure(figsize=(12, 7))
                _P = model.keras_model.predict(X_val[:5])

                if self.probabilistic:
                    _P = _P[..., 0]

                plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5, cmap="gray")
                plt.suptitle("5 example validation patches\n"
                             "top row: input (source),  "
                             "middle row: target (ground truth),  "
                             "bottom row: predicted from source")

                plt.show()

                print("-- Export model for use in Fiji...")
                model.export_TF()
                print("Done")