コード例 #1
0
    def test_custom_model_checkpoint(self):
        """Runs one epoch of CustomModelCheckpoint hooks and expects files in ckpt_dir."""
        # Start from an empty checkpoint directory so leftovers from an
        # earlier run cannot satisfy the final assertion.
        checkpoint_root = '/tmp/tf3d/callback_util_test'
        if tf.io.gfile.exists(checkpoint_root):
            tf.io.gfile.rmtree(checkpoint_root)

        checkpointer = callback_utils.CustomModelCheckpoint(
            ckpt_dir=checkpoint_root, save_epoch_freq=1, max_to_keep=5)
        checkpointer.set_model(tf.keras.Model())

        # Drive the epoch hooks by hand instead of running model.fit().
        checkpointer.on_epoch_begin(epoch=0, logs=None)
        checkpointer.on_epoch_end(epoch=1, logs=None)

        written = tf.io.gfile.glob(os.path.join(checkpoint_root, '*'))
        self.assertNotEmpty(written)
コード例 #2
0
def train(strategy,
          write_path,
          learning_rate_fn=None,
          model_class=None,
          input_fn=None,
          optimizer_fn=tf.keras.optimizers.SGD):
  """Builds the model under `strategy` and trains it with checkpointing.

  The model, its optimizer, the checkpoint callbacks and the training input
  are all created inside `strategy.scope()`. `model.fit` is called after the
  scope exits. Training settings are read from command-line flags:
  `log_freq`, `batch_size`, `num_workers`, `num_gpus`, `num_epochs`,
  `num_steps_per_epoch` and `run_functions_eagerly`.

  Args:
    strategy: A tf.distribute.Strategy object.
    write_path: Root directory for training outputs. The model gets
      `<write_path>/train` as its train dir. BackupAndRestore state goes to
      `<write_path>/backup_model` and checkpoints go to `<write_path>/model`.
    learning_rate_fn: Zero-argument callable whose result is passed as the
      optimizer's `learning_rate`. Required.
    model_class: The class (or factory) of the model to train. It is called
      with `train_dir` and `summary_log_freq` keyword arguments. Required.
    input_fn: An input function taking `is_training` and `batch_size`
      keyword arguments and returning a tf.data.Dataset. Required.
    optimizer_fn: A function taking a `learning_rate` keyword argument and
      returning the optimizer. Defaults to `tf.keras.optimizers.SGD`.

  Raises:
    ValueError: If `learning_rate_fn`, `model_class` or `input_fn` is None.
  """
  # Fail fast with a clear message rather than a later
  # "'NoneType' object is not callable" from inside the strategy scope.
  if learning_rate_fn is None:
    raise ValueError('learning_rate_fn is not set.')
  if model_class is None:
    raise ValueError('model_class is not set.')
  if input_fn is None:
    raise ValueError('input_fn is not set.')

  with strategy.scope():
    logging.info('Model creation starting')
    model = model_class(
        train_dir=os.path.join(write_path, 'train'),
        summary_log_freq=FLAGS.log_freq)

    logging.info('Model compile starting')
    model.compile(optimizer=optimizer_fn(learning_rate=learning_rate_fn()))

    backup_checkpoint_callback = tf.keras.callbacks.experimental.BackupAndRestore(
        backup_dir=os.path.join(write_path, 'backup_model'))
    checkpoint_callback = callback_utils.CustomModelCheckpoint(
        ckpt_dir=os.path.join(write_path, 'model'),
        save_epoch_freq=1,
        max_to_keep=3)

    logging.info('Input creation starting')
    # Total batch size = batch_size * num_workers * num_gpus (presumably
    # batch_size is per device -- confirm).
    # NOTE(review): a num_gpus or num_workers flag of 0 (e.g. a CPU-only run)
    # makes this 0; confirm the flags are always >= 1.
    total_batch_size = FLAGS.batch_size * FLAGS.num_workers * FLAGS.num_gpus
    inputs = input_fn(is_training=True, batch_size=total_batch_size)
    logging.info(
        'Model fit starting for %d epochs, %d step per epoch, total batch size:%d',
        FLAGS.num_epochs, FLAGS.num_steps_per_epoch, total_batch_size)

  try:
    model.fit(
        x=inputs,
        callbacks=[backup_checkpoint_callback, checkpoint_callback],
        steps_per_epoch=FLAGS.num_steps_per_epoch,
        epochs=FLAGS.num_epochs,
        # Keras: verbose=1 shows a per-step progress bar; 2 logs one line
        # per epoch.
        verbose=1 if FLAGS.run_functions_eagerly else 2)
  finally:
    # Call close_writer() even when fit() raises, so the model's writer
    # (presumably its summary writer -- see model_class) is not left open.
    model.close_writer()