def test_custom_model_checkpoint(self):
  """Checks that CustomModelCheckpoint writes files into `ckpt_dir`."""
  import tempfile  # Local import: the file's import block is not visible here.

  # Use a fresh per-test temporary directory rather than a fixed /tmp path, so
  # concurrent or repeated runs cannot observe each other's checkpoints and the
  # test does not depend on a POSIX /tmp existing.
  ckpt_dir = tempfile.mkdtemp()
  self.addCleanup(tf.io.gfile.rmtree, ckpt_dir)

  callback = callback_utils.CustomModelCheckpoint(
      ckpt_dir=ckpt_dir, save_epoch_freq=1, max_to_keep=5)
  model = tf.keras.Model()
  callback.set_model(model)

  # NOTE(review): on_epoch_begin uses epoch=0 but on_epoch_end uses epoch=1;
  # kept as in the original -- confirm the mismatch is intentional.
  callback.on_epoch_begin(epoch=0, logs=None)
  callback.on_epoch_end(epoch=1, logs=None)

  # The directory starts empty (mkdtemp), so anything here came from the
  # callback.
  self.assertNotEmpty(tf.io.gfile.glob(os.path.join(ckpt_dir, '*')))
def train(strategy,
          write_path,
          learning_rate_fn=None,
          model_class=None,
          input_fn=None,
          optimizer_fn=tf.keras.optimizers.SGD):
  """Builds and compiles a model under a distribution strategy, then trains it.

  The model, its optimizer, the training callbacks and the input are created
  inside `strategy.scope()`. Training then runs for `FLAGS.num_epochs` epochs
  of `FLAGS.num_steps_per_epoch` steps each. Two callbacks are used: a
  backup-and-restore callback writing under `<write_path>/backup_model` and a
  `CustomModelCheckpoint` writing under `<write_path>/model`, which keeps at
  most 3 checkpoints and saves every epoch.

  Args:
    strategy: A tf.distribute.Strategy object.
    write_path: A string path under which training logs and checkpoints are
      written.
    learning_rate_fn: A zero-argument callable. Its return value is passed as
      the `learning_rate` argument of `optimizer_fn`. Required.
    model_class: The class of the model to train. It is constructed with
      `train_dir` and `summary_log_freq` keyword arguments and must provide a
      `close_writer()` method. Required.
    input_fn: A callable accepting `is_training` and `batch_size` keyword
      arguments and returning the training input passed to `model.fit` (e.g.
      a tf.data.Dataset). Required.
    optimizer_fn: A callable accepting a `learning_rate` keyword argument and
      returning an optimizer.

  Raises:
    ValueError: If `learning_rate_fn`, `model_class` or `input_fn` is not set.
  """
  # Fail fast with a clear message before any strategy/model work; otherwise a
  # missing callable surfaces later as an opaque "'NoneType' object is not
  # callable" TypeError.
  if learning_rate_fn is None:
    raise ValueError('learning_rate_fn is not set.')
  if model_class is None:
    raise ValueError('model_class is not set.')
  if input_fn is None:
    raise ValueError('input_fn is not set.')

  with strategy.scope():
    logging.info('Model creation starting')
    model = model_class(
        train_dir=os.path.join(write_path, 'train'),
        summary_log_freq=FLAGS.log_freq)

    logging.info('Model compile starting')
    model.compile(optimizer=optimizer_fn(learning_rate=learning_rate_fn()))

    backup_checkpoint_callback = (
        tf.keras.callbacks.experimental.BackupAndRestore(
            backup_dir=os.path.join(write_path, 'backup_model')))
    checkpoint_callback = callback_utils.CustomModelCheckpoint(
        ckpt_dir=os.path.join(write_path, 'model'),
        save_epoch_freq=1,
        max_to_keep=3)

    logging.info('Input creation starting')
    # Global batch size across all workers and devices.
    # NOTE(review): if FLAGS.num_gpus may be 0 (CPU-only runs), this becomes 0;
    # confirm the flag's allowed range.
    total_batch_size = FLAGS.batch_size * FLAGS.num_workers * FLAGS.num_gpus
    inputs = input_fn(is_training=True, batch_size=total_batch_size)
    logging.info(
        'Model fit starting for %d epochs, %d step per epoch, total batch size:%d',
        FLAGS.num_epochs, FLAGS.num_steps_per_epoch, total_batch_size)

  try:
    model.fit(
        x=inputs,
        callbacks=[backup_checkpoint_callback, checkpoint_callback],
        steps_per_epoch=FLAGS.num_steps_per_epoch,
        epochs=FLAGS.num_epochs,
        # More verbose (progress-bar) output when running eagerly.
        verbose=1 if FLAGS.run_functions_eagerly else 2)
  finally:
    # Release the model's writer even if `fit` raises or is interrupted.
    model.close_writer()