Example #1
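# TF1-style TPUEstimator runner. Assumes module-level imports and helpers
# defined elsewhere in the source file: tensorflow (tf), absl's FLAGS and
# logging, plus load_runner_config, model_fn_builder and input_fn_reader.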
def main(_):
  runner_config = load_runner_config()

  if FLAGS.output_dir:
    tf.gfile.MakeDirs(FLAGS.output_dir)

  is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.estimator.tpu.RunConfig(
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=runner_config["save_checkpoints_steps"],
      keep_checkpoint_max=20,
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=runner_config["iterations_per_loop"],
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  model_fn = model_fn_builder(runner_config)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  batch_size = runner_config["batch_size"]
  estimator = tf.estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      predict_batch_size=batch_size)

  if FLAGS.runner_mode == "train":
    train_input_fn = input_fn_reader.create_input_fn(
        runner_config=runner_config,
        mode=tf.estimator.ModeKeys.TRAIN,
        drop_remainder=True)
    estimator.train(
        input_fn=train_input_fn, max_steps=runner_config["train_steps"])
  elif FLAGS.runner_mode == "eval":
    # TPU needs fixed shapes, so if the last batch is smaller, we drop it.
    eval_input_fn = input_fn_reader.create_input_fn(
        runner_config=runner_config,
        mode=tf.estimator.ModeKeys.EVAL,
        drop_remainder=True)

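    # Re-run evaluation each time a new checkpoint appears in output_dir;
    # stop waiting after 600 seconds with no new checkpoint.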
    for _ in tf.train.checkpoints_iterator(FLAGS.output_dir, timeout=600):
      result = estimator.evaluate(input_fn=eval_input_fn)
      for key in sorted(result):
        logging.info("  %s = %s", key, str(result[key]))
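The FLAGS values read above (output_dir, master, num_tpu_cores, use_tpu,
runner_mode) are defined at module level and are not shown in the snippet. A
minimal sketch of what those definitions could look like, assuming absl.flags;
the flag names come from the code above, while the types, defaults and help
strings are illustrative guesses:

from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string("output_dir", None, "Directory for checkpoints and eval logs.")
flags.DEFINE_string("master", None, "Address of the TPU master, if any.")
flags.DEFINE_integer("num_tpu_cores", 8, "Number of TPU shards to use.")
flags.DEFINE_bool("use_tpu", False, "Run on TPU; otherwise fall back to CPU/GPU.")
flags.DEFINE_string("runner_mode", "train", "Either 'train' or 'eval'.")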
Example #2
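# TF2 eager/Keras training loop. Assumes module-level imports and helpers
# defined elsewhere in the source file: tensorflow (tf), tf_estimator (for
# ModeKeys), absl's FLAGS and logging, plus load_runner_config,
# model_fn_builder, compute_loss and input_fn_reader.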
def main(_):
  runner_config = load_runner_config()

  if FLAGS.output_dir:
    tf.io.gfile.makedirs(FLAGS.output_dir)

  train_model = model_fn_builder(runner_config, tf_estimator.ModeKeys.TRAIN)
  optimizer = tf.keras.optimizers.Adam()
  train_input_fn = input_fn_reader.create_input_fn(
      runner_config=runner_config,
      mode=tf_estimator.ModeKeys.TRAIN,
      drop_remainder=True)
  params = {"batch_size": runner_config["batch_size"]}
  train_ds = train_input_fn(params)
  train_loss = tf.keras.metrics.Mean(name="train_loss")

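  # One training step: forward pass, loss, gradients, and optimizer update,
  # traced as a graph via tf.function.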
  @tf.function
  def train_step(features):
    with tf.GradientTape() as tape:
      logits = train_model(features["projection"], features["seq_length"])
      loss = compute_loss(logits, features["label"],
                          runner_config["model_config"],
                          tf_estimator.ModeKeys.TRAIN)
    gradients = tape.gradient(loss, train_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, train_model.trainable_variables))
    train_loss(loss)

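  # Single-epoch training loop; logs the running mean loss every 100 steps.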
  for epoch in range(1):
    train_loss.reset_states()
    for features in train_ds:
      train_step(features)
      step = optimizer.iterations.numpy()
      if step % 100 == 0:
        logging.info("Running step %s in epoch %s", step, epoch)
        logging.info("Training loss: %s, epoch: %s, step: %s",
                     round(train_loss.result().numpy(), 4), epoch, step)
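Both examples obtain their dataset from input_fn_reader.create_input_fn, which
is defined elsewhere in the source module. Below is a minimal sketch of the
contract implied by the calls above, not the original implementation: the
returned callable takes a params dict containing "batch_size" and yields a
tf.data.Dataset of feature dicts with the "projection", "seq_length" and
"label" keys consumed by train_step in Example #2. Shapes and dtypes here are
placeholders.

import tensorflow as tf

# Hypothetical stand-in for input_fn_reader.create_input_fn; it only mirrors
# the call signature and element structure used by the examples above.
def create_input_fn(runner_config, mode, drop_remainder):
  def input_fn(params):
    batch_size = params["batch_size"]
    # Synthetic stand-in data; a real reader would parse serialized examples.
    features = {
        "projection": tf.zeros([128, 64, 8], tf.float32),
        "seq_length": tf.fill([128], 64),
        "label": tf.zeros([128], tf.int32),
    }
    dataset = tf.data.Dataset.from_tensor_slices(features)
    return dataset.batch(batch_size, drop_remainder=drop_remainder)
  return input_fn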