Example #1
# Linear decay of the learning rate from 2e-4 toward 0 once config.startLRdecay is reached.
def scheduler(epoch):  # function header reconstructed; a per-epoch signature is assumed
    epochs_passed = epoch - config.startLRdecay
    decay_step = 2e-4 / (config.epochs - config.startLRdecay)
    return 2e-4 - epochs_passed * decay_step

LRscheduler = callbacks.MultiLRScheduler(scheduler, training_models=[model.d_A, model.d_B, model.combined])
# Generate Callbacks
tensorboard = tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR, write_graph=True, update_freq='epoch')
start_tensorboard = callbacks.StartTensorBoard(LOG_DIR)

prog_bar = tf.keras.callbacks.ProgbarLogger(count_mode='steps', stateful_metrics=None)
log_code = callbacks.LogCode(LOG_DIR, './trainer')
copy_keras = callbacks.CopyKerasModel(MODEL_DIR, LOG_DIR)

saving = callbacks.MultiModelCheckpoint(
    MODEL_DIR + '/model.{epoch:02d}-{val_ssim:.10f}.hdf5',
    monitor='val_ssim', verbose=1, freq='epoch', mode='max',
    save_best_only=False, save_weights_only=True,
    multi_models=[('g_AB', g_AB), ('g_BA', g_BA), ('d_A', d_A), ('d_B', d_B)])

image_gen = callbacks.GenerateImages(g_AB, test_X, test_Y, LOG_DIR, interval=int(dataset_count/config.bs))

# Fit the model
model.fit(train_X, train_Y,
          batch_size=config.bs,
          steps_per_epoch=(dataset_count // config.bs),
          epochs=config.epochs,
          validation_data=(test_X, test_Y),
          validation_steps=10,
          callbacks=[log_code, tensorboard, prog_bar, image_gen, saving,
                     copy_keras, start_tensorboard, LRscheduler])
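
For reference, a minimal standalone sketch of the same linear decay rule; config.startLRdecay and config.epochs are given hypothetical values here, the real ones come from the run configuration:

from types import SimpleNamespace

config = SimpleNamespace(startLRdecay=100, epochs=200)  # hypothetical values

def scheduler(epoch):
    epochs_passed = epoch - config.startLRdecay
    decay_step = 2e-4 / (config.epochs - config.startLRdecay)
    return 2e-4 - epochs_passed * decay_step

for epoch in (100, 150, 200):
    print(epoch, scheduler(epoch))  # approximately 2e-4, 1e-4, 0.0 -- linear decay to zero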
Example #2
log_code = callbacks.LogCode(LOG_DIR, './trainer')
copy_keras = callbacks.CopyKerasModel(MODEL_DIR, LOG_DIR)

saving = callbacks.MultiModelCheckpoint(
    MODEL_DIR + '/model.{epoch:02d}-{val_ssim:.10f}.hdf5',
    monitor='val_ssim',
    verbose=1,
    freq='epoch',
    mode='max',
    save_best_only=False,
    save_weights_only=True,
    multi_models=[('g_AB', g_AB), ('g_BA', g_BA), ('d_A', d_A), ('d_B', d_B)])

image_gen = callbacks.GenerateImages(g_AB,
                                     test_X,
                                     test_Y,
                                     LOG_DIR,
                                     interval=int(dataset_count / config.bs),
                                     z_shape=(config.latent_z_dim, ))

# Fit the model
model.fit(train_X,
          train_Y,
          batch_size=config.bs,
          steps_per_epoch=(dataset_count // config.bs),
          epochs=config.epochs,
          validation_data=(test_X, test_Y),
          validation_steps=10,
          callbacks=[
              log_code, tensorboard, prog_bar, image_gen, saving, copy_keras,
              start_tensorboard
          ])
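
MultiModelCheckpoint, GenerateImages and CopyKerasModel come from the project's own callbacks module rather than from tf.keras. As a rough sketch of the multi-model checkpoint idea only (not the actual implementation), a callback along these lines would save each named sub-model's weights at the end of every epoch:

import tensorflow as tf

class MultiModelCheckpointSketch(tf.keras.callbacks.Callback):
    """Saves the weights of several named sub-models at the end of each epoch."""

    def __init__(self, filepath, multi_models):
        super().__init__()
        self.filepath = filepath          # template such as 'model.{epoch:02d}.hdf5'
        self.multi_models = multi_models  # list of (name, keras_model) pairs

    def on_epoch_end(self, epoch, logs=None):
        filename = self.filepath.format(epoch=epoch + 1, **(logs or {}))
        for name, sub_model in self.multi_models:
            # prefix each file with the sub-model name so the checkpoints stay distinct
            sub_model.save_weights('{}_{}'.format(name, filename))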
Example #3
model_path = os.path.join(config.model_dir,  # statement opening reconstructed; prefix assumed
                          'model.{epoch:02d}-{val_loss:.5f}.h5')
saving = tf.keras.callbacks.ModelCheckpoint(model_path,
                                            monitor='val_loss',
                                            verbose=1,
                                            save_freq='epoch',
                                            save_best_only=True,
                                            save_weights_only=True)

# callbacks -- log training
write_freq = int(train_count / config.batch_size / 10)
tensorboard = tf.keras.callbacks.TensorBoard(log_dir=config.job_dir,
                                             write_graph=True,
                                             update_freq=write_freq)
image_gen_val = callbacks.GenerateImages(model,
                                         validation_dataset,
                                         config.job_dir,
                                         interval=write_freq,
                                         postfix='val')
image_gen = callbacks.GenerateImages(model,
                                     train_dataset,
                                     config.job_dir,
                                     interval=write_freq,
                                     postfix='train')

# callbacks -- start tensorboard
start_tensorboard = callbacks.StartTensorBoard(config.job_dir)

# callbacks -- log code and trained models
log_code = callbacks.LogCode(config.job_dir, './trainer')
copy_keras = callbacks.CopyKerasModel(config.model_dir, config.job_dir)
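
When update_freq is an integer, tf.keras.callbacks.TensorBoard writes losses and metrics every that many batches, so write_freq here works out to roughly a tenth of an epoch; the GenerateImages callbacks reuse the same interval. With hypothetical counts (the real ones come from the run config):

train_count, batch_size = 20000, 16           # hypothetical dataset size and batch size
write_freq = int(train_count / batch_size / 10)
print(write_freq)                             # 125 -> log every 125 batches, about 10 times per epoch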
Example #4
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(  # statement opening reconstructed; callback class assumed
    monitor='val_ssim',
    mode='max',
    factor=0.5,
    patience=3,
    min_lr=0.000002)
early_stopping = callbacks.MultiEarlyStopping(
    multi_models=[g_AB, g_BA, d_A, d_B],
    full_model=model,
    monitor='val_ssim',
    mode='max',
    patience=1,
    restore_best_weights=True,
    verbose=1)

image_gen = callbacks.GenerateImages(g_AB,
                                     validation_dataset,
                                     LOG_DIR,
                                     interval=int(iq_count / config.bs))

# Fit the model
model.fit(iq_dataset,
          dtce_dataset,
          steps_per_epoch=int(iq_count / config.bs),
          epochs=config.epochs,
          validation_data=validation_dataset,
          validation_steps=int(val_count / config.bs),
          callbacks=[
              log_code, reduce_lr, tensorboard, prog_bar, image_gen, saving,
              copy_keras, start_tensorboard, early_stopping
          ])
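
MultiEarlyStopping is likewise a project-specific callback; the stock tf.keras.callbacks.EarlyStopping only restores the weights of the single model it is attached to. A rough sketch of the multi-model idea, as an illustration only and not the actual implementation:

import tensorflow as tf

class MultiEarlyStoppingSketch(tf.keras.callbacks.Callback):
    """Early stopping that restores the best weights of several sub-models."""

    def __init__(self, multi_models, monitor='val_ssim', patience=1):
        super().__init__()
        self.multi_models = multi_models   # sub-models sharing layers with the combined model
        self.monitor = monitor
        self.patience = patience
        self.best = float('-inf')          # mode='max': larger is better
        self.wait = 0
        self.best_weights = None

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is None:
            return
        if current > self.best:
            self.best, self.wait = current, 0
            self.best_weights = [m.get_weights() for m in self.multi_models]
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.model.stop_training = True
                for m, weights in zip(self.multi_models, self.best_weights):
                    m.set_weights(weights)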
Example #5
saving = tf.keras.callbacks.ModelCheckpoint(
    MODEL_DIR + '/model.{epoch:02d}-{val_ssim:.10f}.hdf5',
    monitor='val_ssim',
    verbose=1,
    save_freq='epoch',
    mode='max',
    save_best_only=True)

reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                                 factor=0.2,
                                                 patience=3,
                                                 min_lr=0.002 * .001)
log_code = callbacks.LogCode(LOG_DIR, './trainer')
terminate = tf.keras.callbacks.TerminateOnNaN()
image_gen = callbacks.GenerateImages(model,
                                     validation_dataset,
                                     LOG_DIR,
                                     interval=int(train_count / config.bs))

# Fit the model
model.fit(train_dataset,
          steps_per_epoch=int(train_count / config.bs),
          epochs=config.epochs,
          validation_data=validation_dataset,
          validation_steps=int(val_count / config.bs),
          verbose=1,
          callbacks=[
              log_code, terminate, tensorboard, saving, reduce_lr, copy_keras,
              image_gen
          ])
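
Each time val_loss plateaus for patience=3 epochs, ReduceLROnPlateau multiplies the learning rate by factor=0.2 until it reaches min_lr = 0.002 * .001 = 2e-6. With a hypothetical initial learning rate of 2e-3 the progression looks like this:

lr, factor, min_lr = 2e-3, 0.2, 0.002 * .001  # initial LR is a hypothetical value
while lr > min_lr:
    lr = max(lr * factor, min_lr)
    print(lr)  # ~4e-04, ~8e-05, ~1.6e-05, ~3.2e-06, then clipped to min_lr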