def main(args):
    """Train the attention U-Net described by ``args``.

    Builds the training and validation pipelines from the two CSV files,
    compiles a ``models.AttentionUnetModel`` with ``utils.mse`` as the loss,
    creates the logging / checkpoint / LR-reduction callbacks and runs
    ``model.fit``.

    Args:
        args: Parsed options. The attributes read here are ``image_dir``,
            ``train_csv``, ``test_csv``, ``bs``, ``in_h``, ``in_w``, ``lr``,
            ``job_dir``, ``model_dir`` and ``epochs``.
    """
    # --- Data -------------------------------------------------------------
    input_shape = (args.in_h, args.in_w, 1)
    loader = utils.Dataset(image_dir=args.image_dir)
    train_dataset, train_count = loader.get_dataset(
        csv=args.train_csv,
        batch_size=args.bs,
        shape=input_shape,
    )
    # The test CSV doubles as the validation set passed to fit().
    val_dataset, val_count = loader.get_dataset(
        csv=args.test_csv,
        batch_size=args.bs,
        shape=input_shape,
    )

    # --- Model ------------------------------------------------------------
    filters = [8, 16, 32, 64, 128]
    relu = tf.keras.layers.Activation(activation=tf.nn.relu)
    # The model class is a builder: instantiating it and then calling the
    # instance yields the object that is compiled and fitted below.
    build_model = models.AttentionUnetModel(
        shape=(None, None, 1),
        Activation=relu,
        filters=filters,
        filter_shape=(3, 3),
    )
    model = build_model()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(args.lr),
        loss=utils.mse,
        metrics=[utils.mse, utils.mae, utils.ssim, utils.psnr],
    )

    # --- Callbacks --------------------------------------------------------
    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=args.job_dir,
        write_graph=True,
        update_freq='epoch',
    )
    checkpoint_pattern = (
        args.model_dir + '/model.{epoch:02d}-{val_loss:.10f}.hdf5'
    )
    saving = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_pattern,
        monitor='val_loss',
        verbose=1,
        period=1,
        save_best_only=True,
    )
    # The learning rate is never reduced below a thousandth of its start.
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=2,
        min_lr=args.lr * .001,
    )
    log_code = callbacks.LogCode(args.job_dir, './trainer')
    copy_keras = callbacks.CopyKerasModel(args.model_dir, args.job_dir)

    # --- Training ---------------------------------------------------------
    train_steps = int(train_count / args.bs)
    val_steps = int(val_count / args.bs)
    model.fit(
        train_dataset,
        steps_per_epoch=train_steps,
        epochs=args.epochs,
        validation_data=val_dataset,
        validation_steps=val_steps,
        callbacks=[log_code, tensorboard, saving, copy_keras, reduce_lr],
    )
# NOTE(review): this line is a collapsed fragment. It opens with the tail of a
# call (`metrics=[utils.ssim])`) and ends inside an unfinished `model.fit(`
# call, so the code is left exactly as found and only annotated here.
#
# It defines `scheduler(epoch)`, which returns 2e-4 while
# `epoch < config.startLRdecay`; after that it subtracts a multiple of
# `decay_step = 2e-4 / (config.epochs - config.startLRdecay)` from 2e-4.
# The function is handed to `callbacks.MultiLRScheduler` together with
# `model.d_A`, `model.d_B` and `model.combined`.
#
# NOTE(review): `scheduler` assigns `epochs_passed`, but its final return
# expression reads `epochs_remaining`, which is not defined anywhere in the
# visible code. Presumably `epochs_passed` was intended -- confirm; otherwise
# the decay branch raises NameError.
#
# NOTE(review): `restore_best_weights=True, verbose=1)` has no visible call
# head; the start of that call appears to be missing from this fragment.
metrics=[utils.ssim]) def scheduler(epoch): if epoch < config.startLRdecay: return 2e-4 else: epochs_passed = epoch - config.startLRdecay decay_step = 2e-4 / (config.epochs - config.startLRdecay) return 2e-4 - epochs_remaining * decay_step LRscheduler = callbacks.MultiLRScheduler(scheduler, training_models=[model.d_A, model.d_B, model.combined]) # Generate Callbacks tensorboard = tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR, write_graph=True, update_freq='epoch') start_tensorboard = callbacks.StartTensorBoard(LOG_DIR) prog_bar = tf.keras.callbacks.ProgbarLogger(count_mode='steps', stateful_metrics=None) log_code = callbacks.LogCode(LOG_DIR, './trainer') copy_keras = callbacks.CopyKerasModel(MODEL_DIR, LOG_DIR) saving = callbacks.MultiModelCheckpoint(MODEL_DIR + '/model.{epoch:02d}-{val_ssim:.10f}.hdf5', monitor='val_ssim', verbose=1, freq='epoch', mode='max', save_best_only=False, save_weights_only=True, multi_models=[('g_AB', g_AB), ('g_BA', g_BA), ('d_A', d_A), ('d_B', d_B)]) restore_best_weights=True, verbose=1) image_gen = callbacks.GenerateImages(g_AB, test_X, test_Y, LOG_DIR, interval=int(dataset_count/config.bs)) # Fit the model model.fit(train_X, train_Y, batch_size=config.bs, steps_per_epoch=(dataset_count // config.bs), epochs=config.epochs,
def main(args):
    """Train a Pix2Pix model (generator plus patch discriminator).

    Builds the training and validation pipelines from the two CSV files,
    constructs a ``models.LeakyUnetModel`` generator and a
    ``models.PatchDiscriminatorModel`` discriminator, wraps them in
    ``models.Pix2Pix``, compiles the result and runs ``model.fit`` with the
    logging and checkpoint callbacks.

    Args:
        args: Parsed options. The attributes read here are ``image_dir``,
            ``train_csv``, ``test_csv``, ``bs``, ``in_h``, ``in_w``, ``lr``,
            ``job_dir``, ``model_dir`` and ``epochs``.
    """
    # --- Data -------------------------------------------------------------
    input_shape = (args.in_h, args.in_w, 1)
    loader = trainer.Dataset(image_dir=args.image_dir)
    train_dataset, train_count = loader.get_dataset(
        csv=args.train_csv,
        batch_size=args.bs,
        shape=input_shape,
    )
    # The test CSV doubles as the validation set passed to fit().
    val_dataset, val_count = loader.get_dataset(
        csv=args.test_csv,
        batch_size=args.bs,
        shape=input_shape,
    )

    # --- Generator and discriminator ----------------------------------------
    filters = [8, 16, 32, 64, 128, 256, 512]
    dfilters = [8, 16, 32, 64, 128, 256, 512]

    # Each network receives its own LeakyReLU instance, created immediately
    # before the network it belongs to.
    build_generator = models.LeakyUnetModel(
        shape=(None, None, 1),
        Activation=tf.keras.layers.LeakyReLU(0.1),
        filters=filters,
        filter_shape=(3, 3),
    )
    generator_model = build_generator()

    build_discriminator = models.PatchDiscriminatorModel(
        shape=(args.in_h, args.in_w, 1),
        Activation=tf.keras.layers.LeakyReLU(0.1),
        filters=dfilters,
        filter_shape=(3, 3),
    )
    discriminator_model = build_discriminator()

    # --- Combined model ---------------------------------------------------
    model = models.Pix2Pix(
        verbose=1,
        shape=(None, None, 1),
        g=generator_model,
        d=discriminator_model,
        patch_gan_hw=2**len(filters),
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(args.lr, 0.5),
        d_loss=utils.mse,
        g_loss=[utils.mse, utils.mae],
        loss_weights=[1, 100],
        metrics=[utils.mse, utils.mae, utils.ssim, utils.psnr],
    )

    # --- Callbacks --------------------------------------------------------
    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=args.job_dir,
        write_graph=True,
        update_freq='epoch',
    )
    prog_bar = tf.keras.callbacks.ProgbarLogger(
        count_mode='steps',
        stateful_metrics=None,
    )
    # NOTE(review): the filename template formats `val_ssim` while `monitor`
    # is 'val_psnr' -- confirm that this mismatch is intended.
    checkpoint_pattern = (
        args.model_dir + '/model.{epoch:02d}-{val_ssim:.10f}.hdf5'
    )
    saving = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_pattern,
        monitor='val_psnr',
        verbose=1,
        period=1,
        mode='max',
    )
    sub_models = [('g', generator_model), ('d', discriminator_model)]
    save_multi_model = callbacks.SaveMultiModel(sub_models, args.model_dir)
    log_code = callbacks.LogCode(args.job_dir, './trainer')
    copy_keras = callbacks.CopyKerasModel(args.model_dir, args.job_dir)

    # --- Training ---------------------------------------------------------
    train_steps = int(train_count / args.bs)
    val_steps = int(val_count / args.bs)
    model.fit(
        train_dataset,
        steps_per_epoch=train_steps,
        epochs=args.epochs,
        validation_data=val_dataset,
        validation_steps=val_steps,
        callbacks=[
            log_code,
            tensorboard,
            saving,
            save_multi_model,
            copy_keras,
            prog_bar,
        ],
    )
# NOTE(review): this line is a collapsed fragment. The `d_model.fit(` call at
# its end is cut off inside the `callbacks=[` list, so the code is left exactly
# as found and only annotated here.
#
# It loads a Keras model from `config.d_weight` into `d_model`, then creates
# TensorBoard, ModelCheckpoint, LogCode and StartTensorBoard callbacks before
# starting `d_model.fit`.
# `write_freq` is int(train_count / config.bs / 10), i.e. roughly a tenth of
# the `steps_per_epoch` value given to fit; it is passed to TensorBoard as
# `update_freq`.
# Checkpoints go under `config.model_dir + '/d'` with save_freq='epoch' and
# save_best_only=False.
# NOTE(review): `train_dataset`, `validation_dataset`, `train_count` and
# `validation_count` are defined outside the visible fragment.
d_model = tf.keras.models.load_model(config.d_weight) # Callbacks write_freq = int(train_count / config.bs / 10) tensorboard = tf.keras.callbacks.TensorBoard(log_dir=config.job_dir, write_graph=True, update_freq=write_freq) saving = tf.keras.callbacks.ModelCheckpoint( config.model_dir + '/d' + '/model.{epoch:02d}-{val_loss:.5f}.hdf5', monitor='val_loss', verbose=1, save_freq='epoch', save_best_only=False) log_code = callbacks.LogCode(config.job_dir, './trainer') #copy_keras = callbacks.CopyKerasModel(config.model_dir, config.job_dir) #image_gen_val = callbacks.GenerateImages(generator_model, validation_dataset, config.job_dir, interval=write_freq, postfix='val') #image_gen = callbacks.GenerateImages(generator_model, train_dataset, config.job_dir, interval=write_freq, postfix='train') start_tensorboard = callbacks.StartTensorBoard(config.job_dir) # Fit model d_model.fit( train_dataset, steps_per_epoch=int(train_count / config.bs), epochs=config.epochs, validation_data=validation_dataset, validation_steps=int(validation_count / config.bs), verbose=1, callbacks=[