Example #1
0
def main(args):
    """Train an attention U-Net on the datasets described by *args*.

    Args:
        args: parsed command-line namespace. Attributes read: ``image_dir``,
            ``train_csv``, ``test_csv``, ``bs`` (batch size), ``in_h`` and
            ``in_w`` (dataset input shape is ``(in_h, in_w, 1)``), ``lr``
            (Adam learning rate), ``epochs``, ``job_dir`` (TensorBoard and
            code-log directory) and ``model_dir`` (checkpoint directory).
    """
    # Load Data
    Dataset = utils.Dataset(image_dir=args.image_dir)
    train_dataset, train_count = Dataset.get_dataset(csv=args.train_csv,
                                                     batch_size=args.bs,
                                                     shape=(args.in_h,
                                                            args.in_w, 1))

    test_dataset, val_count = Dataset.get_dataset(csv=args.test_csv,
                                                  batch_size=args.bs,
                                                  shape=(args.in_h, args.in_w,
                                                         1))

    # Compile Model
    filters = [8, 16, 32, 64, 128]
    # Built with unspecified spatial dims, so it is not tied to in_h/in_w.
    model = models.AttentionUnetModel(
        shape=(None, None, 1),
        Activation=tf.keras.layers.Activation(activation=tf.nn.relu),
        filters=filters,
        filter_shape=(3, 3))()

    model.compile(optimizer=tf.keras.optimizers.Adam(args.lr),
                  loss=utils.mse,
                  metrics=[utils.mse, utils.mae, utils.ssim, utils.psnr])

    # Generate Callbacks
    tensorboard = tf.keras.callbacks.TensorBoard(log_dir=args.job_dir,
                                                 write_graph=True,
                                                 update_freq='epoch')
    # `period` is deprecated (and removed in Keras 3); save_freq='epoch' is
    # the equivalent of period=1. With save_best_only=True a checkpoint is
    # written only when val_loss improves.
    saving = tf.keras.callbacks.ModelCheckpoint(
        args.model_dir + '/model.{epoch:02d}-{val_loss:.10f}.hdf5',
        monitor='val_loss',
        verbose=1,
        save_freq='epoch',
        save_best_only=True)
    # Multiply the LR by 0.2 after 2 epochs without val_loss improvement,
    # never going below 1/1000 of the initial LR.
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                                     factor=0.2,
                                                     patience=2,
                                                     min_lr=args.lr * .001)
    log_code = callbacks.LogCode(args.job_dir, './trainer')
    copy_keras = callbacks.CopyKerasModel(args.model_dir, args.job_dir)

    # Fit the model. Step counts use truncating division, so a trailing
    # partial batch is not counted as a step.
    model.fit(train_dataset,
              steps_per_epoch=int(train_count / args.bs),
              epochs=args.epochs,
              validation_data=test_dataset,
              validation_steps=int(val_count / args.bs),
              callbacks=[log_code, tensorboard, saving, copy_keras, reduce_lr])
Example #2
0
def scheduler(epoch, *, base_lr=2e-4, start_decay=None, total_epochs=None):
  """Return the learning rate for *epoch*: constant, then linear decay to 0.

  The rate is ``base_lr`` for every epoch below ``start_decay``. From
  ``start_decay`` on it falls linearly so that it would reach 0 at
  ``total_epochs``, and it is clamped at 0 for any later epoch.

  The extra parameters are keyword-only, so the positional call signature
  ``scheduler(epoch)`` is exactly as before.

  Args:
    epoch: epoch index passed in by the LR-scheduler callback.
    base_lr: learning rate before decay (the value previously hard-coded).
    start_decay: first epoch of the decay phase; defaults to
      ``config.startLRdecay``.
    total_epochs: epoch at which the rate reaches 0; defaults to
      ``config.epochs``.

  Returns:
    The learning rate as a float, never negative.
  """
  if start_decay is None:
    start_decay = config.startLRdecay
  if total_epochs is None:
    total_epochs = config.epochs

  if epoch < start_decay:
    return base_lr

  decay_epochs = total_epochs - start_decay
  if decay_epochs <= 0:
    # Empty decay window: avoid ZeroDivisionError and treat the rate as fully
    # decayed.
    return 0.0

  # The decay is driven by the epochs already spent in the decay phase (the
  # original code referenced an undefined `epochs_remaining` here).
  epochs_passed = epoch - start_decay
  decay_step = base_lr / decay_epochs
  return max(0.0, base_lr - epochs_passed * decay_step)

# Apply `scheduler` to the discriminators and the combined model.
LRscheduler = callbacks.MultiLRScheduler(scheduler, training_models=[model.d_A, model.d_B, model.combined])
# Generate Callbacks
tensorboard = tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR, write_graph=True, update_freq='epoch')
start_tensorboard = callbacks.StartTensorBoard(LOG_DIR)

prog_bar = tf.keras.callbacks.ProgbarLogger(count_mode='steps', stateful_metrics=None)
log_code = callbacks.LogCode(LOG_DIR, './trainer')
copy_keras = callbacks.CopyKerasModel(MODEL_DIR, LOG_DIR)

# Save the weights of all four sub-models every epoch; the filename embeds
# val_ssim.
# NOTE(review): a stray continuation line `restore_best_weights=True,
# verbose=1)` (EarlyStopping-like arguments with no enclosing call, and a
# duplicate `verbose`) followed this call and was a syntax error; it has been
# removed. Restore an EarlyStopping callback if one was intended.
saving = callbacks.MultiModelCheckpoint(MODEL_DIR + '/model.{epoch:02d}-{val_ssim:.10f}.hdf5',
                                        monitor='val_ssim', verbose=1, freq='epoch', mode='max', save_best_only=False,
                                        save_weights_only=True,
                                        multi_models=[('g_AB', g_AB), ('g_BA', g_BA), ('d_A', d_A), ('d_B', d_B)])

image_gen = callbacks.GenerateImages(g_AB, test_X, test_Y, LOG_DIR, interval=int(dataset_count/config.bs))

# Fit the model.
# NOTE(review): the original call was truncated (never closed). The callbacks
# list is reconstructed from the callbacks built above — confirm.
model.fit(train_X, train_Y,
          batch_size=config.bs,
          steps_per_epoch=(dataset_count // config.bs),
          epochs=config.epochs,
          validation_data=(test_X, test_Y),
          callbacks=[LRscheduler, tensorboard, start_tensorboard, prog_bar,
                     log_code, copy_keras, saving, image_gen])
def main(args):
    """Train a Pix2Pix model (U-Net generator + patch discriminator).

    Args:
        args: parsed command-line namespace. Attributes read: ``image_dir``,
            ``train_csv``, ``test_csv``, ``bs`` (batch size), ``in_h`` and
            ``in_w`` (dataset input shape is ``(in_h, in_w, 1)``), ``lr``
            (Adam learning rate), ``epochs``, ``job_dir`` (TensorBoard and
            code-log directory) and ``model_dir`` (checkpoint / saved-model
            directory).
    """
    # Load Data
    Dataset = trainer.Dataset(image_dir=args.image_dir)
    train_dataset, train_count = Dataset.get_dataset(csv=args.train_csv,
                                                     batch_size=args.bs,
                                                     shape=(args.in_h,
                                                            args.in_w, 1))

    test_dataset, val_count = Dataset.get_dataset(csv=args.test_csv,
                                                  batch_size=args.bs,
                                                  shape=(args.in_h, args.in_w,
                                                         1))

    # Select and Compile Model
    filters = [8, 16, 32, 64, 128, 256, 512]
    dfilters = [8, 16, 32, 64, 128, 256, 512]

    # The generator is built with unspecified spatial dims; the discriminator
    # is built for the fixed (in_h, in_w, 1) training shape.
    generator_model = models.LeakyUnetModel(
        shape=(None, None, 1),
        Activation=tf.keras.layers.LeakyReLU(0.1),
        filters=filters,
        filter_shape=(3, 3))()

    discriminator_model = models.PatchDiscriminatorModel(
        shape=(args.in_h, args.in_w, 1),
        Activation=tf.keras.layers.LeakyReLU(0.1),
        filters=dfilters,
        filter_shape=(3, 3))()

    # NOTE(review): patch_gan_hw is 2 ** len(filters) (= 128 here); confirm
    # its intended meaning against models.Pix2Pix.
    model = models.Pix2Pix(verbose=1,
                           shape=(None, None, 1),
                           g=generator_model,
                           d=discriminator_model,
                           patch_gan_hw=2**len(filters))

    # Adam(lr, beta_1=0.5). loss_weights presumably pair positionally with
    # g_loss (mse x1, mae x100) — confirm in models.Pix2Pix.compile.
    model.compile(optimizer=tf.keras.optimizers.Adam(args.lr, 0.5),
                  d_loss=utils.mse,
                  g_loss=[utils.mse, utils.mae],
                  loss_weights=[1, 100],
                  metrics=[utils.mse, utils.mae, utils.ssim, utils.psnr])

    # Generate Callbacks
    tensorboard = tf.keras.callbacks.TensorBoard(log_dir=args.job_dir,
                                                 write_graph=True,
                                                 update_freq='epoch')
    prog_bar = tf.keras.callbacks.ProgbarLogger(count_mode='steps',
                                                stateful_metrics=None)
    # `period` is deprecated (and removed in Keras 3); save_freq='epoch' is
    # the equivalent of period=1. save_best_only is off, so a checkpoint is
    # written every epoch. The monitor matches the val_ssim metric embedded
    # in the filename (it previously said val_psnr).
    saving = tf.keras.callbacks.ModelCheckpoint(
        args.model_dir + '/model.{epoch:02d}-{val_ssim:.10f}.hdf5',
        monitor='val_ssim',
        verbose=1,
        save_freq='epoch',
        mode='max')

    save_multi_model = callbacks.SaveMultiModel([('g', generator_model),
                                                 ('d', discriminator_model)],
                                                args.model_dir)
    log_code = callbacks.LogCode(args.job_dir, './trainer')
    copy_keras = callbacks.CopyKerasModel(args.model_dir, args.job_dir)

    # Fit the model. Step counts use truncating division, so a trailing
    # partial batch is not counted as a step.
    model.fit(train_dataset,
              steps_per_epoch=int(train_count / args.bs),
              epochs=args.epochs,
              validation_data=test_dataset,
              validation_steps=int(val_count / args.bs),
              callbacks=[
                  log_code, tensorboard, saving, save_multi_model, copy_keras,
                  prog_bar
              ])
# Callbacks -- TensorBoard logging into config.job_dir, updated every
# `write_freq` (write_freq, config, model and the datasets are defined outside
# this fragment).
tensorboard = tf.keras.callbacks.TensorBoard(log_dir=config.job_dir,
                                             write_graph=True,
                                             update_freq=write_freq)
# Callbacks -- image generation on the validation and training datasets, both
# writing under config.job_dir and tagged with postfix 'val' / 'train'.
# NOTE(review): presumably these render sample model outputs every
# `write_freq`; confirm against callbacks.GenerateImages.
image_gen_val = callbacks.GenerateImages(model,
                                         validation_dataset,
                                         config.job_dir,
                                         interval=write_freq,
                                         postfix='val')
image_gen = callbacks.GenerateImages(model,
                                     train_dataset,
                                     config.job_dir,
                                     interval=write_freq,
                                     postfix='train')

# callbacks -- start tensorboard
start_tensorboard = callbacks.StartTensorBoard(config.job_dir)

# callbacks -- log code and trained models
log_code = callbacks.LogCode(config.job_dir, './trainer')
copy_keras = callbacks.CopyKerasModel(config.model_dir, config.job_dir)

# Train. Step counts use truncating division, so a trailing partial batch is
# not counted as a step.
# NOTE(review): `saving` is not defined in this fragment; presumably a
# checkpoint callback created earlier in the script — confirm.
model.fit(train_dataset,
          steps_per_epoch=int(train_count / config.batch_size),
          validation_data=validation_dataset,
          validation_steps=int(validation_count / config.batch_size),
          epochs=config.num_epochs,
          callbacks=[
              saving, tensorboard, start_tensorboard, log_code, copy_keras,
              image_gen, image_gen_val
          ])