Example #1
def main(f_log_metrics=lambda logs: None):

    print(vars(args))
    config_tf_session()

    input_shape, output_shape, \
    x_train, y_train, x_val, y_val, x_test, y_test = prepare_data()

    model = get_model(input_shape, output_shape)

    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.Adam(args.lr),
                  metrics=['accuracy'])

    log_callback = callbacks.LambdaCallback(on_epoch_end=lambda _, logs: f_log_metrics(logs=logs))

    model.fit(x_train, y_train,
              batch_size=args.batch_size,
              epochs=args.epochs,
              verbose=args.verbose,
              validation_data=(x_val, y_val),
              callbacks=[log_callback])
    val_score = model.evaluate(x_val, y_val, verbose=0)
    print('Val loss:', val_score[0])
    print('Val accuracy:', val_score[1])

    test_score = model.evaluate(x_test, y_test, verbose=0)
    print('Test loss:', test_score[0])
    print('Test accuracy:', test_score[1])

    print('Saving model: mnist_model.h5')
    model.save('mnist_model.h5')

    return val_score[1], test_score[1]
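
The f_log_metrics hook receives the Keras logs dict at the end of every epoch. A minimal sketch of a caller (the log_metrics collector here is hypothetical):

history = []

def log_metrics(logs):
    # Collect each epoch's metrics, e.g. {'loss': ..., 'accuracy': ..., 'val_loss': ...}
    history.append(dict(logs))

val_acc, test_acc = main(f_log_metrics=log_metrics)
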
Example #2
def keep_limited_checkpoints(base, n=8):
    '''
    Callback that keeps only the n checkpoints with the lowest validation loss.
    Usage:
        keeplimited = keep_limited_checkpoints('temp/checkpoints/modelname', n=3)
    '''
    def f(epoch):
        import os
        if len(os.listdir(base)) <= n:
            return

        # Checkpoint names are assumed to look like 'weights.{val_loss}.hdf5';
        # the slice below extracts the loss value between the two fixed parts.
        max_loss = 0
        rm_cp = None
        for cp in os.listdir(base):
            loss_str = cp[len('weights') + 1:cp.find('hdf5') - 1]
            if loss_str == 'nan':
                rm_cp = cp
                break

            if float(loss_str) > max_loss:
                max_loss = float(loss_str)
                rm_cp = cp

        if rm_cp is not None:
            # logging.debug('Removing %s' % rm_cp)
            os.remove(os.path.join(base, rm_cp))

    return cb.LambdaCallback(on_epoch_end=lambda epoch, logs: f(epoch))
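
The pruning above only works if checkpoints are written with the validation loss embedded in the filename. A sketch of a matching ModelCheckpoint setup (paths are placeholders):

checkpoint = callbacks.ModelCheckpoint(
    filepath='temp/checkpoints/modelname/weights.{val_loss:.4f}.hdf5',
    monitor='val_loss',
    save_weights_only=True)
keeplimited = keep_limited_checkpoints('temp/checkpoints/modelname', n=3)
model.fit(x_train, y_train, callbacks=[checkpoint, keeplimited])
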
Example #3
    def init_callbacks(self):
        self.callbacks.append(
            callbacks.ReduceLROnPlateau(
                **self.config.trainer.reduce_lr_on_plateau))
        if self.config.trainer.tensorboard_enabled:
            self.callbacks.append(
                callbacks.LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: self.log_lr(epoch)))
Example #4
    def init_callbacks(self):
        def schedule(epoch):
            # Piecewise-constant schedule from config; epochs past the last
            # step fall through and return None, which LearningRateScheduler
            # rejects, so the steps must cover the whole run.
            for step in self.config.trainer.lr_schedule:
                if epoch < step.until_epoch:
                    return step.lr

        self.callbacks.append(callbacks.LearningRateScheduler(schedule))
        if self.config.trainer.tensorboard_enabled:
            self.callbacks.append(callbacks.LambdaCallback(
                on_epoch_begin=lambda epoch, logs: self.log_lr(epoch)
            ))
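
The schedule reads step.until_epoch and step.lr from the trainer config; the exact config class is not shown. A hypothetical stand-in that satisfies the same attribute access:

from types import SimpleNamespace

# Hypothetical lr_schedule: 1e-3 until epoch 10, then 1e-4 until 20, then 1e-5.
lr_schedule = [
    SimpleNamespace(until_epoch=10, lr=1e-3),
    SimpleNamespace(until_epoch=20, lr=1e-4),
    SimpleNamespace(until_epoch=30, lr=1e-5),
]
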
Example #5
    def init_callbacks(self):
        self.callbacks.append(
            callbacks.ReduceLROnPlateau(
                **self.config.trainer.reduce_lr_on_plateau))
        if "model_checkpoint" in self.config.trainer:
            self.callbacks.append(
                callbacks.ModelCheckpoint(
                    save_weights_only=True,
                    **self.config.trainer.model_checkpoint))
        self.callbacks.append(
            callbacks.LambdaCallback(
                on_epoch_begin=lambda epoch, logs: self.log_lr(epoch)))
Example #6
    def init_callbacks(self):
        def schedule(epoch):
            for step in self.config.trainer.lr_schedule:
                if epoch < step.until_epoch:
                    return step.lr

        self.callbacks.append(callbacks.LearningRateScheduler(schedule))
        if "model_checkpoint" in self.config.trainer:
            self.callbacks.append(
                callbacks.ModelCheckpoint(
                    save_weights_only=True,
                    **self.config.trainer.model_checkpoint))
        self.callbacks.append(
            callbacks.LambdaCallback(
                on_epoch_begin=lambda epoch, logs: self.log_lr(epoch)))
Example #7
def get_cm_callback(log_dir: str, class_names: List[str]) -> callbacks.Callback:
    """Get the confusion matrix callback for plotting"""
    def log_confusion_matrix(epoch, logs):
        """Use tf.summary.image to plot confusion matrix"""
        # Assumes another callback or metric has placed the confusion matrix in logs['cm'].
        figure = plot_confusion_matrix(logs['cm'], class_names=class_names)
        cm_image = plot_to_image(figure)
        with cm_image_writer.as_default():
            tf.summary.image("Train Confusion Matrix", cm_image, step=epoch)

        if 'val_cm' in logs:
            figure = plot_confusion_matrix(logs['val_cm'], class_names=class_names)
            cm_image = plot_to_image(figure)
            with cm_image_writer.as_default():
                tf.summary.image("Val Confusion Matrix", cm_image, step=epoch)

    cm_image_writer = tf.summary.create_file_writer(log_dir + "/cm")
    return callbacks.LambdaCallback(on_epoch_end=log_confusion_matrix)
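
Example #7 assumes two helpers that are not shown here: plot_to_image (a variant appears in Example #12) and plot_confusion_matrix. A minimal matplotlib sketch of the latter, assuming cm is a square numpy count matrix:

import itertools
import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(cm, class_names):
    # Render the confusion matrix as a heatmap and return the figure.
    figure = plt.figure(figsize=(8, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion matrix')
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)
    # Annotate each cell with its count, switching text color for contrast.
    threshold = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        color = 'white' if cm[i, j] > threshold else 'black'
        plt.text(j, i, cm[i, j], horizontalalignment='center', color=color)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    return figure
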
Example #8
def LimitCheckpoints(ckpt_path, n=16):
    '''
    Callback that keeps only the n checkpoints with the lowest validation loss.
    Usage:
        callback = LimitCheckpoints('progress/test/ckpt', n=16)
    '''
    def f(epoch):
        if len(os.listdir(ckpt_path)) <= n:
            return

        max_loss = 0
        file_to_remove = None
        for filename in os.listdir(ckpt_path):
            cur_loss = parse_weight_path(filename)
            if cur_loss > max_loss:
                max_loss = cur_loss
                file_to_remove = filename

        if file_to_remove is not None:
            os.remove(os.path.join(ckpt_path, file_to_remove))

    return cb.LambdaCallback(on_epoch_end=lambda epoch, logs: f(epoch))
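
parse_weight_path is not shown; a hypothetical implementation for checkpoint names like 'weights.0.1234.hdf5' (the naming scheme is an assumption, mirroring Example #2):

def parse_weight_path(filename):
    # e.g. 'weights.0.1234.hdf5' -> 0.1234
    return float(filename[len('weights') + 1:filename.rfind('.hdf5')])
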
Example #9
# @Author  : daiyu
# @File    : 5-8,回调函数callbacks.py
# @Software: PyCharm

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, metrics, callbacks
import tensorflow.keras.backend as K

# Demonstrates using LambdaCallback to write a relatively simple callback function

import json

json_log = open('./data/keras_log.json', mode='wt', buffering=1)
json_logging_callback = callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: json_log.write(
        json.dumps(dict(epoch=epoch, **logs)) + '\n'),
    on_train_end=lambda logs: json_log.close())
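
With buffering=1 the log file is line-buffered, so one JSON record per epoch is flushed as training runs. A usage sketch (model and data are placeholders):

model.fit(x_train, y_train, epochs=10, callbacks=[json_logging_callback])
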

# Demonstrates writing a callback by subclassing Callback (the source code of LearningRateScheduler)


class LearningRateScheduler(callbacks.Callback):
    def __init__(self, schedule, verbose=0):
        super(LearningRateScheduler, self).__init__()
        self.schedule = schedule
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute')
        try:
            # New API: schedule takes (epoch, current_lr)
            lr = float(K.get_value(self.model.optimizer.lr))
            lr = self.schedule(epoch, lr)
        except TypeError:
            # Old API for backward compatibility: schedule takes (epoch,)
            lr = self.schedule(epoch)
        K.set_value(self.model.optimizer.lr, lr)
        if self.verbose > 0:
            print('\nEpoch %05d: LearningRateScheduler setting learning '
                  'rate to %s.' % (epoch + 1, lr))
Example #10
training_logger = TrainingLogger(epochs_count)

# Print model architecture (model.summary() prints itself; wrapping it in
# print() would also emit its None return value)
model.summary()
sys.stdout.flush()

# Train model on dataset
title('Training Model')
model.fit_generator(generator=training_generator,
                    validation_data=validation_generator,
                    verbose=2,
                    use_multiprocessing=(WORKERS > 0),
                    workers=WORKERS,
                    callbacks=[
                        callbacks.LambdaCallback(
                            on_train_begin=lambda logs: training_logger.load(),
                            on_batch_end=lambda batch, logs:
                                training_logger.update(batch, logs))
                    ])

# Evaluate on test set
title('Testing Model')
score = model.evaluate_generator(generator=test_generator, verbose=0)

log('Top 1 test accuracy: {:0.2f}%'.format(score[1] * 100))
log('Top 5 test accuracy: {:0.2f}%'.format(score[2] * 100))

if not args.save:
    exit(0)

title('Saving Model')
Example #11
# opt= optimizers.Adam(lr=init_lr, decay=init_lr / epochs)
autoencoder.compile(optimizer='adam', loss='mse')

lr_reducer = callbacks.ReduceLROnPlateau(factor=.75,
                                         cooldown=0,
                                         patience=15,
                                         verbose=0,
                                         min_lr=1e-5)
checkpoint = callbacks.ModelCheckpoint(filepath='model_ckpt.h5',
                                       monitor='val_loss',
                                       verbose=0,
                                       save_best_only=True)
test_callback = callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: showOrigGrayColo(
        x_test_gray[r:r + sample],
        x_test_ab[r:r + sample],
        autoencoder,
        epoch,
        nShow=1))
callbacks_list = [lr_reducer, checkpoint, test_callback]
# callbacks_list= [test_callback]

history = autoencoder.fit(x_train_gray,
                          x_train_ab,
                          validation_data=(x_test_gray, x_test_ab),
                          epochs=epochs,
                          batch_size=batch_size,
                          verbose=1,
                          callbacks=callbacks_list)

showOrigGrayColo(x_test_gray[r:r + sample], x_test_ab[r:r + sample],
                 autoencoder, epochs, nShow=1)  # trailing arguments assumed to match the callback above
Example #12
    """Converts the matplotlib plot specified by 'figure' to a PNG image and
  returns it. The supplied figure is closed and inaccessible after this call."""
    # Save the plot to a PNG in memory.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    # Closing the figure prevents it from being displayed directly inside
    # the notebook.
    plt.close(figure)
    buf.seek(0)
    # Convert PNG buffer to TF image
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension
    image = tf.expand_dims(image, 0)
    return image


cm_callback = callbacks.LambdaCallback(on_epoch_end=log_predictions)

tensorboard_callback = callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    embeddings_layer_names=conv_output,
    write_images=True)
a = model.fit_generator(data_generator(train_data, batch_size, n_input_frames),
                        steps_per_epoch=steps_per_epoch,
                        epochs=num_epochs,
                        callbacks=[tensorboard_callback, cm_callback],
                        validation_steps=validation_steps,
                        validation_data=data_generator(val_data, batch_size,
                                                       n_input_frames))
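
log_predictions is defined elsewhere in the source; a hypothetical sketch in the spirit of Example #7, rendering a figure of predictions to TensorBoard each epoch (the plotting body is a placeholder):

import matplotlib.pyplot as plt

pred_writer = tf.summary.create_file_writer(log_dir + '/predictions')

def log_predictions(epoch, logs):
    figure = plt.figure()
    # ... plot model predictions for a few held-out clips ...
    image = plot_to_image(figure)
    with pred_writer.as_default():
        tf.summary.image('Predictions', image, step=epoch)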